diff --git a/dragonfly/src/env/environment.py b/dragonfly/src/env/environment.py index e4b4de49..c2404daa 100755 --- a/dragonfly/src/env/environment.py +++ b/dragonfly/src/env/environment.py @@ -21,7 +21,7 @@ def __init__(self, path, pms): if hasattr(pms, "args"): self.args = pms.args # Generate workers - self.worker = worker(self.name, self.args, mpi.rank, path) + self.worker = worker(self.name, mpi.rank, path, self.args) # Set all slaves to wait for instructions if (mpi.rank != 0): self.worker.work() diff --git a/dragonfly/src/env/utils.py b/dragonfly/src/env/utils.py new file mode 100644 index 00000000..017abd40 --- /dev/null +++ b/dragonfly/src/env/utils.py @@ -0,0 +1,66 @@ +import os +import ast +import importlib.util + +def find_class_in_folder(folder_path: str, class_name: str): + """ + Searches for a class definition with the given name in all .py files within folder_path + + Args: + folder_path: path to the folder containing Python files + class_name: name of the class to search for + + Returns: + matching_files: list of file paths where the class is defined + """ + matching_files = [] + + for root, _, files in os.walk(folder_path): + for file in files: + if file.endswith(".py"): # Look only for Python files + file_path = os.path.join(root, file) + try: + with open(file_path, "r", encoding="utf-8") as f: + tree = ast.parse(f.read(), filename=file_path) + + # Check for class definitions in the parsed AST + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == class_name: + matching_files.append(file_path) + break # No need to check further in this file + + except (SyntaxError, UnicodeDecodeError) as e: + print(f"Skipping {file_path} due to error: {e}") + + return matching_files + +def import_class_from_file(file_path: str, class_name: str): + """ + Dynamically imports a class from a given Python file path + + Args: + file_path: absolute or relative path to the Python file + class_name: name of the class to import + + Returns: + the class object if found, else None + """ + if not os.path.isfile(file_path) or not file_path.endswith(".py"): + raise ValueError(f"Invalid Python file: {file_path}") + + module_name = os.path.splitext(os.path.basename(file_path))[ + 0 + ] # Extract module name + + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec and spec.loader: + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) # Load the module + + # Retrieve the class from the module + if hasattr(module, class_name): + return getattr(module, class_name) + else: + raise ImportError(f"Class '{class_name}' not found in '{file_path}'") + + raise ImportError(f"Could not load module from '{file_path}'") diff --git a/dragonfly/src/env/worker.py b/dragonfly/src/env/worker.py index 2a5517e4..841cdafd 100755 --- a/dragonfly/src/env/worker.py +++ b/dragonfly/src/env/worker.py @@ -4,14 +4,20 @@ # Custom imports from dragonfly.src.env.mpi import mpi +from dragonfly.src.env.utils import find_class_in_folder, import_class_from_file +from dragonfly.src.utils.error import error ############################################### # Worker class for slave processes class worker(): - def __init__(self, env_name, args, cpu, path): + def __init__(self, + env_name, + cpu, + path_hint, + args): # Build environment - try: + try: # Test if this is a gym environment if args is not None: self.env = gym.make(env_name, render_mode="rgb_array", @@ -19,20 +25,24 @@ def __init__(self, env_name, args, cpu, path): else: self.env = gym.make(env_name, render_mode="rgb_array") - except: - sys.path.append(path) - module = __import__(env_name) - env_build = getattr(module, env_name) + except: # Othwerise, look for env_name class in all files of path_hint folder + files = find_class_in_folder(path_hint, env_name) + + if len(files) > 1: + error("worker", "init", + f"Found more than one file containing class: {env_name}") + + builder = import_class_from_file(files[0], env_name) try: if args is not None: - self.env = env_build(cpu, **args.__dict__) + self.env = builder(cpu, **args.__dict__) else: - self.env = env_build(cpu) + self.env = builder(cpu) except: if args is not None: - self.env = env_build(**args.__dict__) + self.env = builder(**args.__dict__) else: - self.env = env_build() + self.env = builder() # Working function for slaves def work(self):