@@ -11,6 +11,7 @@ class ModelConfig:
     patch_size: tuple[int, int]
     num_channels: int
     hidden_dim: int
+    attn_dropout_prob: float
     dropout_prob: float
     num_heads: int
     mlp_dim: int
@@ -25,6 +26,7 @@ def vit_p16_224(cls):
             patch_size=(16, 16),
             num_channels=3,
             hidden_dim=768,
+            attn_dropout_prob=0.0,
             dropout_prob=0.0,
             num_heads=12,
             mlp_dim=3072,
@@ -66,9 +68,9 @@ def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
         )
         self.cls_token = nnx.Variable(jax.random.normal(rngs.params(), (1, 1, cfg.hidden_dim)))
         self.pos_embeddings = nnx.Variable(jax.random.normal(rngs.params(), (1, num_patches + 1, cfg.hidden_dim)))
-        self.dropout = nnx.Dropout(cfg.dropout_prob, rngs=rngs)
+        self.dropout = nnx.Dropout(cfg.dropout_prob)
 
-    def __call__(self, pixel_values: jnp.ndarray) -> jnp.ndarray:
+    def __call__(self, pixel_values: jnp.ndarray, *, rngs: nnx.Rngs | None) -> jnp.ndarray:
         embeddings = self.projection(pixel_values)
         b, h, w, c = embeddings.shape
         embeddings = embeddings.reshape(b, h * w, c)
@@ -89,49 +91,54 @@ def __call__(self, pixel_values: jnp.ndarray) -> jnp.ndarray:
         embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)
 
         embeddings = embeddings + current_pos_embeddings
-        embeddings = self.dropout(embeddings)
+        embeddings = self.dropout(embeddings, rngs=rngs)
         return embeddings
 
 
 class TransformerEncoder(nnx.Module):
     def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
         self.attention = nnx.MultiHeadAttention(
-            num_heads=cfg.num_heads, in_features=cfg.hidden_dim, decode=False, rngs=rngs
+            num_heads=cfg.num_heads,
+            in_features=cfg.hidden_dim,
+            dropout_rate=cfg.attn_dropout_prob,
+            decode=False,
+            rngs=rngs,
         )
         self.linear1 = nnx.Linear(cfg.hidden_dim, cfg.mlp_dim, rngs=rngs)
         self.linear2 = nnx.Linear(cfg.mlp_dim, cfg.hidden_dim, rngs=rngs)
-        self.dropout = nnx.Dropout(cfg.dropout_prob, rngs=rngs)
+        self.dropout = nnx.Dropout(cfg.dropout_prob)
         self.layernorm_before = nnx.LayerNorm(cfg.hidden_dim, epsilon=cfg.eps, rngs=rngs)
         self.layernorm_after = nnx.LayerNorm(cfg.hidden_dim, epsilon=cfg.eps, rngs=rngs)
 
-    def __call__(self, hidden_states, head_mask=None):
+    def __call__(self, hidden_states, head_mask=None, *, rngs: nnx.Rngs | None):
         hidden_states_norm = self.layernorm_before(hidden_states)
-        attention_output = self.attention(hidden_states_norm, head_mask)
+        attention_output = self.attention(hidden_states_norm, head_mask, rngs=rngs)
         hidden_states = attention_output + hidden_states
         layer_output = self.layernorm_after(hidden_states)
         layer_output = jax.nn.gelu(self.linear1(layer_output))
         layer_output = self.linear2(layer_output)
-        layer_output = self.dropout(layer_output)
+        layer_output = self.dropout(layer_output, rngs=rngs)
         layer_output += hidden_states
         return layer_output
 
 
 class ViTClassificationModel(nnx.Module):
     def __init__(self, cfg: ModelConfig, *, rngs: nnx.Rngs):
         self.pos_embeddings = Embeddings(cfg, rngs=rngs)
-        self.layers = nnx.Sequential(*[TransformerEncoder(cfg, rngs=rngs) for _ in range(cfg.num_layers)])
+        self.layers = nnx.List([TransformerEncoder(cfg, rngs=rngs) for _ in range(cfg.num_layers)])
         self.ln = nnx.LayerNorm(cfg.hidden_dim, epsilon=cfg.eps, rngs=rngs)
         self.classifier = nnx.Linear(cfg.hidden_dim, cfg.num_labels, rngs=rngs)
 
-    def __call__(self, x):
-        x = self.pos_embeddings(x)
-        x = self.layers(x)
+    def __call__(self, x, *, rngs: nnx.Rngs | None):
+        x = self.pos_embeddings(x, rngs=rngs)
+        for layer in self.layers:
+            x = layer(x, rngs=rngs)
         x = self.ln(x)
         x = self.classifier(x[:, 0, :])
         return x
 
 
 @jax.jit
-def forward(graphdef: nnx.GraphDef[nnx.Module], state: nnx.State, x: jax.Array) -> jax.Array:
+def forward(graphdef: nnx.GraphDef[nnx.Module], state: nnx.State, x: jax.Array, rngs: nnx.Rngs) -> jax.Array:
     model = nnx.merge(graphdef, state)
-    return model(x)
+    return model(x, rngs=rngs)
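
A minimal usage sketch of the updated call signature, assuming the ModelConfig and ViTClassificationModel definitions in this diff; the seeds, batch size, and the direct (non-jitted) call are illustrative, not part of the change:

    import jax.numpy as jnp
    from flax import nnx

    # Build the config and model; nnx.Rngs(0) seeds the 'params' stream used at init.
    cfg = ModelConfig.vit_p16_224()
    model = ViTClassificationModel(cfg, rngs=nnx.Rngs(0))

    # Dropout rngs are now passed explicitly at call time instead of being
    # stored on the modules, so the same model instance can be used with
    # different dropout keys (or with dropout effectively disabled via rate 0.0).
    x = jnp.ones((2, 224, 224, 3))               # NHWC pixel values
    logits = model(x, rngs=nnx.Rngs(dropout=1))  # explicit dropout stream
    print(logits.shape)                          # (2, cfg.num_labels)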