Commit 135dfbc

Overhaul generate.py
1 parent: b63d97c

6 files changed: +71 -90 lines

generate.py (+20 -82)

@@ -1,85 +1,25 @@
-import argparse, pathlib, uuid, subprocess
+import argparse, uuid, subprocess
+import torch
 from model import MusicTransformer
 from preprocess import SequenceEncoder
-import torch
-import torch.nn.functional as F
-import numpy as np
-from helpers import one_hot
-from pretty_midi import PrettyMIDI, Instrument
+from helpers import sample, write_midi
 import midi_input
-import pdb
+import yaml
 
 class GeneratorError(Exception):
     pass
 
-def write_midi(note_sequence, output_dir, filename):
-
-    #make output directory
-    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
-
-    #generate midi
-    midi = PrettyMIDI()
-    piano_track = Instrument(program=0, is_drum=False, name=filename)
-    piano_track.notes = note_sequence
-    midi.instruments.append(piano_track)
-    output_name = output_dir + f"{filename}.midi"
-    midi.write(output_name)
-
-def sample(model, sample_length, prime_sequence=[], temperature=1,
-        topk=None):
-    """
-    Generate a MIDI event sequence of a fixed length by randomly sampling from a model's distribution of sequences. Optionally, "seed" the sequence with a
-    prime. A well-trained model will create music that responds to the prime
-    and develops upon it.
-    """
-    #deactivate training mode
-    model.eval()
-    if len(prime_sequence) == 0:
-        #if no prime is provided, randomly select a starting event
-        input_sequence = [np.random.randint(model.n_tokens)]
-    else:
-        input_sequence = prime_sequence
-
-    for i in range(sample_length):
-        if torch.cuda.is_available():
-            input_tensor = torch.LongTensor(input_sequence).cuda()
-        else:
-            input_tensor = torch.LongTensor(input_sequence)
-        #add singleton dimension for the batch
-        input_tensor = input_tensor.unsqueeze(0)
-        out = model(input_tensor)
-        probs = F.softmax(out / temperature, dim=-1)
-        #keep the probability distribution for the *next* state only
-        probs = probs[:, -1, :]
-
-        if topk is not None:
-            #sample from only the top k most probable states
-            values, indices = probs.topk(topk)
-            if torch.cuda.is_available():
-                zeros = torch.zeros(model.n_tokens).cuda()
-            else:
-                zeros = torch.zeros(model.n_tokens)
-            probs = torch.scatter(zeros, 0, indices, values)
-
-        next_char_ix = torch.multinomial(probs,1).item()
-
-        input_sequence.append(next_char_ix)
-
-    return input_sequence
-
 def main():
     parser = argparse.ArgumentParser("Script to generate MIDI tracks by sampling from a trained model.")
 
-    # parser.add_argument("--model_key", type=str,
-    #         help="key to MODEL_DICT, allowing access to the path of a saved model & its params")
+    parser.add_argument("--model_key", type=str,
+            help="Key in saved_models/model.yaml, helps look up model arguments and path to saved checkpoint.")
     parser.add_argument("--sample_length", type=int, default=512,
            help="number of events to generate")
     parser.add_argument("--temps", nargs="+", type=float,
            default=[1.0],
            help="space-separated list of temperatures to use when sampling")
-    parser.add_argument("--topks", nargs="+", type=int,
-            help="space-separated list of topks to use when sampling")
-    parser.add_argument("--n_trials", type=int, default=5,
+    parser.add_argument("--n_trials", type=int, default=3,
            help="number of MIDI samples to generate per experiment")
     parser.add_argument("--live_input", action='store_true', default = False,
            help="if true, take in a seed from a MIDI input controller")
@@ -91,13 +31,15 @@ def main():
 
     args=parser.parse_args()
 
-    # model_key = args.model_key
-    # if MODEL_DICT.get(model_key) is None:
-    #     raise GeneratorError("model key not supplied or not recognized!")
-    model_path = "saved_models/tf_20200124"
-    model_key = "tf_20200124"
-    model_args = {"n_states": 413, "d_model": 64,
-            "dim_feedforward": 512, "n_heads": 4, "n_layers": 3}
+    model_key = args.model_key
+
+    try:
+        model_dict = yaml.safe_load(open('saved_models/model.yaml'))[model_key]
+    except:
+        raise GeneratorError(f"could not find yaml information for key {model_key}")
+
+    model_path = model_dict["path"]
+    model_args = model_dict["args"]
     try:
         state = torch.load(model_path)
     except RuntimeError:
@@ -106,7 +48,8 @@ def main():
     n_velocity_events = 32
     n_time_shift_events = 125
 
-    decoder = SequenceEncoder(n_time_shift_events, n_velocity_events)
+    decoder = SequenceEncoder(n_time_shift_events, n_velocity_events,
+            min_events=0)
 
     if args.live_input:
         print("Expecting a midi input...")
@@ -117,14 +60,10 @@ def main():
        prime_sequence = []
 
    model = MusicTransformer(**model_args)
-    model.load_state_dict(state)
+    model.load_state_dict(state, strict=False)
 
    temps = args.temps
 
-    topks = args.topks
-    if topks is None:
-        topks = [None]
-
    trial_key = str(uuid.uuid4())[:6]
    n_trials = args.n_trials
 
@@ -136,8 +75,7 @@ def main():
        note_sequence = []
        for i in range(n_trials):
            print("generating sequence")
-            output_sequence = sample(model, prime_sequence = prime_sequence,
-                    sample_length=args.sample_length, temperature=temp)
+            output_sequence = sample(model, prime_sequence = prime_sequence, sample_length=args.sample_length, temperature=temp)
            note_sequence = decoder.decode_sequence(output_sequence,
                verbose=True, stuck_note_duration=None)
 
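The new --model_key lookup assumes a saved_models/model.yaml file that maps each key to a checkpoint path and the keyword arguments for MusicTransformer. The file itself is not part of this diff; a minimal entry consistent with the code above, reusing the hard-coded values the commit removes, might look like:

tf_20200124:
  path: saved_models/tf_20200124
  args:
    n_states: 413
    d_model: 64
    dim_feedforward: 512
    n_heads: 4
    n_layers: 3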

helpers.py (+45 -2)

@@ -1,8 +1,9 @@
 import torch
 import numpy as np
-from pretty_midi import Note
+from pretty_midi import Note, PrettyMIDI, Instrument
 import torch.nn.functional as F
-import copy
+import copy, pathlib
+import pdb
 
 def vectorize(sequence):
     """
@@ -66,3 +67,45 @@ def d(tensor=None):
     if tensor is None:
         return 'cuda' if torch.cuda.is_available() else 'cpu'
     return 'cuda' if tensor.is_cuda else 'cpu'
+
+def write_midi(note_sequence, output_dir, filename):
+
+    #make output directory
+    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
+
+    #generate midi
+    midi = PrettyMIDI()
+    piano_track = Instrument(program=0, is_drum=False, name=filename)
+    piano_track.notes = note_sequence
+    midi.instruments.append(piano_track)
+    output_name = output_dir + f"{filename}.midi"
+    midi.write(output_name)
+
+def sample(model, sample_length, prime_sequence=[], temperature=1):
+    """
+    Generate a MIDI event sequence of a fixed length by randomly sampling from a model's distribution of sequences. Optionally, "seed" the sequence with a prime. A well-trained model will create music that responds to the prime and develops upon it.
+    """
+    #deactivate training mode
+    model.eval()
+    if len(prime_sequence) == 0:
+        #if no prime is provided, randomly select a starting event
+        input_sequence = [np.random.randint(model.n_tokens)]
+    else:
+        input_sequence = prime_sequence.copy()
+
+    #add singleton dimension for the batch
+    input_tensor = torch.LongTensor(input_sequence).unsqueeze(0)
+
+    for i in range(sample_length):
+        #select probabilities of *next* token
+        out = model(input_tensor)[0, -1, :]
+        #out is a 1d tensor of shape (n_tokens)
+        probs = F.softmax(out / temperature, dim=0)
+        #sample prob distribution for next character
+        pdb.set_trace()
+        c = torch.multinomial(probs,1)
+        input_tensor = torch.cat([input_tensor[:,1:], c[None]], dim=1)
+        input_sequence.append(c.item())
+
+    return input_sequence
+
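For reference, a minimal sketch of how the relocated helpers fit together, mirroring what generate.py's main() now does. Paths and values are illustrative, and it assumes the pdb.set_trace() left in sample() is removed so generation runs uninterrupted:

import torch
from model import MusicTransformer
from preprocess import SequenceEncoder
from helpers import sample, write_midi

#constructor arguments as stored under "args" in saved_models/model.yaml
model = MusicTransformer(n_states=413, d_model=64, dim_feedforward=512,
        n_heads=4, n_layers=3)
model.load_state_dict(torch.load("saved_models/tf_20200124"), strict=False)

#generate 512 events, unprimed, at temperature 1.0
event_sequence = sample(model, sample_length=512, prime_sequence=[],
        temperature=1.0)

#decode events into pretty_midi notes and write them to disk;
#write_midi concatenates output_dir and filename directly, so the
#directory needs its trailing slash
decoder = SequenceEncoder(125, 32, min_events=0)
note_sequence = decoder.decode_sequence(event_sequence, verbose=True,
        stuck_note_duration=None)
write_midi(note_sequence, "output/", "demo")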

midi_input.py (+2 -4)

@@ -69,7 +69,7 @@ def read(n_velocity_events=32, n_time_shift_events=125):
         i += 1
 
     note_sequence = quantize(note_sequence, n_velocity_events, n_time_shift_events)
-    #sigh bad practice
+
     note_sequence = vectorize(note_sequence)
     return note_sequence
 
@@ -84,10 +84,8 @@ def quantize(note_sequence, n_velocity_events, n_time_shift_events):
 
         note.velocity = (note.velocity // velocity_step) * velocity_step + 1
 
-        return note_sequence
-
-
 
+    return note_sequence
 
 if __name__ == "__main__":
     read()
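For a concrete sense of the quantization line retained above: assuming velocity_step = 128 // n_velocity_events (so a step of 4 with the default of 32 velocity events; that computation sits outside this diff), each velocity snaps to the bottom of its bucket plus one:

velocity_step = 128 // 32  #assumed step size of 4
for v in (1, 37, 64, 100, 127):
    print(v, "->", (v // velocity_step) * velocity_step + 1)
#1 -> 1, 37 -> 37, 64 -> 65, 100 -> 101, 127 -> 125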

model/attention.py (+2)

@@ -47,6 +47,8 @@ def forward(self, x, mask):
             for w, x in zip(self.linears, (x,x,x))]
         if self.relative_pos:
             #apply same position embeddings across the batch
+            #Is it possible to apply positional self-attention over
+            #only half of all relative distances?
             Er = self.Er[:, embedding_start:, :].unsqueeze(0)
             QEr = torch.matmul(queries, Er.transpose(-1,-2))
             QEr = self._mask_positions(QEr)
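For context on the question added in this comment: QEr is the query-by-relative-position score matrix from Music Transformer-style relative self-attention (Huang et al., 2018), and in a causal decoder only the non-positive half of the relative distances can ever attend, which appears to be what the comment is asking. The standard "skewing" step that aligns QEr[i, j] with relative distance j - i looks roughly like the sketch below; this is the textbook version, not necessarily what this repo's _mask_positions does, since that implementation is outside this diff:

import torch
import torch.nn.functional as F

def skew(qer):
    #qer: (batch, heads, seq_len, seq_len) scores of each query against
    #each learned relative-position embedding
    batch, heads, seq_len, _ = qer.shape
    #pad one column on the left, reshape, and drop the first row so that
    #entry (i, j) ends up aligned with relative distance j - i
    padded = F.pad(qer, (1, 0))                  #(batch, heads, L, L+1)
    reshaped = padded.reshape(batch, heads, seq_len + 1, seq_len)
    return reshaped[:, :, 1:, :]                 #(batch, heads, L, L)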

tests/transformer_test.py (+1 -1)

@@ -3,7 +3,7 @@
 from preprocess import PreprocessingPipeline
 from train import train
 from model import MusicTransformer
-from generate import sample
+from helpers import sample
 
 def main():
 

train.py (+1 -1)

@@ -103,7 +103,7 @@ def train(model, training_data, validation_data,
             averaged_loss = 0
             batch_num += 1
 
-            print(f"epoch: {e+1}/{epochs} | time: {time.time() - batch_start_time:.0f}s")
+            print(f"epoch: {e+1}/{epochs} | time: {(time.time() - batch_start_time) / 60:,.0f}m")
         shuffle(training_data)
 
         if (e + 1) % evaluate_per == 0:
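The new format spec reports elapsed time per epoch in whole minutes with a thousands separator instead of raw seconds, e.g.:

elapsed = 9000.0  #seconds
print(f"epoch: 2/10 | time: {elapsed / 60:,.0f}m")  #epoch: 2/10 | time: 150m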
