
Conversation

@coder0143
Contributor

Resolves #117

Reference
Link: https://github.com/huggingface/transformers/blob/main/src/transformers/models/vjepa2/modeling_vjepa2.py

Checklist

  • I have read the Contribution Guidelines and used pre-commit hooks to format this commit.
  • I have added all the necessary unit tests for my change. (run_model.py for model usage, test_outputs.py and/or model_validation_colab.ipynb for quality).
  • (If using an LLM) I have carefully reviewed and removed all superfluous comments or unneeded, commented-out code. Only necessary and functional code remains.
  • I have signed the Contributor License Agreement (CLA).

@coder0143
Contributor Author

Can anyone review my PR? @jenriver @chapman20j

@chapman20j
Collaborator

Hi @coder0143. Thanks for the nice PR! This looks great. Could you please address a few comments?

from jax import Array


@dataclasses.dataclass
Collaborator

These dataclasses for the outputs don't add much to the implementation. Could you refactor these out to give a more minimal implementation?
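
For illustration, the stripped-down version could return the array (or a plain tuple) directly; a minimal sketch only, with placeholder method and attribute names rather than the actual implementation:

def __call__(self, pixel_values_videos: Array) -> Array:
    hidden_states = self.encoder(pixel_values_videos)
    # Return the array directly instead of wrapping it in an output dataclass.
    return hidden_states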

from absl.testing import absltest
from huggingface_hub import snapshot_download
from transformers import AutoVideoProcessor, VJEPA2ForVideoClassification, VJEPA2Model

Collaborator

Could you move the tests in this file into the previous file?


def __init__(self, config: VJEPA2FlaxConfig, hidden_size: int, rngs: nnx.Rngs):
    super().__init__()
    self.config = config
Collaborator

The config is not used later in the class. Could you remove this here and in other places where it is unused?

Comment on lines 188 to 191
# Conv3d expects (B, T, H, W, C) which is what we have
x = self.proj(pixel_values_videos) # (B, T', H', W', hidden_size)

# Flatten spatial and temporal dimensions
Collaborator

The comments here aren't necessary since you describe the behavior in the docstring. Could you remove these?

Comment on lines 301 to 303
self.d_dim = int(2 * ((self.attention_head_size // 3) // 2))
self.h_dim = int(2 * ((self.attention_head_size // 3) // 2))
self.w_dim = int(2 * ((self.attention_head_size // 3) // 2))
Collaborator

These lines are identical (up to naming). Could you condense this?
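
For example, a one-line version keeping the same arithmetic:

# d/h/w components all use the same size, so compute it once.
self.d_dim = self.h_dim = self.w_dim = 2 * ((self.attention_head_size // 3) // 2)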

ty_bonsai = torch.tensor(np_y, dtype=torch.float32)

torch.testing.assert_close(ty_bonsai, ty, rtol=1e-5, atol=1e-3)

Collaborator

Could you please add a test corresponding to VJEPA2RopeAttention? These intermediate tests help us ensure high numerical accuracy of the implementations.
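
As a rough illustration of such an intermediate test (a sketch only; the attribute paths used to reach the attention blocks and the call signatures are assumptions that would need to be adapted to the actual modules):

def test_rope_attention(self):
    # Parity check of a single attention block against the HF reference.
    # The access paths (encoder.layer[0].attention / encoder.layers[0].attention)
    # and call signatures below are assumed, not the real APIs.
    torch.manual_seed(0)
    hidden = torch.randn(1, 256, self.config.hidden_size)
    with torch.no_grad():
        torch_out = self.torch_model.encoder.layer[0].attention(hidden)[0]
    flax_out = self.flax_model.encoder.layers[0].attention(jnp.asarray(hidden.numpy()))
    torch.testing.assert_close(torch.tensor(flax_out), torch_out, rtol=1e-4, atol=1e-4)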

Comment on lines 62 to 63
# PyTorch format: (B, T, C, H, W)
torch_shape = (self.batch_size, self.num_frames, self.channels, self.height, self.width)
Collaborator

It might make sense to have a self.torch_input_shape in the setUp method to simplify this. We can also remove this comment since the variable names make it clear.
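
Roughly along these lines (a sketch of the suggested setUp addition, assuming the size attributes are already set there):

def setUp(self):
    super().setUp()
    # Single source of truth for the PyTorch-format input shape (B, T, C, H, W).
    self.torch_input_shape = (self.batch_size, self.num_frames, self.channels, self.height, self.width)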

@chapman20j
Collaborator

Also, you mentioned in the PR title that you "need some help with fm". Do you have a question in particular about it?

Comment on lines 65 to 80
torch.testing.assert_close(
    torch.tensor(flax_hidden),
    torch.tensor(torch_hidden),
    rtol=float("inf"),  # It's a mess
    atol=20.0,  # Allow for accumulated FP error over 24 layers (only 0.05% elements mismatch)
)

# Check predictor output
self.assertEqual(torch_predictor.shape, flax_predictor.shape)

torch.testing.assert_close(
    torch.tensor(flax_predictor),
    torch.tensor(torch_predictor),
    rtol=float("inf"),
    atol=20.0,
)
Member

This currently sets no parity tolerance at all. Is there a missing piece of the implementation, or do you have an idea why there's such a large numeric diff?

Comment on lines 1 to 2
import dataclasses
from typing import Optional, Tuple
Member

Let's modernize to Python 3.10+ syntax by using Array | None rather than Optional and lowercase tuple[] to eliminate unnecessary typing imports.
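
For example (illustrative signature only; the function and parameter names are placeholders):

from jax import Array

def forward_example(x: Array, mask: Array | None = None) -> tuple[Array, Array | None]:
    # No `Optional`/`Tuple` imports needed with Python 3.10+ builtins.
    return x, mask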

Comment on lines 224 to 264
def rotate_queries_or_keys(x: Array, pos: Array, dim: int) -> Array:
    """Apply rotary position embeddings to queries or keys.

    Args:
        x: Input tensor of shape (B, num_heads, N, head_dim)
        pos: Position indices of shape (B, num_heads, N) or (N,) for broadcasting
        dim: Dimension size for this component (d_dim, h_dim, or w_dim)

    Returns:
        Rotated tensor of same shape as input but only for first `dim` dimensions
    """
    _, _, _, D = x.shape

    # Compute frequencies - use input dtype like PyTorch
    omega = jnp.arange(D // 2, dtype=x.dtype)
    omega = omega / (D / 2.0)
    omega = 1.0 / (10000**omega)  # (D/2,)

    # Compute angles: pos * omega
    freq = pos[..., None] * omega  # (..., N, D/2)

    # Build rotation matrix
    emb_sin = jnp.sin(freq)  # (..., N, D/2)
    emb_cos = jnp.cos(freq)  # (..., N, D/2)

    # Repeat for full dimension
    emb_sin = jnp.tile(emb_sin, (1, 1, 1, 2))  # (..., N, D)
    emb_cos = jnp.tile(emb_cos, (1, 1, 1, 2))  # (..., N, D)

    # Split into pairs and rotate like PyTorch
    y = x.reshape(*x.shape[:-1], -1, 2)  # (..., N, D/2, 2)
    y1, y2 = y[..., 0], y[..., 1]  # Each (..., N, D/2)

    # Stack as (-y2, y1) and flatten
    y_rotated = jnp.stack((-y2, y1), axis=-1)  # (..., N, D/2, 2)
    y_rotated = y_rotated.reshape(x.shape)  # (..., N, D)

    # Apply rotation: x * cos + rotated * sin
    rotated = (x * emb_cos) + (y_rotated * emb_sin)

    return rotated
Member

There are several redundancies here that span multiple lines -- could this be rewritten into something like the following?

def rotate_half(x: Array) -> Array:
    """Rotates half the hidden dims of the input."""
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)

def apply_rope(x: Array, pos: Array, dim: int) -> Array:
    """
    Apply 3D rotary positional embeddings.
    Simplified half-rotate implementation.
    """
    # x: (B, H, N, head_dim_fraction)
    # pos: (B, H, N)
    D = x.shape[-1]
    
    omega = 1.0 / (10000**(jnp.arange(0, D, 2, dtype=x.dtype) / D))
    angle = pos[..., None] * omega  # (B, H, N, D/2)
    
    cos = jnp.cos(angle).repeat(2, axis=-1)
    sin = jnp.sin(angle).repeat(2, axis=-1)
    
    return (x * cos) + (rotate_half(x) * sin)

@coder0143
Contributor Author

I have made the following changes:

  1. Reformatted modeling.py
  2. Fixed tests for pretrained models and classifiers.

The issue with the foundation model was the GELU approximation in nnx.gelu, which is why I'm using jax.nn.gelu instead. The video processor normalizes inputs to a narrow range that puts 54% of MLP activations in GELU's sensitive transition zone, and the discrepancy accumulates over 24 layers, causing a very high atol. The classifiers work perfectly fine, and when using the foundation model for downstream tasks with the video processor's output, there is a 0.9987 correlation with the final output (~0.5 atol), which further tasks should be robust to.
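
For reference, the gap between the tanh-approximate and exact (erf-based) GELU can be checked directly in JAX and compared against PyTorch's exact GELU; a small sketch of the kind of discrepancy described (which variant the HF reference model uses is not asserted here):

import jax
import jax.numpy as jnp
import torch

x = jnp.linspace(-3.0, 3.0, 7)
approx = jax.nn.gelu(x, approximate=True)    # tanh approximation
exact = jax.nn.gelu(x, approximate=False)    # erf-based GELU
print(jnp.max(jnp.abs(approx - exact)))      # small per call, but it can accumulate over 24 layers

t = torch.linspace(-3.0, 3.0, 7)
print(torch.nn.functional.gelu(t, approximate="none"))  # PyTorch's exact GELU for comparison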

@coder0143
Contributor Author

Can anyone review the changes? @chapman20j @jenriver

else:
    return {
        # Encoder Embeddings
        r"vjepa2\.encoder\.embeddings\.patch_embeddings\.proj\.weight": (
Collaborator (@chapman20j, Jan 14, 2026)

This dict has a lot of overlap with the previous one. In particular, it appears that the keys starting with vjepa2 are the same as the ones in the previous dict, just without the vjepa2\. prefix. To simplify this, could you create the first dict as a local variable key_mapping? Then you could do

if classifier:
    key_mapping = {r"vjepa2\." + k: r"vjepa2\." + v for k, v in key_mapping.items()}
    return key_mapping | {...}

where {...} denotes the additional key mappings for the pooler and classifier.


def setUp(self):
    super().setUp()
    self.save_dir = constants.default_cache_path
Collaborator

Could you use the tempfile library to create a temporary file? That way we don't keep random weights around after the program finishes, e.g. something like

with tempfile.NamedTemporaryFile(mode='w+t', delete=True) as temp:
   # include saving and loading into bonsai model here
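
A directory-based variant of the same idea can also work, since checkpoints are often written as multiple files (the saving/loading calls are placeholders):

import tempfile

with tempfile.TemporaryDirectory() as tmp_dir:
    # Save the converted weights into tmp_dir and load them back into the bonsai
    # model here; everything under tmp_dir is deleted when the context exits.
    ...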

atol=1.0,
)

def test_get_vision_features(self):
Collaborator

I think this tests the same thing as the previous test since get_vision_features runs the encoder. Could you remove this test?

Comment on lines 186 to 188
self.d_dim = rope_dim
self.h_dim = rope_dim
self.w_dim = rope_dim
Collaborator

These are identical. Can we just update this to a single self.rope_dim?

last_hidden_state = outputs.last_hidden_state
pooler_output = self.pooler(last_hidden_state)
logits = self.classifier(pooler_output)
return VJEPA2ClassificationOutput(logits=logits, last_hidden_state=last_hidden_state)
Collaborator

Can you add something like the following to test if jit works with this implementation:

@jax.jit
def forward(model, inputs):
    return model(inputs)

Collaborator

This would be a stand-alone function.
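
With flax.nnx modules, the stand-alone version might use nnx.jit (a sketch only; whether plain jax.jit suffices depends on how the model object is structured):

from flax import nnx

@nnx.jit
def forward(model, pixel_values_videos):
    # Stand-alone jitted forward pass over the bonsai VJEPA2 model.
    return model(pixel_values_videos)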


from bonsai.models.vjepa2.modeling import VJEPA2FlaxConfig
from bonsai.models.vjepa2.params import create_model_from_safe_tensors

Collaborator

I left a comment in the modeling.py file about using a jitted forward method. Could you use this function in this file to demonstrate that jax.jit works with the current implementation?

self.num_frames = self.config.frames_per_clip
self.height = self.config.crop_size
self.width = self.config.crop_size
self.channels = self.config.in_chans
Collaborator

Could you set the torch seed here so the tests are reproducible? Using torch.manual_seed(0) should be sufficient.

Collaborator

Along with this, once the tests are reproducible, could you double-check the atol and rtol values? I think some of these should be a little smaller. An easy way to test this is to remove all the rtol and atol kwargs and then add back values slightly above the threshold at which the tests pass. This gives us an idea of how numerically consistent the layers are with the reference implementation.

@chapman20j
Collaborator

Hi @coder0143. I gave this another pass. The code looks nice! Could you please implement the suggested changes? Most of the changes involve simplifying the code to reduce the total number of lines. Also, the jitted forward method is important, as it lets us get a lot of performance out of the code and improves compatibility with other libraries that leverage jit for expensive workloads. Thanks for your hard work on this!

@coder0143
Contributor Author

Thank you for reviewing my PR, @chapman20j. I have made the necessary changes:

  1. changed params
  2. updated tests
  3. some parts in modeling

@jenriver
Member

Thanks for the great implementation! Everything looks solid.
One minor note is that we generally prefer to handle input transposition outside the JIT-compiled forward pass (and have it in modeling.py), but I don't want to hold this up over a small detail. I'll go ahead and merge this now and make that quick tweak myself.

Thanks again for the contribution!
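
For concreteness, the kind of pre-JIT transposition referred to above (a sketch; `forward` is the jitted function suggested earlier, and `torch_pixel_values` is a placeholder for the video processor's output):

import jax.numpy as jnp

# Processor/PyTorch layout is (B, T, C, H, W); the Conv3d patch embedding
# expects (B, T, H, W, C), so transpose once outside the jitted call.
pixel_values_videos = jnp.transpose(jnp.asarray(torch_pixel_values.numpy()), (0, 1, 3, 4, 2))
outputs = forward(model, pixel_values_videos)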

@jenriver jenriver merged commit 2e6591f into jax-ml:main Jan 16, 2026
3 checks passed