Migrate VAE example to Flax NNX #5077
base: main
Changes from all commits: b4627cb, 6fd002a, f959a6d, 76971ef, ed731fc
**config.py** (new file)

```diff
@@ -0,0 +1,14 @@
+"""Simple configuration using dataclasses instead of ml_collections."""
+from dataclasses import dataclass
+
+
+@dataclass
+class TrainingConfig:
+  """Training configuration parameters."""
+  learning_rate: float = 0.001
+  latents: int = 20
+  batch_size: int = 128
+  num_epochs: int = 30
+
+
+def get_default_config() -> TrainingConfig:
+  """Get the default configuration."""
+  return TrainingConfig()
```
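A quick usage sketch of the new dataclass config (illustrative only; the override values below are not from the PR):

```python
# Sketch: build the default config, or override selected fields for a smaller run.
from config import TrainingConfig, get_default_config

config = get_default_config()  # TrainingConfig(learning_rate=0.001, latents=20, batch_size=128, num_epochs=30)
smoke_test = TrainingConfig(batch_size=32, num_epochs=1)  # hypothetical quick-run settings
```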
A file was also deleted in this PR (presumably the old `ml_collections`-based config file); its contents are not shown in this excerpt.
**main.py**

```diff
@@ -18,46 +18,65 @@
 that can be easily tested and imported in Colab.
 """
 
-from absl import app
-from absl import flags
-from absl import logging
-from clu import platform
+import argparse
+import logging
 import jax
-from ml_collections import config_flags
 import tensorflow as tf
-
+import time
 import train
+from config import TrainingConfig, get_default_config
+import os
+
+
+def setup_training_args():
+  """Setup training arguments with defaults from config."""
+  parser = argparse.ArgumentParser(description='VAE Training Script')
+  config = get_default_config()
+
+  # Add all config parameters as arguments
+  parser.add_argument('--learning_rate', type=float, default=config.learning_rate,
+                      help='Learning rate for training')
+  parser.add_argument('--latents', type=int, default=config.latents,
+                      help='Number of latent dimensions')
+  parser.add_argument('--batch_size', type=int, default=config.batch_size,
+                      help='Batch size for training')
+  parser.add_argument('--num_epochs', type=int, default=config.num_epochs,
+                      help='Number of training epochs')
+  parser.add_argument('--workdir', type=str, default='/tmp/vae',
+                      help='Working directory for checkpoints and logs')
 
-FLAGS = flags.FLAGS
+  args = parser.parse_args()
 
-config_flags.DEFINE_config_file(
-    'config',
-    None,
-    'File path to the training hyperparameter configuration.',
-    lock_config=True,
-)
+  # Convert args to TrainingConfig
+  return TrainingConfig(
+      learning_rate=args.learning_rate,
+      latents=args.latents,
+      batch_size=args.batch_size,
+      num_epochs=args.num_epochs
+  ), args.workdir
+
+
+def main():
+  # Configure logging
+  logging.basicConfig(level=logging.INFO)
 
-def main(argv):
-  if len(argv) > 1:
-    raise app.UsageError('Too many command-line arguments.')
+  # Parse arguments and get config
+  config, workdir = setup_training_args()
 
-  # Make sure tf does not allocate gpu memory.
-  tf.config.experimental.set_visible_devices([], 'GPU')
+  # Create workdir if it doesn't exist
+  os.makedirs(workdir, exist_ok=True)
 
-  logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
-  logging.info('JAX local devices: %r', jax.local_devices())
+  # Make sure tf does not allocate gpu memory.
+  tf.config.experimental.set_visible_devices([], 'GPU')
 
-  # Add a note so that we can tell which task is which JAX host.
-  # (Depending on the platform task 0 is not guaranteed to be host 0)
-  platform.work_unit().set_task_status(
-      f'process_index: {jax.process_index()}, '
-      f'process_count: {jax.process_count()}'
-  )
+  logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
+  logging.info('JAX local devices: %r', jax.local_devices())
 
-  train.train_and_evaluate(FLAGS.config)
+  # Simple process logging
+  logging.info('Starting training process %d/%d',
+               jax.process_index(), jax.process_count())
+
+  start = time.perf_counter()
+  train.train_and_evaluate(config)
+  logging.info('Total training time: %.2f seconds', time.perf_counter() - start)
 
 
 if __name__ == '__main__':
-  app.run(main)
+  main()
```

**Collaborator:** @sanepunk why do you remove the abseil app and the usage of the config file?
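For context on the question above: the removed entry point used `app.run` plus `ml_collections.config_flags.DEFINE_config_file` (visible in the deleted lines). One hedged way to keep the abseil entry point while still feeding the new dataclass config into `train.train_and_evaluate` would look roughly like the sketch below; this is not code from the PR, and the flag names and defaults simply mirror the argparse version above.

```python
# Sketch: an abseil-based entry point that still builds the new TrainingConfig.
# Illustrative only; flag names/defaults are assumptions mirroring the argparse version.
from absl import app, flags, logging

from config import TrainingConfig
import train

FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate for training.')
flags.DEFINE_integer('latents', 20, 'Number of latent dimensions.')
flags.DEFINE_integer('batch_size', 128, 'Batch size for training.')
flags.DEFINE_integer('num_epochs', 30, 'Number of training epochs.')
flags.DEFINE_string('workdir', '/tmp/vae', 'Working directory for checkpoints and logs.')


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  config = TrainingConfig(
      learning_rate=FLAGS.learning_rate,
      latents=FLAGS.latents,
      batch_size=FLAGS.batch_size,
      num_epochs=FLAGS.num_epochs,
  )
  logging.info('Config: %s', config)
  train.train_and_evaluate(config)


if __name__ == '__main__':
  app.run(main)
```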
**requirements.txt**

```diff
@@ -1,7 +1,7 @@
 absl-py==1.4.0
-flax==0.6.9
-numpy==1.23.5
+flax~=0.12
+numpy>=1.26.4
 optax==0.1.5
 Pillow==10.2.0
-tensorflow==2.12.0
+tensorflow-cpu~=2.18.0
 tensorflow-datasets==4.9.2
```
**train.py**

```diff
@@ -14,15 +14,14 @@
 """Training and evaluation logic."""
 
-from absl import logging
-from flax import linen as nn
+from flax import nnx
 import input_pipeline
 import models
 import utils as vae_utils
-from flax.training import train_state
 import jax
 from jax import random
 import jax.numpy as jnp
-import ml_collections
+from config import TrainingConfig
 import optax
 import tensorflow_datasets as tfds
@@ -34,7 +33,7 @@ def kl_divergence(mean, logvar):
 
 @jax.vmap
 def binary_cross_entropy_with_logits(logits, labels):
-  logits = nn.log_sigmoid(logits)
+  logits = nnx.log_sigmoid(logits)
   return -jnp.sum(
       labels * logits + (1.0 - labels) * jnp.log(-jnp.expm1(logits))
   )
@@ -45,39 +44,37 @@ def compute_metrics(recon_x, x, mean, logvar):
   kld_loss = kl_divergence(mean, logvar).mean()
   return {'bce': bce_loss, 'kld': kld_loss, 'loss': bce_loss + kld_loss}
 
 
-def train_step(state, batch, z_rng, latents):
-  def loss_fn(params):
-    recon_x, mean, logvar = models.model(latents).apply(
-        {'params': params}, batch, z_rng
-    )
+@nnx.jit
+def train_step(model: nnx.Module, optimizer: nnx.Optimizer, batch, z_rng):
+  """Single training step for the VAE model."""
+  def loss_fn(model):
+    recon_x, mean, logvar = model(batch, z_rng)
     bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
     kld_loss = kl_divergence(mean, logvar).mean()
     loss = bce_loss + kld_loss
     return loss
 
-  grads = jax.grad(loss_fn)(state.params)
-  return state.apply_gradients(grads=grads)
+  loss, grads = nnx.value_and_grad(loss_fn)(model)
+  optimizer.update(model, grads)
+  return loss
 
 
-def eval_f(params, images, z, z_rng, latents):
-  def eval_model(vae):
-    recon_images, mean, logvar = vae(images, z_rng)
-    comparison = jnp.concatenate([
-        images[:8].reshape(-1, 28, 28, 1),
-        recon_images[:8].reshape(-1, 28, 28, 1),
-    ])
-
-    generate_images = vae.generate(z)
-    generate_images = generate_images.reshape(-1, 28, 28, 1)
-    metrics = compute_metrics(recon_images, images, mean, logvar)
-    return metrics, comparison, generate_images
-
-  return nn.apply(eval_model, models.model(latents))({'params': params})
+@nnx.jit
+def eval_f(model: nnx.Module, images, z, z_rng):
+  """Evaluation function for the VAE model."""
+  recon_images, mean, logvar = model(images, z_rng)
+  comparison = jnp.concatenate([
+      images[:8].reshape(-1, 28, 28, 1),
+      recon_images[:8].reshape(-1, 28, 28, 1),
+  ])
+  generate_images = model.generate(z)
+  generate_images = generate_images.reshape(-1, 28, 28, 1)
+  metrics = compute_metrics(recon_images, images, mean, logvar)
+  return metrics, comparison, generate_images
 
 
-def train_and_evaluate(config: ml_collections.ConfigDict):
+def train_and_evaluate(config: TrainingConfig):
   """Train and evaulate pipeline."""
   rng = random.key(0)
   rng, key = random.split(rng)
```

**Collaborator** (on the `@nnx.jit` `train_step`): Let's use donate args to donate the model and optimizer to reduce GPU memory usage.

**Author:** I tried adding `donate_argnums` to `nnx.jit` in `train_step`, but was getting NaN loss and KL divergence. What should we do about this?
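One way to experiment with donation without donating the live `nnx` objects is the functional split/merge pattern from the Flax NNX performance guide, where the donated argument is an explicit state pytree that is returned fresh each step. The sketch below is adapted to this example but is untested against the PR (it reuses `binary_cross_entropy_with_logits` and `kl_divergence` from above); whether it avoids the NaNs the author saw would still need to be verified.

```python
# Sketch: donate the flattened train state instead of the live nnx module/optimizer.
# Untested here; adapted from the functional split/merge pattern in the NNX docs.
from functools import partial

import jax
from flax import nnx


@partial(jax.jit, donate_argnums=(1,))  # donate `state`; a fresh state is returned each step
def donated_train_step(graphdef, state, batch, z_rng):
  model, optimizer = nnx.merge(graphdef, state)

  def loss_fn(model):
    recon_x, mean, logvar = model(batch, z_rng)
    bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
    kld_loss = kl_divergence(mean, logvar).mean()
    return bce_loss + kld_loss

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(model, grads)
  return loss, nnx.state((model, optimizer))


# Usage inside the epoch loop (the donated `state` must not be reused after the call):
#   graphdef, state = nnx.split((model, optimizer))
#   for ...:
#       loss, state = donated_train_step(graphdef, state, batch, key)
#   nnx.update((model, optimizer), state)  # write the final state back if needed
```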
@@ -90,14 +87,9 @@ def train_and_evaluate(config: ml_collections.ConfigDict): | |
| test_ds = input_pipeline.build_test_set(ds_builder) | ||
|
|
||
| logging.info('Initializing model.') | ||
| init_data = jnp.ones((config.batch_size, 784), jnp.float32) | ||
| params = models.model(config.latents).init(key, init_data, rng)['params'] | ||
|
|
||
| state = train_state.TrainState.create( | ||
| apply_fn=models.model(config.latents).apply, | ||
| params=params, | ||
| tx=optax.adam(config.learning_rate), | ||
| ) | ||
| rngs = nnx.Rngs(0) | ||
| model = models.model(784, config.latents, rngs=rngs) | ||
| optimizer = nnx.Optimizer(model, optax.adam(config.learning_rate), wrt=nnx.Param) | ||
|
|
||
| rng, z_key, eval_rng = random.split(rng, 3) | ||
| z = random.normal(z_key, (64, config.latents)) | ||
|
|
@@ -110,11 +102,10 @@ def train_and_evaluate(config: ml_collections.ConfigDict): | |
| for _ in range(steps_per_epoch): | ||
| batch = next(train_ds) | ||
| rng, key = random.split(rng) | ||
| state = train_step(state, batch, key, config.latents) | ||
| loss_val = train_step(model, optimizer, batch, key) | ||
|
|
||
| metrics, comparison, sample = eval_f( | ||
| state.params, test_ds, z, eval_rng, config.latents | ||
| ) | ||
| model, test_ds, z, eval_rng) | ||
| vae_utils.save_image( | ||
| comparison, f'results/reconstruction_{epoch}.png', nrow=8 | ||
| ) | ||
|
|
||
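The `models.py` changes are not part of this excerpt. The calls above (`models.model(784, config.latents, rngs=rngs)`, `model(batch, z_rng)` returning `(recon_x, mean, logvar)`, and `model.generate(z)`) imply an NNX interface roughly like the hypothetical sketch below; the `VAE` class, layer sizes, and the `model()` factory are assumptions, not the PR's actual code.

```python
# Hypothetical models.py sketch matching the interface used in train.py above.
import jax
import jax.numpy as jnp
from flax import nnx


class VAE(nnx.Module):
  """Toy VAE with the call signature used in train.py (layer sizes are assumptions)."""

  def __init__(self, input_dim: int, latents: int, *, rngs: nnx.Rngs):
    self.enc_hidden = nnx.Linear(input_dim, 500, rngs=rngs)
    self.enc_mean = nnx.Linear(500, latents, rngs=rngs)
    self.enc_logvar = nnx.Linear(500, latents, rngs=rngs)
    self.dec_hidden = nnx.Linear(latents, 500, rngs=rngs)
    self.dec_out = nnx.Linear(500, input_dim, rngs=rngs)

  def __call__(self, x, z_rng):
    h = nnx.relu(self.enc_hidden(x))
    mean, logvar = self.enc_mean(h), self.enc_logvar(h)
    # Reparameterisation trick.
    z = mean + jnp.exp(0.5 * logvar) * jax.random.normal(z_rng, mean.shape)
    return self.decode(z), mean, logvar  # logits, to pair with BCE-with-logits

  def decode(self, z):
    return self.dec_out(nnx.relu(self.dec_hidden(z)))

  def generate(self, z):
    return nnx.sigmoid(self.decode(z))  # probabilities, for the saved sample grid


def model(input_dim: int, latents: int, *, rngs: nnx.Rngs) -> VAE:
  return VAE(input_dim, latents, rngs=rngs)
```

Unlike the Linen version, parameters are created eagerly in `__init__` from `rngs`, which is why `train.py` no longer needs a separate `init`/`apply` step or a `TrainState`.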
**Reviewer:** Let's remove argparse.