diff --git a/bonsai/models/unet/tests/UNet_segmentation_example.ipynb b/bonsai/models/unet/tests/UNet_segmentation_example.ipynb index 54a1fc37..60548937 100644 --- a/bonsai/models/unet/tests/UNet_segmentation_example.ipynb +++ b/bonsai/models/unet/tests/UNet_segmentation_example.ipynb @@ -629,7 +629,7 @@ " return state, loss\n", "\n", "\n", - "print(\"🚀 Starting training from checkpoint...\")\n", + "print(\"Starting training from checkpoint...\")\n", "train_loader, vis_loader = load_dataset()\n", "num_epochs = 100\n", "state = train_state\n", diff --git a/bonsai/models/unet/tests/UNet_segmentation_example.md b/bonsai/models/unet/tests/UNet_segmentation_example.md index 5663f428..a893efaa 100644 --- a/bonsai/models/unet/tests/UNet_segmentation_example.md +++ b/bonsai/models/unet/tests/UNet_segmentation_example.md @@ -256,7 +256,7 @@ def train_step(state: TrainState, other_vars: nnx.State, batch: tuple[jax.Array, return state, loss -print("🚀 Starting training from checkpoint...") +print("Starting training from checkpoint...") train_loader, vis_loader = load_dataset() num_epochs = 100 state = train_state diff --git a/bonsai/tutorials/JAX_machine_translation.ipynb b/bonsai/tutorials/JAX_machine_translation.ipynb new file mode 100644 index 00000000..cb478b90 --- /dev/null +++ b/bonsai/tutorials/JAX_machine_translation.ipynb @@ -0,0 +1,1253 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ee3e1116-f6cd-497e-b617-1d89d5d1f744", + "metadata": {}, + "source": [ + "# Machine Translation with encoder-decoder transformer model\n" + ] + }, + { + "cell_type": "markdown", + "id": "50f0bd58-dcc6-41f4-9dc4-3a08c8ef751b", + "metadata": {}, + "source": [ + "This tutorial is adapted from [Keras' documentation on English-to-Spanish translation with a sequence-to-sequence Transformer](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/), which is itself an adaptation from the book [Deep Learning with Python, Second Edition by François 
Chollet](https://www.manning.com/books/deep-learning-with-python-second-edition)\n", + "\n", + "We step through an encoder-decoder transformer in JAX and train a model for English->Spanish translation." + ] + }, + { + "cell_type": "markdown", + "id": "0e5066d9", + "metadata": {}, + "source": [ + "### Installing Dependencies\n", + "The current versions are configured for **Google TPU v5**. To switch to the **GPU** version, use the following command:\n", + "\n", + "```python\n", + "%pip install \"jax[cuda12]==0.8.2\" \"flax==0.12.2\" numpy tiktoken tqdm grain optax matplotlib" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7bf8d50f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/lib/python3.11/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " pid, fd = os.forkpty()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: numpy in /home/aries/bonsai/env/lib/python3.11/site-packages (2.4.0)\n", + "Requirement already satisfied: tiktoken in /home/aries/bonsai/env/lib/python3.11/site-packages (0.12.0)\n", + "Requirement already satisfied: flax in /home/aries/bonsai/env/lib/python3.11/site-packages (0.12.2)\n", + "Requirement already satisfied: tqdm in /home/aries/bonsai/env/lib/python3.11/site-packages (4.67.1)\n", + "Requirement already satisfied: grain in /home/aries/bonsai/env/lib/python3.11/site-packages (0.2.15)\n", + "Requirement already satisfied: optax in /home/aries/bonsai/env/lib/python3.11/site-packages (0.2.6)\n", + "Requirement already satisfied: matplotlib in /home/aries/bonsai/env/lib/python3.11/site-packages (3.10.8)\n", + "Requirement already satisfied: jax in /home/aries/bonsai/env/lib/python3.11/site-packages (0.8.2)\n", + "Requirement already satisfied: regex>=2022.1.18 in 
/home/aries/bonsai/env/lib/python3.11/site-packages (from tiktoken) (2025.11.3)\n", + "Requirement already satisfied: requests>=2.26.0 in /home/aries/bonsai/env/lib/python3.11/site-packages (from tiktoken) (2.32.5)\n", + "Requirement already satisfied: msgpack in /home/aries/bonsai/env/lib/python3.11/site-packages (from flax) (1.1.2)\n", + "Requirement already satisfied: orbax-checkpoint in /home/aries/bonsai/env/lib/python3.11/site-packages (from flax) (0.11.31)\n", + "Requirement already satisfied: tensorstore in /home/aries/bonsai/env/lib/python3.11/site-packages (from flax) (0.1.80)\n", + "Requirement already satisfied: rich>=11.1 in /home/aries/bonsai/env/lib/python3.11/site-packages (from flax) (14.2.0)\n", + "Requirement already satisfied: typing_extensions>=4.2 in /home/aries/bonsai/env/lib/python3.11/site-packages (from flax) (4.15.0)\n", + "Requirement already satisfied: PyYAML>=5.4.1 in /home/aries/bonsai/env/lib/python3.11/site-packages (from flax) (6.0.3)\n", + "Requirement already satisfied: treescope>=0.1.7 in /home/aries/bonsai/env/lib/python3.11/site-packages (from flax) (0.1.10)\n", + "Requirement already satisfied: absl-py in /home/aries/bonsai/env/lib/python3.11/site-packages (from grain) (2.3.1)\n", + "Requirement already satisfied: array-record>=0.8.1 in /home/aries/bonsai/env/lib/python3.11/site-packages (from grain) (0.8.3)\n", + "Requirement already satisfied: cloudpickle in /home/aries/bonsai/env/lib/python3.11/site-packages (from grain) (3.1.2)\n", + "Requirement already satisfied: etils[epath,epy] in /home/aries/bonsai/env/lib/python3.11/site-packages (from grain) (1.13.0)\n", + "Requirement already satisfied: protobuf>=5.28.3 in /home/aries/bonsai/env/lib/python3.11/site-packages (from grain) (6.33.2)\n", + "Requirement already satisfied: chex>=0.1.87 in /home/aries/bonsai/env/lib/python3.11/site-packages (from optax) (0.1.91)\n", + "Requirement already satisfied: jaxlib>=0.5.3 in /home/aries/bonsai/env/lib/python3.11/site-packages 
(from optax) (0.8.2)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /home/aries/bonsai/env/lib/python3.11/site-packages (from matplotlib) (1.3.3)\n", + "Requirement already satisfied: cycler>=0.10 in /home/aries/bonsai/env/lib/python3.11/site-packages (from matplotlib) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /home/aries/bonsai/env/lib/python3.11/site-packages (from matplotlib) (4.61.1)\n", + "Requirement already satisfied: kiwisolver>=1.3.1 in /home/aries/bonsai/env/lib/python3.11/site-packages (from matplotlib) (1.4.9)\n", + "Requirement already satisfied: packaging>=20.0 in /home/aries/bonsai/env/lib/python3.11/site-packages (from matplotlib) (25.0)\n", + "Requirement already satisfied: pillow>=8 in /home/aries/bonsai/env/lib/python3.11/site-packages (from matplotlib) (12.1.0)\n", + "Requirement already satisfied: pyparsing>=3 in /home/aries/bonsai/env/lib/python3.11/site-packages (from matplotlib) (3.3.1)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /home/aries/bonsai/env/lib/python3.11/site-packages (from matplotlib) (2.9.0.post0)\n", + "Requirement already satisfied: ml_dtypes>=0.5.0 in /home/aries/bonsai/env/lib/python3.11/site-packages (from jax) (0.5.4)\n", + "Requirement already satisfied: opt_einsum in /home/aries/bonsai/env/lib/python3.11/site-packages (from jax) (3.4.0)\n", + "Requirement already satisfied: scipy>=1.13 in /home/aries/bonsai/env/lib/python3.11/site-packages (from jax) (1.16.3)\n", + "Requirement already satisfied: toolz>=1.0.0 in /home/aries/bonsai/env/lib/python3.11/site-packages (from chex>=0.1.87->optax) (1.1.0)\n", + "Requirement already satisfied: six>=1.5 in /home/aries/bonsai/env/lib/python3.11/site-packages (from python-dateutil>=2.7->matplotlib) (1.17.0)\n", + "Requirement already satisfied: charset_normalizer<4,>=2 in /home/aries/bonsai/env/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (3.4.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in 
/home/aries/bonsai/env/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (3.11)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/aries/bonsai/env/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (2.6.2)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /home/aries/bonsai/env/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (2026.1.4)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /home/aries/bonsai/env/lib/python3.11/site-packages (from rich>=11.1->flax) (4.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /home/aries/bonsai/env/lib/python3.11/site-packages (from rich>=11.1->flax) (2.19.2)\n", + "Requirement already satisfied: mdurl~=0.1 in /home/aries/bonsai/env/lib/python3.11/site-packages (from markdown-it-py>=2.2.0->rich>=11.1->flax) (0.1.2)\n", + "Requirement already satisfied: fsspec in /home/aries/bonsai/env/lib/python3.11/site-packages (from etils[epath,epy]->grain) (2025.12.0)\n", + "Requirement already satisfied: importlib_resources in /home/aries/bonsai/env/lib/python3.11/site-packages (from etils[epath,epy]->grain) (6.5.2)\n", + "Requirement already satisfied: zipp in /home/aries/bonsai/env/lib/python3.11/site-packages (from etils[epath,epy]->grain) (3.23.0)\n", + "Requirement already satisfied: nest_asyncio in /home/aries/bonsai/env/lib/python3.11/site-packages (from orbax-checkpoint->flax) (1.6.0)\n", + "Requirement already satisfied: aiofiles in /home/aries/bonsai/env/lib/python3.11/site-packages (from orbax-checkpoint->flax) (25.1.0)\n", + "Requirement already satisfied: humanize in /home/aries/bonsai/env/lib/python3.11/site-packages (from orbax-checkpoint->flax) (4.15.0)\n", + "Requirement already satisfied: simplejson>=3.16.0 in /home/aries/bonsai/env/lib/python3.11/site-packages (from orbax-checkpoint->flax) (3.20.2)\n", + "Requirement already satisfied: psutil in /home/aries/bonsai/env/lib/python3.11/site-packages (from orbax-checkpoint->flax) 
(7.2.1)\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install numpy tiktoken tqdm grain optax matplotlib \"jax[tpu]==0.8.2\" \"flax == 0.12.2\"" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "dd506ffa-3b91-44f1-92d1-a08ed933e78e", + "metadata": {}, + "outputs": [], + "source": [ + "import dataclasses\n", + "import pathlib\n", + "import random\n", + "import re\n", + "import string\n", + "\n", + "import grain.python as grain\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "import optax\n", + "import tiktoken\n", + "import tqdm\n", + "from flax import nnx" + ] + }, + { + "cell_type": "markdown", + "id": "e1f324b0-140a-48fa-9fcb-d6308f098343", + "metadata": {}, + "source": [ + "## Pull down data to temp and extract into memory\n", + "\n", + "There are lots of ways to get this done, but for simplicity and clear visibility into what's happening this is downloaded to a temporary directory, extracted there, and read into a python object with processing.\n", + "\n", + "### Libraries Used:\n", + "* **tempfile**: Used to create temporary directories to store the dataset during processing.\n", + "* **zipfile**: Used to extract the contents of the downloaded zip file.\n", + "* **requests**: Used to fetch the raw dataset file from the URL.\n", + "\n", + "### Process Overview:\n", + "We extract the zip data into a folder on the local environment. During this process, we apply a critical formatting step to the Spanish target data:\n", + "\n", + "> `\"[start] \" + spa + \" [end]\"`\n", + "\n", + "### Why do we do this? (Teacher Forcing)\n", + "This is essential for the **Decoder** (the part of the AI generating the Spanish translation). 
import tempfile
import zipfile

import requests

# Tab-separated English/Spanish sentence pairs (one pair per line).
url = "https://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip"

with tempfile.TemporaryDirectory() as temp_dir:
    temp_path = pathlib.Path(temp_dir)
    zip_file_path = temp_path / "spa-eng.zip"

    # Bound the request and fail loudly on an HTTP error; otherwise an error
    # page would be written to disk and only surface later as a BadZipFile.
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    zip_file_path.write_bytes(response.content)

    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
        zip_ref.extractall(temp_path)

    text_file = temp_path / "spa-eng" / "spa.txt"

    # Decode explicitly as UTF-8: the platform default encoding (e.g. cp1252 on
    # Windows) would corrupt or reject the accented Spanish characters.
    with open(text_file, encoding="utf-8") as f:
        # Skip empty lines instead of blindly dropping the last element, so the
        # final pair survives even when the file has no trailing newline.
        lines = [line for line in f.read().split("\n") if line]

    text_pairs = []
    for line in lines:
        eng, spa = line.split("\t")
        # Wrap the target sentence in explicit start/end markers for the decoder.
        spa = "[start] " + spa + " [end]"
        text_pairs.append((eng, spa))
# Fixed seed so that Restart Kernel -> Run All always produces the same
# train/validation/test split (and therefore comparable metrics).
SPLIT_SEED = 0

# A dedicated Random instance keeps the global `random` state untouched.
# `text_pairs` is still shuffled in place, as before.
random.Random(SPLIT_SEED).shuffle(text_pairs)

# 15% validation, 15% test, and whatever remains (about 70%) for training.
num_val_samples = int(0.15 * len(text_pairs))
num_train_samples = len(text_pairs) - 2 * num_val_samples
train_pairs = text_pairs[:num_train_samples]
val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]
test_pairs = text_pairs[num_train_samples + num_val_samples :]

# The three slices must partition the data set exactly.
assert len(train_pairs) + len(val_pairs) + len(test_pairs) == len(text_pairs)

print(f"{len(text_pairs)} total pairs")
print(f"{len(train_pairs)} training pairs")
print(f"{len(val_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")
tokenizer = tiktoken.get_encoding("cl100k_base")

# Every punctuation character (plus the Spanish "¿") is stripped, except the
# square brackets, which must survive so "[start]" / "[end]" stay intact.
strip_chars = "".join(ch for ch in string.punctuation + "¿" if ch not in "[]")

vocab_size = tokenizer.n_vocab
sequence_length = 20


def custom_standardization(input_string):
    """Lower-case `input_string` and delete every character listed in `strip_chars`."""
    # The deletion table is built at call time so it always reflects the
    # current module-level `strip_chars`.
    deletion_table = str.maketrans("", "", strip_chars)
    return input_string.lower().translate(deletion_table)


def tokenize_and_pad(text, tokenizer, max_length):
    """Encode `text`, truncate to `max_length` ids, and right-pad with 0.

    `tokenizer.encode` is expected to return a list of ints (as tiktoken does).
    The result always has exactly `max_length` entries.
    """
    token_ids = tokenizer.encode(text)[:max_length]
    missing = max_length - len(token_ids)  # 0 when already at full length
    return token_ids + [0] * missing


def format_dataset(eng, spa, tokenizer, sequence_length):
    """Turn one (English, Spanish) pair into encoder/decoder/target id lists.

    Both sentences are standardized and encoded to `sequence_length` ids. The
    Spanish ids are then offset by one position: the decoder input drops the
    last id and the target drops the first, so each has `sequence_length - 1`.
    """
    source_ids = tokenize_and_pad(custom_standardization(eng), tokenizer, sequence_length)
    target_ids = tokenize_and_pad(custom_standardization(spa), tokenizer, sequence_length)
    return {
        "encoder_inputs": source_ids,
        "decoder_inputs": target_ids[:-1],
        "target_output": target_ids[1:],
    }
def _encode_pairs(pairs):
    """Apply `format_dataset` to every (English, Spanish) pair in `pairs`."""
    encoded = []
    for english, spanish in pairs:
        encoded.append(format_dataset(english, spanish, tokenizer, sequence_length))
    return encoded


# One list of {"encoder_inputs", "decoder_inputs", "target_output"} dicts per split.
train_data = _encode_pairs(train_pairs)
val_data = _encode_pairs(val_pairs)
test_data = _encode_pairs(test_pairs)

## data selection example
print(train_data[135])
@dataclasses.dataclass
class TransformerConfig:
    """Hyperparameters shared by the transformer components defined below."""

    sequence_length: int  # rows in the learned position-embedding table
    vocab_size: int  # rows in the token-embedding table
    embed_dim: int  # feature width of embeddings, attention and layer norms
    latent_dim: int  # hidden width of the feed-forward projection
    num_heads: int  # heads per MultiHeadAttention layer
    dropout_rate: float  # dropout rate (not read by the classes in this cell)


def _as_attention_mask(mask):
    """Reshape a padding mask so it broadcasts against per-head attention weights.

    Attention weights are laid out as (batch, heads, query_len, key_len).

    * rank-2 mask, assumed to be (batch, key_len) -> (batch, 1, 1, key_len)
    * any other rank (e.g. (batch, query_len, key_len)) -> a head axis is
      inserted at position 1, which is what the original code did for all ranks.

    The previous `expand_dims(mask, axis=1)` turned a rank-2 mask into
    (batch, 1, key_len); right-aligned broadcasting then lined the batch axis up
    with the heads axis, which either fails or silently masks the wrong rows.
    """
    if mask.ndim == 2:
        return jnp.expand_dims(mask, axis=(1, 2)).astype(jnp.int32)
    return jnp.expand_dims(mask, axis=1).astype(jnp.int32)


class TransformerEncoder(nnx.Module):
    """Single encoder layer: self-attention then a feed-forward block, each
    followed by a residual connection and layer normalization."""

    def __init__(self, config: TransformerConfig, rngs: nnx.Rngs, **kwargs):
        # NOTE: **kwargs is accepted for call-site compatibility but unused.
        self.attention = nnx.MultiHeadAttention(
            num_heads=config.num_heads, in_features=config.embed_dim, decode=False, rngs=rngs
        )
        self.dense_proj = nnx.Sequential(
            nnx.Linear(config.embed_dim, config.latent_dim, rngs=rngs),
            nnx.relu,
            nnx.Linear(config.latent_dim, config.embed_dim, rngs=rngs),
        )

        self.layernorm_1 = nnx.LayerNorm(config.embed_dim, rngs=rngs)
        self.layernorm_2 = nnx.LayerNorm(config.embed_dim, rngs=rngs)

    def __call__(self, inputs, mask=None):
        """Encode `inputs`; `mask` (optional) marks which key positions may be attended."""
        padding_mask = _as_attention_mask(mask) if mask is not None else None

        attention_output = self.attention(
            inputs_q=inputs, inputs_k=inputs, inputs_v=inputs, mask=padding_mask, decode=False
        )
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)


class PositionalEmbedding(nnx.Module):
    """Sum of a token embedding and a learned position embedding."""

    def __init__(self, config: TransformerConfig, rngs: nnx.Rngs, **kwargs):
        # NOTE: **kwargs is accepted for call-site compatibility but unused.
        self.token_embeddings = nnx.Embed(num_embeddings=config.vocab_size, features=config.embed_dim, rngs=rngs)
        self.position_embeddings = nnx.Embed(
            num_embeddings=config.sequence_length, features=config.embed_dim, rngs=rngs
        )

    def __call__(self, inputs, step=None):
        """Embed token ids.

        With `step=None` positions are 0..inputs.shape[1]-1. With an explicit
        `step` (incremental decoding) the single position `step` is used.
        """
        if step is None:
            length = inputs.shape[1]
            positions = jnp.arange(0, length)[None, :]
        else:
            positions = jnp.array([step])[None, :]

        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions


class TransformerDecoder(nnx.Module):
    """Single decoder layer: causal self-attention, cross-attention over the
    encoder outputs, then a feed-forward block; each sub-layer is followed by a
    residual connection and layer normalization."""

    def __init__(self, config: TransformerConfig, rngs: nnx.Rngs, **kwargs):
        # NOTE: **kwargs is accepted for call-site compatibility but unused.
        # Self-attention is built with decode=True so it can hold a key/value
        # cache; the per-call `decode` flag selects the mode actually used.
        self.attention_1 = nnx.MultiHeadAttention(
            num_heads=config.num_heads, in_features=config.embed_dim, decode=True, rngs=rngs
        )
        self.attention_2 = nnx.MultiHeadAttention(
            num_heads=config.num_heads, in_features=config.embed_dim, decode=False, rngs=rngs
        )

        self.dense_proj = nnx.Sequential(
            nnx.Linear(config.embed_dim, config.latent_dim, rngs=rngs),
            nnx.relu,
            nnx.Linear(config.latent_dim, config.embed_dim, rngs=rngs),
        )
        self.layernorm_1 = nnx.LayerNorm(config.embed_dim, rngs=rngs)
        self.layernorm_2 = nnx.LayerNorm(config.embed_dim, rngs=rngs)
        self.layernorm_3 = nnx.LayerNorm(config.embed_dim, rngs=rngs)

    def init_cache(self, input_shape):
        """Initialise the self-attention cache for incremental decoding."""
        self.attention_1.init_cache(input_shape=input_shape)

    def __call__(self, inputs, encoder_outputs, mask=None, decode=False):
        """Run the decoder layer.

        Args:
            inputs: embedded decoder tokens.
            encoder_outputs: encoder activations used as cross-attention keys/values.
            mask: optional padding mask. The same mask is applied to
                self-attention (non-decode mode only) and to cross-attention.
                NOTE(review): for a rank-2 mask this requires the decoder and
                encoder sequence lengths to agree with the mask; confirm
                against callers.
            decode: when True, no explicit self-attention mask is passed
                (`attention_1` runs in cached decode mode).
        """
        if decode:
            self_mask = None
            cross_mask = None
            if mask is not None:
                # Unchanged decode-time behaviour.
                cross_mask = jnp.expand_dims(mask, axis=(1, 2)).astype(jnp.int32)
        else:
            self_mask = self.get_causal_attention_mask(inputs.shape[1])
            cross_mask = None
            if mask is not None:
                cross_mask = _as_attention_mask(mask)
                # A position is visible only if it is both non-padding and not
                # in the future: (B, 1, *, L) min (1, 1, L, L) -> (B, 1, L, L).
                self_mask = jnp.minimum(cross_mask, self_mask)

        attention_output_1 = self.attention_1(
            inputs_q=inputs, inputs_k=inputs, inputs_v=inputs, mask=self_mask, decode=decode
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        attention_output_2 = self.attention_2(
            inputs_q=out_1, inputs_k=encoder_outputs, inputs_v=encoder_outputs, mask=cross_mask, decode=False
        )

        out_2 = self.layernorm_2(out_1 + attention_output_2)

        proj_output = self.dense_proj(out_2)
        output = self.layernorm_3(out_2 + proj_output)
        return output

    def get_causal_attention_mask(self, sequence_length):
        """Return a (1, 1, L, L) int32 lower-triangular mask (1 = may attend)."""
        i = jnp.arange(sequence_length)[:, None]
        j = jnp.arange(sequence_length)
        mask = (i >= j).astype(jnp.int32)  # query i may see keys j <= i
        mask = jnp.reshape(mask, (1, 1, sequence_length, sequence_length))
        return mask
class TransformerModel(nnx.Module):
    """Encoder-decoder translation model.

    One `PositionalEmbedding` instance embeds both the encoder and the decoder
    token ids. The decoder output goes through dropout and a linear projection
    to `config.vocab_size` logits.
    """

    def __init__(self, config: TransformerConfig, rngs: nnx.Rngs):
        # Sub-modules are created in this fixed order because each one draws
        # its initial parameters from the shared `rngs` stream.
        self.config = config
        self.encoder = TransformerEncoder(config, rngs=rngs)
        self.positional_embedding = PositionalEmbedding(config, rngs=rngs)
        self.decoder = TransformerDecoder(config, rngs=rngs)
        self.dropout = nnx.Dropout(rate=config.dropout_rate, rngs=rngs)
        self.dense = nnx.Linear(config.embed_dim, config.vocab_size, rngs=rngs)

    def init_cache(self, input_shape):
        """Initialise the decoder's attention cache for incremental decoding."""
        self.decoder.init_cache(input_shape)

    def __call__(self, encoder_inputs, decoder_inputs, mask=None, deterministic=None, decode=False):
        """Compute next-token logits for `decoder_inputs` given `encoder_inputs`.

        Args:
            encoder_inputs: source token ids.
            decoder_inputs: target token ids fed to the decoder.
            mask: optional padding mask, forwarded unchanged to both the
                encoder and the decoder.
            deterministic: True disables dropout, False enables it. The default
                None defers to the dropout layer's own flag, i.e. whatever
                `model.train()` / `model.eval()` last set. The previous default
                of False was forwarded explicitly and overrode that flag, so
                dropout stayed active even after `model.eval()`.
            decode: forwarded to the decoder to select cached decoding.

        Returns:
            Logits over the vocabulary for every decoder position.
        """
        x = self.positional_embedding(encoder_inputs)
        encoder_outputs = self.encoder(x, mask=mask)

        x = self.positional_embedding(decoder_inputs)
        decoder_outputs = self.decoder(x, encoder_outputs, mask=mask, decode=decode)

        decoder_outputs = self.dropout(decoder_outputs, deterministic=deterministic)
        logits = self.dense(decoder_outputs)
        return logits

    def decode_step(self, token_input, encoder_outputs, step_index):
        """Logits for one decoding step, reusing precomputed `encoder_outputs`.

        `token_input` is embedded at position `step_index`; the decoder runs in
        cached decode mode and no dropout is applied.
        """
        x = self.positional_embedding(token_input, step=step_index)
        decoder_outputs = self.decoder(x, encoder_outputs, decode=True)
        logits = self.dense(decoder_outputs)
        return logits
batch_size = 64  # set here for the loader and model train later on

_FEATURE_KEYS = ("encoder_inputs", "decoder_inputs", "target_output")


class CustomPreprocessing(grain.MapTransform):
    """Convert the three id lists of one record into NumPy arrays."""

    def map(self, data):
        return {key: np.array(data[key]) for key in _FEATURE_KEYS}


def _make_sampler(num_records, shuffle):
    """Single-epoch, unsharded index sampler with a fixed seed."""
    return grain.IndexSampler(
        num_records,
        shuffle=shuffle,
        seed=12,  # Seed for reproducibility
        shard_options=grain.NoSharding(),  # No sharding since it's a single-device setup
        num_epochs=1,  # Iterate over the dataset for one epoch
    )


def _make_loader(source, sampler, batch_op):
    """Data loader that array-converts each record and then applies `batch_op`."""
    return grain.DataLoader(
        data_source=source,
        sampler=sampler,  # Sampler to determine how to access the data
        worker_count=4,  # Number of child processes launched to parallelize the transformations
        worker_buffer_size=2,  # Count of output batches to produce in advance per worker
        operations=[CustomPreprocessing(), batch_op],
    )


train_sampler = _make_sampler(len(train_data), shuffle=True)
val_sampler = _make_sampler(len(val_data), shuffle=False)

# Training drops the final short batch; validation keeps it.
train_loader = _make_loader(
    train_data, train_sampler, grain.Batch(batch_size=batch_size, drop_remainder=True)
)
val_loader = _make_loader(val_data, val_sampler, grain.Batch(batch_size=batch_size))
def compute_loss(logits, labels):
    """Mean softmax cross-entropy between `logits` and integer `labels`.

    The mean is taken over every element of the per-token loss, so padded
    positions contribute to it as well.
    """
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels)
    return jnp.mean(loss)


@nnx.jit
def train_step(model, optimizer, batch):
    """Run one optimisation step on `batch` and return its loss.

    `batch` is a mapping with "encoder_inputs", "decoder_inputs" and
    "target_output". Gradients are taken with respect to `model` (the first
    argument of `loss_fn`) and applied in place through `optimizer`.
    """

    def loss_fn(model, train_encoder_input, train_decoder_input, train_target_input):
        logits = model(train_encoder_input, train_decoder_input)
        loss = compute_loss(logits, train_target_input)
        return loss

    grad_fn = nnx.value_and_grad(loss_fn)
    loss, grads = grad_fn(
        model, jnp.array(batch["encoder_inputs"]), jnp.array(batch["decoder_inputs"]), jnp.array(batch["target_output"])
    )
    optimizer.update(model, grads)
    return loss


@nnx.jit
def eval_step(model, batch, eval_metrics):
    """Accumulate loss and accuracy for `batch` into `eval_metrics` (in place)."""
    labels = jnp.array(batch["target_output"])
    # deterministic=True switches dropout off for evaluation. Without it the
    # model's `deterministic=False` default is forwarded explicitly to the
    # dropout layer and overrides `model.eval()`, so validation metrics would
    # be computed on randomly dropped activations.
    logits = model(
        jnp.array(batch["encoder_inputs"]),
        jnp.array(batch["decoder_inputs"]),
        deterministic=True,
    )
    loss = compute_loss(logits, labels)

    eval_metrics.update(
        loss=loss,
        logits=logits,
        labels=labels,
    )
bar_format = "{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]"
# Estimated number of steps per epoch, used only as the progress-bar total.
# NOTE(review): the recorded run ends at 1300/1301, so the loader appears to yield
# one batch fewer than this estimate — verify against how train_loader is built.
train_total_steps = len(train_data) // batch_size


def train_one_epoch(epoch):
    """Train the model for one pass over `train_loader`.

    Args:
        epoch: Zero-based epoch index, as produced by `range(num_epochs)`.

    The loss of every step is appended to `train_metrics_history["train_loss"]`
    and shown as the progress-bar postfix.
    """
    model.train()  # Switch the model to training mode.
    with tqdm.tqdm(
        # Displayed 1-based so the bar agrees with the "[test] epoch: N/..." line
        # that evaluate_model prints for the same epoch.
        desc=f"[train] epoch: {epoch + 1}/{num_epochs}, ",
        total=train_total_steps,
        bar_format=bar_format,
        leave=True,
    ) as pbar:
        for batch in train_loader:
            # Convert to a Python scalar once and reuse it for history and display.
            loss_value = train_step(model, optimizer, batch).item()
            train_metrics_history["train_loss"].append(loss_value)
            pbar.set_postfix({"loss": loss_value})
            pbar.update(1)


def evaluate_model(epoch):
    """Evaluate the model on `val_loader` and record the resulting metrics.

    Args:
        epoch: Zero-based epoch index; printed 1-based.

    Each metric returned by `eval_metrics.compute()` is appended to
    `eval_metrics_history` under the key `test_<metric>`. The keys keep the
    "test_" prefix even though the batches come from `val_loader`.
    """
    model.eval()  # Switch the model to evaluation mode.

    eval_metrics.reset()  # Start from empty accumulators for this epoch.
    for val_batch in val_loader:
        eval_step(model, val_batch, eval_metrics)

    for metric, value in eval_metrics.compute().items():
        eval_metrics_history[f"test_{metric}"].append(value)

    print(f"[test] epoch: {epoch + 1}/{num_epochs}")
    print(f"- total loss: {eval_metrics_history['test_loss'][-1]:0.4f}")
    print(f"- Accuracy: {eval_metrics_history['test_accuracy'][-1]:0.4f}")
+ ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "c764510c-4d98-46ad-b877-8cfc2fa5a9ea", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[train] epoch: 0/10, [1300/1301], loss=1.1 [02:31<00:00] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[test] epoch: 1/10\n", + "- total loss: 1.2200\n", + "- Accuracy: 0.7859\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[train] epoch: 1/10, [1300/1301], loss=0.766 [01:37<00:00]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[test] epoch: 2/10\n", + "- total loss: 0.9918\n", + "- Accuracy: 0.8195\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[train] epoch: 2/10, [1300/1301], loss=0.63 [01:35<00:00] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[test] epoch: 3/10\n", + "- total loss: 0.9035\n", + "- Accuracy: 0.8343\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[train] epoch: 3/10, [1300/1301], loss=0.568 [01:36<00:00]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[test] epoch: 4/10\n", + "- total loss: 0.8744\n", + "- Accuracy: 0.8410\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[train] epoch: 4/10, [1300/1301], loss=0.499 [01:35<00:00]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[test] epoch: 5/10\n", + "- total loss: 0.8573\n", + "- Accuracy: 0.8464\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[train] epoch: 5/10, [1300/1301], loss=0.463 [01:35<00:00]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[test] epoch: 6/10\n", + "- total loss: 0.8596\n", + "- Accuracy: 0.8478\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[train] epoch: 6/10, [1300/1301], loss=0.425 [01:35<00:00]\n" + 
# Alternate one training pass with one evaluation pass for every epoch.
# `epoch` is the zero-based index handed to both helpers.
for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate_model(epoch)
import matplotlib.pyplot as plt

# train_metrics_history["train_loss"] holds one value per training step, so the
# x axis counts steps across all epochs. The log scale keeps late progress visible.
fig, ax = plt.subplots()
ax.plot(train_metrics_history["train_loss"], label="Loss value during the training")
ax.set_yscale("log")
ax.set_xlabel("Training step")
ax.set_ylabel("Loss (log scale)")
ax.set_title("Training loss")
ax.legend()
plt.show()  # Render the figure without echoing the Legend repr as cell output.
# One entry is appended per evaluation pass, so entry k belongs to epoch k + 1.
eval_epochs = range(1, len(eval_metrics_history["test_loss"]) + 1)

# Two panels side by side: a wide, short canvas keeps each panel readable.
fig, axs = plt.subplots(1, 2, figsize=(10, 4))

axs[0].set_title("Loss value on eval set")
axs[0].plot(eval_epochs, eval_metrics_history["test_loss"])
axs[0].set_xlabel("Epoch")
axs[0].set_ylabel("Loss")

axs[1].set_title("Accuracy on eval set")
axs[1].plot(eval_epochs, eval_metrics_history["test_accuracy"])
axs[1].set_xlabel("Epoch")
axs[1].set_ylabel("Accuracy")

fig.tight_layout()
plt.show()  # Render the figure without echoing the Line2D list as cell output.
def decode_sequence(input_sentence):
    """Greedily translate a single sentence with the trained model.

    The sentence is standardised, tokenised and encoded once. The decoder is then
    stepped one token at a time, always taking the highest-scoring token, for at
    most `sequence_length` steps or until the end marker has been produced.

    Args:
        input_sentence: The raw source sentence (a single string, not a batch).

    Returns:
        The decoded text. It begins with "[start" and ends with "[end]" when the
        end marker was generated within `sequence_length` steps.
    """
    # Do not rely on evaluate_model having been the last thing to touch the model:
    # put it in evaluation mode explicitly before running inference.
    model.eval()

    input_sentence = custom_standardization(input_sentence)
    tokenized_input_sentence = tokenize_and_pad(input_sentence, tokenizer, sequence_length)
    encoder_input = jnp.array([tokenized_input_sentence])

    # Encode the source once; every decoding step reuses encoder_outputs.
    emb_enc = model.positional_embedding(encoder_input)
    encoder_outputs = model.encoder(emb_enc, mask=None)

    # NOTE(review): 40 is hard-coded and differs from sequence_length; presumably it
    # is the cache length expected by init_cache — confirm and derive it if so.
    dummy_input_shape = (1, 40, model.config.embed_dim)
    model.init_cache(dummy_input_shape)

    # "[start]" is not a single token with this tokenizer, so decoding is seeded
    # with "[start" only and the model is left to produce the closing bracket.
    decoded_sentence = "[start"
    current_token_id = tokenizer.encode("[start")
    current_input = jnp.array([current_token_id])

    for i in range(sequence_length):
        logits = model.decode_step(current_input, encoder_outputs, step_index=i)

        # Greedy choice: take the most likely next token.
        sampled_id = np.argmax(logits[0, 0, :]).item()
        sampled_token = tokenizer.decode([sampled_id])
        decoded_sentence += sampled_token

        # Detect the end marker on the accumulated text rather than on one token,
        # because "[end]" may arrive split over several tokens. A bare word "end"
        # that is not preceded by "[" no longer stops decoding.
        if decoded_sentence.endswith("[end]"):
            break
        if decoded_sentence.endswith("[end"):
            decoded_sentence += "]"  # Close the marker instead of decoding one more step.
            break

        # Feed the chosen token back in for the next step.
        current_input = jnp.array([[sampled_id]])

    return decoded_sentence
# Print each "[Input]: ... [Translation]: ..." line collected in test_result_pairs,
# one per output line.
for i in test_result_pairs:
    print(i)
[Translation]: [start] toma esta medicina después de comer [end]\n", + " [Input]: The English are said to be conservative. [Translation]: [start] el inglés son dijo que sería conservadores [end]\n", + " [Input]: Tom might call Mary tonight. [Translation]: [start] tom quizás podría llamar a mary esta noche [end]\n", + " [Input]: I have not finished lunch. [Translation]: [start] no he finalmente terminado [end]\n", + " [Input]: Are you ready to start? [Translation]: [start] estás listo para empezar [end]\n", + " [Input]: Tom worked as a lifeguard during the summer. [Translation]: [start] tom trabajó como un salvavidas durante el verano [end]\n", + " [Input]: Can I pay later? [Translation]: [start] puedo pagar más tarde [end]\n", + " [Input]: They went hand in hand. [Translation]: [start] ellos se fueron [end]\n", + " [Input]: You look like a baboon. [Translation]: [start] parecés como un papión [end]\n", + " [Input]: A cloud floated across the sky. [Translation]: [start] una sola nubeió en el cielo [end]" + ] + }, + { + "cell_type": "markdown", + "id": "cd25c648", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kernelspec": { + "display_name": "env (3.11.14)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}