Skip to content

Commit 04b12d3

Browse files
No public description
PiperOrigin-RevId: 879570384
1 parent 1b152ff commit 04b12d3

6 files changed

Lines changed: 90 additions & 10 deletions

File tree

official/modeling/activations/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from official.modeling.activations.mish import mish
1818
from official.modeling.activations.relu import relu6
1919
from official.modeling.activations.sigmoid import hard_sigmoid
20+
from official.modeling.activations.squared_relu import squared_relu
2021
from official.modeling.activations.swish import hard_swish
2122
from official.modeling.activations.swish import identity
2223
from official.modeling.activations.swish import simple_swish
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2026 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Customized Squared ReLU activation."""
16+
17+
import tensorflow as tf, tf_keras
18+
19+
20+
@tf_keras.utils.register_keras_serializable(package='Text')
def squared_relu(features: tf.Tensor) -> tf.Tensor:
  """Computes the Squared ReLU activation function.

  Squared ReLU is the element-wise result of applying ReLU and then
  squaring: `relu(x) * relu(x)`.

  Args:
    features: A `Tensor` representing preactivation values.

  Returns:
    The activation value.
  """
  # Accept plain Python/NumPy inputs by promoting them to a Tensor first.
  rectified = tf.nn.relu(tf.convert_to_tensor(features))
  # x * x is numerically identical to tf.math.square(x) for real tensors.
  return rectified * rectified
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2026 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for the customized Squared ReLU activation."""
16+
17+
import numpy as np
18+
import tensorflow as tf, tf_keras
19+
20+
from official.modeling import activations
21+
22+
23+
class CustomizedSquaredReluTest(tf.test.TestCase):
  """Checks the custom activation against a tf.nn-based reference."""

  def _squared_relu_nn(self, x):
    # Reference implementation built directly from TF primitives.
    rectified = tf.nn.relu(np.float32(x))
    return tf.math.square(rectified)

  def test_squared_relu(self):
    features = [[0.25, 0, -0.25], [-1, -2, 3]]
    actual = activations.squared_relu(features)
    expected = self._squared_relu_nn(features)
    self.assertAllClose(actual, expected)
34+
35+
36+
# Run all test cases in this module when executed directly as a script.
if __name__ == '__main__':
  tf.test.main()

official/modeling/tf_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def get_activation(identifier, use_keras_layer=False, **kwargs):
120120
"hard_sigmoid": activations.hard_sigmoid,
121121
"mish": activations.mish,
122122
"gelu": functools.partial(tf.nn.gelu, **kwargs),
123+
"squared_relu": activations.squared_relu,
123124
}
124125
if identifier in keras_layer_allowlist:
125126
return tf_keras.layers.Activation(keras_layer_allowlist[identifier])
@@ -131,6 +132,7 @@ def get_activation(identifier, use_keras_layer=False, **kwargs):
131132
"hard_sigmoid": activations.hard_sigmoid,
132133
"identity": activations.identity,
133134
"mish": activations.mish,
135+
"squared_relu": activations.squared_relu,
134136
}
135137
if identifier in name_to_fn:
136138
return tf_keras.activations.get(name_to_fn[identifier])

official/nlp/modeling/layers/mobile_bert_layers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import tensorflow as tf, tf_keras
1717

1818
from official.modeling import tf_utils
19-
2019
from official.nlp.modeling.layers import on_device_embedding
2120
from official.nlp.modeling.layers import position_embedding
2221

@@ -288,11 +287,12 @@ def __init__(self,
288287
layer_name = layer_prefix + '/intermediate_dense'
289288
intermediate_layer = tf_keras.layers.EinsumDense(
290289
'abc,cd->abd',
291-
activation=self.intermediate_act_fn,
290+
activation=tf_utils.get_activation(self.intermediate_act_fn),
292291
output_shape=[None, self.intermediate_size],
293292
bias_axes='d',
294293
kernel_initializer=tf_utils.clone_initializer(self.initializer),
295-
name=layer_name)
294+
name=layer_name,
295+
)
296296
layer_name = layer_prefix + '/output_dense'
297297
output_layer = tf_keras.layers.EinsumDense(
298298
'abc,cd->abd',

official/nlp/modeling/layers/mobile_bert_layers_test.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,15 @@ def test_embedding_layer_with_token_type(self):
4242
output = layer(input_seq, token_type)
4343
output_shape = output.shape.as_list()
4444
expected_shape = [1, 4, 16]
45-
self.assertListEqual(output_shape, expected_shape, msg=None)
45+
self.assertListEqual(output_shape, expected_shape)
4646

4747
def test_embedding_layer_without_token_type(self):
4848
layer = mobile_bert_layers.MobileBertEmbedding(10, 8, 2, 16)
4949
input_seq = tf.Variable([[2, 3, 4, 5]])
5050
output = layer(input_seq)
5151
output_shape = output.shape.as_list()
5252
expected_shape = [1, 4, 16]
53-
self.assertListEqual(output_shape, expected_shape, msg=None)
53+
self.assertListEqual(output_shape, expected_shape)
5454

5555
def test_embedding_layer_get_config(self):
5656
layer = mobile_bert_layers.MobileBertEmbedding(
@@ -72,7 +72,7 @@ def test_no_norm(self):
7272
output = layer(feature)
7373
output_shape = output.shape.as_list()
7474
expected_shape = [2, 3, 4]
75-
self.assertListEqual(output_shape, expected_shape, msg=None)
75+
self.assertListEqual(output_shape, expected_shape)
7676

7777
@parameterized.named_parameters(('with_kq_shared_bottleneck', False),
7878
('without_kq_shared_bottleneck', True))
@@ -83,7 +83,17 @@ def test_transfomer_kq_shared_bottleneck(self, is_kq_shared):
8383
output = layer(feature)
8484
output_shape = output.shape.as_list()
8585
expected_shape = [2, 3, 512]
86-
self.assertListEqual(output_shape, expected_shape, msg=None)
86+
self.assertListEqual(output_shape, expected_shape)
87+
88+
def test_transformer_with_squared_relu(self):
89+
feature = tf.random.uniform([2, 3, 512])
90+
layer = mobile_bert_layers.MobileBertTransformer(
91+
intermediate_act_fn='squared_relu'
92+
)
93+
output = layer(feature)
94+
output_shape = output.shape.as_list()
95+
expected_shape = [2, 3, 512]
96+
self.assertListEqual(output_shape, expected_shape)
8797

8898
def test_transfomer_with_mask(self):
8999
feature = tf.random.uniform([2, 3, 512])
@@ -94,7 +104,7 @@ def test_transfomer_with_mask(self):
94104
output = layer(feature, input_mask)
95105
output_shape = output.shape.as_list()
96106
expected_shape = [2, 3, 512]
97-
self.assertListEqual(output_shape, expected_shape, msg=None)
107+
self.assertListEqual(output_shape, expected_shape)
98108

99109
def test_transfomer_return_attention_score(self):
100110
sequence_length = 5
@@ -104,8 +114,7 @@ def test_transfomer_return_attention_score(self):
104114
num_attention_heads=num_attention_heads)
105115
_, attention_score = layer(feature, return_attention_scores=True)
106116
expected_shape = [2, num_attention_heads, sequence_length, sequence_length]
107-
self.assertListEqual(
108-
attention_score.shape.as_list(), expected_shape, msg=None)
117+
self.assertListEqual(attention_score.shape.as_list(), expected_shape)
109118

110119
def test_transformer_get_config(self):
111120
layer = mobile_bert_layers.MobileBertTransformer(

0 commit comments

Comments
 (0)