Skip to content

Commit 24ba21f

Browse files
authored
Merge pull request #2 from igv/master
Pull recent igv commits (2019-10-04)
2 parents e30e472 + 85c1912 commit 24ba21f

File tree

5 files changed

+47
-116
lines changed

5 files changed

+47
-116
lines changed

FSRCNN.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ def model(self):
7272
deconv_biases = tf.get_variable('deconv_b', initializer=tf.zeros([self.scale**2]))
7373
deconv = tf.nn.conv2d(conv, deconv_weights, strides=[1,1,1,1], padding='SAME', data_format='NHWC')
7474
deconv = tf.nn.bias_add(deconv, deconv_biases, data_format='NHWC')
75-
deconv = tf.depth_to_space(deconv, self.scale, name='pixel_shuffle', data_format='NHWC')
75+
if self.scale > 1:
76+
deconv = tf.depth_to_space(deconv, self.scale, name='pixel_shuffle', data_format='NHWC')
7677

7778
return deconv
7879

gen.py

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ def format_weights(weights, n, length=4):
2828

2929
def base_header(file):
3030
file.write('//!HOOK LUMA\n')
31-
file.write('//!WHEN OUTPUT.w LUMA.w / {0}.400 > OUTPUT.h LUMA.h / {0}.400 > *\n'.format(scale - 1))
31+
if scale > 1:
32+
file.write('//!WHEN OUTPUT.w LUMA.w / {0}.400 > OUTPUT.h LUMA.h / {0}.400 > *\n'.format(scale - 1))
3233

