Skip to content

Commit 414386a

Browse files
committed
add weight fusion and class train
1 parent 7ca82f9 commit 414386a

32 files changed

+3762882
-75685
lines changed

run2.sh

+19-6
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,26 @@ fi
9494
# #################### execute according to TYPE
9595
########## train
9696
if [ "$TYPE" = "train" ]; then
97-
cd src/text_predict/ocr_bert_base
98-
time python ocr_classifier.py
97+
# cd src/text_predict/ocr_bert_base
98+
# time python ocr_classifier.py
99+
# cd -
100+
cd src/text_predict/ocr_bert_class_train
101+
# time python train.py
102+
cd -
103+
104+
cd src/weight_fusion
105+
time python train.py
99106
cd -
100107
exit 0
101-
fi
102-
########## test
103-
# elif [ "$TYPE" = "test" ]; then
104108

105-
# exit 0
109+
########## test
110+
elif [ "$TYPE" = "test" ]; then
111+
cd src/text_predict/ocr_bert_class_train
112+
time python predict.py
113+
cd -
114+
cd src/weight_fusion
115+
time python predict.py
116+
cd -
117+
exit 0
106118
# ######### text predict
119+
fi

src/test_2nd.txt

+25,000
Large diffs are not rendered by default.

src/text_predict/ocr_bert_base/ocr_classifier.py

-2
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ def train(run_i, model, loss_fn, config, train_loader, val_loader, root_path):
134134
num+=1
135135
loss.backward()
136136
optimizer.step()
137-
break
138137
gap = evaluate(config, val_loader, tokenizer, model)
139138
with open(root_path+'res.csv', 'a+') as out:
140139
writer = csv.writer(out)
@@ -154,7 +153,6 @@ def train(run_i, model, loss_fn, config, train_loader, val_loader, root_path):
154153
if(count>3):
155154
print("overfit stop train!")
156155
break
157-
break
158156

159157
# pre process
160158
def predict(config, model,checkpoint,predict_loader,reslut_path):

0 commit comments

Comments
 (0)