@@ -1,85 +1,25 @@
- import argparse, pathlib, uuid, subprocess
+ import argparse, uuid, subprocess
+ import torch
  from model import MusicTransformer
  from preprocess import SequenceEncoder
- import torch
- import torch.nn.functional as F
- import numpy as np
- from helpers import one_hot
- from pretty_midi import PrettyMIDI, Instrument
+ from helpers import sample, write_midi
  import midi_input
- import pdb
+ import yaml

  class GeneratorError(Exception):
      pass

- def write_midi(note_sequence, output_dir, filename):
-
-     #make output directory
-     pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
-
-     #generate midi
-     midi = PrettyMIDI()
-     piano_track = Instrument(program=0, is_drum=False, name=filename)
-     piano_track.notes = note_sequence
-     midi.instruments.append(piano_track)
-     output_name = output_dir + f"{filename}.midi"
-     midi.write(output_name)
-
- def sample(model, sample_length, prime_sequence=[], temperature=1,
-         topk=None):
-     """
-     Generate a MIDI event sequence of a fixed length by randomly sampling from a model's distribution of sequences. Optionally, "seed" the sequence with a
-     prime. A well-trained model will create music that responds to the prime
-     and develops upon it.
-     """
-     #deactivate training mode
-     model.eval()
-     if len(prime_sequence) == 0:
-         #if no prime is provided, randomly select a starting event
-         input_sequence = [np.random.randint(model.n_tokens)]
-     else:
-         input_sequence = prime_sequence
-
-     for i in range(sample_length):
-         if torch.cuda.is_available():
-             input_tensor = torch.LongTensor(input_sequence).cuda()
-         else:
-             input_tensor = torch.LongTensor(input_sequence)
-         #add singleton dimension for the batch
-         input_tensor = input_tensor.unsqueeze(0)
-         out = model(input_tensor)
-         probs = F.softmax(out / temperature, dim=-1)
-         #keep the probability distribution for the *next* state only
-         probs = probs[:, -1, :]
-
-         if topk is not None:
-             #sample from only the top k most probable states
-             values, indices = probs.topk(topk)
-             if torch.cuda.is_available():
-                 zeros = torch.zeros(model.n_tokens).cuda()
-             else:
-                 zeros = torch.zeros(model.n_tokens)
-             probs = torch.scatter(zeros, 0, indices, values)
-
-         next_char_ix = torch.multinomial(probs, 1).item()
-
-         input_sequence.append(next_char_ix)
-
-     return input_sequence
-
  def main():
      parser = argparse.ArgumentParser("Script to generate MIDI tracks by sampling from a trained model.")

-     # parser.add_argument("--model_key", type=str,
-     #     help="key to MODEL_DICT, allowing access to the path of a saved model & its params")
+     parser.add_argument("--model_key", type=str,
+         help="Key in saved_models/model.yaml, helps look up model arguments and path to saved checkpoint.")
      parser.add_argument("--sample_length", type=int, default=512,
          help="number of events to generate")
      parser.add_argument("--temps", nargs="+", type=float,
          default=[1.0],
          help="space-separated list of temperatures to use when sampling")
-     parser.add_argument("--topks", nargs="+", type=int,
-         help="space-separated list of topks to use when sampling")
-     parser.add_argument("--n_trials", type=int, default=5,
+     parser.add_argument("--n_trials", type=int, default=3,
          help="number of MIDI samples to generate per experiment")
      parser.add_argument("--live_input", action='store_true', default=False,
          help="if true, take in a seed from a MIDI input controller")
@@ -91,13 +31,15 @@ def main():

      args = parser.parse_args()

-     # model_key = args.model_key
-     # if MODEL_DICT.get(model_key) is None:
-     #     raise GeneratorError("model key not supplied or not recognized!")
-     model_path = "saved_models/tf_20200124"
-     model_key = "tf_20200124"
-     model_args = {"n_states": 413, "d_model": 64,
-         "dim_feedforward": 512, "n_heads": 4, "n_layers": 3}
+     model_key = args.model_key
+
+     try:
+         model_dict = yaml.safe_load(open('saved_models/model.yaml'))[model_key]
+     except:
+         raise GeneratorError(f"could not find yaml information for key {model_key}")
+
+     model_path = model_dict["path"]
+     model_args = model_dict["args"]
      try:
          state = torch.load(model_path)
      except RuntimeError:
@@ -106,7 +48,8 @@ def main():
      n_velocity_events = 32
      n_time_shift_events = 125

-     decoder = SequenceEncoder(n_time_shift_events, n_velocity_events)
+     decoder = SequenceEncoder(n_time_shift_events, n_velocity_events,
+             min_events=0)

      if args.live_input:
          print("Expecting a midi input...")
@@ -117,14 +60,10 @@ def main():
          prime_sequence = []

      model = MusicTransformer(**model_args)
-     model.load_state_dict(state)
+     model.load_state_dict(state, strict=False)

      temps = args.temps

-     topks = args.topks
-     if topks is None:
-         topks = [None]
-
      trial_key = str(uuid.uuid4())[:6]
      n_trials = args.n_trials

@@ -136,8 +75,7 @@ def main():
        note_sequence = []
        for i in range(n_trials):
            print("generating sequence")
-             output_sequence = sample(model, prime_sequence=prime_sequence,
-                 sample_length=args.sample_length, temperature=temp)
+             output_sequence = sample(model, prime_sequence=prime_sequence, sample_length=args.sample_length, temperature=temp)
            note_sequence = decoder.decode_sequence(output_sequence,
                verbose=True, stuck_note_duration=None)

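Note: the new lookup expects saved_models/model.yaml to map each model key to a
checkpoint path plus the keyword arguments for MusicTransformer. As a sketch only
(this file is not part of the diff), an entry assembled from the hard-coded values
removed above might look like:

    tf_20200124:
      path: saved_models/tf_20200124
      args:
        n_states: 413
        d_model: 64
        dim_feedforward: 512
        n_heads: 4
        n_layers: 3

With such an entry in place, the script could then be invoked as, for example
(script name assumed to be generate.py):

    python generate.py --model_key tf_20200124 --sample_length 512 --temps 0.8 1.0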