8
8
from preprocess import PreprocessingPipeline
9
9
from train import train
10
10
from model import MusicTransformer
11
+ import argparse
11
12
12
13
# Script entry point: parses command-line options, builds the MusicTransformer,
# optionally restores saved weights, constructs the PreprocessingPipeline over
# the "data" directory and hands the encoded sequences to train().
# NOTE(review): this is a diff rendering; lines prefixed with "+" were added,
# lines prefixed with "-" were removed, and the lines between the two hunks
# (around the "@@" header) are not shown here.
def main ():
14
# Command-line interface: --checkpoint (optional path, default None) and
# --n_epochs (int, default 5).
+ parser = argparse .ArgumentParser ("Script to train model on a GPU" )
15
+ parser .add_argument ("--checkpoint" , type = str , default = None ,
16
+ help = "Optional path to saved model, if none provided, the model is trained from scratch." )
17
+ parser .add_argument ("--n_epochs" , type = int , default = 5 ,
18
+ help = "Number of training epochs." )
19
+ args = parser .parse_args ()
20
+
13
21
# Settings shared by the vocabulary size below and the preprocessing pipeline.
sampling_rate = 125
14
22
n_velocity_bins = 32
15
23
seq_length = 1024
24
# Vocabulary size is now derived from the settings above; with the current
# values it equals the previously hard-coded 256 + 125 + 32.
+ n_tokens = 256 + sampling_rate + n_velocity_bins
25
# The model is built before preprocessing so a checkpoint can be loaded into it
# right away.
+ transformer = MusicTransformer (n_tokens , seq_length ,
26
+ d_model = 64 , n_heads = 8 , d_feedforward = 256 ,
27
+ depth = 4 , positional_encoding = True , relative_pos = True )
28
+
29
# Restore weights only when --checkpoint was supplied; otherwise the freshly
# constructed model is trained from scratch.
# NOTE(review): `torch` is not imported in the lines visible in this diff --
# confirm it is imported elsewhere in the file.
# NOTE(review): torch.load is called without map_location, so loading a
# checkpoint saved on a GPU may fail on a CPU-only machine -- confirm intended.
# NOTE(review): assumes the file holds a state_dict matching the architecture
# built above (same n_tokens, d_model, depth, ...) -- TODO confirm.
+ if args .checkpoint is not None :
30
+ state = torch .load (args .checkpoint )
31
+ transformer .load_state_dict (state )
32
+ print (f"Successfully loaded checkpoint at { args .checkpoint } " )
16
33
#rule of thumb: 1 minute is roughly 2k tokens
34
+
17
35
# Preprocessing configuration; max_encoded_length is seq_length + 1.
# NOTE(review): the remaining arguments of this call fall in the part of the
# file that the diff does not show.
pipeline = PreprocessingPipeline (input_dir = "data" , stretch_factors = [0.975 , 1 , 1.025 ],
18
36
split_size = 30 , sampling_rate = sampling_rate , n_velocity_bins = n_velocity_bins ,
19
37
transpositions = range (- 2 ,3 ), training_val_split = 0.9 , max_encoded_length = seq_length + 1 ,
@@ -28,15 +46,14 @@ def main():
28
46
29
47
# Training and validation splits are read from the pipeline's encoded_sequences.
training_sequences = pipeline .encoded_sequences ['training' ]
30
48
validation_sequences = pipeline .encoded_sequences ['validation' ]
31
- n_tokens = 256 + 125 + 32
32
49
33
50
batch_size = 16
34
- transformer = MusicTransformer (n_tokens , seq_length , d_model = 64 , n_heads = 8 ,
35
- d_feedforward = 256 , depth = 4 , positional_encoding = True , relative_pos = True )
36
51
37
52
# The epoch count now comes from --n_epochs (default 5, the value that was
# hard-coded before this change).
# NOTE(review): `checkpoint` passed as checkpoint_path is not assigned in the
# visible lines and is a different name from args.checkpoint -- presumably it
# is defined in the elided part of the function; confirm.
train (transformer , training_sequences , validation_sequences ,
38
- epochs = 5 , evaluate_per = 1 ,
53
+ epochs = args . n_epochs , evaluate_per = 1 ,
39
54
batch_size = batch_size , batches_per_print = 100 ,
40
55
padding_index = 0 , checkpoint_path = checkpoint )
41
56
42
57
58
# Run main() only when this file is executed as a script, not when it is imported.
+ if __name__ == "__main__" :
59
+ main ()
0 commit comments