@@ -19,7 +19,9 @@ module nf_optimizers
     real :: learning_rate = 0.01
   contains
     procedure(init), deferred :: init
-    procedure(minimize), deferred :: minimize
+    procedure(minimize_1d), deferred :: minimize_1d
+    procedure(minimize_2d), deferred :: minimize_2d
+    generic :: minimize => minimize_1d, minimize_2d
   end type optimizer_base_type

   abstract interface
@@ -30,13 +32,19 @@ impure elemental subroutine init(self, num_params)
       integer, intent(in) :: num_params
     end subroutine init

-    pure subroutine minimize(self, weights, biases, gradient)
+    pure subroutine minimize_1d(self, param, gradient)
       import :: optimizer_base_type
       class(optimizer_base_type), intent(inout) :: self
-      real, intent(inout), pointer :: weights(:)
-      real, intent(inout), pointer :: biases(:)
-      real, intent(in), pointer :: gradient(:)
-    end subroutine minimize
+      real, intent(inout) :: param(:)
+      real, intent(in) :: gradient(:)
+    end subroutine minimize_1d
+
+    pure subroutine minimize_2d(self, param, gradient)
+      import :: optimizer_base_type
+      class(optimizer_base_type), intent(inout) :: self
+      real, intent(inout) :: param(:,:)
+      real, intent(in) :: gradient(:,:)
+    end subroutine minimize_2d

   end interface

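With the two rank-specific deferred procedures bound to a single generic name, callers keep writing one `minimize` call and the compiler selects the specific from the rank of the arguments. A minimal caller sketch, assuming a unit that uses nf_optimizers; the variable names are purely illustrative:

    ! Hypothetical caller; opt, kernel, and bias are illustrative names only.
    type(sgd) :: opt
    real :: kernel(3, 8), dkernel(3, 8)
    real :: bias(8), dbias(8)
    call opt % minimize(bias, dbias)       ! rank-1 arguments resolve to minimize_1d
    call opt % minimize(kernel, dkernel)   ! rank-2 arguments resolve to minimize_2d

Because generic resolution happens at compile time, every concrete optimizer has to provide both specifics even if a given layer only ever passes one rank.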
@@ -47,7 +55,8 @@ end subroutine minimize
     real, allocatable, private :: velocity(:)
   contains
     procedure :: init => init_sgd
-    procedure :: minimize => minimize_sgd
+    procedure :: minimize_1d => minimize_sgd_1d
+    procedure :: minimize_2d => minimize_sgd_2d
   end type sgd

   type, extends(optimizer_base_type) :: rmsprop
@@ -62,7 +71,8 @@ end subroutine minimize
     real, allocatable, private :: rms_gradient(:)
   contains
     procedure :: init => init_rmsprop
-    procedure :: minimize => minimize_rmsprop
+    procedure :: minimize_1d => minimize_rmsprop_1d
+    procedure :: minimize_2d => minimize_rmsprop_2d
   end type rmsprop

   type, extends(optimizer_base_type) :: adam
@@ -85,7 +95,8 @@ end subroutine minimize
     integer, private :: t = 0
   contains
     procedure :: init => init_adam
-    procedure :: minimize => minimize_adam
+    procedure :: minimize_1d => minimize_adam_1d
+    procedure :: minimize_2d => minimize_adam_2d
   end type adam

   type, extends(optimizer_base_type) :: adagrad
@@ -102,7 +113,8 @@ end subroutine minimize
     integer, private :: t = 0
   contains
     procedure :: init => init_adagrad
-    procedure :: minimize => minimize_adagrad
+    procedure :: minimize_1d => minimize_adagrad_1d
+    procedure :: minimize_2d => minimize_adagrad_2d
   end type adagrad

 contains
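A note on state sizing before the implementations: each optimizer keeps its running state (velocity, rms_gradient, m and v, sum_squared_gradient) as a flat rank-1 array, and the 2D specifics below flatten their gradients with reshape so they can share that state with the 1D path. Assuming the init_* routines (whose bodies are not shown in this diff) allocate that state to num_params elements, the sizing contract looks roughly like this sketch with made-up shapes:

    type(rmsprop) :: opt
    real :: kernel(3, 8), dkernel(3, 8)
    call opt % init(size(kernel))          ! assumed: allocates rms_gradient(24)
    call opt % minimize(kernel, dkernel)   ! the 2-d specific reshapes to the same 24 elements

The sketch assumes num_params matches the total size of whatever array is later passed to minimize; a mismatch would make the elementwise state updates nonconforming.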
@@ -117,35 +129,30 @@ impure elemental subroutine init_sgd(self, num_params)
   end subroutine init_sgd


-  pure subroutine minimize_sgd(self, weights, biases, gradient)
+  pure subroutine minimize_sgd_1d(self, param, gradient)
     !! Concrete implementation of a stochastic gradient descent optimizer
     !! update rule.
     class(sgd), intent(inout) :: self
-    real, intent(inout), pointer :: weights(:)
-    real, intent(inout), pointer :: biases(:)
-    real, intent(in), pointer :: gradient(:)
+    real, intent(inout) :: param(:)
+    real, intent(in) :: gradient(:)

     if (self % momentum > 0) then
       ! Apply momentum update
       self % velocity = self % momentum * self % velocity &
         - self % learning_rate * gradient
       if (self % nesterov) then
         ! Apply Nesterov update
-        weights = weights + self % momentum * self % velocity &
-          - self % learning_rate * gradient
-        biases = biases + self % momentum * self % velocity &
+        param = param + self % momentum * self % velocity &
           - self % learning_rate * gradient
       else
-        weights = weights + self % velocity
-        biases = biases + self % velocity
+        param = param + self % velocity
       end if
     else
       ! Apply regular update
-      weights = weights - self % learning_rate * gradient
-      biases = biases - self % learning_rate * gradient
+      param = param - self % learning_rate * gradient
     end if

-  end subroutine minimize_sgd
+  end subroutine minimize_sgd_1d

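To make the momentum branches above concrete, here is a worked pass with illustrative numbers that are not from the source (momentum = 0.9, learning_rate = 0.1, a constant gradient of 1, velocity starting at 0):

    ! Step 1: velocity = 0.9*0 - 0.1*1 = -0.10
    !   plain momentum applies param + velocity:   param - 0.10
    !   nesterov applies 0.9*(-0.10) - 0.1*1:      param - 0.19
    ! Step 2: velocity = 0.9*(-0.10) - 0.1*1 = -0.19
    !   plain momentum:                            param - 0.19
    !   nesterov applies 0.9*(-0.19) - 0.1*1:      param - 0.271

The Nesterov branch looks ahead by applying the momentum-scaled new velocity plus a fresh gradient step rather than the velocity alone.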
   impure elemental subroutine init_rmsprop(self, num_params)
@@ -158,24 +165,21 @@ impure elemental subroutine init_rmsprop(self, num_params)
   end subroutine init_rmsprop


-  pure subroutine minimize_rmsprop(self, weights, biases, gradient)
+  pure subroutine minimize_rmsprop_1d(self, param, gradient)
     !! Concrete implementation of a RMSProp optimizer update rule.
     class(rmsprop), intent(inout) :: self
-    real, intent(inout), pointer :: weights(:)
-    real, intent(inout), pointer :: biases(:)
-    real, intent(in), pointer :: gradient(:)
+    real, intent(inout) :: param(:)
+    real, intent(in) :: gradient(:)

     ! Compute the RMS of the gradient using the RMSProp rule
     self % rms_gradient = self % decay_rate * self % rms_gradient &
       + (1 - self % decay_rate) * gradient**2

     ! Update the network parameters based on the new RMS of the gradient
-    weights = weights - self % learning_rate &
-      / sqrt(self % rms_gradient + self % epsilon) * gradient
-    biases = biases - self % learning_rate &
+    param = param - self % learning_rate &
       / sqrt(self % rms_gradient + self % epsilon) * gradient

-  end subroutine minimize_rmsprop
+  end subroutine minimize_rmsprop_1d


   impure elemental subroutine init_adam(self, num_params)
@@ -189,18 +193,17 @@ impure elemental subroutine init_adam(self, num_params)
   end subroutine init_adam


-  pure subroutine minimize_adam(self, weights, biases, gradient)
+  pure subroutine minimize_adam_1d(self, param, gradient)
     !! Concrete implementation of an Adam optimizer update rule.
     class(adam), intent(inout) :: self
-    real, intent(inout), pointer :: weights(:)
-    real, intent(inout), pointer :: biases(:)
-    real, intent(in), pointer :: gradient(:)
+    real, intent(inout) :: param(:)
+    real, intent(in) :: gradient(:)

     self % t = self % t + 1

     ! If weight_decay_l2 > 0, use L2 regularization;
     ! otherwise, default to regular Adam.
-    associate(g => gradient + self % weight_decay_l2 * weights)
+    associate(g => gradient + self % weight_decay_l2 * param)
       self % m = self % beta1 * self % m + (1 - self % beta1) * g
       self % v = self % beta2 * self % v + (1 - self % beta2) * g**2
     end associate
@@ -212,19 +215,13 @@ pure subroutine minimize_adam(self, weights, biases, gradient)
     )

       ! Update parameters.
-      weights = weights &
+      param = param &
         - self % learning_rate * (m_hat / (sqrt(v_hat) + self % epsilon) &
-        + self % weight_decay_decoupled * weights)
-
-      ! Update biases (without weight decay for biases)
-      associate(g => gradient)
-        biases = biases &
-          - self % learning_rate * (m_hat / (sqrt(v_hat) + self % epsilon))
-      end associate
+        + self % weight_decay_decoupled * param)

     end associate

-  end subroutine minimize_adam
+  end subroutine minimize_adam_1d
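For context on the bias-corrected estimates used just above: with beta1 = 0.9 the first moment after one step is only 0.1 of the gradient, and dividing by (1 - beta1**t) rescales it while the running averages warm up. A worked example with a constant gradient g (illustrative numbers, not from the source):

    ! t = 1: m = 0.1*g            m_hat = 0.1*g  / (1 - 0.9**1) = g
    ! t = 2: m = 0.09*g + 0.1*g   m_hat = 0.19*g / (1 - 0.9**2) = g
    ! As t grows, 1 - beta1**t approaches 1 and m_hat approaches m.

The same correction is applied to the second moment through beta2.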


   impure elemental subroutine init_adagrad(self, num_params)
@@ -237,43 +234,133 @@ impure elemental subroutine init_adagrad(self, num_params)
   end subroutine init_adagrad


-  pure subroutine minimize_adagrad(self, weights, biases, gradient)
+  pure subroutine minimize_adagrad_1d(self, param, gradient)
     !! Concrete implementation of an Adagrad optimizer update rule.
     class(adagrad), intent(inout) :: self
-    real, intent(inout), pointer :: weights(:)
-    real, intent(inout), pointer :: biases(:)
-    real, intent(in), pointer :: gradient(:)
+    real, intent(inout) :: param(:)
+    real, intent(in) :: gradient(:)

     ! Update the current time step
     self % t = self % t + 1

-    ! For weights
     associate( &
       ! If weight_decay_l2 > 0, use L2 regularization;
       ! otherwise, default to regular Adagrad.
-      g => gradient + self % weight_decay_l2 * weights, &
+      g => gradient + self % weight_decay_l2 * param, &
       ! Amortize the learning rate as function of the current time step.
       learning_rate => self % learning_rate &
         / (1 + (self % t - 1) * self % learning_rate_decay) &
     )

       self % sum_squared_gradient = self % sum_squared_gradient + g**2

-      weights = weights - learning_rate * g / (sqrt(self % sum_squared_gradient) &
+      param = param - learning_rate * g / (sqrt(self % sum_squared_gradient) &
         + self % epsilon)

     end associate
-
-    ! For biases (without weight decay)
+
+  end subroutine minimize_adagrad_1d
+
+
+  pure subroutine minimize_sgd_2d(self, param, gradient)
+    !! Concrete implementation of a stochastic gradient descent optimizer
+    !! update rule for 2D arrays.
+    class(sgd), intent(inout) :: self
+    real, intent(inout) :: param(:,:)
+    real, intent(in) :: gradient(:,:)
+
+    if (self % momentum > 0) then
+      ! Apply momentum update
+      self % velocity = self % momentum * self % velocity &
+        - self % learning_rate * reshape(gradient, [size(gradient)])
+      if (self % nesterov) then
+        ! Apply Nesterov update
+        param = param + reshape(self % momentum * self % velocity &
+          - self % learning_rate * reshape(gradient, [size(gradient)]), shape(param))
+      else
+        param = param + reshape(self % velocity, shape(param))
+      end if
+    else
+      ! Apply regular update
+      param = param - self % learning_rate * gradient
+    end if
+
+  end subroutine minimize_sgd_2d
+
+
+  pure subroutine minimize_rmsprop_2d(self, param, gradient)
+    !! Concrete implementation of a RMSProp optimizer update rule for 2D arrays.
+    class(rmsprop), intent(inout) :: self
+    real, intent(inout) :: param(:,:)
+    real, intent(in) :: gradient(:,:)
+
+    ! Compute the RMS of the gradient using the RMSProp rule
+    self % rms_gradient = self % decay_rate * self % rms_gradient &
+      + (1 - self % decay_rate) * reshape(gradient, [size(gradient)])**2
+
+    ! Update the network parameters based on the new RMS of the gradient
+    param = param - self % learning_rate &
+      / sqrt(reshape(self % rms_gradient, shape(param)) + self % epsilon) * gradient
+
+  end subroutine minimize_rmsprop_2d
+
+
+  pure subroutine minimize_adam_2d(self, param, gradient)
+    !! Concrete implementation of an Adam optimizer update rule for 2D arrays.
+    class(adam), intent(inout) :: self
+    real, intent(inout) :: param(:,:)
+    real, intent(in) :: gradient(:,:)
+
+    self % t = self % t + 1
+
+    ! If weight_decay_l2 > 0, use L2 regularization;
+    ! otherwise, default to regular Adam.
+    associate(g => reshape(gradient, [size(gradient)]) + self % weight_decay_l2 * reshape(param, [size(param)]))
+      self % m = self % beta1 * self % m + (1 - self % beta1) * g
+      self % v = self % beta2 * self % v + (1 - self % beta2) * g**2
+    end associate
+
+    ! Compute bias-corrected first and second moment estimates.
+    associate( &
+      m_hat => self % m / (1 - self % beta1**self % t), &
+      v_hat => self % v / (1 - self % beta2**self % t) &
+    )
+
+      ! Update parameters.
+      param = param &
+        - self % learning_rate * reshape(m_hat / (sqrt(v_hat) + self % epsilon), shape(param)) &
+        - self % learning_rate * self % weight_decay_decoupled * param
+
+    end associate
+
+  end subroutine minimize_adam_2d
+
+
+  pure subroutine minimize_adagrad_2d(self, param, gradient)
+    !! Concrete implementation of an Adagrad optimizer update rule for 2D arrays.
+    class(adagrad), intent(inout) :: self
+    real, intent(inout) :: param(:,:)
+    real, intent(in) :: gradient(:,:)
+
+    ! Update the current time step
+    self % t = self % t + 1
+
     associate( &
-      g => gradient, &
+      ! If weight_decay_l2 > 0, use L2 regularization;
+      ! otherwise, default to regular Adagrad.
+      g => reshape(gradient, [size(gradient)]) + self % weight_decay_l2 * reshape(param, [size(param)]), &
+      ! Amortize the learning rate as function of the current time step.
       learning_rate => self % learning_rate &
         / (1 + (self % t - 1) * self % learning_rate_decay) &
     )
-      biases = biases - learning_rate * g / (sqrt(self % sum_squared_gradient) &
-        + self % epsilon)
+
+      self % sum_squared_gradient = self % sum_squared_gradient + g**2
+
+      param = param - learning_rate * reshape(g / (sqrt(self % sum_squared_gradient) &
+        + self % epsilon), shape(param))
+
     end associate

-  end subroutine minimize_adagrad
+  end subroutine minimize_adagrad_2d

 end module nf_optimizers
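A quick way to exercise both specifics end to end is a small driver program. The sketch below is hypothetical (the program and variable names are invented here) and assumes sgd's momentum defaults to zero, so both ranks reduce to the plain param - learning_rate * gradient update and no optimizer state needs to be allocated:

    program check_minimize_ranks
      use nf_optimizers, only: sgd
      implicit none
      type(sgd) :: opt              ! default learning_rate = 0.01
      real :: kernel(2, 3), dkernel(2, 3)
      real :: bias(3), dbias(3)

      kernel = 1; dkernel = 0.5
      bias = 1; dbias = 0.5

      call opt % minimize(kernel, dkernel)   ! dispatches to minimize_sgd_2d
      call opt % minimize(bias, dbias)       ! dispatches to minimize_sgd_1d

      ! Every element should now equal 1 - 0.01 * 0.5 = 0.995.
      print *, kernel
      print *, bias
    end program check_minimize_ranks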