Skip to content

Conversation

@eari100
Copy link
Contributor

@eari100 eari100 commented Jan 7, 2026

@chapman20j
My apologies for the late PR. I’ve finalized the changes, and I would appreciate it if you could take a look when you have a moment.

Conflicts:
bonsai/models/vae/README.md
bonsai/models/vae/modeling.py
bonsai/models/vae/params.py
bonsai/models/vae/tests/VAE_segmentation_example.ipynb
bonsai/models/vae/tests/run_model.py
pyproject.toml

Resolves #46

Reference

Checklist

  • I have read the Contribution Guidelines and used pre-commit hooks to format this commit.
  • I have added all the necessary unit tests for my change. (run_model.py for model usage, test_outputs.py and/or model_validation_colab.ipynb for quality).
  • (If using an LLM) I have carefully reviewed and removed all superfluous comments or unneeded, commented-out code. Only necessary and functional code remains.
  • I have signed the Contributor License Agreement (CLA).

@chapman20j
Copy link
Collaborator

Hi @eari100 . Thank you for the nice PR! I can take a look tomorrow. In the meantime, could you rebase onto main? It looks like there are branch conflicts. We recently added some formatting tools that may have caused this. Apologies for the inconvenience.

# Conflicts:
#	bonsai/models/vae/README.md
#	bonsai/models/vae/modeling.py
#	bonsai/models/vae/params.py
#	bonsai/models/vae/tests/VAE_segmentation_example.ipynb
#	bonsai/models/vae/tests/run_model.py
#	pyproject.toml

# Conflicts:
#	bonsai/models/vae/tests/VAE_segmentation_example.ipynb
#	bonsai/models/vae/tests/run_model.py
@eari100 eari100 force-pushed the vae-weights-and-tests branch from b0a759e to 1dc82f3 Compare January 8, 2026 06:04
@eari100
Copy link
Contributor Author

eari100 commented Jan 8, 2026

@chapman20j
Rebase completed. I appreciate your assistance with this.

jy = nnx_model(jx)
with torch.no_grad():
ty = dif_model(tx).sample
np.testing.assert_allclose(jy, ty.permute(0, 2, 3, 1).cpu().detach().numpy(), atol=9e-1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add a few more tests on the intermediate layers? An error of 9e-1 may be a little high.

return jnp.array(image[None, ...])
```

## **Image Postproessing**
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should be "Image Postprocessing"

## **Set-up**

```{code-cell}
!pip install -q git+https://github.com/eari100/bonsai@vae-weights-and-tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you update this to pip install from the main bonsai repo? I know this will temporarily break the notebook when you're testing but it will work when the PR is approved.

def forward(model, x, key):
return model(x, key)
def __init__(self, rngs: nnx.Rngs):
block_out_channels = [128, 256, 512, 512]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some of these constants are things that would better fit in a model config dataclass. Could you refactor the __init__ methods to accept a config? This would also improve the consistency of this implementation with other models on the repo. ViT or Qwen3 may be a good reference for this. If you have any other questions, feel free to ask here.

Comment on lines 25 to 26
TO_JAX_CONV_2D_KERNEL = (2, 3, 1, 0) # (C_out, C_in, kH, kW) -> (kH, kW, C_in, C_out)
TO_JAX_LINEAR_KERNEL = (1, 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you update this to use the Transform enum like in other implementations (e.g. resnet).

@chapman20j
Copy link
Collaborator

Hi @eari100 . Just gave the PR a first pass and it looks good. I think there are a few stylistic things (e.g. adding enums) which would improve consistency with the rest of the repo and make it easier to add more variants of this architecture. One of the most helpful changes would be adding some intermediate tests. This helps us ensure high model quality. Thanks again for addressing this issue!

@eari100
Copy link
Contributor Author

eari100 commented Jan 9, 2026

@chapman20j
I have updated the code reflecting your feedback.

Although I haven't modified the model itself, I’ve found that the tests now pass with atol=5e-3 after several attempts. Previously, it failed multiple times, which had forced me to set the tolerance to 9e-1.I'll keep an eye on it to see if the discrepancy occurs again.

@chapman20j chapman20j merged commit a493b03 into jax-ml:main Jan 12, 2026
3 checks passed
@eari100
Copy link
Contributor Author

eari100 commented Jan 13, 2026

@chapman20j
Thank you for the detailed feedback. It has been a valuable learning experience for me.

Aatman09 pushed a commit to Aatman09/bonsai that referenced this pull request Jan 13, 2026
* vae implementation

# Conflicts:
#	bonsai/models/vae/README.md
#	bonsai/models/vae/modeling.py
#	bonsai/models/vae/params.py
#	bonsai/models/vae/tests/VAE_segmentation_example.ipynb
#	bonsai/models/vae/tests/run_model.py
#	pyproject.toml

# Conflicts:
#	bonsai/models/vae/tests/VAE_segmentation_example.ipynb
#	bonsai/models/vae/tests/run_model.py

* run a pre-commit hook

* Modify code style

* Add intermediate tests
coder0143 pushed a commit to coder0143/bonsai that referenced this pull request Jan 19, 2026
* vae implementation

# Conflicts:
#	bonsai/models/vae/README.md
#	bonsai/models/vae/modeling.py
#	bonsai/models/vae/params.py
#	bonsai/models/vae/tests/VAE_segmentation_example.ipynb
#	bonsai/models/vae/tests/run_model.py
#	pyproject.toml

# Conflicts:
#	bonsai/models/vae/tests/VAE_segmentation_example.ipynb
#	bonsai/models/vae/tests/run_model.py

* run a pre-commit hook

* Modify code style

* Add intermediate tests
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

VAE: Add proper weight loading and quality tests.

2 participants