-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy path get_batch_images.py
69 lines (59 loc) · 1.81 KB
/
get_batch_images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#!/qcfs/jackjack5/3party-software/anaconda/kdft/anaconda2/bin/python
import argparse
import glob
import json
import os
import pickle
import random
import sys

import numpy as np
from tqdm import tqdm
def _save_sparse_batch(batch_images, savedir, batch_iter, threshold):
    """Write one batch to disk in sparse (indices, values) form.

    Writes ``<savedir>/<batch_iter>_images.npy`` (int32 array of shape
    ``(5, nnz)``: the batch index followed by the four per-image indices of
    every element >= ``threshold``) and ``<savedir>/<batch_iter>_pvals.npy``
    (the element values at those indices, in the same order).
    """
    fin_batch = np.array(batch_images)
    indices = np.where(fin_batch >= threshold)  # tuple of 5 index arrays
    pvals = fin_batch[indices]
    np.save(os.path.join(savedir, str(batch_iter) + '_images.npy'),
            np.array(indices, np.int32))
    np.save(os.path.join(savedir, str(batch_iter) + '_pvals.npy'),
            np.array(pvals))


def makebatch(args):
    """Split multi-channel images into single-channel batches and save them sparsely.

    Every ``*.json`` file in the directory ``args.files`` is loaded and its
    ``'image'`` entry is converted to a NumPy array. The array is expected to
    be 4-D, ``(dim, dim, dim, nc)``, because each channel slice is reshaped to
    ``(dim, dim, dim, 1)`` (TODO confirm against the producer of these files).
    Each channel becomes one image. Whenever ``args.batchsize`` images have
    been collected, the batch is written by ``_save_sparse_batch``, keeping
    only elements >= 0.02.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``files`` (input directory), ``batchsize`` (int) and
        ``savedir`` (output directory; created if it does not exist).

    Returns
    -------
    int
        Always 1.

    Notes
    -----
    Images remaining after the last full batch are not written.
    """
    threshold = 0.02  # elements below this value are dropped from the sparse output
    batchsize = args.batchsize
    savedir = args.savedir

    # glob() order depends on the filesystem; sort first so the seeded
    # shuffle below gives the same order on every machine.
    filelist = sorted(glob.glob(os.path.join(args.files, '*.json')))
    if not filelist:
        sys.stderr.write('no *.json files found in %s\n' % args.files)
        return 1
    random.seed(777)  # You can also change random seed value
    random.shuffle(filelist)

    if not os.path.isdir(savedir):
        os.makedirs(savedir)

    batch_iter = 0
    batch_images = []
    sum_nc = 0
    # glob already restricted the list to *.json; the former '.npy' name
    # filter rejected every one of these files, so it has been removed.
    for filename in tqdm(filelist):
        with open(filename) as fh:
            images = np.array(json.load(fh)['image'])
        dim = images.shape[0]
        nc = images.shape[-1]
        sum_nc += nc
        for c in range(nc):
            batch_images.append(images[:, :, :, c].reshape(dim, dim, dim, 1))
            if len(batch_images) == batchsize:
                # Single-argument print behaves the same on Python 2 and 3.
                print('saving batch # %d' % batch_iter)
                print(sum_nc)
                _save_sparse_batch(batch_images, savedir, batch_iter, threshold)
                batch_iter += 1
                batch_images = []
    return 1
def main():
    """Parse command-line options and build the image batches via makebatch().

    Exits with argparse's usage error (status 2) when ``--files`` is missing
    or ``--batchsize`` is not a positive integer. A non-positive batch size
    would otherwise make makebatch() finish without ever writing a batch.
    """
    parser = argparse.ArgumentParser(description='script for making single image batches')
    parser.add_argument('--savedir', type=str, default='batch_images/',
                        help='save destination for batch images')
    # makebatch() globs '<files>/*.json', so this option is a directory and is
    # mandatory (omitting it used to crash with None + '/*.json').
    parser.add_argument('--files', type=str, required=True,
                        help='directory containing the input image files (*.json)')
    parser.add_argument('--batchsize', type=int, default=20,
                        help='the size of batches')
    args = parser.parse_args()
    if args.batchsize < 1:
        parser.error('--batchsize must be a positive integer')
    makebatch(args)
    return


if __name__ == '__main__':
    main()