-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
144 lines (103 loc) · 3.84 KB
/
main.py
File metadata and controls
144 lines (103 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
Created on Thu Nov 26 15:12:46 2020
@author: eych
"""
import os
import sys
from sys import platform
from Models_1B import *
from Build_fit import *
from Models_2B import *
from Buid_2B import *
from LoadData import LoadData1B
from LoadData import LoadData2B
# Directory containing this script, made importable for any imports executed
# after this point. NOTE: the project imports above (Models_1B, Build_fit,
# Models_2B, Buid_2B, LoadData) run *before* this line, so they do not benefit
# from it and rely on the working directory / existing sys.path instead.
# abspath() guards against __file__ being relative (dirname would then be "").
parent_path = os.path.dirname(os.path.abspath(__file__))
if parent_path not in sys.path:
    sys.path.append(parent_path)
print('Executable env: ' + str(sys.executable))
# Dataset location handed to the loaders. The trailing "*/" suggests a glob over
# per-tile sub-directories — TODO confirm against LoadData.load_data.
DATA_PATH = "D:/Tiles/Final_datasets/Sentinel_2_tiles_test/*/"

# Experiment switches. The one-branch experiment is what this script runs by
# default. The two-branch experiment used to be disabled by wrapping it in a
# string literal; it is now real code that only runs when RUN_TWO_BRANCH is True.
RUN_ONE_BRANCH = True
RUN_TWO_BRANCH = False

if __name__ == '__main__':
    data_path = DATA_PATH
    print(data_path)
    print(platform)

    if RUN_ONE_BRANCH:
        ################### Experiments with 1-Branch models ###################
        ### Data loading
        ld = LoadData1B()
        ld.load_data(data_path)
        # rasterio-based readers
        training_labels = ld.labels_forTr(10)
        tr_set = ld.training_set10_Tr(10)
        ### Arcpy-based alternative
        # training_labels = ld.labels_forTr_Arcpy(10)
        # tr_set = ld.training_set_arcpy(10)
        print(training_labels.shape)
        print(tr_set.shape)

        #### Parameters derived from the label array
        # Assumes labels are (samples, rows, cols, classes) — TODO confirm.
        nb_labels = training_labels.shape[3]
        print(nb_labels)
        nb_rows = training_labels.shape[1]
        print(nb_rows)
        nb_cols = training_labels.shape[2]
        print(nb_cols)

        ### Class weights for the weighted loss function
        ws = ld.weights_prep(training_labels, nb_rows, nb_cols)
        print(ws)

        training_images_nor = ld.norm(tr_set)
        print(len(training_images_nor))
        del tr_set  # drop the raw-stack reference once normalised
        # Augmentation: rotate, then flip the *rotated* set. The (0, 5) and
        # (5, 10) ranges are interpreted by LoadData1B.rot / LoadData1B.flip.
        training_images_nor_aug, training_labels_aug = ld.rot(training_images_nor, training_labels, 0, 5)
        del training_labels
        training_images_nor_aug, training_labels_aug = ld.flip(training_images_nor_aug, training_labels_aug, 5, 10)

        # Assumes images are (samples, rows, cols, channels) — TODO confirm.
        channels = training_images_nor_aug.shape[3]
        print(channels)
        input_shape = (nb_rows, nb_cols, channels)

        #### Model building + training
        model = DeepForestM2(input_shape, nb_labels)
        model.summary()
        # name/loss are consumed by the (currently disabled) training call below.
        name = 'DeepForestM2'
        loss = 'categorical_crossentropy'
        m = Buid_1B(nb_labels, model)
        print(m.nb_labels)
        m.build()
        # m.train_model(training_images_nor_aug, training_labels_aug, ws, name, loss, epochs=1)

    if RUN_TWO_BRANCH:
        ################### Experiments with 2-Branch models ###################
        ### Data loading
        ld = LoadData2B()
        ld.load_data(data_path)
        training_labels = ld.labels_forTr(10)
        # training_labels = ld.labels_forTr_Arcpy(10)
        print(training_labels.shape)

        #### Parameters derived from the label array
        nb_labels = training_labels.shape[3]
        print(nb_labels)
        nb_rows = training_labels.shape[1]
        print(nb_rows)
        nb_cols = training_labels.shape[2]
        print(nb_cols)

        training_images_s1 = ld.training_s1(10)
        training_images_s2 = ld.training_s2(10)
        # Assumes S1 is (samples, timestamps, rows, cols, channels) and S2 is
        # (samples, rows, cols, channels), per the indices used — TODO confirm.
        timestamps = training_images_s1.shape[1]
        print(timestamps)
        channels1 = training_images_s1.shape[4]
        print(channels1)
        channels2 = training_images_s2.shape[3]
        print(channels2)
        input_shape1 = (timestamps, nb_rows, nb_cols, channels1)
        print(input_shape1)
        input_shape2 = (nb_rows, nb_cols, channels2)
        print(input_shape2)

        ### Class weights for the weighted loss function
        ws = ld.weights_prep(training_labels, nb_rows, nb_cols)
        print(ws)

        # NOTE(review): unlike the 1-branch path, no normalisation step is
        # applied before augmentation despite the *_nor_aug names — confirm.
        training_images_s1_nor_aug, training_images_s2_nor_aug, training_labels_aug = ld.rot(training_images_s1, training_images_s2, training_labels, 0, 5)
        del training_labels
        # Flip the *rotated* arrays so the images stay aligned with
        # training_labels_aug. The previous version passed the un-rotated
        # S1/S2 arrays here while passing the rotated labels.
        training_images_s1_nor_aug, training_images_s2_nor_aug, training_labels_aug = ld.flip(training_images_s1_nor_aug, training_images_s2_nor_aug, training_labels_aug, 5, 10)

        #### Model building + training
        model_glob, model1, model2 = two_branches(input_shape1, input_shape2, nb_labels)
        model_glob.summary()
        # name/loss are consumed by the (currently disabled) training call below.
        name = 'two_branch M'
        loss = 'categorical_crossentropy'
        m = Buid_2B(nb_labels, model_glob)
        print(m.nb_labels)
        m.build()
        # m.train_model(training_images_s1_nor_aug, training_images_s2_nor_aug, training_labels_aug, weights=ws, valstart=15, name=name, loss=loss, epochs=1, batch_size=2)

    print('End Main')