Skip to content

Commit 8def90c

Browse files
committed
CHG: Add Reshape + flatten from ethereon#147
1 parent 5893492 commit 8def90c

File tree

4 files changed

+75
-7
lines changed

4 files changed

+75
-7
lines changed

kaffe/layers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
'EuclideanLoss': shape_scalar,
2323
'Eltwise': shape_identity,
2424
'Exp': shape_identity,
25-
'Flatten': shape_not_implemented,
25+
'Flatten': flatten_shape,
2626
'HDF5Data': shape_data,
2727
'HDF5Output': shape_identity,
2828
'HingeLoss': shape_scalar,
@@ -38,7 +38,9 @@
3838
'Normalize': shape_identity,
3939
'Pooling': shape_pool,
4040
'Power': shape_identity,
41+
'PReLU': shape_identity,
4142
'ReLU': shape_identity,
43+
'Reshape': reshape_shape,
4244
'Scale': shape_identity,
4345
'Sigmoid': shape_identity,
4446
'SigmoidCrossEntropyLoss': shape_scalar,
@@ -49,8 +51,7 @@
4951
'Slice': shape_not_implemented,
5052
'TanH': shape_identity,
5153
'WindowData': shape_not_implemented,
52-
'Threshold': shape_identity,
53-
'PReLU': shape_identity,
54+
'Threshold': shape_identity
5455
}
5556

5657
LAYER_TYPES = LAYER_DESCRIPTORS.keys()

kaffe/shapes.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,36 @@ def shape_concat(node):
7070
return tuple(output_shape)
7171

7272

73+
def reshape_shape(node):
    """Compute the output TensorShape of a Caffe Reshape layer.

    Target dims come from ``node.parameters.shape.dim`` and follow Caffe
    semantics: 0 copies the corresponding input dimension, -1 infers the
    dimension from the remaining element count, and any other value is
    used as-is. The batch dimension (index 0) is always copied from the
    input, matching the original implementation which skipped ``axes[0]``.
    """
    input_shape = node.get_only_parent().output_shape
    # Element count excluding batch; used to resolve an inferred (-1) dim.
    element_count = input_shape.channels * input_shape.height * input_shape.width
    input_dims = [input_shape.batch_size, input_shape.channels,
                  input_shape.height, input_shape.width]
    axes = node.parameters.shape.dim
    new_shape = [input_shape.batch_size, 1, 1, 1]
    known_product = 1
    for j in range(1, len(axes)):
        if axes[j] == 0:
            # 0 means "copy this dimension from the input".
            new_shape[j] = input_dims[j]
            known_product *= new_shape[j]
        elif axes[j] == -1:
            # Resolved below once all known dims are accumulated.
            # (The original spelled this as a redundant second elif.)
            new_shape[j] = -1
        else:
            new_shape[j] = int(axes[j])
            known_product *= new_shape[j]
    for j in range(1, len(new_shape)):
        if new_shape[j] == -1:
            # Floor division keeps this in exact integer arithmetic; the
            # original int(a / b) went through a float and could lose
            # precision for large shapes.
            new_shape[j] = element_count // known_product
    return TensorShape(new_shape[0], new_shape[1], new_shape[2], new_shape[3])
95+
96+
97+
def flatten_shape(node):
    """Collapse every non-batch dimension of the parent's output into one.

    Mirrors Caffe's Flatten: (N, C, H, W) -> (N, C*H*W, 1, 1).
    """
    parent_shape = node.get_only_parent().output_shape
    flat_size = parent_shape.channels * parent_shape.height * parent_shape.width
    return TensorShape(parent_shape.batch_size, flat_size, 1, 1)
101+
102+
73103
def shape_convolution(node):
74104
return get_strided_kernel_output_shape(node, math.floor)
75105

kaffe/tensorflow/network.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,11 +255,27 @@ def dropout(self, input, keep_prob, name):
255255
return tf.nn.dropout(input, keep, name=name)
256256

