
Commit cf6810b

add VAE on MNIST
1 parent 4abf3f1 commit cf6810b

9 files changed: +529, -221 lines

README.md

+2-2
```diff
@@ -8,7 +8,7 @@
 </p>

 <p align="center">
-On MNIST: (a) acc[96.80%] & loss vs. epochs for mlp; (b) acc[97.86%] & loss vs. epochs for LeNet
+On MNIST: (a) acc[96.80%] & loss vs. epochs for mlp; (b) acc[98.24%] & loss vs. epochs for LeNet
 </p>


@@ -35,7 +35,7 @@ On MNIST: (a) acc[96.80%] & loss vs. epochs for mlp; (b) acc[97.86%] & loss vs.
 - Nerual ODE[[5](#reference)]
   - MNIST. <mark>TODO</mark>
 - VAE[[7](#reference)]
-  - MNIST. <mark>TODO</mark>
+  - MNIST.

 ## # NoteBook Docs
```

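The headline change of this commit is the new VAE-on-MNIST notebook. For orientation only (this is not the repo's implementation), here is a minimal JAX sketch of the core computation such a notebook needs: the reparameterization trick plus the ELBO loss, written against the `f(params, x, y_true, key)` loss signature introduced in `plugins/minitorch/optimizer.py` below. All names (`init_vae`, `elbo_loss`, the parameter layout and layer sizes) are illustrative.

```python
import jax.numpy as jnp
from jax import random, grad

def init_vae(key, x_dim=784, h_dim=256, z_dim=20):
    # one-hidden-layer encoder, linear decoder; sizes are illustrative
    k1, k2, k3, k4 = random.split(key, 4)
    s = 0.01
    return {
        'enc_w': s * random.normal(k1, (x_dim, h_dim)), 'enc_b': jnp.zeros(h_dim),
        'mu_w':  s * random.normal(k2, (h_dim, z_dim)), 'mu_b':  jnp.zeros(z_dim),
        'lv_w':  s * random.normal(k3, (h_dim, z_dim)), 'lv_b':  jnp.zeros(z_dim),
        'dec_w': s * random.normal(k4, (z_dim, x_dim)), 'dec_b': jnp.zeros(x_dim),
    }

def elbo_loss(params, x, y_unused, key):
    # encoder: x -> (mu, logvar) of q(z|x)
    h = jnp.tanh(x @ params['enc_w'] + params['enc_b'])
    mu = h @ params['mu_w'] + params['mu_b']
    logvar = h @ params['lv_w'] + params['lv_b']
    # reparameterization trick: z = mu + sigma * eps keeps the sampling step differentiable
    eps = random.normal(key, mu.shape)
    z = mu + jnp.exp(0.5 * logvar) * eps
    # decoder: z -> Bernoulli logits over the 784 pixels
    logits = z @ params['dec_w'] + params['dec_b']
    # reconstruction term: numerically stable binary cross-entropy from logits
    bce = jnp.sum(jnp.maximum(logits, 0) - logits * x + jnp.log1p(jnp.exp(-jnp.abs(logits))), axis=-1)
    # KL(q(z|x) || N(0, I)) in closed form
    kl = -0.5 * jnp.sum(1.0 + logvar - mu ** 2 - jnp.exp(logvar), axis=-1)
    return jnp.mean(bce + kl)

# smoke test with a random stand-in batch of flattened 28x28 images
key = random.PRNGKey(0)
params = init_vae(key)
x = random.uniform(random.PRNGKey(1), (32, 784))
grads = grad(elbo_loss, argnums=0)(params, x, None, random.PRNGKey(2))
```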
cnn_cifar10.ipynb

+17-19
Large diffs are not rendered by default.

cnn_mnist.ipynb

+23-20
Large diffs are not rendered by default.

gru_ucihar.ipynb

+44-94
Large diffs are not rendered by default.

knn_cifar10.ipynb

+36-8
```diff
@@ -183,7 +183,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 6,
 "id": "79ce60c8-8923-400e-b7fa-4753442b85ae",
 "metadata": {},
 "outputs": [
@@ -192,19 +192,47 @@
 "output_type": "stream",
 "text": [
 "0.30600002\n",
-"\n",
-"time: 5.3901941776275635 s\n"
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"0.30600002\n",
+"5.64 s ± 74.3 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)\n"
 ]
 }
 ],
 "source": [
 "import time\n",
 "\n",
-"s = time.time()\n",
-"y_pred = knn.predict(x_test)\n",
-"print(jnp.mean(y_test==y_pred))\n",
-"print()\n",
-"print(f'time: {time.time() - s} s')"
+"def fast_run():\n",
+"    y_pred = knn.predict(x_test)\n",
+"    print(jnp.mean(y_test==y_pred))\n",
+"\n",
+"%timeit -n10 -r3 fast_run()"
 ]
 }
 ],
```

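The timing cell now uses IPython's `%timeit -n10 -r3`, which executes `fast_run()` 30 times and reports mean ± std instead of a single `time.time()` measurement; that is also why the accuracy line appears 30 times in the cell output. One JAX-specific caveat when benchmarking this way: dispatch is asynchronous, so the timed function must force the result (here the `print` of the mean already does that). A variant that makes the synchronization explicit is sketched below; `knn`, `x_test`, and `y_test` are the notebook's own objects and are assumed here, not redefined.

```python
import jax.numpy as jnp

def timed_run():
    y_pred = knn.predict(x_test)   # assumed: the notebook's fitted KNN model and test set
    y_pred.block_until_ready()     # wait for the device computation, not just its dispatch
    return jnp.mean(y_test == y_pred)

# inside the notebook, one could then time it with:
#   %timeit -n10 -r3 timed_run()
```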
lstm_ucihar.ipynb

+55-58
Large diffs are not rendered by default.

notebook_docs/jax_tips.ipynb

+3-9
```diff
@@ -679,16 +679,10 @@
 "id": "b7375ed6-dee1-416e-a6d9-a81247f4ba2f",
 "metadata": {},
 "source": [
-"## # jax.lax.scan: Iter Functool"
+"## # jax.lax.scan: Iter Functool\n",
+"\n",
+"see [knn_on_cifar10](https://github.com/HugoPhi/jaxdls/blob/main/knn_cifar10.ipynb) & [lstm cell](https://github.com/HugoPhi/jaxdls/blob/main/plugins/minitorch/nn/JaxOptimized/rnncell.py)."
 ]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "d3e6e1c8-4ad3-4943-929f-c9f749b09ffd",
-"metadata": {},
-"outputs": [],
-"source": []
 }
 ],
 "metadata": {
```

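The rewritten markdown cell points at knn_cifar10.ipynb and the LSTM cell for real uses of `jax.lax.scan`. As a quick, standalone illustration of the pattern (not taken from either file): `lax.scan` threads a carry through a sequence and collects a per-step output, replacing an explicit Python loop.

```python
import jax.numpy as jnp
from jax import lax

def cumulative_sum(xs):
    def step(carry, x):
        carry = carry + x
        return carry, carry          # (new carry, per-step output)
    total, partials = lax.scan(step, 0.0, xs)
    return total, partials

total, partials = cumulative_sum(jnp.arange(5.0))
print(total)     # 10.0
print(partials)  # [ 0.  1.  3.  6. 10.]
```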
plugins/minitorch/optimizer.py

+39-11
```diff
@@ -1,7 +1,7 @@
 '''
 JAX Optimization Algorithms Module

-* Last Updated: 2025-03-09
+* Last Updated: 2025-03-15
 * Author: HugoPhi, [GitHub](https://github.com/HugoPhi)
 * Maintainer: [email protected]

@@ -39,7 +39,7 @@
 '''

 import jax.numpy as jnp
-from jax import grad, tree, lax
+from jax import grad, tree, lax, random
 from abc import ABC, abstractmethod


@@ -89,19 +89,20 @@ def flash(self):
         '''
         pass

-    def open(self, loss_function, x_train: jnp.ndarray, y_train: jnp.ndarray, short_batch='drop'):
+    def open(self, loss_function, x_train: jnp.ndarray, y_train: jnp.ndarray, short_batch='drop', key=random.PRNGKey(42)):
         '''
         Prepares the optimizer for training by initializing its state and setting up the training data.

         Args:
             loss_function: A loss function that computes the scalar loss given model parameters,
                 input data, and true labels. It must be JIT-compiled.
-                Signature: `f(params, x, y_true) -> scalar`.
+                Signature: `f(params, x, y_true, key=random.PRNGKey(42)) -> scalar`.
             x_train: Input data for training. Shape: `(num_samples, ...)`.
             y_train: True labels for training. Shape: `(num_samples, ...)`.
             short_batch: The Strategy to handle short batch. including:
                 - 'drop': drop short batch, used when: dataset size >> batch size, num_batches = N // B
                 - 'pad': append arr[-B:] to trimmed arr, used when: dataset size >~ batch size, num_batches = N // B + 1
+            key: A random number generator key used for initialization.

         Notes:
             - The training data is divided into batches based on the `batch_size` attribute.
@@ -114,6 +115,7 @@ def open(self, loss_function, x_train: jnp.ndarray, y_train: jnp.ndarray, short_
         else:
             self.flash()
         self._loss = loss_function
+        self.key = key

         if short_batch == 'drop':
             self.num_batches = x_train.shape[0] // self.batch_size
@@ -247,6 +249,8 @@ def update(self):
         ixs = jnp.arange(self.num_batches)
         bxs = self.x_train.reshape(self.num_batches, self.batch_size, *self.x_train.shape[1:])
         bys = self.y_train.reshape(self.num_batches, self.batch_size, *self.y_train.shape[1:])
+        subkeys = random.split(self.key, self.num_batches + 1)
+        self.key, subkeys = subkeys[0], subkeys[1:] # update self.key & get subkeys

         def one_batch(carry, ix):

@@ -270,7 +274,8 @@ def adam(d_w, w, v, vv):

             bx = bxs[ix]
             by = bys[ix]
-            d_params = grad(self._loss, argnums=0)(carry['params'], bx, by)
+            kkey = subkeys[ix]
+            d_params = grad(self._loss, argnums=0)(carry['params'], bx, by, kkey)

             pack = tree.map(adam, d_params, carry['params'], carry['V'], carry['VV']) # use Adam
             carry['params'] = tree.map(lambda x: x[0], pack)
@@ -346,6 +351,8 @@ def update(self):
         ixs = jnp.arange(self.num_batches)
         bxs = self.x_train.reshape(self.num_batches, self.batch_size, *self.x_train.shape[1:])
         bys = self.y_train.reshape(self.num_batches, self.batch_size, *self.y_train.shape[1:])
+        subkeys = random.split(self.key, self.num_batches + 1)
+        self.key, subkeys = subkeys[0], subkeys[1:] # update self.key & get subkeys

         def one_batch(carry, ix):

@@ -355,8 +362,9 @@ def gd(d_w, w):

             bx = bxs[ix]
             by = bys[ix]
+            kkey = subkeys[ix]

-            d_params = grad(self._loss, argnums=0)(carry['params'], bx, by)
+            d_params = grad(self._loss, argnums=0)(carry['params'], bx, by, kkey)

             pack = tree.map(gd, d_params, carry['params'])
             carry['params'] = pack
@@ -436,6 +444,8 @@ def update(self):
         ixs = jnp.arange(self.num_batches)
         bxs = self.x_train.reshape(self.num_batches, self.batch_size, *self.x_train.shape[1:])
         bys = self.y_train.reshape(self.num_batches, self.batch_size, *self.y_train.shape[1:])
+        subkeys = random.split(self.key, self.num_batches + 1)
+        self.key, subkeys = subkeys[0], subkeys[1:] # update self.key & get subkeys

         def one_batch(carry, ix):

@@ -446,7 +456,9 @@ def momentum(d_w, w, v):

             bx = bxs[ix]
             by = bys[ix]
-            d_params = grad(self._loss, argnums=0)(carry['params'], bx, by)
+            kkey = subkeys[ix]
+
+            d_params = grad(self._loss, argnums=0)(carry['params'], bx, by, kkey)

             pack = tree.map(momentum, d_params, carry['params'], carry['V'])
             carry['params'] = tree.map(lambda x: x[0], pack)
@@ -529,6 +541,8 @@ def update(self):
         ixs = jnp.arange(self.num_batches)
         bxs = self.x_train.reshape(self.num_batches, self.batch_size, *self.x_train.shape[1:])
         bys = self.y_train.reshape(self.num_batches, self.batch_size, *self.y_train.shape[1:])
+        subkeys = random.split(self.key, self.num_batches + 1)
+        self.key, subkeys = subkeys[0], subkeys[1:] # update self.key & get subkeys

         def one_batch(carry, ix):

@@ -541,7 +555,9 @@ def nag(d_w, w, v):

             bx = bxs[ix]
             by = bys[ix]
-            d_params = grad(self._loss, argnums=0)(carry['params'], bx, by)
+            kkey = subkeys[ix]
+
+            d_params = grad(self._loss, argnums=0)(carry['params'], bx, by, kkey)

             pack = tree.map(nag, d_params, carry['params'], carry['V'])
             carry['params'] = tree.map(lambda x: x[0], pack)
@@ -623,6 +639,8 @@ def update(self):
         ixs = jnp.arange(self.num_batches)
         bxs = self.x_train.reshape(self.num_batches, self.batch_size, *self.x_train.shape[1:])
         bys = self.y_train.reshape(self.num_batches, self.batch_size, *self.y_train.shape[1:])
+        subkeys = random.split(self.key, self.num_batches + 1)
+        self.key, subkeys = subkeys[0], subkeys[1:] # update self.key & get subkeys

         def one_batch(carry, ix):

@@ -633,7 +651,9 @@ def adagrad(d_w, w, g):

             bx = bxs[ix]
             by = bys[ix]
-            d_params = grad(self._loss, argnums=0)(carry['params'], bx, by)
+            kkey = subkeys[ix]
+
+            d_params = grad(self._loss, argnums=0)(carry['params'], bx, by, kkey)

             pack = tree.map(adagrad, d_params, carry['params'], carry['G'])
             carry['params'] = tree.map(lambda x: x[0], pack)
@@ -720,6 +740,8 @@ def update(self):
         ixs = jnp.arange(self.num_batches)
         bxs = self.x_train.reshape(self.num_batches, self.batch_size, *self.x_train.shape[1:])
         bys = self.y_train.reshape(self.num_batches, self.batch_size, *self.y_train.shape[1:])
+        subkeys = random.split(self.key, self.num_batches + 1)
+        self.key, subkeys = subkeys[0], subkeys[1:] # update self.key & get subkeys

         def one_batch(carry, ix):

@@ -730,7 +752,9 @@ def rmsprop(d_w, w, g):

             bx = bxs[ix]
             by = bys[ix]
-            d_params = grad(self._loss, argnums=0)(carry['params'], bx, by)
+            kkey = subkeys[ix]
+
+            d_params = grad(self._loss, argnums=0)(carry['params'], bx, by, kkey)

             pack = tree.map(rmsprop, d_params, carry['params'], carry['G'])
             carry['params'] = tree.map(lambda x: x[0], pack)
@@ -817,6 +841,8 @@ def update(self):
         ixs = jnp.arange(self.num_batches)
         bxs = self.x_train.reshape(self.num_batches, self.batch_size, *self.x_train.shape[1:])
         bys = self.y_train.reshape(self.num_batches, self.batch_size, *self.y_train.shape[1:])
+        subkeys = random.split(self.key, self.num_batches + 1)
+        self.key, subkeys = subkeys[0], subkeys[1:] # update self.key & get subkeys

         def one_batch(carry, ix):

@@ -829,7 +855,9 @@ def adadelta(d_w, w, e_g2, e_dx2):

             bx = bxs[ix]
             by = bys[ix]
-            d_params = grad(self._loss, argnums=0)(carry['params'], bx, by)
+            kkey = subkeys[ix]
+
+            d_params = grad(self._loss, argnums=0)(carry['params'], bx, by, kkey)

             pack = tree.map(adadelta, d_params, carry['params'], carry['E_g2'], carry['E_dx2'])
             carry['params'] = tree.map(lambda x: x[0], pack)
```

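Every `update()` method above now follows the same functional-PRNG pattern: split `self.key` into one subkey per batch, keep the first split as the optimizer's next key, and hand each batch its own subkey so a keyed loss (dropout, VAE sampling, etc.) gets fresh but reproducible randomness. A standalone sketch of just that pattern, with hypothetical `num_batches` and `loss_fn` names:

```python
from jax import random

key = random.PRNGKey(42)
num_batches = 4

# split into (num_batches + 1) keys: the first becomes the new state key,
# the rest are handed out one per batch
subkeys = random.split(key, num_batches + 1)
key, subkeys = subkeys[0], subkeys[1:]

for ix in range(num_batches):
    batch_key = subkeys[ix]
    # each batch sees an independent, reproducible key, e.g.:
    #   d_params = grad(loss_fn)(params, bx, by, batch_key)
    print(ix, batch_key)
```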