
Unintuitive Logic in masked_fill Function of DistilBERT Model Implementation #2721

Open
sondalex opened this issue Jan 16, 2025 · 1 comment

sondalex commented Jan 16, 2025

The masked_fill function in the DistilBERT model implementation currently has unintuitive logic:

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    // Broadcast the scalar fill value to the mask's shape.
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    // Positions where mask is nonzero are OVERWRITTEN with the fill value;
    // only positions where mask is zero keep their original score.
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}
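
For illustration, here is a minimal sketch (not from the repo) of what these semantics do to a raw tokenizer mask, assuming the masked_fill above is in scope:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Toy attention scores and a tokenizer-style mask (1 = real token, 0 = padding).
    let scores = Tensor::new(&[1f32, 2., 3., 4.], &device)?;
    let mask = Tensor::new(&[1u8, 1, 0, 0], &device)?;
    // With the current masked_fill, nonzero mask positions are overwritten,
    // so the two REAL tokens end up at -inf while the padding survives.
    let filled = masked_fill(&scores, &mask, f32::NEG_INFINITY)?;
    println!("{filled}"); // [-inf, -inf, 3, 4]
    Ok(())
}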

In the current setup, the user must invert the attention mask obtained from the tokenizer before passing it to model.forward. This requirement can be confusing, as it differs from the transformers implementation.

...
let text: Vec<&str> = vec![...];
let encoded = tokenizer.encode_batch(text.clone(), true)?;
let input_ids = encoded.iter().map(|v| v.get_ids().to_vec()).collect::<Vec<_>>();
let input_ids = Tensor::new(input_ids, &device)?;
let attention_mask = encoded
    .iter()
    .map(|encoding| encoding.get_attention_mask().to_vec())
    .collect::<Vec<_>>();
let attention_mask = Tensor::new(attention_mask, &device)?;

let (batch_size, feature_size) = input_ids.dims2()?;

// Invert the attention mask for correct behavior --> counterintuitive:
// the tokenizer emits 1 for real tokens, but the model expects 1 for padding.
let attention_mask = attention_mask.eq(0u32)?.reshape((batch_size, 1, 1, feature_size))?;

let output = model.forward(&input_ids, &attention_mask)?;
...

Proposal:

Replace the masked_fill function with:

fn masked_fill(on_true: &Tensor, mask: &Tensor, on_false: f32) -> Result<Tensor> {
    let shape = mask.shape();
    // Broadcast the scalar fill value to the mask's shape.
    let on_false = Tensor::new(on_false, on_true.device())?.broadcast_as(shape.dims())?;
    // Positions where mask is nonzero KEEP their original score;
    // positions where mask is zero are filled with on_false.
    let m = mask.where_cond(on_true, &on_false)?;
    Ok(m)
}
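
With that signature, the tokenizer's attention mask (1 = real token, 0 = padding) could be passed through without inversion. A minimal sketch of the caller side, assuming the variable names from the snippet above and that the model's internal call sites are updated to pass f32::NEG_INFINITY as on_false:

// No inversion needed: mask == 1 keeps the score, mask == 0 gets filled.
let attention_mask = attention_mask.reshape((batch_size, 1, 1, feature_size))?;
let output = model.forward(&input_ids, &attention_mask)?;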

fbilhaut commented Jan 30, 2025

I second this. The masked_fill function is indeed counterintuitive (though its name isn’t that telling either ^^). In any case, the fact that the forward function actually expects an inverted mask for that reason can be quite troublesome.

I imagine this example isn’t the most critical one, as it’s not the most up-to-date, but I think many people (like me) might try it for initial tests with Candle, since it’s well known. In that sense, the example doesn’t really serve its purpose well because it’s misleading.