257257
@layer
258-
def l2_normalize(self, input):
258+
def reshape(self, input, b, x, y, c, name, transpose=False):
    """Reshape ``input`` to (batch, x, y, c) in NHWC layout.

    ``b`` is accepted for call-site compatibility but not used: the batch
    dimension is always inferred with -1. When ``transpose`` is True the
    tensor is first reshaped in NCHW order and then permuted to NHWC —
    needed when a fully-connected output is reshaped back into a 4-D map.
    """
    if transpose:
        nchw = tf.reshape(input, [-1, c, x, y])
        return tf.transpose(nchw, (0, 2, 3, 1))
    return tf.reshape(input, [-1, x, y, c], name=name)
264+
265+
@layer
266+
def flatten(self, input, name):
    """Flatten an NHWC tensor to (batch, features).

    The input is permuted to NCHW first so the flattened element order
    matches Caffe's Flatten layer.
    """
    nchw = tf.transpose(input, (0, 3, 1, 2))
    feature_count = 1
    for dim in nchw.get_shape()[1:].as_list():
        feature_count *= dim
    return tf.reshape(nchw, [-1, feature_count], name=name)
272+
273+
@layer
274+
def l2_normalize(self, input, name):
    """L2-normalize ``input`` along its last axis and scale by a learned
    per-channel ``alpha`` variable.

    NOTE: Currently, only inference is supported.
    """
    with tf.variable_scope(name):
        # shape [-1:] -> one alpha per channel (last-axis size).
        channel_shape = input.get_shape().as_list()[-1:]
        normalized = tf.nn.l2_normalize(x=input, axis=-1)
        alpha = self.make_var('alpha', shape=channel_shape)
        return tf.multiply(normalized, alpha)

kaffe/tensorflow/transformer.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,10 @@ def map_lrn(self, node):
142142
return TensorFlowNode('lrn', int(params.local_size / 2), alpha, params.beta)
143143

144144
def map_concat(self, node):
    """Map a Caffe Concat layer, translating the axis for NHWC layout.

    After a Flatten the tensor is already 2-D, so Caffe's axis is used
    directly; otherwise the NCHW axis index is remapped to its NHWC
    position via the (2, 3, 1, 0) permutation table.
    """
    caffe_axis = node.parameters.axis
    if node.parents[0].kind == 'Flatten':
        axis = caffe_axis
    else:
        axis = (2, 3, 1, 0)[caffe_axis]
    return TensorFlowNode('concat', axis)
147150

148151
def map_dropout(self, node):
@@ -161,9 +164,27 @@ def map_eltwise(self, node):
161164
except KeyError:
162165
raise KaffeError('Unknown elementwise operation: {}'.format(op_code))
163166

167+
def map_reshape(self, node):
    """Map a Caffe Reshape layer to the TF 'reshape' op.

    The computed output shape is NCHW; the TF layer takes (b, x, y, c),
    so the channel dimension is moved last. If the parent produced a 1x1
    spatial map (a fully-connected output), the data must additionally be
    transposed back to NCHW before reshaping (see network.reshape).
    """
    out = node.output_shape
    batch, channels, height, width = out[0], out[1], out[2], out[3]
    parent_shape = node.get_only_parent().output_shape
    # We need to transpose again if a fc layer is reshaped to conv.
    needs_transpose = parent_shape.height == 1 and parent_shape.width == 1
    return TensorFlowNode('reshape', batch, height, width, channels,
                          transpose=needs_transpose)
181+
182+
def map_flatten(self,node) :
    # Map a Caffe Flatten layer to the parameterless TF 'flatten' op.
    return TensorFlowNode('flatten')
184+
164185
def map_normalize(self, node):
    # Map a Caffe Normalize layer to the TF 'l2_normalize' op.
    return TensorFlowNode('l2_normalize')
166-
187+
167188
def commit(self, chains):
    # No-op commit hook: pass the fused chains through unchanged.
    return chains
169190

0 commit comments

Comments
 (0)