About lm_loss for distillm #16
Hi authors, I found what may be a problem in your code when calculating lm_loss for distillm. As shown in the code below, lm_loss is computed from logits and the labels in no_model_batch. However, model_batch and no_model_batch are sometimes sampled from student-generated outputs (SGO), and computing lm_loss on such data pushes the student model to learn from outputs it generated itself.
Lines 321 to 339 in d47e77f
```python
elif "adaptive" in args.type and r < adaptive_threshold:
    model_batch, no_model_batch = replay_buffer.sample()
    model_batch, no_model_batch = replay_buffer.move_to_device(model_batch, no_model_batch, device)
    model.train()
    outputs = model(**model_batch, use_cache=False)
    logits = outputs.logits
    if args.model_parallel:
        raise NotImplementedError
    else:
        lm_loss = loss_func(logits.float().view(-1, logits.shape[-1]), no_model_batch["label"].view(-1))
    if teacher_model is not None:
        distil_loss = get_distil_loss(args, tokenizer, model, teacher_model, model_batch, no_model_batch, logits)
        loss = (1 - args.kd_ratio) * lm_loss + args.kd_ratio * distil_loss
    else:
        loss = lm_loss
```
In our experiments, we found that this leads to model collapse as the adaptive threshold increases (after several epochs of training).
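One possible mitigation, sketched below, is to drop the lm_loss term whenever the batch was drawn from the replay buffer of student-generated outputs, so that only the teacher-guided distillation loss supervises that data. This is not code from the repository; the function names, tensor shapes, and the `from_sgo` flag are hypothetical, and the distillation loss is illustrated here with a simple forward KL between student and teacher distributions.

```python
import torch
import torch.nn.functional as F


def combined_loss(logits, labels, teacher_logits, kd_ratio, from_sgo):
    """Combine the language-modeling loss with a distillation loss.

    When `from_sgo` is True, the labels are the student's own samples,
    so the lm_loss term is dropped to avoid the student reinforcing
    its own generations; only the teacher-guided loss is kept.
    """
    vocab_size = logits.size(-1)
    # Standard next-token cross-entropy against the (possibly self-generated) labels.
    lm_loss = F.cross_entropy(
        logits.view(-1, vocab_size), labels.view(-1), ignore_index=-100
    )
    # Forward KL from the teacher's distribution to the student's (illustrative choice).
    distil_loss = F.kl_div(
        F.log_softmax(logits, dim=-1),
        F.softmax(teacher_logits, dim=-1),
        reduction="batchmean",
    )
    if from_sgo:
        # Self-generated data: supervise only with the teacher signal.
        return distil_loss
    return (1 - kd_ratio) * lm_loss + kd_ratio * distil_loss
```

With uniform student and teacher logits, the distillation term vanishes, so an SGO batch contributes no lm_loss at all while a ground-truth batch still does.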