@@ -108,22 +108,25 @@ def call(self, inputs):
class Encoder(layers.Layer):
  """Transformer Encoder."""

-  def __init__(self,
-               num_layers,
-               mlp_dim,
-               num_heads,
-               dropout_rate=0.1,
-               attention_dropout_rate=0.1,
-               kernel_regularizer=None,
-               inputs_positions=None,
-               init_stochastic_depth_rate=0.0,
-               kernel_initializer='glorot_uniform',
-               add_pos_embed=True,
-               pos_embed_origin_shape=None,
-               pos_embed_target_shape=None,
-               layer_scale_init_value=0.0,
-               transformer_partition_dims=None,
-               **kwargs):
+  def __init__(
+      self,
+      num_layers,
+      mlp_dim,
+      num_heads,
+      dropout_rate=0.1,
+      attention_dropout_rate=0.1,
+      kernel_regularizer=None,
+      inputs_positions=None,
+      init_stochastic_depth_rate=0.0,
+      kernel_initializer='glorot_uniform',
+      add_pos_embed=True,
+      pos_embed_origin_shape=None,
+      pos_embed_target_shape=None,
+      layer_scale_init_value=0.0,
+      transformer_partition_dims=None,
+      output_attention_scores=False,
+      **kwargs,
+  ):
    super().__init__(**kwargs)
    self._num_layers = num_layers
    self._mlp_dim = mlp_dim
@@ -139,6 +142,7 @@ def __init__(self,
    self._pos_embed_target_shape = pos_embed_target_shape
    self._layer_scale_init_value = layer_scale_init_value
    self._transformer_partition_dims = transformer_partition_dims
+    self._output_attention_scores = output_attention_scores

  def build(self, input_shape):
    if self._add_pos_embed:
@@ -163,10 +167,13 @@ def build(self, input_shape):
          kernel_initializer=self._kernel_initializer,
          norm_first=True,
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
-              self._init_stochastic_depth_rate, i + 1, self._num_layers),
+              self._init_stochastic_depth_rate, i + 1, self._num_layers
+          ),
          norm_epsilon=1e-6,
          layer_scale_init_value=self._layer_scale_init_value,
-          transformer_partition_dims=self._transformer_partition_dims)
+          transformer_partition_dims=self._transformer_partition_dims,
+          return_attention_scores=self._output_attention_scores,
+      )
      self._encoder_layers.append(encoder_layer)
    self._norm = layers.LayerNormalization(epsilon=1e-6)
    super().build(input_shape)
@@ -177,9 +184,16 @@ def call(self, inputs, training=None):
      x = self._pos_embed(x, inputs_positions=self._inputs_positions)
    x = self._dropout(x, training=training)

+    attention_scores = None  # Needed to suppress undefined-variable warning.
    for encoder_layer in self._encoder_layers:
-      x = encoder_layer(x, training=training)
+      if self._output_attention_scores:
+        x, attention_scores = encoder_layer(x, training=training)
+      else:
+        x = encoder_layer(x, training=training)
    x = self._norm(x)
+
+    if self._output_attention_scores:
+      return x, attention_scores
    return x

  def get_config(self):
@@ -199,6 +213,7 @@ def get_config(self):
        'pos_embed_target_shape': self._pos_embed_target_shape,
        'layer_scale_init_value': self._layer_scale_init_value,
        'transformer_partition_dims': self._transformer_partition_dims,
+        'output_attention_scores': self._output_attention_scores,
    }
    config.update(updates)
    return config
@@ -227,6 +242,7 @@ def __init__(
      pos_embed_shape: Optional[Tuple[int, int]] = None,
      layer_scale_init_value: float = 0.0,
      transformer_partition_dims: Optional[Tuple[int, int, int, int]] = None,
+      output_attention_scores: bool = False,
  ):
    """VisionTransformer initialization function."""
    self._mlp_dim = mlp_dim
@@ -265,20 +281,29 @@ def __init__(
    if pooler == 'token':
      x = TokenLayer(name='cls')(x)

-    x = Encoder(
+    encoder_output = Encoder(
        num_layers=num_layers,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        kernel_regularizer=kernel_regularizer,
-        kernel_initializer='glorot_uniform' if original_init else dict(
-            class_name='TruncatedNormal', config=dict(stddev=.02)),
+        kernel_initializer='glorot_uniform'
+        if original_init
+        else dict(class_name='TruncatedNormal', config=dict(stddev=0.02)),
        init_stochastic_depth_rate=init_stochastic_depth_rate,
        pos_embed_origin_shape=pos_embed_shape,
        pos_embed_target_shape=pos_embed_target_shape,
-        layer_scale_init_value=layer_scale_init_value)(
-            x)
+        layer_scale_init_value=layer_scale_init_value,
+        output_attention_scores=output_attention_scores,
+    )(x)
+
+    endpoints = {}
+    if output_attention_scores:
+      x, attention_scores = encoder_output
+      endpoints['attention_scores'] = attention_scores
+    else:
+      x = encoder_output

    if pooler == 'token':
      output_feature = x[:, 1:]
@@ -292,7 +317,6 @@ def __init__(
    else:
      raise ValueError(f'unrecognized pooler type: {pooler}')

-    endpoints = {}
    if output_2d_feature_maps:
      # Use the closest feature level.
      feat_level = round(math.log2(patch_size))
@@ -376,4 +400,6 @@ def build_vit(input_specs,
      output_2d_feature_maps=backbone_cfg.output_2d_feature_maps,
      layer_scale_init_value=backbone_cfg.layer_scale_init_value,
      pos_embed_shape=backbone_cfg.pos_embed_shape,
-      transformer_partition_dims=backbone_cfg.transformer_partition_dims)
+      transformer_partition_dims=backbone_cfg.transformer_partition_dims,
+      output_attention_scores=backbone_cfg.output_attention_scores,
+  )
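A minimal usage sketch of the new flag, not part of the diff: it assumes the patched Encoder is importable as official.vision.modeling.backbones.vit.Encoder in your checkout, and that the wrapped TransformerEncoderBlock honors return_attention_scores as the change implies; the tensor shapes below are illustrative only.

import tensorflow as tf

from official.vision.modeling.backbones import vit  # Assumed import path.

# Build a small encoder with attention-score output enabled (new flag above).
encoder = vit.Encoder(
    num_layers=2,
    mlp_dim=256,
    num_heads=4,
    output_attention_scores=True,
)

tokens = tf.random.normal([8, 197, 64])  # [batch, num_tokens, hidden_size]

# With the flag set, call() returns a (features, attention_scores) tuple,
# where attention_scores comes from the last encoder block per the loop above.
features, attention_scores = encoder(tokens, training=False)

print(features.shape)          # (8, 197, 64)
print(attention_scores.shape)  # e.g. (8, 4, 197, 197) for 4 attention heads

At the VisionTransformer level the same flag surfaces the scores under endpoints['attention_scores'], and build_vit forwards backbone_cfg.output_attention_scores, so the behavior can be toggled from the backbone config.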