Cade Stocker edited this page Nov 28, 2025 · 2 revisions

Generators

Models

The following files define the different model types you can create for generators:

  • generator_gru.py
  • generator_lstm.py
  • generator_transformer.py

Checkpoints

There is a subdirectory for each tokenization of the dataset (naive, miditok, and miditok augmented), and the trained model weights (checkpoints) for that tokenization are stored in it.
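A minimal sketch of how code might locate checkpoints under this layout. The folder spellings (`miditok_augmented` in particular), the `.pt` extension, and the helper names `checkpoint_dir` and `list_checkpoints` are assumptions for illustration, not the repository's actual names.

```python
from __future__ import annotations

from pathlib import Path

# Hypothetical folder names: the wiki names the tokenizations "naive",
# "miditok" and "miditok augmented", but the exact spelling on disk is assumed.
TOKENIZATIONS = ("naive", "miditok", "miditok_augmented")


def checkpoint_dir(root: Path, tokenization: str) -> Path:
    """Return the checkpoint subdirectory for one tokenization of the dataset."""
    if tokenization not in TOKENIZATIONS:
        raise ValueError(
            f"unknown tokenization {tokenization!r}; expected one of {TOKENIZATIONS}"
        )
    return root / tokenization


def list_checkpoints(root: Path, tokenization: str, pattern: str = "*.pt") -> list[Path]:
    """List checkpoint files in that subdirectory, sorted by file name.

    The .pt extension is an assumption (a common PyTorch convention).
    """
    directory = checkpoint_dir(root, tokenization)
    if not directory.is_dir():
        return []
    return sorted(p for p in directory.glob(pattern) if p.is_file())
```

Validating the tokenization name up front turns a misspelled folder into an immediate error instead of a silently empty result.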

Generator Factory

The factory makes it easy to create any of the supported generator types (lstm, gru, or transformer) and returns a generator model instance.

Arguments for get_generator are:

  • model_type (lstm, gru, transformer)
  • vocab_size
  • kwargs (other architecture-specific parameters):
    • common kwargs:
      • embed_size
      • hidden_size
      • num_layers
      • dropout
    • for transformers:
      • d_model (model dimension)
      • nhead (number of attention heads)
      • dim_feedforward (feedforward dimension)
      • max_seq_length
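The factory pattern described above can be sketched as a registry that maps each `model_type` string to a class and forwards `vocab_size` plus any architecture kwargs. This is a hedged illustration, not the repository's code: the class names `LSTMGenerator`, `GRUGenerator` and `TransformerGenerator` are hypothetical stand-ins that only record their arguments, whereas the real classes in the generator files are neural network models.

```python
from typing import Any, Dict, Type


class _StubGenerator:
    """Placeholder that only records the arguments it was built with."""

    def __init__(self, vocab_size: int, **kwargs: Any) -> None:
        self.vocab_size = vocab_size
        self.hparams = kwargs


# Hypothetical stand-ins for the model classes in generator_lstm.py,
# generator_gru.py and generator_transformer.py.
class LSTMGenerator(_StubGenerator):
    pass


class GRUGenerator(_StubGenerator):
    pass


class TransformerGenerator(_StubGenerator):
    pass


# Registry: each accepted model_type string maps to the class that builds it.
_GENERATORS: Dict[str, Type[_StubGenerator]] = {
    "lstm": LSTMGenerator,
    "gru": GRUGenerator,
    "transformer": TransformerGenerator,
}


def get_generator(model_type: str, vocab_size: int, **kwargs: Any) -> _StubGenerator:
    """Look up model_type and forward vocab_size plus architecture kwargs."""
    try:
        cls = _GENERATORS[model_type]
    except KeyError:
        raise ValueError(
            f"unknown model_type {model_type!r}; choose from {sorted(_GENERATORS)}"
        ) from None
    return cls(vocab_size, **kwargs)
```

For example, `get_generator("gru", 512, embed_size=128, hidden_size=256, num_layers=2, dropout=0.1)` builds a GRU stand-in, while a transformer would instead receive `d_model`, `nhead`, `dim_feedforward` and `max_seq_length`. A registry keeps adding a new architecture to a one-line change.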

get_default_config returns the default hyperparameters for a given model type.
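One common way to use such defaults is to copy them and override only the keys you care about. The sketch below assumes that pattern; every default value in it is invented for illustration (the real values live in the repository), and `build_config` is a hypothetical helper.

```python
from typing import Any, Dict

# Hypothetical default values, for illustration only; the repository's
# actual defaults may differ.
_DEFAULTS: Dict[str, Dict[str, Any]] = {
    "lstm": {"embed_size": 128, "hidden_size": 256, "num_layers": 2, "dropout": 0.2},
    "gru": {"embed_size": 128, "hidden_size": 256, "num_layers": 2, "dropout": 0.2},
    "transformer": {
        "d_model": 256,
        "nhead": 8,
        "num_layers": 4,
        "dim_feedforward": 1024,
        "dropout": 0.1,
        "max_seq_length": 512,
    },
}


def get_default_config(model_type: str) -> Dict[str, Any]:
    """Return a fresh copy of the default hyperparameters for model_type."""
    if model_type not in _DEFAULTS:
        raise ValueError(f"unknown model_type {model_type!r}")
    # Copy so callers can modify the result without changing the shared defaults.
    return dict(_DEFAULTS[model_type])


def build_config(model_type: str, **overrides: Any) -> Dict[str, Any]:
    """Start from the defaults and replace only the keys the caller passes."""
    config = get_default_config(model_type)
    config.update(overrides)
    return config
```

Returning a copy matters: if the shared dictionary were handed out directly, one caller's tweak would leak into every later model built from the defaults.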

Discriminators

Models

The following files define the different model types you can create for discriminators:

  • discriminators_lstm.py
  • discriminators_mlp.py
  • discriminator_transformer.py

Checkpoints

There is a subdirectory for each tokenization of the dataset (naive, miditok, and miditok augmented), and the trained model weights (checkpoints) for that tokenization are stored in it.
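When a tokenization subdirectory holds several checkpoints, a loader typically needs to pick one. A minimal sketch, assuming checkpoints use a `.pt` extension and that "most recently modified" is the desired choice; `latest_checkpoint` is a hypothetical helper, not part of the repository.

```python
from __future__ import annotations

from pathlib import Path


def latest_checkpoint(directory: Path, pattern: str = "*.pt") -> Path | None:
    """Return the most recently modified checkpoint in directory, or None."""
    candidates = [p for p in directory.glob(pattern) if p.is_file()]
    if not candidates:
        return None
    # Compare by modification time so the newest training run wins.
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

Modification time is a simple heuristic; if checkpoint names encode an epoch or a validation score, selecting by that value is usually more reliable.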

Discriminator Factory

The factory makes it easy to create any of the supported discriminator types (lstm, mlp, or transformer) and returns a discriminator model instance.

Arguments for get_discriminator are:

  • model_type (lstm, mlp, transformer)
  • hidden1 and hidden2 (sizes of the first and second hidden layers)
  • pitch_dim
  • context_measures
  • pool
  • dropout
  • for lstm and mlp:
    • embed_size
    • hidden_size
    • num_layers
    • dropout
  • for transformer:
    • embed_size
    • num_heads
    • num_layers
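Because each discriminator type accepts a different set of arguments, a factory can reject unexpected keyword arguments before building a model. The sketch below encodes one reading of the argument list above (the first group shared by all types, plus per-type extras); that split and the helper name `check_discriminator_kwargs` are assumptions, so the repository's factory may behave differently.

```python
from typing import Any, Dict, FrozenSet

# One reading of the argument list above: these keys are accepted by every
# discriminator type, and each type adds its own architecture-specific keys.
# This split is an assumption, not taken from the repository's code.
_SHARED: FrozenSet[str] = frozenset(
    {"pitch_dim", "context_measures", "hidden1", "hidden2", "pool", "dropout"}
)
_SPECIFIC: Dict[str, FrozenSet[str]] = {
    "lstm": frozenset({"embed_size", "hidden_size", "num_layers", "dropout"}),
    "mlp": frozenset({"embed_size", "hidden_size", "num_layers", "dropout"}),
    "transformer": frozenset({"embed_size", "num_heads", "num_layers"}),
}


def check_discriminator_kwargs(model_type: str, **kwargs: Any) -> Dict[str, Any]:
    """Validate kwargs for one discriminator type and return them unchanged."""
    if model_type not in _SPECIFIC:
        raise ValueError(
            f"unknown model_type {model_type!r}; choose from {sorted(_SPECIFIC)}"
        )
    allowed = _SHARED | _SPECIFIC[model_type]
    unknown = set(kwargs) - allowed
    if unknown:
        # Mirror Python's own convention for unexpected keyword arguments.
        raise TypeError(
            f"{model_type} discriminator got unexpected kwargs: {sorted(unknown)}"
        )
    return kwargs
```

A check like this is useful here because the generator transformer takes `nhead` while the discriminator transformer takes `num_heads`; passing the wrong spelling fails loudly instead of being silently ignored.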

get_default_config returns the default hyperparameters for a given model type.
