"""
Various TensorFlow utilities, modified from PixelCNN++.
"""

import numpy as np
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import add_arg_scope


def int_shape(x):
    return x.shape.as_list()


def get_name(layer_name, counters):
    ''' utility for keeping track of layer names '''
    if layer_name not in counters:
        counters[layer_name] = 0
    name = layer_name + '_' + str(counters[layer_name])
    counters[layer_name] += 1
    return name


@add_arg_scope
def dense(x, num_units, init_scale=1., counters={}, init=False, **kwargs):
    ''' fully connected layer '''
    name = get_name('dense', counters)
    with tf.variable_scope(name):
        if init:
            xs = x.shape.as_list()
            # data based initialization of parameters
            V = tf.get_variable('V', [xs[1], num_units], tf.float32, tf.random_normal_initializer(0, 0.05))
            V_norm = tf.nn.l2_normalize(V.initialized_value(), [0])
            x_init = tf.matmul(x, V_norm)
            m_init, v_init = tf.nn.moments(x_init, [0])
            scale_init = init_scale / tf.sqrt(v_init + 1e-10)
            g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init)
            b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init * scale_init)
            # scale_init * (x_init - m_init) equals g * (x @ V_norm) + b for the
            # initial values of g and b, so the init pass already returns the
            # same activations the reuse pass will produce
            x_init = tf.reshape(scale_init, [1, num_units]) * (x_init - tf.reshape(m_init, [1, num_units]))

            return x_init
        else:
            V = tf.get_variable("V")
            g = tf.get_variable("g")
            b = tf.get_variable("b")
            with tf.control_dependencies([tf.assert_variables_initialized([V, g, b])]):
                # use weight normalization (Salimans & Kingma, 2016)
                x = tf.matmul(x, V)
                scaler = g / tf.sqrt(tf.reduce_sum(tf.square(V), [0]))
                x = tf.reshape(scaler, [1, num_units]) * x + tf.reshape(b, [1, num_units])

                return x


@add_arg_scope
def conv2d(x, num_filters, filter_size=[3, 3], stride=[1, 1], pad='SAME', init_scale=1., counters={}, init=False, **kwargs):
    ''' convolutional layer '''
    num_filters = int(num_filters)
    strides = [1] + stride + [1]
    name = get_name('conv2d', counters)
    with tf.variable_scope(name):
        if init:
            xs = x.shape.as_list()
            # data based initialization of parameters
            V = tf.get_variable('V', filter_size + [xs[-1], num_filters],
                                tf.float32, tf.random_normal_initializer(0, 0.05))
            V_norm = tf.nn.l2_normalize(V.initialized_value(), [0, 1, 2])
            x_init = tf.nn.conv2d(x, V_norm, strides, pad)
            m_init, v_init = tf.nn.moments(x_init, [0, 1, 2])
            scale_init = init_scale / tf.sqrt(v_init + 1e-8)
            g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init)
            b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init * scale_init)
            x_init = tf.reshape(scale_init, [1, 1, 1, num_filters]) * (x_init - tf.reshape(m_init, [1, 1, 1, num_filters]))

            return x_init
        else:
            V = tf.get_variable("V")
            g = tf.get_variable("g")
            b = tf.get_variable("b")
            with tf.control_dependencies([tf.assert_variables_initialized([V, g, b])]):
                # use weight normalization (Salimans & Kingma, 2016)
                W = tf.reshape(g, [1, 1, 1, num_filters]) * tf.nn.l2_normalize(V, [0, 1, 2])

                # calculate convolutional layer output
                x = tf.nn.bias_add(tf.nn.conv2d(x, W, strides, pad), b)

                return x


@add_arg_scope
def deconv2d(x, num_filters, filter_size=[3, 3], stride=[1, 1], pad='SAME', init_scale=1., counters={}, init=False, **kwargs):
    ''' transposed convolutional layer '''
    num_filters = int(num_filters)
    name = get_name('deconv2d', counters)
    xs = int_shape(x)
    strides = [1] + stride + [1]
    if pad == 'SAME':
        target_shape = [xs[0], xs[1] * stride[0],
                        xs[2] * stride[1], num_filters]
    else:
        target_shape = [xs[0], xs[1] * stride[0] + filter_size[0] - 1,
                        xs[2] * stride[1] + filter_size[1] - 1, num_filters]
    with tf.variable_scope(name):
        if init:
            # data based initialization of parameters
            V = tf.get_variable('V', filter_size + [num_filters, xs[-1]], tf.float32, tf.random_normal_initializer(0, 0.05))
            V_norm = tf.nn.l2_normalize(V.initialized_value(), [0, 1, 3])
            x_init = tf.nn.conv2d_transpose(x, V_norm, target_shape, strides, padding=pad)
            m_init, v_init = tf.nn.moments(x_init, [0, 1, 2])
            scale_init = init_scale / tf.sqrt(v_init + 1e-8)
            g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init)
            b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init * scale_init)
            x_init = tf.reshape(scale_init, [1, 1, 1, num_filters]) * (x_init - tf.reshape(m_init, [1, 1, 1, num_filters]))

            return x_init
        else:
            V = tf.get_variable("V")
            g = tf.get_variable("g")
            b = tf.get_variable("b")
            with tf.control_dependencies([tf.assert_variables_initialized([V, g, b])]):
                # use weight normalization (Salimans & Kingma, 2016)
                W = tf.reshape(g, [1, 1, num_filters, 1]) * tf.nn.l2_normalize(V, [0, 1, 3])

                # calculate transposed convolutional layer output
                x = tf.nn.conv2d_transpose(x, W, target_shape, strides, padding=pad)
                x = tf.nn.bias_add(x, b)

                return x


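# Usage note (an illustrative sketch, not part of the original module): the
# weight-normalized layers above (dense, conv2d, deconv2d) follow the
# PixelCNN++-style two-pass setup assumed here -- the graph is built once with
# init=True on a data batch to perform the data-dependent initialization of g
# and b, and once more with init=False and variable reuse for training, e.g.
#
#   with tf.variable_scope("model"):
#       _ = dense(x_init_batch, 128, init=True, counters={})
#   with tf.variable_scope("model", reuse=True):
#       y = dense(x, 128, init=False, counters={})

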
@add_arg_scope
def activate(x, activation, **kwargs):
    if activation is None:
        return x
    elif activation == "elu":
        return tf.nn.elu(x)
    else:
        raise NotImplementedError(activation)


def nin(x, num_units):
    """ a network in network layer (1x1 CONV) """
    s = int_shape(x)
    x = tf.reshape(x, [np.prod(s[:-1]), s[-1]])
    x = dense(x, num_units)
    return tf.reshape(x, s[:-1] + [num_units])


def downsample(x, num_units):
    return conv2d(x, num_units, stride=[2, 2])


def upsample(x, num_units, method="subpixel"):
    if method == "conv_transposed":
        return deconv2d(x, num_units, stride=[2, 2])
    elif method == "subpixel":
        # subpixel upsampling: produce 4x the channels, then rearrange them
        # into a 2x larger spatial grid
        x = conv2d(x, 4 * num_units)
        x = tf.depth_to_space(x, 2)
        return x
    else:
        raise NotImplementedError(method)


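# Example (illustrative sketch): subpixel upsampling doubles the spatial
# resolution of an NHWC tensor by trading channels for space,
#
#   h = upsample(h, num_units=64)    # [N, H, W, C] -> [N, 2*H, 2*W, 64]
#
# while downsample halves it via a strided convolution:
#
#   h = downsample(h, num_units=64)  # [N, H, W, C] -> [N, H/2, W/2, 64]

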
@add_arg_scope
def residual_block(x, a=None, conv=conv2d, init=False, dropout_p=0.0, gated=False, **kwargs):
    """Slight variation of original.
    The activation is expected to be supplied to activate() via arg_scope."""
    xs = int_shape(x)
    num_filters = xs[-1]

    residual = x
    if a is not None:
        # incorporate auxiliary input via a 1x1 convolution
        a = nin(activate(a), num_filters)
        residual = tf.concat([residual, a], axis=-1)
    residual = activate(residual)
    residual = tf.nn.dropout(residual, keep_prob=1.0 - dropout_p)
    residual = conv(residual, num_filters)
    if gated:
        residual = activate(residual)
        residual = tf.nn.dropout(residual, keep_prob=1.0 - dropout_p)
        residual = conv(residual, 2 * num_filters)
        a, b = tf.split(residual, 2, 3)
        # gated activation: one half modulates the other through a sigmoid
        residual = a * tf.nn.sigmoid(b)

    return x + residual


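# Illustrative usage sketch (assumes the layers are configured through
# tf.contrib's arg_scope, which the add_arg_scope decorator enables; the names
# init_pass and skip_connection below are examples, not part of this module):
#
#   from tensorflow.contrib.framework.python.ops import arg_scope
#
#   with arg_scope([conv2d, deconv2d, dense, activate],
#                  counters={}, init=init_pass, activation="elu"):
#       h = residual_block(h, a=skip_connection, gated=True)

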
def make_linear_var(
        step,
        start, end,
        start_value, end_value,
        clip_min=0.0, clip_max=1.0):
    """Linear interpolation from (start, start_value) to (end, end_value), i.e.
    (end_value - start_value) / (end - start) * (step - start) + start_value,
    clipped to [clip_min, clip_max]."""
    linear = (
        (end_value - start_value) /
        (end - start) *
        (tf.cast(step, tf.float32) - start) + start_value)
    return tf.clip_by_value(linear, clip_min, clip_max)


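# Example (illustrative sketch; the variable names are assumptions): a weight
# ramped linearly from 0.0 at step 1000 to 1.0 at step 5000 and held constant
# outside that range by the clipping:
#
#   global_step = tf.train.get_or_create_global_step()
#   kl_weight = make_linear_var(global_step, start=1000, end=5000,
#                               start_value=0.0, end_value=1.0,
#                               clip_min=0.0, clip_max=1.0)

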
def split_groups(x, bs=2):
    """Split each bs x bs spatial block into bs**2 separate tensors along the
    channel axis."""
    return tf.split(tf.space_to_depth(x, bs), bs**2, axis=3)


def merge_groups(xs, bs=2):
    """Inverse of split_groups: concatenate the groups along the channel axis
    and rearrange them back into bs x bs spatial blocks."""
    return tf.depth_to_space(tf.concat(xs, axis=3), bs)
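

# Example (illustrative sketch): split_groups and merge_groups are inverse
# operations for NHWC tensors whose spatial dimensions are divisible by bs:
#
#   groups = split_groups(x, bs=2)   # list of 4 tensors, each [N, H/2, W/2, C]
#   x_back = merge_groups(groups)    # same shape and contents as x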