-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlayers_temp.py
60 lines (48 loc) · 1.81 KB
/
layers_temp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from __future__ import absolute_import, division
import tensorflow as tf
from keras.layers import Conv2D
from keras.initializers import RandomNormal
from tf_helpers import tf_batch_map_offsets
class ConvOffset2D(Conv2D):
    """Convolution layer that predicts 2-D offsets and resamples its input.

    A 3x3 ``'same'`` convolution without bias and with ``2 * filters`` output
    channels produces the offsets. The input and the offsets are reshaped so
    that every (sample, channel) pair becomes one batch entry and are then
    handed to ``tf_batch_map_offsets``; the result is reshaped back, so the
    layer returns a tensor with the same shape as its input.

    Inputs are expected in channels-last ``(b, h, w, c)`` layout. Height,
    width and channel count are read from the static shape with ``int()``,
    so they must be known when the graph is built.

    NOTE(review): the reshapes only line up when ``filters`` equals the
    number of input channels -- presumably callers always pass it that way;
    confirm.
    """

    def __init__(self, filters, init_normal_stddev=0.01, **kwargs):
        """Create the offset convolution.

        Args:
            filters: number of channels to compute offsets for; the
                underlying convolution is built with twice as many output
                channels (two offset components per channel).
            init_normal_stddev: standard deviation for the alternative
                random-normal kernel initialisation. Currently unused
                because the kernel is initialised with zeros (see TODO).
            **kwargs: forwarded unchanged to ``Conv2D.__init__``. Must not
                contain ``kernel_size``, ``padding``, ``use_bias`` or
                ``kernel_initializer``, which are fixed here.
        """
        self.filters = filters
        super(ConvOffset2D, self).__init__(
            self.filters * 2, (3, 3), padding='same', use_bias=False,
            # TODO gradients are near zero if init is zeros
            kernel_initializer='zeros',
            # kernel_initializer=RandomNormal(0, init_normal_stddev),
            **kwargs
        )
        # Remember the constructor arguments under names the base class does
        # not touch, so get_config() can reproduce them exactly.
        self.requested_filters = filters
        self.init_normal_stddev = init_normal_stddev

    def get_config(self):
        """Return a config that mirrors the arguments of ``__init__``.

        The inherited config describes the underlying convolution (its own
        filter count, kernel size, padding, bias flag and kernel
        initializer). Feeding those back through ``__init__`` would pass
        them to the base constructor a second time via ``**kwargs``, so they
        are dropped here and replaced by this layer's own arguments.
        """
        config = super(ConvOffset2D, self).get_config()
        for fixed_key in ('kernel_size', 'padding', 'use_bias',
                          'kernel_initializer'):
            config.pop(fixed_key, None)
        config['filters'] = self.requested_filters
        config['init_normal_stddev'] = self.init_normal_stddev
        return config

    def call(self, x):
        """Compute offsets from ``x`` and return ``x`` resampled with them.

        Args:
            x: input tensor, channels-last ``(b, h, w, c)``.

        Returns:
            Tensor reshaped back to ``(b, h, w, c)``.
        """
        # TODO offsets probably have no nonlinearity?
        x_shape = x.get_shape()
        offsets = super(ConvOffset2D, self).call(x)

        # Fold the channel axis into the batch axis for both tensors.
        offsets = self._to_bc_h_w_2(offsets, x_shape)
        x = self._to_bc_h_w(x, x_shape)
        x_offset = tf_batch_map_offsets(x, offsets)

        # Undo the folding so the output layout matches the input layout.
        x_offset = self._to_b_h_w_c(x_offset, x_shape)
        return x_offset

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` unchanged."""
        return input_shape

    @staticmethod
    def _to_bc_h_w_2(x, x_shape):
        """(b, h, w, 2c) -> (b*c, h, w, 2)"""
        x = tf.transpose(x, [0, 3, 1, 2])
        # NOTE(review): this reshape reinterprets two consecutive (h, w)
        # channel planes as one row-major (h, w, 2) array, so the two offset
        # components of a pixel are adjacent elements of a plane rather than
        # the same pixel in two planes. Left as-is because changing it would
        # alter the behaviour of already-trained weights -- confirm intent.
        x = tf.reshape(x, (-1, int(x_shape[1]), int(x_shape[2]), 2))
        return x

    @staticmethod
    def _to_bc_h_w(x, x_shape):
        """(b, h, w, c) -> (b*c, h, w)"""
        x = tf.transpose(x, [0, 3, 1, 2])
        x = tf.reshape(x, (-1, int(x_shape[1]), int(x_shape[2])))
        return x

    @staticmethod
    def _to_b_h_w_c(x, x_shape):
        """(b*c, h, w) -> (b, h, w, c)"""
        x = tf.reshape(
            x, (-1, int(x_shape[3]), int(x_shape[1]), int(x_shape[2]))
        )
        x = tf.transpose(x, [0, 2, 3, 1])
        return x