diff --git a/examples/vae/README.md b/examples/vae/README.md index 325aba8ff..51de127f0 100644 --- a/examples/vae/README.md +++ b/examples/vae/README.md @@ -5,21 +5,20 @@ This code follows [pytorch/examples/vae](https://github.com/pytorch/examples/blo ```bash pip install -r requirements.txt -python main.py --workdir=/tmp/mnist --config=configs/default.py +python main.py --workdir=/tmp/mnist ``` -## Overriding Hyperparameter configurations +## Configuring hyperparameters -This VAE example allows specifying a hyperparameter configuration by the means of -setting `--config` flag. Configuration flag is defined using -[config_flags](https://github.com/google/ml_collections/tree/master#config-flags). -`config_flags` allows overriding configuration fields. This can be done as -follows: +The VAE example uses simple command line arguments for configuration. You can override the default values as follows: ```shell python main.py \ ---workdir=/tmp/mnist --config=configs/default.py \ ---config.learning_rate=0.01 --config.num_epochs=10 +--workdir=/tmp/mnist \ +--learning_rate=0.01 \ +--num_epochs=10 \ +--batch_size=128 \ +--latents=20 ``` diff --git a/examples/vae/config.py b/examples/vae/config.py new file mode 100644 index 000000000..a0f26ebdd --- /dev/null +++ b/examples/vae/config.py @@ -0,0 +1,14 @@ +"""Simple configuration using dataclasses instead of ml_collections.""" +from dataclasses import dataclass + +@dataclass +class TrainingConfig: + """Training configuration parameters.""" + learning_rate: float = 0.001 + latents: int = 20 + batch_size: int = 128 + num_epochs: int = 30 + +def get_default_config() -> TrainingConfig: + """Get the default configuration.""" + return TrainingConfig() \ No newline at end of file diff --git a/examples/vae/configs/default.py b/examples/vae/configs/default.py deleted file mode 100644 index c18d0cf20..000000000 --- a/examples/vae/configs/default.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2024 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Default Hyperparameter configuration.""" - -import ml_collections - - -def get_config(): - """Get the default hyperparameter configuration.""" - config = ml_collections.ConfigDict() - - config.learning_rate = 0.001 - config.latents = 20 - config.batch_size = 128 - config.num_epochs = 30 - return config diff --git a/examples/vae/main.py b/examples/vae/main.py index 537ec08d6..6bd277f68 100644 --- a/examples/vae/main.py +++ b/examples/vae/main.py @@ -18,46 +18,65 @@ that can be easily tested and imported in Colab. """ -from absl import app -from absl import flags -from absl import logging -from clu import platform +import argparse +import logging import jax -from ml_collections import config_flags import tensorflow as tf - +import time import train +from config import TrainingConfig, get_default_config +import os + +def setup_training_args(): + """Setup training arguments with defaults from config.""" + parser = argparse.ArgumentParser(description='VAE Training Script') + config = get_default_config() + # Add all config parameters as arguments + parser.add_argument('--learning_rate', type=float, default=config.learning_rate, + help='Learning rate for training') + parser.add_argument('--latents', type=int, default=config.latents, + help='Number of latent dimensions') + parser.add_argument('--batch_size', type=int, default=config.batch_size, + help='Batch size for training') + parser.add_argument('--num_epochs', type=int, default=config.num_epochs, + help='Number of training epochs') + parser.add_argument('--workdir', type=str, default='/tmp/vae', + help='Working directory for checkpoints and logs') -FLAGS = flags.FLAGS + args = parser.parse_args() -config_flags.DEFINE_config_file( - 'config', - None, - 'File path to the training hyperparameter configuration.', - lock_config=True, -) + # Convert args to TrainingConfig + return TrainingConfig( + learning_rate=args.learning_rate, + latents=args.latents, + batch_size=args.batch_size, + num_epochs=args.num_epochs + ), args.workdir +def main(): + # Configure logging + logging.basicConfig(level=logging.INFO) -def main(argv): - if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') + # Parse arguments and get config + config, workdir = setup_training_args() - # Make sure tf does not allocate gpu memory. - tf.config.experimental.set_visible_devices([], 'GPU') + # Create workdir if it doesn't exist + os.makedirs(workdir, exist_ok=True) - logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) - logging.info('JAX local devices: %r', jax.local_devices()) + # Make sure tf does not allocate gpu memory. + tf.config.experimental.set_visible_devices([], 'GPU') - # Add a note so that we can tell which task is which JAX host. - # (Depending on the platform task 0 is not guaranteed to be host 0) - platform.work_unit().set_task_status( - f'process_index: {jax.process_index()}, ' - f'process_count: {jax.process_count()}' - ) + logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) + logging.info('JAX local devices: %r', jax.local_devices()) - train.train_and_evaluate(FLAGS.config) + # Simple process logging + logging.info('Starting training process %d/%d', + jax.process_index(), jax.process_count()) + start = time.perf_counter() + train.train_and_evaluate(config) + logging.info('Total training time: %.2f seconds', time.perf_counter() - start) if __name__ == '__main__': - app.run(main) + main() diff --git a/examples/vae/models.py b/examples/vae/models.py index a1bbd94c3..0130d9c57 100644 --- a/examples/vae/models.py +++ b/examples/vae/models.py @@ -14,44 +14,47 @@ """VAE model definitions.""" -from flax import linen as nn +from flax import nnx from jax import random import jax.numpy as jnp -class Encoder(nn.Module): +class Encoder(nnx.Module): """VAE Encoder.""" - latents: int + def __init__(self, input_features: int, latents: int, *, rngs: nnx.Rngs): + self.fc1 = nnx.Linear(input_features, 500, rngs=rngs) + self.fc2_mean = nnx.Linear(500, latents, rngs=rngs) + self.fc2_logvar = nnx.Linear(500, latents, rngs=rngs) - @nn.compact def __call__(self, x): - x = nn.Dense(500, name='fc1')(x) - x = nn.relu(x) - mean_x = nn.Dense(self.latents, name='fc2_mean')(x) - logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x) + x = self.fc1(x) + x = nnx.relu(x) + mean_x = self.fc2_mean(x) + logvar_x = self.fc2_logvar(x) return mean_x, logvar_x -class Decoder(nn.Module): +class Decoder(nnx.Module): """VAE Decoder.""" - @nn.compact + def __init__(self, latents: int, output_features: int, *, rngs: nnx.Rngs): + self.fc1 = nnx.Linear(latents, 500, rngs=rngs) + self.fc2 = nnx.Linear(500, output_features, rngs=rngs) + def __call__(self, z): - z = nn.Dense(500, name='fc1')(z) - z = nn.relu(z) - z = nn.Dense(784, name='fc2')(z) + z = self.fc1(z) + z = nnx.relu(z) + z = self.fc2(z) return z -class VAE(nn.Module): +class VAE(nnx.Module): """Full VAE model.""" - latents: int = 20 - - def setup(self): - self.encoder = Encoder(self.latents) - self.decoder = Decoder() + def __init__(self, input_features: int, latents: int, rngs: nnx.Rngs): + self.encoder = Encoder(input_features=input_features, latents=latents, rngs=rngs) + self.decoder = Decoder(latents=latents, output_features=input_features, rngs=rngs) def __call__(self, x, z_rng): mean, logvar = self.encoder(x) @@ -60,7 +63,7 @@ def __call__(self, x, z_rng): return recon_x, mean, logvar def generate(self, z): - return nn.sigmoid(self.decoder(z)) + return nnx.sigmoid(self.decoder(z)) def reparameterize(rng, mean, logvar): @@ -69,5 +72,5 @@ def reparameterize(rng, mean, logvar): return mean + eps * std -def model(latents): - return VAE(latents=latents) +def model(input_features: int, latents: int, rngs: nnx.Rngs): + return VAE(input_features=input_features, latents=latents, rngs=rngs) diff --git a/examples/vae/requirements.txt b/examples/vae/requirements.txt index cc497ca7e..dcee66a19 100644 --- a/examples/vae/requirements.txt +++ b/examples/vae/requirements.txt @@ -1,7 +1,7 @@ absl-py==1.4.0 -flax==0.6.9 -numpy==1.23.5 +flax~=0.12 +numpy>=1.26.4 optax==0.1.5 Pillow==10.2.0 -tensorflow==2.12.0 +tensorflow-cpu~=2.18.0 tensorflow-datasets==4.9.2 \ No newline at end of file diff --git a/examples/vae/train.py b/examples/vae/train.py index 84f1b582a..4fb40546c 100644 --- a/examples/vae/train.py +++ b/examples/vae/train.py @@ -14,15 +14,14 @@ """Training and evaluation logic.""" from absl import logging -from flax import linen as nn +from flax import nnx import input_pipeline import models import utils as vae_utils -from flax.training import train_state import jax from jax import random import jax.numpy as jnp -import ml_collections +from config import TrainingConfig import optax import tensorflow_datasets as tfds @@ -34,7 +33,7 @@ def kl_divergence(mean, logvar): @jax.vmap def binary_cross_entropy_with_logits(logits, labels): - logits = nn.log_sigmoid(logits) + logits = nnx.log_sigmoid(logits) return -jnp.sum( labels * logits + (1.0 - labels) * jnp.log(-jnp.expm1(logits)) ) @@ -45,39 +44,37 @@ def compute_metrics(recon_x, x, mean, logvar): kld_loss = kl_divergence(mean, logvar).mean() return {'bce': bce_loss, 'kld': kld_loss, 'loss': bce_loss + kld_loss} - -def train_step(state, batch, z_rng, latents): - def loss_fn(params): - recon_x, mean, logvar = models.model(latents).apply( - {'params': params}, batch, z_rng - ) - +@nnx.jit +def train_step(model: nnx.Module, optimizer: nnx.Optimizer, batch, z_rng): + """Single training step for the VAE model.""" + def loss_fn(model): + recon_x, mean, logvar = model(batch, z_rng) bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean() kld_loss = kl_divergence(mean, logvar).mean() loss = bce_loss + kld_loss return loss - grads = jax.grad(loss_fn)(state.params) - return state.apply_gradients(grads=grads) - + loss, grads = nnx.value_and_grad(loss_fn)(model) + optimizer.update(model, grads) + return loss -def eval_f(params, images, z, z_rng, latents): - def eval_model(vae): - recon_images, mean, logvar = vae(images, z_rng) - comparison = jnp.concatenate([ - images[:8].reshape(-1, 28, 28, 1), - recon_images[:8].reshape(-1, 28, 28, 1), - ]) - generate_images = vae.generate(z) - generate_images = generate_images.reshape(-1, 28, 28, 1) - metrics = compute_metrics(recon_images, images, mean, logvar) - return metrics, comparison, generate_images +@nnx.jit +def eval_f(model: nnx.Module, images, z, z_rng): + """Evaluation function for the VAE model.""" + recon_images, mean, logvar = model(images, z_rng) + comparison = jnp.concatenate([ + images[:8].reshape(-1, 28, 28, 1), + recon_images[:8].reshape(-1, 28, 28, 1), + ]) + generate_images = model.generate(z) + generate_images = generate_images.reshape(-1, 28, 28, 1) + metrics = compute_metrics(recon_images, images, mean, logvar) + return metrics, comparison, generate_images - return nn.apply(eval_model, models.model(latents))({'params': params}) -def train_and_evaluate(config: ml_collections.ConfigDict): +def train_and_evaluate(config: TrainingConfig): """Train and evaulate pipeline.""" rng = random.key(0) rng, key = random.split(rng) @@ -90,14 +87,9 @@ def train_and_evaluate(config: ml_collections.ConfigDict): test_ds = input_pipeline.build_test_set(ds_builder) logging.info('Initializing model.') - init_data = jnp.ones((config.batch_size, 784), jnp.float32) - params = models.model(config.latents).init(key, init_data, rng)['params'] - - state = train_state.TrainState.create( - apply_fn=models.model(config.latents).apply, - params=params, - tx=optax.adam(config.learning_rate), - ) + rngs = nnx.Rngs(0) + model = models.model(784, config.latents, rngs=rngs) + optimizer = nnx.Optimizer(model, optax.adam(config.learning_rate), wrt=nnx.Param) rng, z_key, eval_rng = random.split(rng, 3) z = random.normal(z_key, (64, config.latents)) @@ -110,11 +102,10 @@ def train_and_evaluate(config: ml_collections.ConfigDict): for _ in range(steps_per_epoch): batch = next(train_ds) rng, key = random.split(rng) - state = train_step(state, batch, key, config.latents) + loss_val = train_step(model, optimizer, batch, key) metrics, comparison, sample = eval_f( - state.params, test_ds, z, eval_rng, config.latents - ) + model, test_ds, z, eval_rng) vae_utils.save_image( comparison, f'results/reconstruction_{epoch}.png', nrow=8 )