-
Notifications
You must be signed in to change notification settings - Fork 37
vae implementation #122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
vae implementation #122
Conversation
|
Hi @eari100 . Thank you for the nice PR! I can take a look tomorrow. In the meantime, could you rebase onto main? It looks like there are branch conflicts. We recently added some formatting tools that may have caused this. Apologies for the inconvenience. |
# Conflicts: # bonsai/models/vae/README.md # bonsai/models/vae/modeling.py # bonsai/models/vae/params.py # bonsai/models/vae/tests/VAE_segmentation_example.ipynb # bonsai/models/vae/tests/run_model.py # pyproject.toml # Conflicts: # bonsai/models/vae/tests/VAE_segmentation_example.ipynb # bonsai/models/vae/tests/run_model.py
b0a759e to
1dc82f3
Compare
|
@chapman20j |
| jy = nnx_model(jx) | ||
| with torch.no_grad(): | ||
| ty = dif_model(tx).sample | ||
| np.testing.assert_allclose(jy, ty.permute(0, 2, 3, 1).cpu().detach().numpy(), atol=9e-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please add a few more tests on the intermediate layers? An error of 9e-1 may be a little high.
| return jnp.array(image[None, ...]) | ||
| ``` | ||
|
|
||
| ## **Image Postproessing** |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: should be "Image Postprocessing"
| ## **Set-up** | ||
|
|
||
| ```{code-cell} | ||
| !pip install -q git+https://github.com/eari100/bonsai@vae-weights-and-tests |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you update this to pip install from the main bonsai repo? I know this will temporarily break the notebook when you're testing but it will work when the PR is approved.
bonsai/models/vae/modeling.py
Outdated
| def forward(model, x, key): | ||
| return model(x, key) | ||
| def __init__(self, rngs: nnx.Rngs): | ||
| block_out_channels = [128, 256, 512, 512] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some of these constants are things that would better fit in a model config dataclass. Could you refactor the __init__ methods to accept a config? This would also improve the consistency of this implementation with other models on the repo. ViT or Qwen3 may be a good reference for this. If you have any other questions, feel free to ask here.
bonsai/models/vae/params.py
Outdated
| TO_JAX_CONV_2D_KERNEL = (2, 3, 1, 0) # (C_out, C_in, kH, kW) -> (kH, kW, C_in, C_out) | ||
| TO_JAX_LINEAR_KERNEL = (1, 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you update this to use the Transform enum like in other implementations (e.g. resnet).
|
Hi @eari100 . Just gave the PR a first pass and it looks good. I think there are a few stylistic things (e.g. adding enums) which would improve consistency with the rest of the repo and make it easier to add more variants of this architecture. One of the most helpful changes would be adding some intermediate tests. This helps us ensure high model quality. Thanks again for addressing this issue! |
|
@chapman20j Although I haven't modified the model itself, I’ve found that the tests now pass with |
|
@chapman20j |
* vae implementation # Conflicts: # bonsai/models/vae/README.md # bonsai/models/vae/modeling.py # bonsai/models/vae/params.py # bonsai/models/vae/tests/VAE_segmentation_example.ipynb # bonsai/models/vae/tests/run_model.py # pyproject.toml # Conflicts: # bonsai/models/vae/tests/VAE_segmentation_example.ipynb # bonsai/models/vae/tests/run_model.py * run a pre-commit hook * Modify code style * Add intermediate tests
* vae implementation # Conflicts: # bonsai/models/vae/README.md # bonsai/models/vae/modeling.py # bonsai/models/vae/params.py # bonsai/models/vae/tests/VAE_segmentation_example.ipynb # bonsai/models/vae/tests/run_model.py # pyproject.toml # Conflicts: # bonsai/models/vae/tests/VAE_segmentation_example.ipynb # bonsai/models/vae/tests/run_model.py * run a pre-commit hook * Modify code style * Add intermediate tests
@chapman20j
My apologies for the late PR. I’ve finalized the changes, and I would appreciate it if you could take a look when you have a moment.
Conflicts:
bonsai/models/vae/README.md
bonsai/models/vae/modeling.py
bonsai/models/vae/params.py
bonsai/models/vae/tests/VAE_segmentation_example.ipynb
bonsai/models/vae/tests/run_model.py
pyproject.toml
Resolves #46
Reference
Checklist
run_model.pyfor model usage,test_outputs.pyand/ormodel_validation_colab.ipynbfor quality).