VJEPA2 model, classifier works well! need some help with fm though #121
Conversation
…sine similarity in tests
…ue to gated repo.
Can anyone review my PR? @jenriver @chapman20j

Hi @coder0143. Thanks for the nice PR! This looks great. Could you please address a few comments?
bonsai/models/vjepa2/modeling.py
Outdated
from jax import Array

@dataclasses.dataclass
These dataclasses for the outputs don't add much to the implementation. Could you refactor these out to give a more minimal implementation?
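For illustration, a minimal sketch of the kind of refactor being suggested, returning plain arrays instead of an output dataclass (the method signature and attribute names below are assumptions loosely based on the classifier code quoted later in this review):

def __call__(self, pixel_values_videos: Array) -> tuple[Array, Array]:
    # Return the arrays directly; no output-dataclass wrapper needed.
    last_hidden_state = self.vjepa2(pixel_values_videos)
    logits = self.classifier(self.pooler(last_hidden_state))
    return logits, last_hidden_state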
from absl.testing import absltest
from huggingface_hub import snapshot_download
from transformers import AutoVideoProcessor, VJEPA2ForVideoClassification, VJEPA2Model
Could you move the tests in this file into the previous file?
bonsai/models/vjepa2/modeling.py
Outdated
def __init__(self, config: VJEPA2FlaxConfig, hidden_size: int, rngs: nnx.Rngs):
    super().__init__()
    self.config = config
The config is not used later in the class. Could you remove this here and in other places where it is unused?
bonsai/models/vjepa2/modeling.py
Outdated
# Conv3d expects (B, T, H, W, C) which is what we have
x = self.proj(pixel_values_videos)  # (B, T', H', W', hidden_size)

# Flatten spatial and temporal dimensions
The comments here aren't necessary since you describe the behavior in the docstring. Could you remove these?
bonsai/models/vjepa2/modeling.py
Outdated
self.d_dim = int(2 * ((self.attention_head_size // 3) // 2))
self.h_dim = int(2 * ((self.attention_head_size // 3) // 2))
self.w_dim = int(2 * ((self.attention_head_size // 3) // 2))
These lines are identical (up to naming). Could you condense this?
ty_bonsai = torch.tensor(np_y, dtype=torch.float32)

torch.testing.assert_close(ty_bonsai, ty, rtol=1e-5, atol=1e-3)
Could you please add a test corresponding to VJEPA2RopeAttention? These intermediate tests help us ensure high numerical accuracy of the implementations.
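A rough sketch of what such a test could look like, mirroring the parity-test pattern quoted above; the module paths (encoder.layer[0].attention on the HF side, encoder.layers[0].attention on the bonsai side), input shapes, and call signatures are assumptions that would need to match the real implementations:

def test_rope_attention(self):
    # Compare a single VJEPA2RopeAttention block against the HF reference block.
    torch_attn = self.torch_model.encoder.layer[0].attention
    flax_attn = self.flax_model.encoder.layers[0].attention
    x = torch.randn(self.batch_size, self.seq_len, self.hidden_size)
    with torch.no_grad():
        ty = torch_attn(x)[0]
    np_y = np.asarray(flax_attn(jnp.asarray(x.numpy())))
    ty_bonsai = torch.tensor(np_y, dtype=torch.float32)
    torch.testing.assert_close(ty_bonsai, ty, rtol=1e-5, atol=1e-3)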
# PyTorch format: (B, T, C, H, W)
torch_shape = (self.batch_size, self.num_frames, self.channels, self.height, self.width)
It might make sense to have a self.torch_input_shape in the setUp method to simplify this. We can also remove this comment since the variable names make it clear.
Also, you mentioned in the PR title that you "need some help with fm". Do you have a question in particular about it?
torch.testing.assert_close(
    torch.tensor(flax_hidden),
    torch.tensor(torch_hidden),
    rtol=float("inf"),  # Its a mess
    atol=20.0,  # Allow for accumulated FP error over 24 layers (only 0.05% elements mismatch)
)

# Check predictor output
self.assertEqual(torch_predictor.shape, flax_predictor.shape)

torch.testing.assert_close(
    torch.tensor(flax_predictor),
    torch.tensor(torch_predictor),
    rtol=float("inf"),
    atol=20.0,
)
This currently sets no parity tolerance at all. Is there a missing implementation of a component, or do you have an idea why there's such a large numeric diff?
bonsai/models/vjepa2/modeling.py
Outdated
import dataclasses
from typing import Optional, Tuple
Let's modernize to Python 3.10+ syntax by using Array | None rather than Optional and lowercase tuple[] to eliminate unnecessary typing imports.
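For illustration (the signature below is hypothetical), the change amounts to:

# before
def __call__(self, x: Array, mask: Optional[Array] = None) -> Tuple[Array, Array]: ...

# after: Python 3.10+ builtin syntax, no typing imports needed
def __call__(self, x: Array, mask: Array | None = None) -> tuple[Array, Array]: ...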
def rotate_queries_or_keys(x: Array, pos: Array, dim: int) -> Array:
    """Apply rotary position embeddings to queries or keys.

    Args:
        x: Input tensor of shape (B, num_heads, N, head_dim)
        pos: Position indices of shape (B, num_heads, N) or (N,) for broadcasting
        dim: Dimension size for this component (d_dim, h_dim, or w_dim)

    Returns:
        Rotated tensor of same shape as input but only for first `dim` dimensions
    """
    _, _, _, D = x.shape

    # Compute frequencies - use input dtype like PyTorch
    omega = jnp.arange(D // 2, dtype=x.dtype)
    omega = omega / (D / 2.0)
    omega = 1.0 / (10000**omega)  # (D/2,)

    # Compute angles: pos * omega
    freq = pos[..., None] * omega  # (..., N, D/2)

    # Build rotation matrix
    emb_sin = jnp.sin(freq)  # (..., N, D/2)
    emb_cos = jnp.cos(freq)  # (..., N, D/2)

    # Repeat for full dimension
    emb_sin = jnp.tile(emb_sin, (1, 1, 1, 2))  # (..., N, D)
    emb_cos = jnp.tile(emb_cos, (1, 1, 1, 2))  # (..., N, D)

    # Split into pairs and rotate like PyTorch
    y = x.reshape(*x.shape[:-1], -1, 2)  # (..., N, D/2, 2)
    y1, y2 = y[..., 0], y[..., 1]  # Each (..., N, D/2)

    # Stack as (-y2, y1) and flatten
    y_rotated = jnp.stack((-y2, y1), axis=-1)  # (..., N, D/2, 2)
    y_rotated = y_rotated.reshape(x.shape)  # (..., N, D)

    # Apply rotation: x * cos + rotated * sin
    rotated = (x * emb_cos) + (y_rotated * emb_sin)

    return rotated
There are several redundancies here that span over multiple lines -- can this be rewritten into something like this?
def rotate_half(x: Array) -> Array:
    """Rotates half the hidden dims of the input."""
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)


def apply_rope(x: Array, pos: Array, dim: int) -> Array:
    """
    Apply 3D rotary positional embeddings.
    Simplified half-rotate implementation.
    """
    # x: (B, H, N, head_dim_fraction)
    # pos: (B, H, N)
    D = x.shape[-1]
    omega = 1.0 / (10000 ** (jnp.arange(0, D, 2, dtype=x.dtype) / D))
    angle = pos[..., None] * omega  # (B, H, N, D/2)
    cos = jnp.cos(angle).repeat(2, axis=-1)
    sin = jnp.sin(angle).repeat(2, axis=-1)
    return (x * cos) + (rotate_half(x) * sin)
I have made the following changes:

The issue with the foundation model was nnx.gelu using the approximate GELU, which is why I'm now using jax.nn.gelu. The video processor normalizes inputs to a narrow range that puts 54% of MLP activations in GELU's sensitive transition zone, and the resulting error carries over 24 layers, causing a very high atol. The classifiers work perfectly fine, and when the foundation model is used for downstream tasks with the video processor's output, there is a 0.9987 correlation with the final output (~0.5 atol), which further tasks will be robust to.
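For context, a quick way to see the gap being described, assuming the fix amounts to using the exact (erf-based) GELU that PyTorch's nn.GELU() uses by default rather than the tanh approximation; jax.nn.gelu exposes both via its approximate flag:

import jax
import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 101)
exact = jax.nn.gelu(x, approximate=False)   # erf-based GELU, matches torch.nn.GELU() defaults
approx = jax.nn.gelu(x, approximate=True)   # tanh approximation (the JAX default)
print(jnp.max(jnp.abs(exact - approx)))     # small per call, but the error can compound over 24 layers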
Can anyone review the changes? @chapman20j @jenriver
bonsai/models/vjepa2/params.py
Outdated
else:
    return {
        # Encoder Embeddings
        r"vjepa2\.encoder\.embeddings\.patch_embeddings\.proj\.weight": (
This dict has a lot of overlap with the previous one. In particular, it appears that the keys starting with vjepa2 are the same as the ones in the previous dict, just without the vjepa2\. prefix. To simplify this, could you create the first dict as a local variable key_mapping? Then you could check:

if classifier:
    key_mapping = {r"vjepa2\." + k: r"vjepa2\." + v for k, v in key_mapping.items()}
return key_mapping | {...}

where {...} denotes the additional key mappings for the pooler and classifier.
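Putting that together, a sketch of the resulting helper (the function name _torch_key_mapping is hypothetical, the entries are left as placeholder comments, and only the torch-side regex keys are prefixed here; whether the bonsai-side values also need the prefix depends on the parameter names):

def _torch_key_mapping(classifier: bool) -> dict:
    # Single source of truth for the shared encoder/predictor mappings.
    key_mapping = {
        # ... existing regex -> bonsai-parameter entries, unchanged ...
    }
    if classifier:
        # Per the suggestion above: prefix the shared keys and add the classifier-only entries.
        key_mapping = {r"vjepa2\." + k: v for k, v in key_mapping.items()}
        key_mapping |= {
            # ... pooler / classifier entries ...
        }
    return key_mapping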
def setUp(self):
    super().setUp()
    self.save_dir = constants.default_cache_path
Could you use the tempfile library to create a temporary file? This way we don't keep random weights saved after the program has run. E.g. something like:
with tempfile.NamedTemporaryFile(mode='w+t', delete=True) as temp:
    # include saving and loading into bonsai model here

    atol=1.0,
)

def test_get_vision_features(self):
I think this tests the same thing as the previous test since get_vision_features runs the encoder. Could you remove this test?
bonsai/models/vjepa2/modeling.py
Outdated
self.d_dim = rope_dim
self.h_dim = rope_dim
self.w_dim = rope_dim
These are identical. Can we just update this to self.rope_dim?
bonsai/models/vjepa2/modeling.py
Outdated
last_hidden_state = outputs.last_hidden_state
pooler_output = self.pooler(last_hidden_state)
logits = self.classifier(pooler_output)
return VJEPA2ClassificationOutput(logits=logits, last_hidden_state=last_hidden_state)
Can you add something like the following to test if jit works with this implementation:
@jax.jit
def forward(model, inputs):
    return model(inputs)
This would be a stand-alone function.
from bonsai.models.vjepa2.modeling import VJEPA2FlaxConfig
from bonsai.models.vjepa2.params import create_model_from_safe_tensors
I left a comment in the modeling.py file about using a jitted forward method. Could you use this function in this file to demonstrate that jax.jit works with the current implementation?
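A sketch of how that could look here, reusing the stand-alone jitted forward from the modeling.py thread (the loader signature, config defaults, checkpoint path, and test-class framing are assumptions based on the imports above and the config attributes used elsewhere in the tests):

@jax.jit
def forward(model, inputs):
    # Stand-alone jitted entry point, as suggested in the modeling.py review.
    return model(inputs)


class VJEPA2JitTest(absltest.TestCase):
    def test_jit_forward(self):
        config = VJEPA2FlaxConfig()
        model = create_model_from_safe_tensors("/path/to/vjepa2.safetensors", config)  # hypothetical signature
        videos = jnp.zeros((1, config.frames_per_clip, config.crop_size, config.crop_size, config.in_chans))
        _ = forward(model, videos)  # should trace, compile, and run without errors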
self.num_frames = self.config.frames_per_clip
self.height = self.config.crop_size
self.width = self.config.crop_size
self.channels = self.config.in_chans
Could you set the torch seed here so the tests are reproducible? Using torch.manual_seed(0) should be sufficient.
Along with this, once the tests are reproducible, could you double-check the atol and rtol values? I think some of these should be a little smaller. An easy way to test this is to remove all the rtol and atol kwargs and then add back values slightly above the threshold at which they pass. This gives us an idea of how numerically consistent the layers are with the reference implementation.
Hi @coder0143. I gave this another pass. The code looks nice! Could you please implement the suggested changes? Most of the changes involve simplifying the code to reduce the total number of lines. Also, the jit forward method is important, as it lets us get a lot of performance out of the code and improves compatibility with other libraries that leverage jit for expensive workloads. Thanks for your hard work on this!
Thank you for reviewing my PR @chapman20j, I have made the necessary changes.
Thanks for the great implementation! Everything looks solid. Thanks again for the contribution!
Resolves #117
Reference
Link: https://github.com/huggingface/transformers/blob/main/src/transformers/models/vjepa2/modeling_vjepa2.py
Checklist
run_model.py for model usage, test_outputs.py and/or model_validation_colab.ipynb for quality.