Skip to content

Commit 18d3000

Browse files
Juarez BochiLaurentMazare
Juarez Bochi
andauthored
Add support to UL2 model family (huggingface#1300)
* Add support to UL2 model family * Update docs with UL2 * Create ActivationWithOptionalGating to avoid polluting activations * Also refactor quantized t5 * Remove useless conversion * Revert Activation::NewGelu name change * Remove useless return * Apply rustfmt and clippy recommendations * Reuse t5::ActivationWithOptionalGating in quantized version * (cosmetic change) use a match rather than ifs + avoid early returns. --------- Co-authored-by: Laurent <[email protected]>
1 parent 6958384 commit 18d3000

File tree

9 files changed

+71
-15
lines changed

9 files changed

+71
-15
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ rayon = "1.7.0"
5050
rusttype = { version = "0.9", default-features = false }
5151
safetensors = "0.3.1"
5252
serde = { version = "1.0.171", features = ["derive"] }
53+
serde_plain = "1.0.2"
5354
serde_json = "1.0.99"
5455
thiserror = "1"
5556
tokenizers = { version = "0.13.4", default-features = false }

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ If you have an addition to this list, please submit a pull request.
175175
- Replit-code-v1.5-3B.
176176
- Bert.
177177
- Text to text.
178-
- T5 and its variants: FlanT5, MADLAD400 (translation), CoEdit (Grammar correction).
178+
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
179179
- Marian MT (Machine Translation).
180180
- Whisper (multi-lingual support).
181181
- Text to image.

candle-core/src/op.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,8 @@ unary_op!(Recip, "recip", v, v.recip());
551551
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
552552
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
553553

554-
/// `gelu` operation
554+
/// Tanh based approximation of the `gelu` operation
555+
/// GeluErf is the more precise one.
555556
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
556557
impl UnaryOpT for Gelu {
557558
const NAME: &'static str = "gelu";

candle-examples/examples/t5/README.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ $ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate
99
9 tokens generated (2.42 token/s)
1010
```
1111

12+
Variants such as [flan-t5](https://huggingface.co/google/flan-t5-small), [flan-ul2](https://huggingface.co/google/flan-ul2) (with `--revision "refs/pr/25"`), and [Co-EdIT](https://huggingface.co/grammarly/coedit-large) are also supported.
13+
1214
## Translation with [MADLAD-400](https://arxiv.org/abs/2309.04662)
1315

1416
MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
@@ -22,7 +24,7 @@ cargo run --example t5 --release -- \
2224
Wie geht es dir, mein Freund?
2325
```
2426

25-
## Sentence embedding example:
27+
## Sentence embedding example
2628

2729
```bash
2830
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."

candle-examples/examples/t5/main.rs

+11
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,17 @@ impl T5ModelBuilder {
104104
api.get("model-00004-of-00005.safetensors")?,
105105
api.get("model-00005-of-00005.safetensors")?,
106106
]
107+
} else if model_id == "google/flan-ul2" {
108+
vec![
109+
api.get("model-00001-of-00008.safetensors")?,
110+
api.get("model-00002-of-00008.safetensors")?,
111+
api.get("model-00003-of-00008.safetensors")?,
112+
api.get("model-00004-of-00008.safetensors")?,
113+
api.get("model-00005-of-00008.safetensors")?,
114+
api.get("model-00006-of-00008.safetensors")?,
115+
api.get("model-00007-of-00008.safetensors")?,
116+
api.get("model-00008-of-00008.safetensors")?,
117+
]
107118
} else {
108119
vec![api.get("model.safetensors")?]
109120
};

candle-nn/src/activation.rs

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use serde::Deserialize;
66
pub enum Activation {
77
#[default]
88
Gelu,
9-
#[serde(rename = "gated-gelu")]
109
NewGelu,
1110
Relu,
1211
Relu2,

candle-transformers/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ rand = { workspace = true }
2121
rayon = { workspace = true }
2222
serde = { workspace = true }
2323
serde_json = { workspace = true }
24+
serde_plain = { workspace = true }
2425
tracing = { workspace = true }
2526
wav = { workspace = true }
2627

candle-transformers/src/models/quantized_t5.rs

+9-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// T5 Text Model, quantized version
22
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
33

4+
use crate::models::t5::{deserialize_feed_forward_proj_activation, ActivationWithOptionalGating};
45
use crate::models::with_tracing::QMatMul;
56
use crate::quantized_nn::Embedding;
67
pub use crate::quantized_var_builder::VarBuilder;
@@ -54,8 +55,8 @@ pub struct Config {
5455
dropout_rate: f64,
5556
layer_norm_epsilon: f64,
5657
initializer_factor: f64,
57-
#[serde(default)]
58-
feed_forward_proj: Activation,
58+
#[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
59+
pub feed_forward_proj: ActivationWithOptionalGating,
5960
#[serde(default = "default_tie_word_embeddings")]
6061
tie_word_embeddings: bool,
6162
#[serde(default = "default_is_decoder")]
@@ -83,7 +84,10 @@ impl Default for Config {
8384
dropout_rate: 0.1,
8485
layer_norm_epsilon: 1e-6,
8586
initializer_factor: 1.0,
86-
feed_forward_proj: Activation::Relu,
87+
feed_forward_proj: ActivationWithOptionalGating {
88+
gated: false,
89+
activation: Activation::Relu,
90+
},
8791
tie_word_embeddings: true,
8892
is_decoder: false,
8993
is_encoder_decoder: true,
@@ -176,7 +180,7 @@ impl T5DenseGatedActDense {
176180
wi_0,
177181
wi_1,
178182
wo,
179-
act: Activation::NewGelu,
183+
act: cfg.feed_forward_proj.activation,
180184
span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
181185
})
182186
}
@@ -205,7 +209,7 @@ impl T5LayerFF {
205209
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
206210
let layer_norm =
207211
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
208-
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
212+
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
209213
(
210214
None,
211215
Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),

candle-transformers/src/models/t5.rs

+43-6
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,37 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
3737
Ok(m)
3838
}
3939

40+
#[derive(Debug, Deserialize, Default, Clone, PartialEq)]
41+
pub struct ActivationWithOptionalGating {
42+
pub gated: bool,
43+
pub activation: candle_nn::Activation,
44+
}
45+
46+
pub fn deserialize_feed_forward_proj_activation<'de, D>(
47+
deserializer: D,
48+
) -> std::result::Result<ActivationWithOptionalGating, D::Error>
49+
where
50+
D: serde::de::Deserializer<'de>,
51+
{
52+
match String::deserialize(deserializer)?.as_str() {
53+
"gated-gelu" => Ok(ActivationWithOptionalGating {
54+
gated: true,
55+
activation: candle_nn::Activation::NewGelu,
56+
}),
57+
"gated-silu" => Ok(ActivationWithOptionalGating {
58+
gated: true,
59+
activation: candle_nn::Activation::Silu,
60+
}),
61+
buf => {
62+
let activation = serde_plain::from_str(buf).map_err(serde::de::Error::custom)?;
63+
Ok(ActivationWithOptionalGating {
64+
gated: false,
65+
activation,
66+
})
67+
}
68+
}
69+
}
70+
4071
#[derive(Debug, Clone, PartialEq, Deserialize)]
4172
pub struct Config {
4273
vocab_size: usize,
@@ -52,8 +83,8 @@ pub struct Config {
5283
dropout_rate: f64,
5384
layer_norm_epsilon: f64,
5485
initializer_factor: f64,
55-
#[serde(default)]
56-
feed_forward_proj: Activation,
86+
#[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
87+
feed_forward_proj: ActivationWithOptionalGating,
5788
#[serde(default = "default_tie_word_embeddings")]
5889
tie_word_embeddings: bool,
5990
#[serde(default = "default_is_decoder")]
@@ -81,7 +112,10 @@ impl Default for Config {
81112
dropout_rate: 0.1,
82113
layer_norm_epsilon: 1e-6,
83114
initializer_factor: 1.0,
84-
feed_forward_proj: Activation::Relu,
115+
feed_forward_proj: ActivationWithOptionalGating {
116+
gated: false,
117+
activation: Activation::Relu,
118+
},
85119
tie_word_embeddings: true,
86120
is_decoder: false,
87121
is_encoder_decoder: true,
@@ -102,7 +136,10 @@ impl Config {
102136
d_model: 768,
103137
dropout_rate: 0.1,
104138
eos_token_id: 1,
105-
feed_forward_proj: Activation::Relu,
139+
feed_forward_proj: ActivationWithOptionalGating {
140+
gated: false,
141+
activation: Activation::Relu,
142+
},
106143
tie_word_embeddings: true,
107144
initializer_factor: 1.0,
108145
is_decoder: false,
@@ -202,7 +239,7 @@ impl T5DenseGatedActDense {
202239
wi_0,
203240
wi_1,
204241
wo,
205-
act: Activation::NewGelu,
242+
act: cfg.feed_forward_proj.activation,
206243
span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
207244
})
208245
}
@@ -231,7 +268,7 @@ impl T5LayerFF {
231268
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
232269
let layer_norm =
233270
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
234-
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
271+
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
235272
(
236273
None,
237274
Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),

0 commit comments

Comments
 (0)