@@ -1,7 +1,7 @@
'''
JAX Optimization Algorithms Module

- * Last Updated: 2025-03-09
+ * Last Updated: 2025-03-15
 * Author: HugoPhi, [GitHub](https://github.com/HugoPhi)

@@ -39,7 +39,7 @@
'''

import jax.numpy as jnp
-from jax import grad, tree, lax
+from jax import grad, tree, lax, random
from abc import ABC, abstractmethod

@@ -89,19 +89,20 @@ def flash(self):
        '''
        pass

-    def open(self, loss_function, x_train: jnp.ndarray, y_train: jnp.ndarray, short_batch='drop'):
+    def open(self, loss_function, x_train: jnp.ndarray, y_train: jnp.ndarray, short_batch='drop', key=random.PRNGKey(42)):
        '''
        Prepares the optimizer for training by initializing its state and setting up the training data.

        Args:
            loss_function: A loss function that computes the scalar loss given model parameters,
                input data, and true labels. It must be JIT-compiled.
-               Signature: `f(params, x, y_true) -> scalar`.
+               Signature: `f(params, x, y_true, key=random.PRNGKey(42)) -> scalar`.
            x_train: Input data for training. Shape: `(num_samples, ...)`.
            y_train: True labels for training. Shape: `(num_samples, ...)`.
            short_batch: The strategy for handling a short final batch. One of:
                - 'drop': drop the short batch; used when dataset size >> batch size, num_batches = N // B
                - 'pad': append arr[-B:] to the trimmed arr; used when dataset size >~ batch size, num_batches = N // B + 1
+           key: A random number generator key used for initialization.

        Notes:
            - The training data is divided into batches based on the `batch_size` attribute.
@@ -114,6 +115,7 @@ def open(self, loss_function, x_train: jnp.ndarray, y_train: jnp.ndarray, short_
        else:
            self.flash()
        self._loss = loss_function
+       self.key = key

        if short_batch == 'drop':
            self.num_batches = x_train.shape[0] // self.batch_size
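
For orientation, here is a minimal sketch of a loss function matching the new signature and of how `open` might be called with it. The dropout-style forward pass, the parameter names `w`/`b`, and the `Adam(...)` constructor in the comment are illustrative assumptions; only the `f(params, x, y_true, key) -> scalar` contract and the `key` argument to `open` come from this change.

```python
import jax.numpy as jnp
from jax import jit, random

@jit
def loss_fn(params, x, y_true, key=random.PRNGKey(42)):
    # Hypothetical forward pass that consumes the PRNG key (e.g. dropout),
    # which is why open()/update() now thread a key into the loss.
    h = jnp.dot(x, params['w']) + params['b']
    keep = random.bernoulli(key, 0.8, shape=h.shape)  # dropout mask
    h = jnp.where(keep, h / 0.8, 0.0)
    return jnp.mean((h - y_true) ** 2)  # stand-in scalar loss

# Assumed usage (constructor name and arguments are hypothetical):
# opt = Adam(params, lr=1e-3, batch_size=32)
# opt.open(loss_fn, x_train, y_train, short_batch='drop', key=random.PRNGKey(0))
```
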
@@ -247,6 +249,8 @@ def update(self):
        ixs = jnp.arange(self.num_batches)
        bxs = self.x_train.reshape(self.num_batches, self.batch_size, *self.x_train.shape[1:])
        bys = self.y_train.reshape(self.num_batches, self.batch_size, *self.y_train.shape[1:])
+       subkeys = random.split(self.key, self.num_batches + 1)
+       self.key, subkeys = subkeys[0], subkeys[1:]  # update self.key & get subkeys

        def one_batch(carry, ix):
@@ -270,7 +274,8 @@ def adam(d_w, w, v, vv):

            bx = bxs[ix]
            by = bys[ix]
-           d_params = grad(self._loss, argnums=0)(carry['params'], bx, by)
+           kkey = subkeys[ix]
+           d_params = grad(self._loss, argnums=0)(carry['params'], bx, by, kkey)

            pack = tree.map(adam, d_params, carry['params'], carry['V'], carry['VV'])  # use Adam
            carry['params'] = tree.map(lambda x: x[0], pack)
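
The pair of added lines above (splitting `self.key` into per-batch subkeys and passing `subkeys[ix]` to the loss gradient) is repeated in every optimizer's `update()` below. As a self-contained illustration of the pattern outside the class, a bare-bones epoch loop with plain gradient descent might look like the following; `scan_epoch`, `loss_fn`, and `lr` are assumed names, not part of this module.

```python
import jax.numpy as jnp
from jax import grad, lax, random, tree

def scan_epoch(params, bxs, bys, key, loss_fn, lr=0.1):
    # One epoch over pre-batched data: draw one fresh PRNG subkey per batch
    # and keep one key for the next epoch, mirroring the change in update().
    num_batches = bxs.shape[0]
    subkeys = random.split(key, num_batches + 1)
    key, subkeys = subkeys[0], subkeys[1:]

    def one_batch(carry, ix):
        d_params = grad(loss_fn, argnums=0)(carry, bxs[ix], bys[ix], subkeys[ix])
        carry = tree.map(lambda w, g: w - lr * g, carry, d_params)  # plain GD step
        return carry, None

    params, _ = lax.scan(one_batch, params, jnp.arange(num_batches))
    return params, key
```
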
@@ -346,6 +351,8 @@ def update(self):
        ixs = jnp.arange(self.num_batches)
        bxs = self.x_train.reshape(self.num_batches, self.batch_size, *self.x_train.shape[1:])
        bys = self.y_train.reshape(self.num_batches, self.batch_size, *self.y_train.shape[1:])
+       subkeys = random.split(self.key, self.num_batches + 1)
+       self.key, subkeys = subkeys[0], subkeys[1:]  # update self.key & get subkeys

        def one_batch(carry, ix):
@@ -355,8 +362,9 @@ def gd(d_w, w):

            bx = bxs[ix]
            by = bys[ix]
+           kkey = subkeys[ix]

-           d_params = grad(self._loss, argnums=0)(carry['params'], bx, by)
+           d_params = grad(self._loss, argnums=0)(carry['params'], bx, by, kkey)

            pack = tree.map(gd, d_params, carry['params'])
            carry['params'] = pack
@@ -436,6 +444,8 @@ def update(self):
        ixs = jnp.arange(self.num_batches)
        bxs = self.x_train.reshape(self.num_batches, self.batch_size, *self.x_train.shape[1:])
        bys = self.y_train.reshape(self.num_batches, self.batch_size, *self.y_train.shape[1:])
+       subkeys = random.split(self.key, self.num_batches + 1)
+       self.key, subkeys = subkeys[0], subkeys[1:]  # update self.key & get subkeys

        def one_batch(carry, ix):
@@ -446,7 +456,9 @@ def momentum(d_w, w, v):

            bx = bxs[ix]
            by = bys[ix]
-           d_params = grad(self._loss, argnums=0)(carry['params'], bx, by)
+           kkey = subkeys[ix]
+
+           d_params = grad(self._loss, argnums=0)(carry['params'], bx, by, kkey)

            pack = tree.map(momentum, d_params, carry['params'], carry['V'])
            carry['params'] = tree.map(lambda x: x[0], pack)
@@ -529,6 +541,8 @@ def update(self):
        ixs = jnp.arange(self.num_batches)
        bxs = self.x_train.reshape(self.num_batches, self.batch_size, *self.x_train.shape[1:])
        bys = self.y_train.reshape(self.num_batches, self.batch_size, *self.y_train.shape[1:])
+       subkeys = random.split(self.key, self.num_batches + 1)
+       self.key, subkeys = subkeys[0], subkeys[1:]  # update self.key & get subkeys

        def one_batch(carry, ix):
@@ -541,7 +555,9 @@ def nag(d_w, w, v):

            bx = bxs[ix]
            by = bys[ix]
-           d_params = grad(self._loss, argnums=0)(carry['params'], bx, by)
+           kkey = subkeys[ix]
+
+           d_params = grad(self._loss, argnums=0)(carry['params'], bx, by, kkey)

            pack = tree.map(nag, d_params, carry['params'], carry['V'])
            carry['params'] = tree.map(lambda x: x[0], pack)
@@ -623,6 +639,8 @@ def update(self):
        ixs = jnp.arange(self.num_batches)
        bxs = self.x_train.reshape(self.num_batches, self.batch_size, *self.x_train.shape[1:])
        bys = self.y_train.reshape(self.num_batches, self.batch_size, *self.y_train.shape[1:])
+       subkeys = random.split(self.key, self.num_batches + 1)
+       self.key, subkeys = subkeys[0], subkeys[1:]  # update self.key & get subkeys

        def one_batch(carry, ix):
@@ -633,7 +651,9 @@ def adagrad(d_w, w, g):

            bx = bxs[ix]
            by = bys[ix]
-           d_params = grad(self._loss, argnums=0)(carry['params'], bx, by)
+           kkey = subkeys[ix]
+
+           d_params = grad(self._loss, argnums=0)(carry['params'], bx, by, kkey)

            pack = tree.map(adagrad, d_params, carry['params'], carry['G'])
            carry['params'] = tree.map(lambda x: x[0], pack)
@@ -720,6 +740,8 @@ def update(self):
        ixs = jnp.arange(self.num_batches)
        bxs = self.x_train.reshape(self.num_batches, self.batch_size, *self.x_train.shape[1:])
        bys = self.y_train.reshape(self.num_batches, self.batch_size, *self.y_train.shape[1:])
+       subkeys = random.split(self.key, self.num_batches + 1)
+       self.key, subkeys = subkeys[0], subkeys[1:]  # update self.key & get subkeys

        def one_batch(carry, ix):
@@ -730,7 +752,9 @@ def rmsprop(d_w, w, g):

            bx = bxs[ix]
            by = bys[ix]
-           d_params = grad(self._loss, argnums=0)(carry['params'], bx, by)
+           kkey = subkeys[ix]
+
+           d_params = grad(self._loss, argnums=0)(carry['params'], bx, by, kkey)

            pack = tree.map(rmsprop, d_params, carry['params'], carry['G'])
            carry['params'] = tree.map(lambda x: x[0], pack)
@@ -817,6 +841,8 @@ def update(self):
        ixs = jnp.arange(self.num_batches)
        bxs = self.x_train.reshape(self.num_batches, self.batch_size, *self.x_train.shape[1:])
        bys = self.y_train.reshape(self.num_batches, self.batch_size, *self.y_train.shape[1:])
+       subkeys = random.split(self.key, self.num_batches + 1)
+       self.key, subkeys = subkeys[0], subkeys[1:]  # update self.key & get subkeys

        def one_batch(carry, ix):
@@ -829,7 +855,9 @@ def adadelta(d_w, w, e_g2, e_dx2):

            bx = bxs[ix]
            by = bys[ix]
-           d_params = grad(self._loss, argnums=0)(carry['params'], bx, by)
+           kkey = subkeys[ix]
+
+           d_params = grad(self._loss, argnums=0)(carry['params'], bx, by, kkey)

            pack = tree.map(adadelta, d_params, carry['params'], carry['E_g2'], carry['E_dx2'])
            carry['params'] = tree.map(lambda x: x[0], pack)