🐞 Describe the Bug
The loss mask is currently applied only when sequence_first is False. It needs to be applied in all cases, and the loss-masking spans need to be clipped based on sequence_k and sequence_q. This only affects runs with sampling.use_loss_masking_spans=True.
if batch.loss_masking_spans is not None:
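As an illustration of what clipping the spans to sequence_k/sequence_q could look like, here is a minimal pure-Python sketch. It is not the Fast-LLM implementation. It assumes that spans are inclusive (start, end) token indices already expressed in label coordinates (any label shift is ignored), and that the current micro-sequence covers positions [sequence_k - sequence_q, sequence_k). The helpers clip_spans and build_loss_mask are hypothetical names; a real fix would work on tensors and handle both layouts in place.

```python
from typing import List, Tuple

# Hypothetical span format (assumption): inclusive (start, end) indices within one sample.
Span = Tuple[int, int]


def clip_spans(spans: List[Span], sequence_k: int, sequence_q: int) -> List[Span]:
    """Clip spans to the assumed window [sequence_k - sequence_q, sequence_k), window-relative."""
    window_start = sequence_k - sequence_q
    window_end = sequence_k - 1  # inclusive last position of the window
    clipped = []
    for start, end in spans:
        # Skip spans that do not overlap the current window at all.
        if end < window_start or start > window_end:
            continue
        # Clamp to the window, then shift so that index 0 is the window start.
        clipped.append((max(start, window_start) - window_start,
                        min(end, window_end) - window_start))
    return clipped


def build_loss_mask(batch_spans: List[List[Span]], sequence_k: int, sequence_q: int,
                    sequence_first: bool) -> List[List[bool]]:
    """Boolean mask where True means the token contributes to the loss (hypothetical helper)."""
    # Build in (batch, sequence) layout first.
    mask = [[True] * sequence_q for _ in range(len(batch_spans))]
    for i, spans in enumerate(batch_spans):
        for start, end in clip_spans(spans, sequence_k, sequence_q):
            for pos in range(start, end + 1):
                mask[i][pos] = False
    # Same spans, transposed to (sequence, batch) layout, so both cases are masked.
    if sequence_first:
        mask = [list(row) for row in zip(*mask)]
    return mask
```

For example, with sequence_k=8 and sequence_q=4 the window covers positions 4 to 7, so a span (2, 5) is clipped to the window-relative span (0, 1), and build_loss_mask([[(2, 5)], [(7, 9)]], 8, 4, False) gives [[False, False, True, True], [True, True, True, False]]. The key point is that the same clipped spans are applied regardless of sequence_first; only the indexing layout changes.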
🎯 Expected Behavior
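The loss mask is applied in all cases (both sequence_first=True and sequence_first=False), with the loss-masking spans clipped based on sequence_k and sequence_q.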
📝 Additional Context
Loss masking was introduced in #113, but the implementation did not account for sequence_k/sequence_q.