__init__.py
import logging
from typing import Dict, Any

import torch
from omegaconf import DictConfig
from transformers import Trainer, TrainingArguments

from trainer.base import FinetuneTrainer
from trainer.unlearn.grad_ascent import GradAscent
from trainer.unlearn.grad_diff import GradDiff
from trainer.unlearn.npo import NPO
from trainer.unlearn.dpo import DPO
from trainer.unlearn.simnpo import SimNPO
from trainer.unlearn.rmu import RMU

logger = logging.getLogger(__name__)

# Maps a handler name (the trainer class's __name__) to the trainer class itself.
TRAINER_REGISTRY: Dict[str, Any] = {}


def _register_trainer(trainer_class):
    TRAINER_REGISTRY[trainer_class.__name__] = trainer_class
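
# After registration, a class is looked up by its name: _register_trainer(GradAscent)
# makes TRAINER_REGISTRY["GradAscent"] resolve to the GradAscent class, and "GradAscent"
# is the string expected in the config's `handler` field.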


def load_trainer_args(trainer_args: DictConfig, dataset):
    """Build TrainingArguments, converting an optional `warmup_epochs` into `warmup_steps`."""
    trainer_args = dict(trainer_args)
    warmup_epochs = trainer_args.pop("warmup_epochs", None)
    if warmup_epochs:
        batch_size = trainer_args["per_device_train_batch_size"]
        grad_accum_steps = trainer_args["gradient_accumulation_steps"]
        # Guard against a zero device count so CPU-only runs don't divide by zero.
        num_devices = max(torch.cuda.device_count(), 1)
        dataset_len = len(dataset)
        trainer_args["warmup_steps"] = int(
            (warmup_epochs * dataset_len)
            // (batch_size * grad_accum_steps * num_devices)
        )
    trainer_args = TrainingArguments(**trainer_args)
    return trainer_args


def load_trainer(
    trainer_cfg: DictConfig,
    model,
    train_dataset=None,
    eval_dataset=None,
    tokenizer=None,
    data_collator=None,
    evaluator=None,
    template_args=None,
):
    """Instantiate the trainer class named by `trainer_cfg.handler` from the registry."""
    trainer_args = trainer_cfg.args
    method_args = trainer_cfg.get("method_args", {})
    trainer_args = load_trainer_args(trainer_args, train_dataset)
    trainer_handler_name = trainer_cfg.get("handler")
    if trainer_handler_name is None:
        raise ValueError("Trainer handler not set in the trainer config")
    trainer_cls = TRAINER_REGISTRY.get(trainer_handler_name, None)
    if trainer_cls is None:
        raise NotImplementedError(
            f"{trainer_handler_name} not implemented or not registered"
        )
    trainer = trainer_cls(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        args=trainer_args,
        evaluator=evaluator,
        template_args=template_args,
        **method_args,
    )
    logger.info(
        f"{trainer_handler_name} Trainer loaded, output_dir: {trainer_args.output_dir}"
    )
    return trainer, trainer_args
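
# A minimal usage sketch (assumes a Hydra/OmegaConf-style config; the handler choice and
# every argument value below are illustrative, not prescriptive):
#
#   trainer_cfg = OmegaConf.create({
#       "handler": "GradDiff",  # must match a registered class's __name__
#       "args": {
#           "output_dir": "saves/unlearn",
#           "per_device_train_batch_size": 4,
#           "gradient_accumulation_steps": 4,
#           "num_train_epochs": 10,
#       },
#       "method_args": {},  # extra kwargs forwarded to the trainer constructor
#   })
#   trainer, trainer_args = load_trainer(
#       trainer_cfg, model, train_dataset=train_dataset,
#       tokenizer=tokenizer, data_collator=data_collator,
#   )
#   trainer.train()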


# Register finetuning trainers
_register_trainer(Trainer)
_register_trainer(FinetuneTrainer)

# Register unlearning trainers
_register_trainer(GradAscent)
_register_trainer(GradDiff)
_register_trainer(NPO)
_register_trainer(DPO)
_register_trainer(SimNPO)
_register_trainer(RMU)
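
# To add a new method (a sketch; `MyUnlearn` is a hypothetical class, not part of this
# package), subclass one of the trainers above, register it, and reference it from the
# config via `handler: MyUnlearn`:
#
#   class MyUnlearn(FinetuneTrainer):
#       def compute_loss(self, model, inputs, return_outputs=False):
#           ...  # custom unlearning objective goes here
#
#   _register_trainer(MyUnlearn)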