Skip to content

Commit b63d97c

Browse files
author
DeepLearning VM
committed
Fixed argparser in run.py
2 parents bf77ac2 + 501836b commit b63d97c

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

run.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,30 @@
88
from preprocess import PreprocessingPipeline
99
from train import train
1010
from model import MusicTransformer
11+
import argparse
1112

1213
def main():
14+
parser = argparse.ArgumentParser("Script to train model on a GPU")
15+
parser.add_argument("--checkpoint", type=str, default=None,
16+
help="Optional path to saved model, if none provided, the model is trained from scratch.")
17+
parser.add_argument("--n_epochs", type=int, default=5,
18+
help="Number of training epochs.")
19+
args = parser.parse_args()
20+
1321
sampling_rate = 125
1422
n_velocity_bins = 32
1523
seq_length = 1024
24+
n_tokens = 256 + sampling_rate + n_velocity_bins
25+
transformer = MusicTransformer(n_tokens, seq_length,
26+
d_model = 64, n_heads = 8, d_feedforward=256,
27+
depth = 4, positional_encoding=True, relative_pos=True)
28+
29+
if args.checkpoint is not None:
30+
state = torch.load(args.checkpoint)
31+
transformer.load_state_dict(state)
32+
print(f"Successfully loaded checkpoint at {args.checkpoint}")
1633
#rule of thumb: 1 minute is roughly 2k tokens
34+
1735
pipeline = PreprocessingPipeline(input_dir="data", stretch_factors=[0.975, 1, 1.025],
1836
split_size=30, sampling_rate=sampling_rate, n_velocity_bins=n_velocity_bins,
1937
transpositions=range(-2,3), training_val_split=0.9, max_encoded_length=seq_length+1,
@@ -28,15 +46,14 @@ def main():
2846

2947
training_sequences = pipeline.encoded_sequences['training']
3048
validation_sequences = pipeline.encoded_sequences['validation']
31-
n_tokens = 256 + 125 + 32
3249

3350
batch_size = 16
34-
transformer = MusicTransformer(n_tokens, seq_length, d_model = 64, n_heads = 8,
35-
d_feedforward=256, depth = 4, positional_encoding=True, relative_pos=True)
3651

3752
train(transformer, training_sequences, validation_sequences,
38-
epochs = 5, evaluate_per = 1,
53+
epochs = args.n_epochs, evaluate_per = 1,
3954
batch_size = batch_size, batches_per_print=100,
4055
padding_index=0, checkpoint_path=checkpoint)
4156

4257

58+
if __name__=="__main__":
59+
main()

0 commit comments

Comments
 (0)