Conversation

Contributor

@Aatman09 Aatman09 commented Nov 4, 2025

Resolves #62

Please check issues for any pending model implementations. Consider opening an issue if none exists.

Added the ConvNeXt model in JAX.

Reference

  • Paper: https://arxiv.org/abs/2201.03545
    
  • Model code: https://github.com/facebookresearch/ConvNeXt
    
  • Model weights: https://huggingface.co/facebook/convnext-large-224
    

Checklist

  • [x] I have read contribution guidelines.
  • [x] I have added all the necessary unit tests for my change (run_model.py for usage; test_outputs.py and model_validation_colab.ipynb, if applicable, for quality).
  • [x] I have verified that my change does not break existing code and all unit tests pass.
  • [x] I have added all appropriate doc-strings/documentation.
  • [x] My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • [x] I have signed the Contributor License Agreement.

from flax import nnx
from huggingface_hub import snapshot_download

# Import your ConvNeXt-specific modules
Collaborator
Could you please simplify some of these comments.

@chapman20j
Collaborator

Hi @Aatman09. Thank you for the nice commit! I left some basic comments involving simplifying comments in the code and adding test cases. I can take another look after these are added.

@Aatman09
Contributor Author

Aatman09 commented Nov 5, 2025

Thank you, @chapman20j. I will apply the necessary changes as soon as possible. Could you please clarify what you mean by simplifying the comments? Should I remove them, or should I add 2–3 lines explaining the steps?

@chapman20j
Collaborator

For the comments, please remove them when it is clear from context what is happening (e.g. Import your ConvNeXt-specific modules). For the other comments, please refer to one of the other implementations (e.g. resnet) for how much to comment the code (or ask again here).

@Aatman09
Contributor Author

Aatman09 commented Nov 8, 2025

Hi, I’ve made the required changes. Please review them when you get a chance.
Apologies for the multiple commits — I’m still learning the process.
Thank you for your time and patience!

@@ -0,0 +1,56 @@
import jax
import jax.numpy as jnp
import torch
Contributor Author
No, that is still left to do. I will do it right now.

Contributor Author

@Aatman09 Aatman09 Nov 10, 2025

Hey, I wanted to ask what the values of the tolerances (rtol, atol) should be.

Collaborator

This depends on the layer, the dtype, and other hyperparameters. Generally an individual operation in float64 should have very small error (e.g. atol=1e-5 or 1e-7). Among the layers, layernorm tends to have larger error (typically atol=1e-3). If you put multiple layers after one another, the errors tend to grow. Testing individual layers lets you see that those parts are implemented correctly. What numbers are you seeing right now for convnext?

Collaborator

Also as a side note, I think you forgot to put an assert_close in the test_full function.

Contributor Author

Yeah, I found out that I forgot assert_close in test_full, so I added it and checked again, and the output was:

Mismatched elements: 31911 / 32000 (99.7%)
Greatest absolute difference: 0.009215220808982849 at index (27, 131) (up to 1e-05 allowed)
Greatest relative difference: 89.54647064208984 at index (2, 433) (up to 1e-05 allowed)

So I checked the values manually, and these are the results:

Torch output (first 5x5 block):
tensor([[-0.3498, -0.0746,  0.0529, -0.0059,  0.0205],
        [-0.3477, -0.0707,  0.0386, -0.0194,  0.0054],
        [-0.3414, -0.0668,  0.0493, -0.0090,  0.0202],
        [-0.3532, -0.0726,  0.0502, -0.0033,  0.0186],
        [-0.3521, -0.0766,  0.0479, -0.0065,  0.0193]])

JAX output (first 5x5 block):
[[-0.3550264  -0.07526195  0.05261551 -0.00813212  0.01877142]
 [-0.3530529  -0.07142937  0.03878181 -0.02106674  0.00416216]
 [-0.34643137 -0.06731458  0.04927637 -0.0108891   0.01859591]
 [-0.35840794 -0.07314007  0.05010794 -0.00517488  0.01698814]
 [-0.35734856 -0.07718338  0.04766323 -0.00877798  0.01744177]]

/home/aries/bonsai/bonsai/models/ConvNext/tests/test_outputs_ConvNext.py:53: FutureWarning: `torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. Please use `torch.testing.assert_close()` instead. You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.
  torch.testing.assert_allclose(ty, jy, rtol=1e-5, atol=1e-5)
[  FAILED  ] TestModuleForwardPasses.test_full
======================================================================
FAIL: test_full (__main__.TestModuleForwardPasses.test_full)
TestModuleForwardPasses.test_full
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/aries/bonsai/bonsai/models/ConvNext/tests/test_outputs_ConvNext.py", line 53, in test_full
    torch.testing.assert_allclose(ty, jy, rtol=1e-5, atol=1e-5)
  File "/home/aries/bonsai/env/lib/python3.12/site-packages/typing_extensions.py", line 3004, in wrapper
    return arg(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/aries/bonsai/env/lib/python3.12/site-packages/torch/testing/_comparison.py", line 1629, in assert_allclose
    torch.testing.assert_close(
  File "/home/aries/bonsai/env/lib/python3.12/site-packages/torch/testing/_comparison.py", line 1589, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 31911 / 32000 (99.7%)
Greatest absolute difference: 0.009215220808982849 at index (27, 131) (up to 1e-05 allowed)
Greatest relative difference: 89.54647064208984 at index (2, 433) (up to 1e-05 allowed)

I was trying to debug this whole night but was unable to find out the problem.
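As a debugging aid, the mismatch statistics that torch.testing.assert_close prints above (mismatched elements, greatest absolute and relative difference) can be recomputed with a few lines of NumPy. The helper name below is illustrative, not part of the repo, but it uses the same closeness rule torch documents (|actual - expected| <= atol + rtol * |expected|), so it can be run on intermediate activations while bisecting for the offending layer:

```python
import numpy as np

def mismatch_report(actual, expected, rtol=1e-5, atol=1e-5):
    """Report mismatch statistics in the style of torch.testing.assert_close."""
    actual = np.asarray(actual, dtype=np.float64)
    expected = np.asarray(expected, dtype=np.float64)
    abs_diff = np.abs(actual - expected)
    # An element "matches" when |actual - expected| <= atol + rtol * |expected|.
    close = abs_diff <= atol + rtol * np.abs(expected)
    rel_diff = abs_diff / np.maximum(np.abs(expected), np.finfo(np.float64).tiny)
    return {
        "mismatched": int((~close).sum()),
        "total": actual.size,
        "max_abs_diff": float(abs_diff.max()),
        "max_rel_diff": float(rel_diff.max()),
    }

# Example: one element within tolerance, one clearly outside it.
expected = np.array([1.0, 2.0, 3.0])
actual = expected + np.array([0.0, 5e-6, 2e-3])
print(mismatch_report(actual, expected))
```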

Collaborator

Hi. With 36 layers, we expect to see errors like this. To know if the implementation is correct, it is helpful to compare against individual layers. Could you compare the numerics for a single Block layer? If the implementation is correct, that will likely have an error around 1e-5.

Collaborator

But this is promising! If the individual block layer tests pass then we can accept the PR. Nice work on this!

@Aatman09
Contributor Author

Aatman09 commented Nov 20, 2025

Done — I added the individual block checker. Let me know if everything looks good.

If this part is completed and approved, I was thinking about working on YOLO in a Bonsai-style implementation next. After that, I can create a pull request.

chapman20j previously approved these changes Nov 24, 2025
Member

@jenriver jenriver left a comment

Approve; please update the test dependencies.

@chapman20j chapman20j merged commit a11f3f4 into jax-ml:main Nov 25, 2025
3 of 5 checks passed
@chapman20j
Collaborator

@Aatman09 thank you for the nice commit! Just merged into the repo.


Development

Successfully merging this pull request may close these issues.

Model Request: ConvNext

3 participants