-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtest.py
More file actions
30 lines (25 loc) · 1.1 KB
/
Copy pathtest.py
File metadata and controls
30 lines (25 loc) · 1.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from util import load,load_lookup
from Models import SimpleCNN
import csv
import numpy as np
# Predict facial keypoints for every test image and write a submission CSV.
#
# For each test image the model predicts a vector of keypoint coordinates.
# The lookup table then selects which (RowId, FeatureName) pairs must be
# reported for that image, and each requested coordinate is written as one
# row of submission_results.csv.

# NOTE(review): presumably the model emits coordinates normalised to [-1, 1]
# and the images are 96x96 pixels, so `x * 48 + 48` maps a prediction back to
# pixel space -- confirm against the normalisation done in util.load.
COORD_SCALE = 48
COORD_OFFSET = 48

model = SimpleCNN()
model.load_weights("FKD_weights.h5")
lookup = load_lookup('IdLookupTable.csv')
# The .npy file holds a 0-d object array wrapping a string-keyed mapping
# (it is indexed as feature2kpId[FeatureName] below). Object arrays are stored
# via pickle, and numpy >= 1.16.3 defaults to allow_pickle=False, which would
# raise ValueError here. Only load this file from a trusted source:
# unpickling can execute arbitrary code.
feature2kpId = np.load('feature2kpId.npy', allow_pickle=True).item()
print('Reading Test Data')
X_test, _, _ = load(test=True)

# newline='' is required by the csv module; without it every row is followed
# by an extra blank line on platforms that translate '\n' (e.g. Windows).
with open('submission_results.csv', 'w', newline='') as fw:
    myFields = ['RowId', 'Location']
    writer = csv.DictWriter(fw, fieldnames=myFields)
    writer.writeheader()
    # Enumerate from 1: lookup appears to be keyed by 1-based image id
    # (TODO confirm against load_lookup).
    for idx, img in enumerate(X_test, 1):
        print('Predict Keypoints of Image {0}'.format(idx))
        # Prepend a batch axis of size 1, then take the single prediction.
        keypoints = model.predict(img.reshape((1,) + img.shape))[0]
        entry = lookup[idx]
        for RowId, FeatureName in zip(entry['RowId'], entry['FeatureName']):
            location = keypoints[feature2kpId[FeatureName]]
            writer.writerow({
                myFields[0]: RowId,
                myFields[1]: location * COORD_SCALE + COORD_OFFSET,
            })
print('Done')