diff --git a/encodeval/eval_tasks/abstract_eval.py b/encodeval/eval_tasks/abstract_eval.py index 1d2e05b..3081f23 100644 --- a/encodeval/eval_tasks/abstract_eval.py +++ b/encodeval/eval_tasks/abstract_eval.py @@ -27,6 +27,8 @@ class EvalConfig: task_type: Task type identifier (SC, SR, TC, or IR). loss_fn: Optional callable for the training loss function. loss_kwargs: Optional keyword arguments for the loss function. + finetuning_learning_rate: Learning rate used during fine-tuning, used to + build result directories. """ model_class: Callable = None @@ -40,6 +42,7 @@ class EvalConfig: task_type: Literal["SC", "SR", "TC", "IR"] = None loss_fn: Callable = None loss_kwargs: Dict = None + finetuning_learning_rate: str = None def __post_init__(self): """ @@ -53,10 +56,19 @@ def __post_init__(self): self.model_dtype = self.model_kwargs.pop("dtype") self.device = self.model_kwargs.pop("device") + base_output_dir = self.tr_args_kwargs.get("output_dir", "") + base_model_name = os.environ["EVAL_MODEL_PATH"].split("/")[-1] + ft_lr = None + if self.finetuning_learning_rate is not None: + ft_lr = str(self.finetuning_learning_rate).replace(".", "_").replace("-", "_").replace("+", "p") + model_name = f"{base_model_name}_ftlr_{ft_lr}" if ft_lr else base_model_name + # Handle loading fine-tuned model from disk if specified ft_model_config_dir = self.model_kwargs.pop("ft_model_config_dir", None) if ft_model_config_dir is not None: - ft_model_path = f"{os.environ['EVAL_MODEL_PATH']}/evaluation/weights/{self.task_type}/{ft_model_config_dir}" + ft_model_path = ( + f"{base_output_dir}/{model_name}/evaluation/weights/{self.task_type}/{ft_model_config_dir}" + ) print(f"Loading fine-tuned model at {ft_model_path}") if "pretrained_model_name_or_path" in self.model_kwargs: self.model_kwargs["pretrained_model_name_or_path"] = ft_model_path @@ -144,15 +156,15 @@ def __post_init__(self): self.dataset_name = "" # Prepare output/log directories - model_name = os.environ["EVAL_MODEL_PATH"].split("/")[-1] - output_dir = self.tr_args.output_dir + output_dir = base_output_dir output_subdir = ( f"{self.task_type}/{self.dataset_name}/{ft_model_config_dir.replace('/', '_')}/{output_subdir}" if ft_model_config_dir is not None else f"{self.task_type}/{self.dataset_name}/{output_subdir}" ) - self.tr_args.output_dir = f"{os.environ['EVAL_MODEL_PATH']}/evaluation/weights/{output_subdir}" - self.tr_args.logging_dir = f"{os.environ['EVAL_MODEL_PATH']}/evaluation/logs/{output_subdir}" - self.results_dir = f"{output_dir}/{model_name}/{output_subdir}" + model_results_root = f"{output_dir}/{model_name}" + self.tr_args.output_dir = f"{model_results_root}/evaluation/weights/{output_subdir}" + self.tr_args.logging_dir = f"{model_results_root}/evaluation/logs/{output_subdir}" + self.results_dir = f"{model_results_root}/{output_subdir}" # Clear old logs if logging directory is not empty if os.path.exists(self.tr_args.logging_dir) and len(os.listdir(self.tr_args.logging_dir)) > 0: