Skip to content

Commit 85c1912

Browse files
committed
Always use multiple processes for pre-processing
1 parent 8f30c0f commit 85c1912

File tree

3 files changed

+10
-84
lines changed

main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
flags.DEFINE_string("output_dir", "result", "Name of test output directory [result]")
1919
flags.DEFINE_string("data_dir", "Train", "Name of data directory to train on [FastTrain]")
2020
flags.DEFINE_boolean("train", True, "True for training, false for testing [True]")
21-
flags.DEFINE_integer("threads", 1, "Number of processes to pre-process data with [1]")
2221
flags.DEFINE_boolean("distort", False, "Distort some images with JPEG compression artifacts after downscaling [False]")
2322
flags.DEFINE_boolean("params", False, "Save weight and bias parameters [False]")
2423

model.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from utils import (
2-
thread_train_setup,
3-
train_input_setup,
2+
multiprocess_train_setup,
43
test_input_setup,
54
save_params,
65
merge,
@@ -31,7 +30,6 @@ def __init__(self, sess, config):
3130
self.radius = config.radius
3231
self.batch_size = config.batch_size
3332
self.learning_rate = config.learning_rate
34-
self.threads = config.threads
3533
self.distort = config.distort
3634
self.params = config.params
3735

@@ -94,11 +92,8 @@ def run(self):
9492
def run_train(self):
9593
start_time = time.time()
9694
print("Beginning training setup...")
97-
if self.threads == 1:
98-
train_data, train_label = train_input_setup(self)
99-
else:
100-
train_data, train_label = thread_train_setup(self)
101-
print("Training setup took {} seconds with {} threads".format(time.time() - start_time, self.threads))
95+
train_data, train_label = multiprocess_train_setup(self)
96+
print("Training setup took {} seconds".format(time.time() - start_time))
10297

10398
print("Training...")
10499
start_time = time.time()

utils.py

Lines changed: 7 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import tensorflow as tf
1313
from PIL import Image
1414
import numpy as np
15-
from multiprocessing import Pool, Lock, active_children
15+
import multiprocessing
1616

1717
FLAGS = tf.app.flags.FLAGS
1818

@@ -145,46 +145,19 @@ def train_input_worker(args):
145145

146146
return [single_input_sequence, single_label_sequence]
147147

148-
149-
def thread_train_setup(config):
148+
def multiprocess_train_setup(config):
150149
"""
151-
Spawns |config.threads| worker processes to pre-process the data
152-
153-
This has not been extensively tested so use at your own risk.
154-
Also this is technically multiprocessing not threading, I just say thread
155-
because it's shorter to type.
150+
Spawns several processes to pre-process the data
156151
"""
157152
if downsample == False:
158153
import sys
159154
sys.exit()
160155

161-
sess = config.sess
162-
163-
# Load data path
164-
data = prepare_data(sess, dataset=config.data_dir)
165-
166-
# Initialize multiprocessing pool with # of processes = config.threads
167-
pool = Pool(config.threads)
168-
169-
# Distribute |images_per_thread| images across each worker process
170-
config_values = [config.image_size, config.label_size, config.stride, config.scale, config.padding // 2, config.distort]
171-
images_per_thread = len(data) // config.threads
172-
workers = []
173-
for thread in range(config.threads):
174-
args_list = [(data[i], config_values) for i in range(thread * images_per_thread, (thread + 1) * images_per_thread)]
175-
worker = pool.map_async(train_input_worker, args_list)
176-
workers.append(worker)
177-
print("{} worker processes created".format(config.threads))
178-
179-
pool.close()
156+
data = prepare_data(config.sess, dataset=config.data_dir)
180157

181-
results = []
182-
for i in range(len(workers)):
183-
print("Waiting for worker process {}".format(i))
184-
results.extend(workers[i].get(timeout=240))
185-
print("Worker process {} done".format(i))
186-
187-
print("All worker processes done!")
158+
with multiprocessing.Pool(max(multiprocessing.cpu_count() - 1, 1)) as pool:
159+
config_values = [config.image_size, config.label_size, config.stride, config.scale, config.padding // 2, config.distort]
160+
results = pool.map(train_input_worker, [(data[i], config_values) for i in range(len(data))])
188161

189162
sub_input_sequence, sub_label_sequence = [], []
190163

@@ -198,47 +171,6 @@ def thread_train_setup(config):
198171

199172
return (arrdata, arrlabel)
200173

201-
def train_input_setup(config):
202-
"""
203-
Read image files, make their sub-images, and save them as a h5 file format.
204-
"""
205-
if downsample == False:
206-
import sys
207-
sys.exit()
208-
209-
sess = config.sess
210-
image_size, label_size, stride, scale, padding = config.image_size, config.label_size, config.stride, config.scale, config.padding // 2
211-
212-
# Load data path
213-
data = prepare_data(sess, dataset=config.data_dir)
214-
215-
sub_input_sequence, sub_label_sequence = [], []
216-
217-
for i in range(len(data)):
218-
input_, label_ = preprocess(data[i], scale, distort=config.distort)
219-
220-
if len(input_.shape) == 3:
221-
h, w, _ = input_.shape
222-
else:
223-
h, w = input_.shape
224-
225-
for x in range(0, h - image_size + 1, stride):
226-
for y in range(0, w - image_size + 1, stride):
227-
sub_input = input_[x : x + image_size, y : y + image_size]
228-
x_loc, y_loc = x + padding, y + padding
229-
sub_label = label_[x_loc * scale : x_loc * scale + label_size, y_loc * scale : y_loc * scale + label_size]
230-
231-
sub_input = sub_input.reshape([image_size, image_size, 1])
232-
sub_label = sub_label.reshape([label_size, label_size, 1])
233-
234-
sub_input_sequence.append(sub_input)
235-
sub_label_sequence.append(sub_label)
236-
237-
arrdata = np.asarray(sub_input_sequence)
238-
arrlabel = np.asarray(sub_label_sequence)
239-
240-
return (arrdata, arrlabel)
241-
242174
def test_input_setup(config):
243175
sess = config.sess
244176

0 commit comments

Comments (0)