Skip to content

Commit 5893492

Browse files
committed
CHG: Add L2 normalization layer
1 parent cd08dcd commit 5893492

File tree

4 files changed

+15
-1
lines changed

4 files changed

+15
-1
lines changed

kaffe/layers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
'MemoryData': shape_mem_data,
3636
'MultinomialLogisticLoss': shape_scalar,
3737
'MVN': shape_not_implemented,
38+
'Normalize': shape_identity,
3839
'Pooling': shape_pool,
3940
'Power': shape_identity,
4041
'ReLU': shape_identity,

kaffe/tensorflow/network.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,3 +253,13 @@ def batch_normalization(self, input, name, scale_offset=True, relu=False):
253253
def dropout(self, input, keep_prob, name):
254254
keep = 1 - self.use_dropout + (self.use_dropout * keep_prob)
255255
return tf.nn.dropout(input, keep, name=name)
256+
257+
@layer
258+
def l2_normalize(self, input):
259+
# NOTE: Currently, only inference is supported
260+
with tf.variable_scope(name) as scope:
261+
shp = input.get_shape().as_list()
262+
outputs = tf.nn.l2_normalize(x=input, axis=-1)
263+
alpha = self.make_var('alpha', shape=[-1:])
264+
outputs = tf.multiply(outputs, alpha)
265+
return outputs

kaffe/tensorflow/transformer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ def map_eltwise(self, node):
160160
return TensorFlowNode(operations[op_code])
161161
except KeyError:
162162
raise KaffeError('Unknown elementwise operation: {}'.format(op_code))
163+
164+
def map_normalize(self, node):
165+
return TensorFlowNode('l2_normalize')
163166

164167
def commit(self, chains):
165168
return chains

kaffe/transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def __call__(self, graph):
282282
names = ('mean', 'variance')
283283
if len(node.data) == 4:
284284
names += ('scale', 'offset')
285-
elif node.kind == NodeKind.PReLU:
285+
elif node.kind == NodeKind.PReLU or node.kind == NodeKind.Normalize:
286286
names = ('alpha',)
287287
else:
288288
print_stderr('WARNING: Unhandled parameters: {}'.format(node.kind))

0 commit comments

Comments
 (0)