
Commit b79f641

No public description
PiperOrigin-RevId: 673432468
1 parent 5d3b4df commit b79f641

File tree

3 files changed: +65 −26 lines


official/vision/configs/backbones.py

+2
@@ -52,6 +52,8 @@ class VisionTransformer(hyperparams.Config):
   layer_scale_init_value: float = 0.0
   # Transformer encoder spatial partition dimensions.
   transformer_partition_dims: Optional[Tuple[int, int, int, int]] = None
+  # If True, output attention scores.
+  output_attention_scores: bool = False


 @dataclasses.dataclass
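
The new config field defaults to False, so existing experiments are unaffected. As a rough illustration (not part of this commit), the flag could be flipped on through the usual backbone config wrapper; the `Backbone` one-of config and the `model_name` field are assumed from the surrounding `official/vision/configs/backbones.py` module and may differ in your checkout.

# Hypothetical sketch: enabling the new field on a ViT backbone config.
from official.vision.configs import backbones

backbone_cfg = backbones.Backbone(
    type='vit',
    vit=backbones.VisionTransformer(
        model_name='vit-b16',
        output_attention_scores=True,  # field added in this commit
    ),
)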

official/vision/modeling/backbones/vit.py

+52 −26
@@ -108,22 +108,25 @@ def call(self, inputs):
 class Encoder(layers.Layer):
   """Transformer Encoder."""

-  def __init__(self,
-               num_layers,
-               mlp_dim,
-               num_heads,
-               dropout_rate=0.1,
-               attention_dropout_rate=0.1,
-               kernel_regularizer=None,
-               inputs_positions=None,
-               init_stochastic_depth_rate=0.0,
-               kernel_initializer='glorot_uniform',
-               add_pos_embed=True,
-               pos_embed_origin_shape=None,
-               pos_embed_target_shape=None,
-               layer_scale_init_value=0.0,
-               transformer_partition_dims=None,
-               **kwargs):
+  def __init__(
+      self,
+      num_layers,
+      mlp_dim,
+      num_heads,
+      dropout_rate=0.1,
+      attention_dropout_rate=0.1,
+      kernel_regularizer=None,
+      inputs_positions=None,
+      init_stochastic_depth_rate=0.0,
+      kernel_initializer='glorot_uniform',
+      add_pos_embed=True,
+      pos_embed_origin_shape=None,
+      pos_embed_target_shape=None,
+      layer_scale_init_value=0.0,
+      transformer_partition_dims=None,
+      output_attention_scores=False,
+      **kwargs,
+  ):
     super().__init__(**kwargs)
     self._num_layers = num_layers
     self._mlp_dim = mlp_dim
@@ -139,6 +142,7 @@ def __init__(self,
     self._pos_embed_target_shape = pos_embed_target_shape
     self._layer_scale_init_value = layer_scale_init_value
     self._transformer_partition_dims = transformer_partition_dims
+    self._output_attention_scores = output_attention_scores

   def build(self, input_shape):
     if self._add_pos_embed:
@@ -163,10 +167,13 @@ def build(self, input_shape):
           kernel_initializer=self._kernel_initializer,
           norm_first=True,
           stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
-              self._init_stochastic_depth_rate, i + 1, self._num_layers),
+              self._init_stochastic_depth_rate, i + 1, self._num_layers
+          ),
           norm_epsilon=1e-6,
           layer_scale_init_value=self._layer_scale_init_value,
-          transformer_partition_dims=self._transformer_partition_dims)
+          transformer_partition_dims=self._transformer_partition_dims,
+          return_attention_scores=self._output_attention_scores,
+      )
       self._encoder_layers.append(encoder_layer)
     self._norm = layers.LayerNormalization(epsilon=1e-6)
     super().build(input_shape)
@@ -177,9 +184,16 @@ def call(self, inputs, training=None):
       x = self._pos_embed(x, inputs_positions=self._inputs_positions)
     x = self._dropout(x, training=training)

+    attention_scores = None  # Needed to suppress undefined-variable warning.
     for encoder_layer in self._encoder_layers:
-      x = encoder_layer(x, training=training)
+      if self._output_attention_scores:
+        x, attention_scores = encoder_layer(x, training=training)
+      else:
+        x = encoder_layer(x, training=training)
     x = self._norm(x)
+
+    if self._output_attention_scores:
+      return x, attention_scores
     return x

   def get_config(self):
@@ -199,6 +213,7 @@ def get_config(self):
         'pos_embed_target_shape': self._pos_embed_target_shape,
         'layer_scale_init_value': self._layer_scale_init_value,
         'transformer_partition_dims': self._transformer_partition_dims,
+        'output_attention_scores': self._output_attention_scores,
     }
     config.update(updates)
     return config
@@ -227,6 +242,7 @@ def __init__(
       pos_embed_shape: Optional[Tuple[int, int]] = None,
       layer_scale_init_value: float = 0.0,
       transformer_partition_dims: Optional[Tuple[int, int, int, int]] = None,
+      output_attention_scores: bool = False,
   ):
     """VisionTransformer initialization function."""
     self._mlp_dim = mlp_dim
@@ -265,20 +281,29 @@ def __init__(
     if pooler == 'token':
       x = TokenLayer(name='cls')(x)

-    x = Encoder(
+    encoder_output = Encoder(
         num_layers=num_layers,
         mlp_dim=mlp_dim,
         num_heads=num_heads,
         dropout_rate=dropout_rate,
         attention_dropout_rate=attention_dropout_rate,
         kernel_regularizer=kernel_regularizer,
-        kernel_initializer='glorot_uniform' if original_init else dict(
-            class_name='TruncatedNormal', config=dict(stddev=.02)),
+        kernel_initializer='glorot_uniform'
+        if original_init
+        else dict(class_name='TruncatedNormal', config=dict(stddev=0.02)),
         init_stochastic_depth_rate=init_stochastic_depth_rate,
         pos_embed_origin_shape=pos_embed_shape,
         pos_embed_target_shape=pos_embed_target_shape,
-        layer_scale_init_value=layer_scale_init_value)(
-            x)
+        layer_scale_init_value=layer_scale_init_value,
+        output_attention_scores=output_attention_scores,
+    )(x)
+
+    endpoints = {}
+    if output_attention_scores:
+      x, attention_scores = encoder_output
+      endpoints['attention_scores'] = attention_scores
+    else:
+      x = encoder_output

     if pooler == 'token':
       output_feature = x[:, 1:]
@@ -292,7 +317,6 @@ def __init__(
     else:
       raise ValueError(f'unrecognized pooler type: {pooler}')

-    endpoints = {}
     if output_2d_feature_maps:
       # Use the closest feature level.
       feat_level = round(math.log2(patch_size))
@@ -376,4 +400,6 @@ def build_vit(input_specs,
       output_2d_feature_maps=backbone_cfg.output_2d_feature_maps,
       layer_scale_init_value=backbone_cfg.layer_scale_init_value,
       pos_embed_shape=backbone_cfg.pos_embed_shape,
-      transformer_partition_dims=backbone_cfg.transformer_partition_dims)
+      transformer_partition_dims=backbone_cfg.transformer_partition_dims,
+      output_attention_scores=backbone_cfg.output_attention_scores,
+  )
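
A minimal usage sketch of the new code path, mirroring the unit test added below; it assumes the default backbone settings and the standard repo imports. Because the encoder loop overwrites `attention_scores` on every iteration, the scores returned in the endpoint come from the final `TransformerEncoderBlock` only.

# Sketch only: read the new 'attention_scores' endpoint from the backbone.
import tf_keras
from official.vision.modeling.backbones import vit

network = vit.VisionTransformer(
    input_specs=tf_keras.layers.InputSpec(shape=[None, 224, 224, 3]),
    output_attention_scores=True,
)
outputs = network(tf_keras.Input(shape=(224, 224, 3), batch_size=1))
attn = outputs['attention_scores']  # [batch, num_heads, tokens, tokens], last block only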

official/vision/modeling/backbones/vit_test.py

+11
@@ -95,6 +95,17 @@ def test_posembedding_interpolation(self):
     output = network(inputs)['pre_logits']
     self.assertEqual(output.shape, [1, 1, 1, 768])

+  def test_attention_scores(self):
+    tf_keras.backend.set_image_data_format('channels_last')
+    input_specs = tf_keras.layers.InputSpec(shape=[2, 224, 224, 3])
+    network = vit.VisionTransformer(
+        input_specs=input_specs, output_attention_scores=True
+    )
+
+    inputs = tf_keras.Input(shape=(224, 224, 3), batch_size=1)
+    outputs = network(inputs)
+    self.assertEqual(outputs['attention_scores'].shape, [1, 12, 197, 197])
+

 if __name__ == '__main__':
   tf.test.main()
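
The asserted shape follows from the ViT-B/16 defaults, assuming this commit leaves them unchanged: 12 attention heads, and at 224x224 input with 16x16 patches there are (224 / 16)^2 = 196 patch tokens plus one class token, hence [1, 12, 197, 197] for a batch of one.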
