@@ -11,9 +11,9 @@ train!(args...; kwargs...) = train!(MSELoss(), AutoZygote(), args...; kwargs...)
11
11
function train! (loss, backend, model, ps, st, data; epochs= 10 )
12
12
l1 = loss (model, ps, st, first (data))
13
13
14
- tstate = Lux . Experimental . TrainState (model, ps, st, Adam (0.01f0 ))
14
+ tstate = Training . TrainState (model, ps, st, Adam (0.01f0 ))
15
15
for _ in 1 : epochs, (x, y) in data
16
- _, _, _, tstate = Lux . Experimental . single_train_step! (backend, loss, (x, y), tstate)
16
+ _, _, _, tstate = Training . single_train_step! (backend, loss, (x, y), tstate)
17
17
end
18
18
19
19
l2 = loss (model, ps, st, first (data))
25
25
n_points = 128
26
26
batch_size = 64
27
27
28
- x = rand (Float32, 1 , n_points , batch_size);
29
- y = rand (Float32, 1 , n_points , batch_size);
28
+ x = rand (Float32, n_points, 1 , batch_size);
29
+ y = rand (Float32, n_points, 1 , batch_size);
30
30
data = [(x, y)];
31
31
t_fwd = zeros (5 )
32
32
t_train = zeros (5 )
33
33
for i in 1 : 5
34
34
chs = (1 , 128 , fill (64 , i)... , 128 , 1 )
35
- model = FourierNeuralOperator (gelu; chs= chs , modes= (16 ,))
35
+ model = FourierNeuralOperator (gelu; chs, modes= (16 ,), permuted = Val ( true ))
36
36
ps, st = Lux. setup (rng, model)
37
37
model (x, ps, st) # TTFX
38
38
0 commit comments