About lm_loss for distillm #16

@songmzhang

Description

Hi authors, I think there may be a problem in your code when calculating lm_loss for distillm. As shown in the code below, lm_loss is calculated from the logits and the labels in no_model_batch. However, model_batch and no_model_batch sometimes come from SGO (student-generated outputs), and calculating lm_loss on this data pushes the student model to learn outputs it generated itself.

distillm/finetune.py

Lines 321 to 339 in d47e77f

```python
elif "adaptive" in args.type and r < adaptive_threshold:
    model_batch, no_model_batch = replay_buffer.sample()
    model_batch, no_model_batch = replay_buffer.move_to_device(model_batch, no_model_batch, device)
    model.train()
    outputs = model(**model_batch, use_cache=False)
    logits = outputs.logits
    if args.model_parallel:
        raise NotImplementedError
    else:
        lm_loss = loss_func(logits.float().view(-1, logits.shape[-1]), no_model_batch["label"].view(-1))
    if teacher_model is not None:
        distil_loss = get_distil_loss(args, tokenizer, model, teacher_model, model_batch, no_model_batch, logits)
        loss = (1 - args.kd_ratio) * lm_loss + args.kd_ratio * distil_loss
    else:
        loss = lm_loss
```

In our experiments, we found that this leads to model collapse as the threshold increases (after several epochs).
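One possible workaround is to drop the lm_loss term whenever the batch comes from SGO, so the student is supervised only by the teacher on its own generations. A minimal sketch below; the function name, the `is_sgo_batch` flag, and the plain KL distillation term are my illustrative assumptions, not the repo's actual API:

```python
import torch
import torch.nn.functional as F


def combined_loss(logits, labels, teacher_logits, kd_ratio, is_sgo_batch):
    """Illustrative loss: skip the LM term on student-generated (SGO) batches
    so the student does not imitate its own outputs. Hypothetical helper,
    not the code from distillm/finetune.py."""
    # Token-level KL between student and teacher distributions
    # (a simplified stand-in for get_distil_loss).
    distil_loss = F.kl_div(
        F.log_softmax(logits, dim=-1),
        F.log_softmax(teacher_logits, dim=-1),
        log_target=True,
        reduction="batchmean",
    )
    if is_sgo_batch:
        # SGO data: rely on the teacher signal only.
        return distil_loss
    # Ground-truth data: keep the usual mixed objective.
    lm_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
    return (1 - kd_ratio) * lm_loss + kd_ratio * distil_loss
```

This keeps the original `(1 - kd_ratio) * lm_loss + kd_ratio * distil_loss` objective on ground-truth batches and only changes the SGO path.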
