Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dragonfly/src/env/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, path, pms):
if hasattr(pms, "args"): self.args = pms.args

# Generate workers
self.worker = worker(self.name, self.args, mpi.rank, path)
self.worker = worker(self.name, mpi.rank, path, self.args)

# Set all slaves to wait for instructions
if (mpi.rank != 0): self.worker.work()
Expand Down
66 changes: 66 additions & 0 deletions dragonfly/src/env/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
import ast
import importlib.util

def find_class_in_folder(folder_path: str, class_name: str):
"""
Searches for a class definition with the given name in all .py files within folder_path

Args:
folder_path: path to the folder containing Python files
class_name: name of the class to search for

Returns:
matching_files: list of file paths where the class is defined
"""
matching_files = []

for root, _, files in os.walk(folder_path):
for file in files:
if file.endswith(".py"): # Look only for Python files
file_path = os.path.join(root, file)
try:
with open(file_path, "r", encoding="utf-8") as f:
tree = ast.parse(f.read(), filename=file_path)

# Check for class definitions in the parsed AST
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef) and node.name == class_name:
matching_files.append(file_path)
break # No need to check further in this file

except (SyntaxError, UnicodeDecodeError) as e:
print(f"Skipping {file_path} due to error: {e}")

return matching_files

def import_class_from_file(file_path: str, class_name: str):
"""
Dynamically imports a class from a given Python file path

Args:
file_path: absolute or relative path to the Python file
class_name: name of the class to import

Returns:
the class object if found, else None
"""
if not os.path.isfile(file_path) or not file_path.endswith(".py"):
raise ValueError(f"Invalid Python file: {file_path}")

module_name = os.path.splitext(os.path.basename(file_path))[
0
] # Extract module name

spec = importlib.util.spec_from_file_location(module_name, file_path)
if spec and spec.loader:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # Load the module

# Retrieve the class from the module
if hasattr(module, class_name):
return getattr(module, class_name)
else:
raise ImportError(f"Class '{class_name}' not found in '{file_path}'")

raise ImportError(f"Could not load module from '{file_path}'")
30 changes: 20 additions & 10 deletions dragonfly/src/env/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,45 @@

# Custom imports
from dragonfly.src.env.mpi import mpi
from dragonfly.src.env.utils import find_class_in_folder, import_class_from_file
from dragonfly.src.utils.error import error

###############################################
# Worker class for slave processes
class worker():
def __init__(self, env_name, args, cpu, path):
def __init__(self,
env_name,
cpu,
path_hint,
args):

# Build environment
try:
try: # Test if this is a gym environment
if args is not None:
self.env = gym.make(env_name,
render_mode="rgb_array",
**args.__dict__)
else:
self.env = gym.make(env_name,
render_mode="rgb_array")
except:
sys.path.append(path)
module = __import__(env_name)
env_build = getattr(module, env_name)
except: # Othwerise, look for env_name class in all files of path_hint folder
files = find_class_in_folder(path_hint, env_name)

if len(files) > 1:
error("worker", "init",
f"Found more than one file containing class: {env_name}")

builder = import_class_from_file(files[0], env_name)
try:
if args is not None:
self.env = env_build(cpu, **args.__dict__)
self.env = builder(cpu, **args.__dict__)
else:
self.env = env_build(cpu)
self.env = builder(cpu)
except:
if args is not None:
self.env = env_build(**args.__dict__)
self.env = builder(**args.__dict__)
else:
self.env = env_build()
self.env = builder()

# Working function for slaves
def work(self):
Expand Down