Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions encodeval/eval_tasks/abstract_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class EvalConfig:
task_type: Task type identifier (SC, SR, TC, or IR).
loss_fn: Optional callable for the training loss function.
loss_kwargs: Optional keyword arguments for the loss function.
finetuning_learning_rate: Learning rate used during fine-tuning; also
used to build the result directory names.
"""

model_class: Callable = None
Expand All @@ -40,6 +42,7 @@ class EvalConfig:
task_type: Literal["SC", "SR", "TC", "IR"] = None
loss_fn: Callable = None
loss_kwargs: Dict = None
finetuning_learning_rate: str = None

def __post_init__(self):
"""
Expand All @@ -53,10 +56,19 @@ def __post_init__(self):
self.model_dtype = self.model_kwargs.pop("dtype")
self.device = self.model_kwargs.pop("device")

base_output_dir = self.tr_args_kwargs.get("output_dir", "")
base_model_name = os.environ["EVAL_MODEL_PATH"].split("/")[-1]
ft_lr = None
if self.finetuning_learning_rate is not None:
ft_lr = str(self.finetuning_learning_rate).replace(".", "_").replace("-", "_").replace("+", "p")
model_name = f"{base_model_name}_ftlr_{ft_lr}" if ft_lr else base_model_name

# Handle loading fine-tuned model from disk if specified
ft_model_config_dir = self.model_kwargs.pop("ft_model_config_dir", None)
if ft_model_config_dir is not None:
ft_model_path = f"{os.environ['EVAL_MODEL_PATH']}/evaluation/weights/{self.task_type}/{ft_model_config_dir}"
ft_model_path = (
f"{base_output_dir}/{model_name}/evaluation/weights/{self.task_type}/{ft_model_config_dir}"
)
print(f"Loading fine-tuned model at {ft_model_path}")
if "pretrained_model_name_or_path" in self.model_kwargs:
self.model_kwargs["pretrained_model_name_or_path"] = ft_model_path
Expand Down Expand Up @@ -144,15 +156,15 @@ def __post_init__(self):
self.dataset_name = ""

# Prepare output/log directories
model_name = os.environ["EVAL_MODEL_PATH"].split("/")[-1]
output_dir = self.tr_args.output_dir
output_dir = base_output_dir
output_subdir = (
f"{self.task_type}/{self.dataset_name}/{ft_model_config_dir.replace('/', '_')}/{output_subdir}"
if ft_model_config_dir is not None else f"{self.task_type}/{self.dataset_name}/{output_subdir}"
)
self.tr_args.output_dir = f"{os.environ['EVAL_MODEL_PATH']}/evaluation/weights/{output_subdir}"
self.tr_args.logging_dir = f"{os.environ['EVAL_MODEL_PATH']}/evaluation/logs/{output_subdir}"
self.results_dir = f"{output_dir}/{model_name}/{output_subdir}"
model_results_root = f"{output_dir}/{model_name}"
self.tr_args.output_dir = f"{model_results_root}/evaluation/weights/{output_subdir}"
self.tr_args.logging_dir = f"{model_results_root}/evaluation/logs/{output_subdir}"
self.results_dir = f"{model_results_root}/{output_subdir}"

# Clear old logs if logging directory is not empty
if os.path.exists(self.tr_args.logging_dir) and len(os.listdir(self.tr_args.logging_dir)) > 0:
Expand Down