3334
def header1(file, n, d):
3435
base_header(file)
@@ -75,8 +76,9 @@ def header5(file, n, d, inp):
7576
file.write('//!DESC sub-pixel convolution {}\n'.format((n//comps) + 1))
7677
for i in range(d//4):
7778
file.write('//!BIND {}{}\n'.format(inp, i + 1))
78-
file.write('//!SAVE SUBCONV{}\n'.format((n//comps) + 1))
79-
file.write('//!COMPONENTS {}\n'.format(comps))
79+
if scale > 1:
80+
file.write('//!SAVE SUBCONV{}\n'.format((n//comps) + 1))
81+
file.write('//!COMPONENTS {}\n'.format(comps))
8082

8183
def header6(file):
8284
base_header(file)
@@ -219,45 +221,47 @@ def main():
219221
ln = get_line_number("deconv_b", fname)
220222
biases = read_weights(fname, ln)
221223
inp = "EXPANDED" if shrinking else "RES"
222-
comps = 3 if scale == 3 else 4
224+
comps = scale if scale % 2 == 1 else 4
223225
for n in range(0, scale**2, comps):
224226
header5(file, n, d, inp)
225227
file.write('vec4 hook()\n')
226228
file.write('{\n')
227-
file.write('vec{0} res = vec{0}({1});\n'.format(comps, format_weights(biases[0], n, length=comps)))
229+
if scale == 1:
230+
file.write('float res = {};\n'.format(format_weights(biases[0], n, length=comps)))
231+
else:
232+
file.write('vec{0} res = vec{0}({1});\n'.format(comps, format_weights(biases[0], n, length=comps)))
228233
p = 0
229234
for l in range(0, len(weights), 4):
230235
if l % d == 0:
231236
y, x = p%(radius*2+1)-radius, p//(radius*2+1)-radius
232237
p += 1
233238
idx = (l//4)%(d//4)
234-
file.write('res += mat4x{}({},{},{},{}) * {}{}_texOff(vec2({},{}));\n'.format(
235-
comps, format_weights(weights[l], n, length=comps), format_weights(weights[l+1], n, length=comps),
239+
file.write('res += {}{}({},{},{},{}){} {}{}_texOff(vec2({},{})){};\n'.format(
240+
"mat4x" if scale > 1 else "dot(", comps if scale > 1 else "vec4",
241+
format_weights(weights[l], n, length=comps), format_weights(weights[l+1], n, length=comps),
236242
format_weights(weights[l+2], n, length=comps), format_weights(weights[l+3], n, length=comps),
237-
inp, idx + 1, x, y))
238-
if comps == 4:
239-
file.write('return res;\n')
240-
else:
241-
file.write('return vec4(res, 0);\n')
243+
" *" if scale > 1 else ",", inp, idx + 1, x, y, "" if scale > 1 else ")"))
244+
file.write('return vec4(res{});\n'.format(", 0" * (4 - comps)))
242245
file.write('}\n\n')
243246

244-
# Aggregation
245-
header6(file)
246-
file.write('vec4 hook()\n')
247-
file.write('{\n')
248-
file.write('vec2 fcoord = fract(SUBCONV1_pos * SUBCONV1_size);\n')
249-
file.write('vec2 base = SUBCONV1_pos + (vec2(0.5) - fcoord) * SUBCONV1_pt;\n')
250-
file.write('ivec2 index = ivec2(fcoord * vec2({}));\n'.format(scale))
251-
if scale > 2:
252-
file.write('mat{0} res = mat{0}(SUBCONV1_tex(base).{1}'.format(scale, "rgba"[:comps]))
253-
for i in range(scale-1):
254-
file.write(',SUBCONV{}_tex(base).{}'.format(i + 2, "rgba"[:comps]))
255-
file.write(');\n')
256-
file.write('return vec4(res[index.x][index.y], 0, 0, 1);\n')
257-
else:
258-
file.write('vec4 res = SUBCONV1_tex(base);\n')
259-
file.write('return vec4(res[index.x * {} + index.y], 0, 0, 1);\n'.format(scale))
260-
file.write('}\n')
247+
if scale > 1:
248+
# Aggregation
249+
header6(file)
250+
file.write('vec4 hook()\n')
251+
file.write('{\n')
252+
file.write('vec2 fcoord = fract(SUBCONV1_pos * SUBCONV1_size);\n')
253+
file.write('vec2 base = SUBCONV1_pos + (vec2(0.5) - fcoord) * SUBCONV1_pt;\n')
254+
file.write('ivec2 index = ivec2(fcoord * vec2({}));\n'.format(scale))
255+
if scale > 2:
256+
file.write('mat{0} res = mat{0}(SUBCONV1_tex(base).{1}'.format(scale, "rgba"[:comps]))
257+
for i in range(scale-1):
258+
file.write(',SUBCONV{}_tex(base).{}'.format(i + 2, "rgba"[:comps]))
259+
file.write(');\n')
260+
file.write('return vec4(res[index.x][index.y], 0, 0, 1);\n')
261+
else:
262+
file.write('vec4 res = SUBCONV1_tex(base);\n')
263+
file.write('return vec4(res[index.x * {} + index.y], 0, 0, 1);\n'.format(scale))
264+
file.write('}\n')
261265

262266
else:
263267
print("Missing argument: You must specify a file name")

main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
flags.DEFINE_string("output_dir", "result", "Name of test output directory [result]")
2424
flags.DEFINE_string("data_dir", "Train", "Name of data directory to train on [FastTrain]")
2525
flags.DEFINE_boolean("train", True, "True for training, false for testing [True]")
26-
flags.DEFINE_integer("threads", 1, "Number of processes to pre-process data with [1]")
2726
flags.DEFINE_boolean("distort", False, "Distort some images with JPEG compression artifacts after downscaling [False]")
2827
flags.DEFINE_boolean("params", False, "Save weight and bias parameters [False]")
2928

model.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from utils import (
2-
thread_train_setup,
3-
train_input_setup,
2+
multiprocess_train_setup,
43
test_input_setup,
54
save_params,
65
merge,
@@ -31,14 +30,13 @@ def __init__(self, sess, config):
3130
self.radius = config.radius
3231
self.batch_size = config.batch_size
3332
self.learning_rate = config.learning_rate
34-
self.threads = config.threads
3533
self.distort = config.distort
3634
self.params = config.params
3735

3836
self.padding = 4
3937
# Different image/label sub-sizes for different scaling factors x2, x3, x4
40-
scale_factors = [[20 + self.padding, 40], [14 + self.padding, 42], [12 + self.padding, 48]]
41-
self.image_size, self.label_size = scale_factors[self.scale - 2]
38+
scale_factors = [[40 + self.padding, 40], [20 + self.padding, 40], [14 + self.padding, 42], [12 + self.padding, 48]]
39+
self.image_size, self.label_size = scale_factors[self.scale - 1]
4240

4341
self.stride = self.image_size - self.padding
4442

@@ -94,11 +92,8 @@ def run(self):
9492
def run_train(self):
9593
start_time = time.time()
9694
print("Beginning training setup...")
97-
if self.threads == 1:
98-
train_data, train_label = train_input_setup(self)
99-
else:
100-
train_data, train_label = thread_train_setup(self)
101-
print("Training setup took {} seconds with {} threads".format(time.time() - start_time, self.threads))
95+
train_data, train_label = multiprocess_train_setup(self)
96+
print("Training setup took {} seconds".format(time.time() - start_time))
10297

10398
print("Training...")
10499
start_time = time.time()

utils.py

Lines changed: 7 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import tensorflow as tf
1313
from PIL import Image
1414
import numpy as np
15-
from multiprocessing import Pool, Lock, active_children
15+
import multiprocessing
1616

1717
FLAGS = tf.app.flags.FLAGS
1818

@@ -145,46 +145,19 @@ def train_input_worker(args):
145145

146146
return [single_input_sequence, single_label_sequence]
147147

148-
149-
def thread_train_setup(config):
148+
def multiprocess_train_setup(config):
150149
"""
151-
Spawns |config.threads| worker processes to pre-process the data
152-
153-
This has not been extensively tested so use at your own risk.
154-
Also this is technically multiprocessing not threading, I just say thread
155-
because it's shorter to type.
150+
Spawns several processes to pre-process the data
156151
"""
157152
if downsample == False:
158153
import sys
159154
sys.exit()
160155

161-
sess = config.sess
162-
163-
# Load data path
164-
data = prepare_data(sess, dataset=config.data_dir)
165-
166-
# Initialize multiprocessing pool with # of processes = config.threads
167-
pool = Pool(config.threads)
168-
169-
# Distribute |images_per_thread| images across each worker process
170-
config_values = [config.image_size, config.label_size, config.stride, config.scale, config.padding // 2, config.distort]
171-
images_per_thread = len(data) // config.threads
172-
workers = []
173-
for thread in range(config.threads):
174-
args_list = [(data[i], config_values) for i in range(thread * images_per_thread, (thread + 1) * images_per_thread)]
175-
worker = pool.map_async(train_input_worker, args_list)
176-
workers.append(worker)
177-
print("{} worker processes created".format(config.threads))
178-
179-
pool.close()
156+
data = prepare_data(config.sess, dataset=config.data_dir)
180157

181-
results = []
182-
for i in range(len(workers)):
183-
print("Waiting for worker process {}".format(i))
184-
results.extend(workers[i].get(timeout=240))
185-
print("Worker process {} done".format(i))
186-
187-
print("All worker processes done!")
158+
with multiprocessing.Pool(max(multiprocessing.cpu_count() - 1, 1)) as pool:
159+
config_values = [config.image_size, config.label_size, config.stride, config.scale, config.padding // 2, config.distort]
160+
results = pool.map(train_input_worker, [(data[i], config_values) for i in range(len(data))])
188161

189162
sub_input_sequence, sub_label_sequence = [], []
190163

@@ -198,47 +171,6 @@ def thread_train_setup(config):
198171

199172
return (arrdata, arrlabel)
200173

201-
def train_input_setup(config):
202-
"""
203-
Read image files, make their sub-images, and save them as a h5 file format.
204-
"""
205-
if downsample == False:
206-
import sys
207-
sys.exit()
208-
209-
sess = config.sess
210-
image_size, label_size, stride, scale, padding = config.image_size, config.label_size, config.stride, config.scale, config.padding // 2
211-
212-
# Load data path
213-
data = prepare_data(sess, dataset=config.data_dir)
214-
215-
sub_input_sequence, sub_label_sequence = [], []
216-
217-
for i in range(len(data)):
218-
input_, label_ = preprocess(data[i], scale, distort=config.distort)
219-
220-
if len(input_.shape) == 3:
221-
h, w, _ = input_.shape
222-
else:
223-
h, w = input_.shape
224-
225-
for x in range(0, h - image_size + 1, stride):
226-
for y in range(0, w - image_size + 1, stride):
227-
sub_input = input_[x : x + image_size, y : y + image_size]
228-
x_loc, y_loc = x + padding, y + padding
229-
sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]
230-
231-
sub_input = sub_input.reshape([image_size, image_size, 1])
232-
sub_label = sub_label.reshape([label_size, label_size, 1])
233-
234-
sub_input_sequence.append(sub_input)
235-
sub_label_sequence.append(sub_label)
236-
237-
arrdata = np.asarray(sub_input_sequence)
238-
arrlabel = np.asarray(sub_label_sequence)
239-
240-
return (arrdata, arrlabel)
241-
242174
def test_input_setup(config):
243175
sess = config.sess
244176

0 commit comments

Comments
 (0)