1212import tensorflow as tf
1313from PIL import Image
1414import numpy as np
15- from multiprocessing import Pool , Lock , active_children
15+ import multiprocessing
1616
1717FLAGS = tf .app .flags .FLAGS
1818
@@ -145,46 +145,19 @@ def train_input_worker(args):
145145
146146 return [single_input_sequence , single_label_sequence ]
147147
148-
149- def thread_train_setup (config ):
148+ def multiprocess_train_setup (config ):
150149 """
151- Spawns |config.threads| worker processes to pre-process the data
152-
153- This has not been extensively tested so use at your own risk.
154- Also this is technically multiprocessing not threading, I just say thread
155- because it's shorter to type.
150+ Spawns several processes to pre-process the data
156151 """
157152 if downsample == False :
158153 import sys
159154 sys .exit ()
160155
161- sess = config .sess
162-
163- # Load data path
164- data = prepare_data (sess , dataset = config .data_dir )
165-
166- # Initialize multiprocessing pool with # of processes = config.threads
167- pool = Pool (config .threads )
168-
169- # Distribute |images_per_thread| images across each worker process
170- config_values = [config .image_size , config .label_size , config .stride , config .scale , config .padding // 2 , config .distort ]
171- images_per_thread = len (data ) // config .threads
172- workers = []
173- for thread in range (config .threads ):
174- args_list = [(data [i ], config_values ) for i in range (thread * images_per_thread , (thread + 1 ) * images_per_thread )]
175- worker = pool .map_async (train_input_worker , args_list )
176- workers .append (worker )
177- print ("{} worker processes created" .format (config .threads ))
178-
179- pool .close ()
156+ data = prepare_data (config .sess , dataset = config .data_dir )
180157
181- results = []
182- for i in range (len (workers )):
183- print ("Waiting for worker process {}" .format (i ))
184- results .extend (workers [i ].get (timeout = 240 ))
185- print ("Worker process {} done" .format (i ))
186-
187- print ("All worker processes done!" )
158+ with multiprocessing .Pool (max (multiprocessing .cpu_count () - 1 , 1 )) as pool :
159+ config_values = [config .image_size , config .label_size , config .stride , config .scale , config .padding // 2 , config .distort ]
160+ results = pool .map (train_input_worker , [(data [i ], config_values ) for i in range (len (data ))])
188161
189162 sub_input_sequence , sub_label_sequence = [], []
190163
@@ -198,47 +171,6 @@ def thread_train_setup(config):
198171
199172 return (arrdata , arrlabel )
200173
201- def train_input_setup (config ):
202- """
203- Read image files, make their sub-images, and save them as a h5 file format.
204- """
205- if downsample == False :
206- import sys
207- sys .exit ()
208-
209- sess = config .sess
210- image_size , label_size , stride , scale , padding = config .image_size , config .label_size , config .stride , config .scale , config .padding // 2
211-
212- # Load data path
213- data = prepare_data (sess , dataset = config .data_dir )
214-
215- sub_input_sequence , sub_label_sequence = [], []
216-
217- for i in range (len (data )):
218- input_ , label_ = preprocess (data [i ], scale , distort = config .distort )
219-
220- if len (input_ .shape ) == 3 :
221- h , w , _ = input_ .shape
222- else :
223- h , w = input_ .shape
224-
225- for x in range (0 , h - image_size + 1 , stride ):
226- for y in range (0 , w - image_size + 1 , stride ):
227- sub_input = input_ [x : x + image_size , y : y + image_size ]
228- x_loc , y_loc = x + padding , y + padding
229- sub_label = label_ [x_loc * scale : x_loc * scale + label_size , y_loc * scale : y_loc * scale + label_size ]
230-
231- sub_input = sub_input .reshape ([image_size , image_size , 1 ])
232- sub_label = sub_label .reshape ([label_size , label_size , 1 ])
233-
234- sub_input_sequence .append (sub_input )
235- sub_label_sequence .append (sub_label )
236-
237- arrdata = np .asarray (sub_input_sequence )
238- arrlabel = np .asarray (sub_label_sequence )
239-
240- return (arrdata , arrlabel )
241-
242174def test_input_setup (config ):
243175 sess = config .sess
244176
0 commit comments