@@ -155,6 +155,32 @@ pure subroutine minimize_sgd_1d(self, param, gradient)
   end subroutine minimize_sgd_1d
 
 
+  pure subroutine minimize_sgd_2d(self, param, gradient)
+    !! Concrete implementation of a stochastic gradient descent optimizer
+    !! update rule for 2D arrays.
+    class(sgd), intent(inout) :: self
+    real, intent(inout) :: param(:,:)
+    real, intent(in) :: gradient(:,:)
+
+    if (self % momentum > 0) then
+      ! Apply momentum update
+      self % velocity = self % momentum * self % velocity &
+        - self % learning_rate * reshape(gradient, [size(gradient)])
+      if (self % nesterov) then
+        ! Apply Nesterov update
+        param = param + reshape(self % momentum * self % velocity &
+          - self % learning_rate * reshape(gradient, [size(gradient)]), shape(param))
+      else
+        param = param + reshape(self % velocity, shape(param))
+      end if
+    else
+      ! Apply regular update
+      param = param - self % learning_rate * gradient
+    end if
+
+  end subroutine minimize_sgd_2d
+
+
   impure elemental subroutine init_rmsprop(self, num_params)
     class(rmsprop), intent(inout) :: self
     integer, intent(in) :: num_params
@@ -182,6 +208,23 @@ pure subroutine minimize_rmsprop_1d(self, param, gradient)
   end subroutine minimize_rmsprop_1d
 
 
+  pure subroutine minimize_rmsprop_2d(self, param, gradient)
+    !! Concrete implementation of a RMSProp optimizer update rule for 2D arrays.
+    class(rmsprop), intent(inout) :: self
+    real, intent(inout) :: param(:,:)
+    real, intent(in) :: gradient(:,:)
+
+    ! Compute the RMS of the gradient using the RMSProp rule
+    self % rms_gradient = self % decay_rate * self % rms_gradient &
+      + (1 - self % decay_rate) * reshape(gradient, [size(gradient)])**2
+
+    ! Update the network parameters based on the new RMS of the gradient
+    param = param - self % learning_rate &
+      / sqrt(reshape(self % rms_gradient, shape(param)) + self % epsilon) * gradient
+
+  end subroutine minimize_rmsprop_2d
+
+
   impure elemental subroutine init_adam(self, num_params)
     class(adam), intent(inout) :: self
     integer, intent(in) :: num_params
@@ -224,6 +267,37 @@ pure subroutine minimize_adam_1d(self, param, gradient)
   end subroutine minimize_adam_1d
 
 
+  pure subroutine minimize_adam_2d(self, param, gradient)
+    !! Concrete implementation of an Adam optimizer update rule for 2D arrays.
+    class(adam), intent(inout) :: self
+    real, intent(inout) :: param(:,:)
+    real, intent(in) :: gradient(:,:)
+
+    self % t = self % t + 1
+
+    ! If weight_decay_l2 > 0, use L2 regularization;
+    ! otherwise, default to regular Adam.
+    associate(g => reshape(gradient, [size(gradient)]) + self % weight_decay_l2 * reshape(param, [size(param)]))
+      self % m = self % beta1 * self % m + (1 - self % beta1) * g
+      self % v = self % beta2 * self % v + (1 - self % beta2) * g**2
+    end associate
+
+    ! Compute bias-corrected first and second moment estimates.
+    associate( &
+      m_hat => self % m / (1 - self % beta1**self % t), &
+      v_hat => self % v / (1 - self % beta2**self % t) &
+    )
+
+      ! Update parameters.
+      param = param &
+        - self % learning_rate * reshape(m_hat / (sqrt(v_hat) + self % epsilon), shape(param)) &
+        - self % learning_rate * self % weight_decay_decoupled * param
+
+    end associate
+
+  end subroutine minimize_adam_2d
+
+
   impure elemental subroutine init_adagrad(self, num_params)
     class(adagrad), intent(inout) :: self
     integer, intent(in) :: num_params
@@ -262,80 +336,6 @@ pure subroutine minimize_adagrad_1d(self, param, gradient)
   end subroutine minimize_adagrad_1d
 
 
-  pure subroutine minimize_sgd_2d(self, param, gradient)
-    !! Concrete implementation of a stochastic gradient descent optimizer
-    !! update rule for 2D arrays.
-    class(sgd), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    if (self % momentum > 0) then
-      ! Apply momentum update
-      self % velocity = self % momentum * self % velocity &
-        - self % learning_rate * reshape(gradient, [size(gradient)])
-      if (self % nesterov) then
-        ! Apply Nesterov update
-        param = param + reshape(self % momentum * self % velocity &
-          - self % learning_rate * reshape(gradient, [size(gradient)]), shape(param))
-      else
-        param = param + reshape(self % velocity, shape(param))
-      end if
-    else
-      ! Apply regular update
-      param = param - self % learning_rate * gradient
-    end if
-
-  end subroutine minimize_sgd_2d
-
-
-  pure subroutine minimize_rmsprop_2d(self, param, gradient)
-    !! Concrete implementation of a RMSProp optimizer update rule for 2D arrays.
-    class(rmsprop), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    ! Compute the RMS of the gradient using the RMSProp rule
-    self % rms_gradient = self % decay_rate * self % rms_gradient &
-      + (1 - self % decay_rate) * reshape(gradient, [size(gradient)])**2
-
-    ! Update the network parameters based on the new RMS of the gradient
-    param = param - self % learning_rate &
-      / sqrt(reshape(self % rms_gradient, shape(param)) + self % epsilon) * gradient
-
-  end subroutine minimize_rmsprop_2d
-
-
-  pure subroutine minimize_adam_2d(self, param, gradient)
-    !! Concrete implementation of an Adam optimizer update rule for 2D arrays.
-    class(adam), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    self % t = self % t + 1
-
-    ! If weight_decay_l2 > 0, use L2 regularization;
-    ! otherwise, default to regular Adam.
-    associate(g => reshape(gradient, [size(gradient)]) + self % weight_decay_l2 * reshape(param, [size(param)]))
-      self % m = self % beta1 * self % m + (1 - self % beta1) * g
-      self % v = self % beta2 * self % v + (1 - self % beta2) * g**2
-    end associate
-
-    ! Compute bias-corrected first and second moment estimates.
-    associate( &
-      m_hat => self % m / (1 - self % beta1**self % t), &
-      v_hat => self % v / (1 - self % beta2**self % t) &
-    )
-
-      ! Update parameters.
-      param = param &
-        - self % learning_rate * reshape(m_hat / (sqrt(v_hat) + self % epsilon), shape(param)) &
-        - self % learning_rate * self % weight_decay_decoupled * param
-
-    end associate
-
-  end subroutine minimize_adam_2d
-
-
   pure subroutine minimize_adagrad_2d(self, param, gradient)
     !! Concrete implementation of an Adagrad optimizer update rule for 2D arrays.
     class(adagrad), intent(inout) :: self
@@ -363,4 +363,4 @@ pure subroutine minimize_adagrad_2d(self, param, gradient)
 
   end subroutine minimize_adagrad_2d
 
-end module nf_optimizers
+end module nf_optimizers
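
For reference, `minimize_adam_2d` appears to apply the same rule as the 1D Adam routine, just with `reshape` used to run the flattened gradient and parameters through the 1D moment buffers `self % m` and `self % v` and reshape the result back to `shape(param)`. A sketch of the update as written in the code, using my own notation (eta for `learning_rate`, lambda_2 for `weight_decay_l2`, lambda_d for `weight_decay_decoupled`):

\[
\begin{aligned}
g_t &= \nabla_\theta \mathcal{L}(\theta_{t-1}) + \lambda_2\,\theta_{t-1} \\
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2 \\
\hat{m}_t &= m_t / (1-\beta_1^t), \qquad \hat{v}_t = v_t / (1-\beta_2^t) \\
\theta_t &= \theta_{t-1} - \eta\,\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} - \eta\,\lambda_d\,\theta_{t-1}
\end{aligned}
\]

The lambda_2 term is plain L2 regularization folded into the gradient before the moment updates, while the final lambda_d term is the decoupled (AdamW-style) weight decay applied directly to the parameters.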