Added ConvNext model #69
Conversation
```python
from flax import nnx
from huggingface_hub import snapshot_download

# Import your ConvNeXt-specific modules
```
Could you please simplify some of these comments.
Hi @Aatman09. Thank you for the nice commit! I left some basic comments about simplifying the in-code comments and adding test cases. I can take another look after these are added.

Thank you, @chapman20j. I will apply the necessary changes as soon as possible. Could you please clarify what you mean by simplifying the comments? Should I remove them, or should I add 2–3 lines explaining the steps?
For the comments, please remove them when it is clear from context what is happening (e.g. Import your ConvNeXt-specific modules). For the other comments, please refer to one of the other implementations (e.g. resnet) for how much to comment the code (or ask again here). |
Hi, I’ve made the required changes. Please review them when you get a chance.
```python
import jax
import jax.numpy as jnp
import torch
```
Have you installed the pre-commit hooks and run them?
https://github.com/jenriver/bonsai/blob/main/CONTRIBUTING.md#bonsai-pull-request-checklist
No, not yet. I will do it right now.
Hey, I wanted to ask: what should the tolerance values (rtol, atol) be?
This depends on the layer, the dtype, and other hyperparameters. Generally an individual operation in float64 should have very small error (e.g. atol=1e-5 or 1e-7). Among the layers, layernorm tends to have larger error (typically atol=1e-3). If you put multiple layers after one another, the errors tend to grow. Testing individual layers lets you see that those parts are implemented correctly. What numbers are you seeing right now for convnext?
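To make the per-layer tolerances concrete, here is a minimal sketch of a single-layer check (the layer choice and all names are illustrative, not this PR's test code). It compares a torch LayerNorm against the equivalent JAX computation, using the looser atol=1e-3 suggested above for layernorm:

```python
import jax.numpy as jnp
import numpy as np
import torch

# Illustrative single-layer numerics check, assuming default LayerNorm
# weights (gamma=1, beta=0) so both computations share parameters.
torch.manual_seed(0)
x = torch.randn(4, 96)
ty = torch.nn.LayerNorm(96)(x).detach()

# Equivalent JAX computation; eps=1e-5 matches torch's default.
xj = jnp.asarray(x.numpy())
mean = xj.mean(axis=-1, keepdims=True)
var = xj.var(axis=-1, keepdims=True)
jy = (xj - mean) / jnp.sqrt(var + 1e-5)

# layernorm tends to need the looser tolerance discussed above.
torch.testing.assert_close(ty, torch.tensor(np.asarray(jy)), rtol=1e-3, atol=1e-3)
```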
Also as a side note, I think you forgot to put an assert_close in the test_full function.
Yeah, I found out that I forgot assert_close in test_full, so I added it and checked again. This was the output:
```
Mismatched elements: 31911 / 32000 (99.7%)
Greatest absolute difference: 0.009215220808982849 at index (27, 131) (up to 1e-05 allowed)
Greatest relative difference: 89.54647064208984 at index (2, 433) (up to 1e-05 allowed)
```
So I checked the values manually, and these are the results:
```
Torch output (first 5x5 block):
tensor([[-0.3498, -0.0746,  0.0529, -0.0059,  0.0205],
        [-0.3477, -0.0707,  0.0386, -0.0194,  0.0054],
        [-0.3414, -0.0668,  0.0493, -0.0090,  0.0202],
        [-0.3532, -0.0726,  0.0502, -0.0033,  0.0186],
        [-0.3521, -0.0766,  0.0479, -0.0065,  0.0193]])

JAX output (first 5x5 block):
[[-0.3550264  -0.07526195  0.05261551 -0.00813212  0.01877142]
 [-0.3530529  -0.07142937  0.03878181 -0.02106674  0.00416216]
 [-0.34643137 -0.06731458  0.04927637 -0.0108891   0.01859591]
 [-0.35840794 -0.07314007  0.05010794 -0.00517488  0.01698814]
 [-0.35734856 -0.07718338  0.04766323 -0.00877798  0.01744177]]
```
```
/home/aries/bonsai/bonsai/models/ConvNext/tests/test_outputs_ConvNext.py:53: FutureWarning: `torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. Please use `torch.testing.assert_close()` instead. You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.
  torch.testing.assert_allclose(ty, jy, rtol=1e-5, atol=1e-5)
[  FAILED  ] TestModuleForwardPasses.test_full
======================================================================
FAIL: test_full (__main__.TestModuleForwardPasses.test_full)
TestModuleForwardPasses.test_full
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/aries/bonsai/bonsai/models/ConvNext/tests/test_outputs_ConvNext.py", line 53, in test_full
    torch.testing.assert_allclose(ty, jy, rtol=1e-5, atol=1e-5)
  File "/home/aries/bonsai/env/lib/python3.12/site-packages/typing_extensions.py", line 3004, in wrapper
    return arg(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/aries/bonsai/env/lib/python3.12/site-packages/torch/testing/_comparison.py", line 1629, in assert_allclose
    torch.testing.assert_close(
  File "/home/aries/bonsai/env/lib/python3.12/site-packages/torch/testing/_comparison.py", line 1589, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 31911 / 32000 (99.7%)
Greatest absolute difference: 0.009215220808982849 at index (27, 131) (up to 1e-05 allowed)
Greatest relative difference: 89.54647064208984 at index (2, 433) (up to 1e-05 allowed)
```
I was trying to debug this the whole night but was unable to find the problem.
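As an aside, the FutureWarning in the log above points at a one-line fix. A minimal sketch of the suggested replacement, with placeholder tensors standing in for the `ty` and `jy` from the test:

```python
import torch

# Placeholders standing in for the torch output (ty) and the JAX output
# converted to a torch tensor (jy) in the test above.
ty = torch.zeros(32, 1000)
jy = torch.zeros(32, 1000)

# torch.testing.assert_allclose is deprecated since torch 1.12;
# torch.testing.assert_close is the documented replacement.
torch.testing.assert_close(ty, jy, rtol=1e-5, atol=1e-5)
```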
Hi. With 36 layers, we expect to see errors like this. To know if the implementation is correct, it is helpful to compare against individual layers. Could you compare the numerics for a single Block layer? If the implementation is correct, that will likely have an error around 1e-5.
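For reference, a hedged sketch of what such a single-Block comparison could look like; `torch_block` and `jax_block` are hypothetical stand-ins (e.g. torchvision's block and this PR's nnx Block), assumed to already share weights, and the NHWC layout for the JAX side is an assumption:

```python
import numpy as np
import torch

def compare_block(torch_block, jax_block, rtol=1e-5, atol=1e-5):
    """Compare one torch Block against one JAX Block on a fixed input."""
    torch.manual_seed(0)
    x = torch.randn(1, 96, 56, 56)  # NCHW input for a stage-1 block

    with torch.no_grad():
        ty = torch_block(x)

    # The nnx Block is assumed to take NHWC input, hence the transposes.
    jy = jax_block(np.transpose(x.numpy(), (0, 2, 3, 1)))
    jy = torch.tensor(np.asarray(jy)).permute(0, 3, 1, 2)

    torch.testing.assert_close(ty, jy, rtol=rtol, atol=atol)
```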
But this is promising! If the individual block layer tests pass then we can accept the PR. Nice work on this!
Done! I added the individual Block check. Let me know if everything looks good. If this part is completed and approved, I was thinking about working on YOLO in a Bonsai-style implementation next. After that, I can create a pull request.
jenriver left a comment:
Approve, please update test dependencies
@Aatman09 thank you for the nice commit! Just merged into the repo.
Resolves #62
Added the ConvNeXt model in JAX
Reference
Checklist
run_model.py for usage, test_outputs.py and model_validation_colab.ipynb (if applicable) for quality.