Skip to content

Commit dab4da4

Browse files
Aatman09coder0143
authored andcommitted
Refactor tutorial to use dataclass for configuration (jax-ml#119)
* Refactor tutorial to use dataclass for configuration * Imeplemented KV caching (WIP) * final changes
1 parent 206c991 commit dab4da4

File tree

3 files changed

+1255
-2
lines changed

3 files changed

+1255
-2
lines changed

bonsai/models/unet/tests/UNet_segmentation_example.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@
629629
" return state, loss\n",
630630
"\n",
631631
"\n",
632-
"print(\"🚀 Starting training from checkpoint...\")\n",
632+
"print(\"Starting training from checkpoint...\")\n",
633633
"train_loader, vis_loader = load_dataset()\n",
634634
"num_epochs = 100\n",
635635
"state = train_state\n",

bonsai/models/unet/tests/UNet_segmentation_example.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def train_step(state: TrainState, other_vars: nnx.State, batch: tuple[jax.Array,
256256
return state, loss
257257
258258
259-
print("🚀 Starting training from checkpoint...")
259+
print("Starting training from checkpoint...")
260260
train_loader, vis_loader = load_dataset()
261261
num_epochs = 100
262262
state = train_state

0 commit comments

Comments
 (0)