33 changes: 33 additions & 0 deletions bonsai/models/clip/README.md
@@ -0,0 +1,33 @@
## CLIP in JAX

This directory contains a pure JAX implementation of the **CLIP (Contrastive Language–Image Pretraining)** model, written with **Flax**.

The model consists of:
- A Vision Transformer (ViT-style) image encoder
- A Transformer-based text encoder
- A shared embedding space trained with a contrastive objective (sketched below)

This implementation focuses on correctness, modularity, and testability, and is designed to integrate cleanly with the rest of the Bonsai model zoo.
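
For reference, the contrastive objective pairs each image with its matching caption inside a batch: the model produces a (batch × batch) similarity matrix, and the loss is a symmetric cross-entropy over its rows and columns. The snippet below is a minimal sketch of that loss using `optax`; the function name `clip_contrastive_loss` is illustrative and not part of this directory's API.

```python
import jax.numpy as jnp
import optax


def clip_contrastive_loss(logits: jnp.ndarray) -> jnp.ndarray:
    # `logits` is the (batch, batch) image-text similarity matrix produced by the model.
    # Matching image/text pairs sit on the diagonal of that matrix.
    labels = jnp.arange(logits.shape[0])
    loss_images = optax.softmax_cross_entropy_with_integer_labels(logits, labels)   # image -> text
    loss_texts = optax.softmax_cross_entropy_with_integer_labels(logits.T, labels)  # text -> image
    return 0.5 * (loss_images.mean() + loss_texts.mean())
```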

---

## Tested on

| Model Name | Config | CPU | GPU A100 (1x) | GPU H100 (1x) | GPU A100 (8x) | GPU H100 (8x) | TPU v2 (8x) | TPU v5e (1x) |
| :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |
| CLIP (ViT + Text Transformer) | ✅ Supported | ✅ Runs | ❔ Needs check | ❔ Needs check | ❔ Needs check | ❔ Needs check | ❔ Needs check | ❔ Needs check |

> **Note**
> This model is tested and supported on **Python 3.11**.
> Python 3.13 is currently **not supported** due to upstream JAX/Flax incompatibilities.

---

## Running this model

### Forward pass test (recommended)

You can verify that the model runs correctly by executing the pytest forward test:

```sh
python -m pytest bonsai/models/clip/tests/test_model.py -vv
```
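
### Quick usage sketch

Assuming the exported `CLIP` and `CLIPConfig` follow the same Flax Linen interface as the module in `clip_model_jax.ipynb` (an assumption; check `modeling.py` for the authoritative API), a forward pass looks roughly like this:

```python
import jax
import jax.numpy as jnp

from bonsai.models.clip import CLIP, CLIPConfig

cfg = CLIPConfig()
model = CLIP(cfg)

# Dummy inputs: NHWC images and integer token ids.
images = jnp.ones((2, cfg.image_size, cfg.image_size, 3))
texts = jnp.ones((2, cfg.context_length), dtype=jnp.int32)

params = model.init(jax.random.PRNGKey(0), images, texts)
logits = model.apply(params, images, texts)  # (2, 2) image-text similarity matrix
```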
4 changes: 4 additions & 0 deletions bonsai/models/clip/__init__.py
@@ -0,0 +1,4 @@
from .params import CLIPConfig
from .modeling import CLIP

__all__ = ["CLIP", "CLIPConfig"]
261 changes: 261 additions & 0 deletions bonsai/models/clip/clip_model_jax.ipynb
@@ -0,0 +1,261 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "FeIEqVH9r3dn"
},
"outputs": [],
"source": [
"!pip install -q flax einops optax"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "3aRJQOHwr4kN"
},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import flax.linen as nn\n",
"from dataclasses import dataclass\n",
"from einops import rearrange"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "iYj4uuHmr9Ep"
},
"outputs": [],
"source": [
"@dataclass\n",
"class CLIPConfig:\n",
" embed_dim: int = 512\n",
"\n",
" image_size: int = 224\n",
" patch_size: int = 32\n",
" vision_width: int = 768\n",
" vision_layers: int = 12\n",
" vision_heads: int = 12\n",
"\n",
" vocab_size: int = 49408\n",
" context_length: int = 77\n",
" text_width: int = 512\n",
" text_layers: int = 12\n",
" text_heads: int = 8"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "pATk-Ozgr-PX"
},
"outputs": [],
"source": [
"class VisionTransformer(nn.Module):\n",
" cfg: CLIPConfig\n",
"\n",
" @nn.compact\n",
" def __call__(self, x):\n",
" B = x.shape[0]\n",
"\n",
" x = nn.Conv(\n",
" features=self.cfg.vision_width,\n",
" kernel_size=(self.cfg.patch_size, self.cfg.patch_size),\n",
" strides=(self.cfg.patch_size, self.cfg.patch_size),\n",
" padding=\"VALID\"\n",
" )(x)\n",
"\n",
" x = rearrange(x, \"b h w c -> b (h w) c\")\n",
"\n",
" cls = self.param(\n",
" \"cls_token\",\n",
" nn.initializers.zeros,\n",
" (1, 1, self.cfg.vision_width)\n",
" )\n",
" cls = jnp.tile(cls, (B, 1, 1))\n",
" x = jnp.concatenate([cls, x], axis=1)\n",
"\n",
" pos = self.param(\n",
" \"pos_embed\",\n",
" nn.initializers.normal(stddev=0.01),\n",
" (1, x.shape[1], self.cfg.vision_width)\n",
" )\n",
" x = x + pos\n",
"\n",
" for _ in range(self.cfg.vision_layers):\n",
" h = nn.LayerNorm()(x)\n",
" h = nn.SelfAttention(\n",
" num_heads=self.cfg.vision_heads,\n",
" qkv_features=self.cfg.vision_width\n",
" )(h)\n",
" x = x + h\n",
" x = nn.LayerNorm()(x[:, 0])\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9AyT2fC4sA07"
},
"outputs": [],
"source": [
"class TextTransformer(nn.Module):\n",
" cfg: CLIPConfig\n",
"\n",
" @nn.compact\n",
" def __call__(self, tokens):\n",
" x = nn.Embed(self.cfg.vocab_size, self.cfg.text_width)(tokens)\n",
"\n",
" pos = self.param(\n",
" \"pos_embed\",\n",
" nn.initializers.normal(stddev=0.01),\n",
" (1, self.cfg.context_length, self.cfg.text_width)\n",
" )\n",
" x = x + pos\n",
"\n",
" causal_mask = nn.attention.make_causal_mask(tokens)\n",
"\n",
" for _ in range(self.cfg.text_layers):\n",
" h = nn.LayerNorm()(x)\n",
" h = nn.SelfAttention(\n",
" num_heads=self.cfg.text_heads,\n",
" qkv_features=self.cfg.text_width,\n",
" )(h, mask=causal_mask)\n",
" x = x + h\n",
"\n",
" x = nn.LayerNorm()(x[:, -1])\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "Mb-K3VBfsB_j"
},
"outputs": [],
"source": [
"class CLIP(nn.Module):\n",
" cfg: CLIPConfig\n",
"\n",
" @nn.compact\n",
" def __call__(self, images, texts):\n",
"\n",
" image_feat = VisionTransformer(self.cfg)(images)\n",
" text_feat = TextTransformer(self.cfg)(texts)\n",
"\n",
" image_emb = nn.Dense(self.cfg.embed_dim, name=\"image_proj\")(image_feat)\n",
" text_emb = nn.Dense(self.cfg.embed_dim, name=\"text_proj\")(text_feat)\n",
"\n",
" image_emb = image_emb / jnp.linalg.norm(image_emb, axis=-1, keepdims=True)\n",
" text_emb = text_emb / jnp.linalg.norm(text_emb, axis=-1, keepdims=True)\n",
"\n",
" logit_scale = self.param(\"logit_scale\", nn.initializers.zeros, ())\n",
" logit_scale = jnp.exp(logit_scale)\n",
"\n",
" logits = logit_scale * image_emb @ text_emb.T\n",
" return logits"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "Ge5nGFQ8sDKz"
},
"outputs": [],
"source": [
"cfg = CLIPConfig()\n",
"model = CLIP(cfg)\n",
"\n",
"key = jax.random.PRNGKey(0)\n",
"\n",
"dummy_images = jnp.ones((2, 224, 224, 3))\n",
"dummy_texts = jnp.ones((2, cfg.context_length), dtype=jnp.int32)\n",
"\n",
"params = model.init(key, dummy_images, dummy_texts)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ETy-ECrjsEMq",
"outputId": "7693f191-8806-4059-a0ff-bb77465f529d"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([[0.01149538, 0.01149538],\n",
" [0.01149538, 0.01149538]], dtype=float32)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logits = model.apply(params, dummy_images, dummy_texts)\n",
"logits"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "W4t2GKepsFh5",
"outputId": "afca5309-7173-447a-d138-604bfba16a56"
},
"outputs": [
{
"data": {
"text/plain": [
"69381121"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def count_params(p):\n",
" return sum(x.size for x in jax.tree_util.tree_leaves(p))\n",
"\n",
"count_params(params)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}