-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_sup.sh
74 lines (64 loc) · 2.08 KB
/
run_sup.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#!/bin/bash
# Supervised SimCSE training pipeline: train, extract embeddings, package
# them, and back the results up to Google Drive.
# If you want to train it with multiple GPU cards, see "run_sup_example.sh"
# about how to use PyTorch's distributed data parallel.
set -euo pipefail

# --- Hyper-parameters / run configuration ---
pretrain="cyclone/simcse-chinese-roberta-wwm-ext"
date=''
epoch=6
bs=128
pooler="cls"
max_seq_length=64
comment=''
# Encode the key hyper-parameters into the run name so result folders are
# self-describing and never collide between configurations.
model_name="${comment}_${date}_ep${epoch}_bs${bs}_${pooler}_max_seq_length${max_seq_length}"
dir_path="./result/${model_name}"
drive_result="/content/drive/MyDrive/competition/simcse-mini/result"
# Remove any previous output for this run. ':?' aborts if model_name were
# ever empty, preventing an accidental 'rm -rf ./result/'.
rm -rf -- "./result/${model_name:?}"
# --- Supervised training ---
# All variables are quoted (SC2086); an explicit if/else replaces the
# 'cmd && {ok} || {fail}' pattern, which would also run the failure branch
# if the success branch itself returned non-zero.
if python main.py \
  --model_name_or_path "$pretrain" \
  --train_file "./data/simple_X_train.csv" \
  --validation_file "./data/simple_X_val.csv" \
  --output_dir "$dir_path" \
  --num_train_epochs "$epoch" \
  --per_device_train_batch_size "$bs" \
  --per_device_eval_batch_size "$bs" \
  --optim adamw_apex_fused \
  --learning_rate 3e-5 \
  --max_seq_length "$max_seq_length" \
  --save_total_limit 5 \
  --evaluation_strategy steps \
  --greater_is_better False \
  --metric_for_best_model eval_loss \
  --load_best_model_at_end \
  --eval_steps 100 \
  --save_steps 500 \
  --logging_steps 100 \
  --pooler_type "$pooler" \
  --overwrite_output_dir \
  --temp 0.05 \
  --do_train \
  --do_eval \
  --fp16 \
  --do_fgm \
  --do_ema \
  --lr_scheduler_type cosine_with_restarts \
  --label_smoothing_factor 0.1; then
  echo "train finished!"
else
  echo 'train failed' >&2
  exit 1
fi
# Keep a copy of this script next to the results for reproducibility.
cp -- ./run_sup.sh "$dir_path/run_backup.sh"
# --- Extract embeddings from the trained model ---
if python get_embedding.py \
  --dir_path "$dir_path" \
  --pooler_type "$pooler" \
  --temp 0.05 \
  --batchsize 1024 \
  --max_seq_length "$max_seq_length"; then
  echo "get embedding finished!"
else
  echo "get embedding failed" >&2
  exit 1
fi
# Validate the embedding files; previously the exit status was ignored, so a
# failed check would still get packaged and uploaded — now it aborts.
python data_check.py || { echo "data check failed" >&2; exit 1; }
# Package the embeddings straight into the run's result folder (no
# intermediate foo.tar.gz + mv round-trip).
tar zcvf "$dir_path/${model_name}.tar.gz" query_embedding doc_embedding
# Copy all results back to Google Drive for persistence (Colab storage is
# ephemeral).
cp -r ./result/* "$drive_result"
echo "Finished!"