17 changes: 8 additions & 9 deletions examples/vae/README.md
@@ -5,21 +5,20 @@ This code follows [pytorch/examples/vae](https://github.com/pytorch/examples/blo

```bash
pip install -r requirements.txt
python main.py --workdir=/tmp/mnist --config=configs/default.py
python main.py --workdir=/tmp/mnist
```

## Overriding Hyperparameter configurations
## Configuring hyperparameters

This VAE example allows specifying a hyperparameter configuration by the means of
setting `--config` flag. Configuration flag is defined using
[config_flags](https://github.com/google/ml_collections/tree/master#config-flags).
`config_flags` allows overriding configuration fields. This can be done as
follows:
The VAE example uses simple command-line arguments for configuration. You can override the default values as follows:

```shell
python main.py \
--workdir=/tmp/mnist --config=configs/default.py \
--config.learning_rate=0.01 --config.num_epochs=10
--workdir=/tmp/mnist \
--learning_rate=0.01 \
--num_epochs=10 \
--batch_size=128 \
--latents=20
```


14 changes: 14 additions & 0 deletions examples/vae/config.py
@@ -0,0 +1,14 @@
"""Simple configuration using dataclasses instead of ml_collections."""
from dataclasses import dataclass

@dataclass
class TrainingConfig:
"""Training configuration parameters."""
learning_rate: float = 0.001
latents: int = 20
batch_size: int = 128
num_epochs: int = 30

def get_default_config() -> TrainingConfig:
"""Get the default configuration."""
return TrainingConfig()
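
The snippet below is not part of the PR; it is a small usage sketch showing how the new dataclass config can also be overridden programmatically (the example itself overrides values through the CLI flags in `main.py`). `dataclasses.replace` is standard-library behaviour; the specific override values are illustrative.

```python
# Usage sketch (not in the PR): override fields on the default config.
import dataclasses

from config import get_default_config

config = get_default_config()
config = dataclasses.replace(config, learning_rate=0.01, num_epochs=10)
print(config)
# TrainingConfig(learning_rate=0.01, latents=20, batch_size=128, num_epochs=10)
```

Because `TrainingConfig` is a plain (non-frozen) dataclass, setting attributes directly would also work; `dataclasses.replace` just keeps the original default instance untouched.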
28 changes: 0 additions & 28 deletions examples/vae/configs/default.py

This file was deleted.

75 changes: 47 additions & 28 deletions examples/vae/main.py
@@ -18,46 +18,65 @@
that can be easily tested and imported in Colab.
"""

from absl import app
from absl import flags
from absl import logging
from clu import platform
import argparse
import logging
import jax
from ml_collections import config_flags
import tensorflow as tf

import time
import train
from config import TrainingConfig, get_default_config
import os

def setup_training_args():
[Review comment from Collaborator]: Let's remove argparse
"""Setup training arguments with defaults from config."""
parser = argparse.ArgumentParser(description='VAE Training Script')
config = get_default_config()

# Add all config parameters as arguments
parser.add_argument('--learning_rate', type=float, default=config.learning_rate,
help='Learning rate for training')
parser.add_argument('--latents', type=int, default=config.latents,
help='Number of latent dimensions')
parser.add_argument('--batch_size', type=int, default=config.batch_size,
help='Batch size for training')
parser.add_argument('--num_epochs', type=int, default=config.num_epochs,
help='Number of training epochs')
parser.add_argument('--workdir', type=str, default='/tmp/vae',
help='Working directory for checkpoints and logs')

FLAGS = flags.FLAGS
args = parser.parse_args()

config_flags.DEFINE_config_file(
'config',
None,
'File path to the training hyperparameter configuration.',
lock_config=True,
)
# Convert args to TrainingConfig
return TrainingConfig(
learning_rate=args.learning_rate,
latents=args.latents,
batch_size=args.batch_size,
num_epochs=args.num_epochs
), args.workdir

def main():
# Configure logging
logging.basicConfig(level=logging.INFO)

def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
# Parse arguments and get config
config, workdir = setup_training_args()

# Make sure tf does not allocate gpu memory.
tf.config.experimental.set_visible_devices([], 'GPU')
# Create workdir if it doesn't exist
os.makedirs(workdir, exist_ok=True)

logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
logging.info('JAX local devices: %r', jax.local_devices())
# Make sure tf does not allocate gpu memory.
tf.config.experimental.set_visible_devices([], 'GPU')

# Add a note so that we can tell which task is which JAX host.
# (Depending on the platform task 0 is not guaranteed to be host 0)
platform.work_unit().set_task_status(
f'process_index: {jax.process_index()}, '
f'process_count: {jax.process_count()}'
)
logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
logging.info('JAX local devices: %r', jax.local_devices())

train.train_and_evaluate(FLAGS.config)
# Simple process logging
logging.info('Starting training process %d/%d',
jax.process_index(), jax.process_count())

start = time.perf_counter()
train.train_and_evaluate(config)
logging.info('Total training time: %.2f seconds', time.perf_counter() - start)

if __name__ == '__main__':
app.run(main)
[Review comment from Collaborator]: @sanepunk why do you remove abseil app and the usage of config file?
main()
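
For context on the review question above, here is a minimal sketch of the abseil + ml_collections entry point that this diff removes, reconstructed from the deleted lines shown above. Anything not visible in those deleted lines (for example a `workdir` flag or the `clu` work-unit status call arguments) is omitted rather than guessed.

```python
# Sketch of the pre-PR entry point, rebuilt from the removed lines above.
from absl import app, flags, logging
from ml_collections import config_flags
import jax
import tensorflow as tf

import train

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    'config',
    None,
    'File path to the training hyperparameter configuration.',
    lock_config=True,
)


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
  train.train_and_evaluate(FLAGS.config)


if __name__ == '__main__':
  app.run(main)
```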
47 changes: 25 additions & 22 deletions examples/vae/models.py
@@ -14,44 +14,47 @@

"""VAE model definitions."""

from flax import linen as nn
from flax import nnx
from jax import random
import jax.numpy as jnp


class Encoder(nn.Module):
class Encoder(nnx.Module):
"""VAE Encoder."""

latents: int
def __init__(self, input_features: int, latents: int, *, rngs: nnx.Rngs):
self.fc1 = nnx.Linear(input_features, 500, rngs=rngs)
self.fc2_mean = nnx.Linear(500, latents, rngs=rngs)
self.fc2_logvar = nnx.Linear(500, latents, rngs=rngs)

@nn.compact
def __call__(self, x):
x = nn.Dense(500, name='fc1')(x)
x = nn.relu(x)
mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
x = self.fc1(x)
x = nnx.relu(x)
mean_x = self.fc2_mean(x)
logvar_x = self.fc2_logvar(x)
return mean_x, logvar_x


class Decoder(nn.Module):
class Decoder(nnx.Module):
"""VAE Decoder."""

@nn.compact
def __init__(self, latents: int, output_features: int, *, rngs: nnx.Rngs):
self.fc1 = nnx.Linear(latents, 500, rngs=rngs)
self.fc2 = nnx.Linear(500, output_features, rngs=rngs)

def __call__(self, z):
z = nn.Dense(500, name='fc1')(z)
z = nn.relu(z)
z = nn.Dense(784, name='fc2')(z)
z = self.fc1(z)
z = nnx.relu(z)
z = self.fc2(z)
return z


class VAE(nn.Module):
class VAE(nnx.Module):
"""Full VAE model."""

latents: int = 20

def setup(self):
self.encoder = Encoder(self.latents)
self.decoder = Decoder()
def __init__(self, input_features: int, latents: int, rngs: nnx.Rngs):
self.encoder = Encoder(input_features=input_features, latents=latents, rngs=rngs)
self.decoder = Decoder(latents=latents, output_features=input_features, rngs=rngs)

def __call__(self, x, z_rng):
mean, logvar = self.encoder(x)
@@ -60,7 +63,7 @@ def __call__(self, x, z_rng):
return recon_x, mean, logvar

def generate(self, z):
return nn.sigmoid(self.decoder(z))
return nnx.sigmoid(self.decoder(z))


def reparameterize(rng, mean, logvar):
@@ -69,5 +72,5 @@ def reparameterize(rng, mean, logvar):
return mean + eps * std


def model(latents):
return VAE(latents=latents)
def model(input_features: int, latents: int, rngs: nnx.Rngs):
return VAE(input_features=input_features, latents=latents, rngs=rngs)
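
A hedged usage sketch of the new NNX model factory, mirroring how `train.py` in this PR constructs it (784 input features for flattened 28x28 MNIST, `nnx.Rngs(0)`); the dummy batch and key value are illustrative only.

```python
# Usage sketch: build the NNX VAE eagerly and run one forward pass.
import jax.numpy as jnp
from jax import random
from flax import nnx

import models

model = models.model(input_features=784, latents=20, rngs=nnx.Rngs(0))
x = jnp.ones((8, 784), jnp.float32)            # dummy batch of flattened images
recon, mean, logvar = model(x, random.key(1))  # reparameterized forward pass
samples = model.generate(random.normal(random.key(2), (16, 20)))  # decode noise
```

Unlike the linen version, no separate `init`/`apply` step is needed: parameters are created eagerly in `__init__` from the supplied `nnx.Rngs`.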
6 changes: 3 additions & 3 deletions examples/vae/requirements.txt
@@ -1,7 +1,7 @@
absl-py==1.4.0
flax==0.6.9
numpy==1.23.5
flax~=0.12
numpy>=1.26.4
optax==0.1.5
Pillow==10.2.0
tensorflow==2.12.0
tensorflow-cpu~=2.18.0
tensorflow-datasets==4.9.2
67 changes: 29 additions & 38 deletions examples/vae/train.py
@@ -14,15 +14,14 @@
"""Training and evaluation logic."""

from absl import logging
from flax import linen as nn
from flax import nnx
import input_pipeline
import models
import utils as vae_utils
from flax.training import train_state
import jax
from jax import random
import jax.numpy as jnp
import ml_collections
from config import TrainingConfig
import optax
import tensorflow_datasets as tfds

@@ -34,7 +33,7 @@ def kl_divergence(mean, logvar):

@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
logits = nn.log_sigmoid(logits)
logits = nnx.log_sigmoid(logits)
return -jnp.sum(
labels * logits + (1.0 - labels) * jnp.log(-jnp.expm1(logits))
)
@@ -45,39 +44,37 @@ def compute_metrics(recon_x, x, mean, logvar):
kld_loss = kl_divergence(mean, logvar).mean()
return {'bce': bce_loss, 'kld': kld_loss, 'loss': bce_loss + kld_loss}


def train_step(state, batch, z_rng, latents):
def loss_fn(params):
recon_x, mean, logvar = models.model(latents).apply(
{'params': params}, batch, z_rng
)

@nnx.jit
[Review comment from Collaborator]: Let's use donate args to donate model and optimizer to reduce GPU memory usage.

[Reply from Contributor Author]: I tried adding donate_argnums to nnx.jit in the train_step, but was getting NaN loss and KL divergence. What should I do?

[Follow-up review comment]: What to do about this?
def train_step(model: nnx.Module, optimizer: nnx.Optimizer, batch, z_rng):
"""Single training step for the VAE model."""
def loss_fn(model):
recon_x, mean, logvar = model(batch, z_rng)
bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
kld_loss = kl_divergence(mean, logvar).mean()
loss = bce_loss + kld_loss
return loss

grads = jax.grad(loss_fn)(state.params)
return state.apply_gradients(grads=grads)

loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(model, grads)
return loss
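
Regarding the donation thread above, one hedged possibility (an assumption on my part, not something verified in this PR): since `nnx.jit` updates the module and optimizer state in place, donating those two arguments invalidates buffers that are still referenced outside the jitted function, which could plausibly explain the reported NaNs. Donating only the batch buffer avoids that. The sketch reuses the loss helpers defined above in this file; the argument index and behaviour should be checked against the installed Flax version.

```python
from functools import partial

# Hedged sketch: donate only the batch (argument index 2), leaving the
# in-place-updated model/optimizer buffers untouched by donation.
@partial(nnx.jit, donate_argnums=(2,))
def train_step_donating_batch(model: nnx.Module, optimizer: nnx.Optimizer, batch, z_rng):
  def loss_fn(model):
    recon_x, mean, logvar = model(batch, z_rng)
    bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
    kld_loss = kl_divergence(mean, logvar).mean()
    return bce_loss + kld_loss

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(model, grads)
  return loss
```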

def eval_f(params, images, z, z_rng, latents):
def eval_model(vae):
recon_images, mean, logvar = vae(images, z_rng)
comparison = jnp.concatenate([
images[:8].reshape(-1, 28, 28, 1),
recon_images[:8].reshape(-1, 28, 28, 1),
])

generate_images = vae.generate(z)
generate_images = generate_images.reshape(-1, 28, 28, 1)
metrics = compute_metrics(recon_images, images, mean, logvar)
return metrics, comparison, generate_images
@nnx.jit
def eval_f(model: nnx.Module, images, z, z_rng):
"""Evaluation function for the VAE model."""
recon_images, mean, logvar = model(images, z_rng)
comparison = jnp.concatenate([
images[:8].reshape(-1, 28, 28, 1),
recon_images[:8].reshape(-1, 28, 28, 1),
])
generate_images = model.generate(z)
generate_images = generate_images.reshape(-1, 28, 28, 1)
metrics = compute_metrics(recon_images, images, mean, logvar)
return metrics, comparison, generate_images

return nn.apply(eval_model, models.model(latents))({'params': params})


def train_and_evaluate(config: ml_collections.ConfigDict):
def train_and_evaluate(config: TrainingConfig):
"""Train and evaulate pipeline."""
rng = random.key(0)
rng, key = random.split(rng)
@@ -90,14 +87,9 @@ def train_and_evaluate(config: ml_collections.ConfigDict):
test_ds = input_pipeline.build_test_set(ds_builder)

logging.info('Initializing model.')
init_data = jnp.ones((config.batch_size, 784), jnp.float32)
params = models.model(config.latents).init(key, init_data, rng)['params']

state = train_state.TrainState.create(
apply_fn=models.model(config.latents).apply,
params=params,
tx=optax.adam(config.learning_rate),
)
rngs = nnx.Rngs(0)
model = models.model(784, config.latents, rngs=rngs)
optimizer = nnx.Optimizer(model, optax.adam(config.learning_rate), wrt=nnx.Param)

rng, z_key, eval_rng = random.split(rng, 3)
z = random.normal(z_key, (64, config.latents))
@@ -110,11 +102,10 @@
for _ in range(steps_per_epoch):
batch = next(train_ds)
rng, key = random.split(rng)
state = train_step(state, batch, key, config.latents)
loss_val = train_step(model, optimizer, batch, key)

metrics, comparison, sample = eval_f(
state.params, test_ds, z, eval_rng, config.latents
)
model, test_ds, z, eval_rng)
vae_utils.save_image(
comparison, f'results/reconstruction_{epoch}.png', nrow=8
)