11program test_dense_network
22 use iso_fortran_env, only: stderr = > error_unit
3- use nf, only: dense, input, network
4- use nf_optimizers , only: sgd
3+ use nf, only: dense, input, network, sgd
4+ use tuff , only: test, test_result
55 implicit none
66 type (network) :: net
7- logical :: ok = .true.
7+ type (test_result) :: tests
88
99 ! Minimal 2-layer network
1010 net = network([ &
1111 input(1 ), &
1212 dense(1 ) &
1313 ])
1414
15- if (.not. size (net % layers) == 2 ) then
16- write (stderr, ' (a)' ) ' dense network should have 2 layers.. failed'
17- ok = .false.
18- end if
15+ tests = test(" test_dense_network" , [ &
16+ test(" network has 2 layers" , size (net % layers) == 2 ), &
17+ test(" network predicts 0.5 for input 0" , all (net % predict([0 .]) == 0.5 )), &
18+ test(simple_training), &
19+ test(larger_network_size) &
20+ ])
1921
20- if (.not. all (net % predict([0 .]) == 0.5 )) then
21- write (stderr, ' (a)' ) &
22- ' dense network should output exactly 0.5 for input 0.. failed'
23- ok = .false.
24- end if
22+ contains
2523
26- training: block
24+ type (test_result) function simple_training() result(res)
2725 real :: x(1 ), y(1 )
2826 real :: tolerance = 1e-3
2927 integer :: n
30- integer , parameter :: num_iterations = 1000
28+ integer , parameter :: num_iterations = 1000
29+ type (network) :: net
30+
31+ res % name = ' simple training'
32+
33+ ! Minimal 2-layer network
34+ net = network([ &
35+ input(1 ), &
36+ dense(1 ) &
37+ ])
3138
3239 x = [0.123 ]
3340 y = [0.765 ]
@@ -39,32 +46,25 @@ program test_dense_network
3946 if (all (abs (net % predict(x) - y) < tolerance)) exit
4047 end do
4148
42- if (.not. n <= num_iterations) then
43- write (stderr, ' (a)' ) &
44- ' dense network should converge in simple training.. failed'
45- ok = .false.
46- end if
49+ res % ok = n <= num_iterations
4750
48- end block training
51+ end function simple_training
4952
50- ! A bit larger multi-layer network
51- net = network([ &
52- input(784 ), &
53- dense(30 ), &
54- dense(20 ), &
55- dense(10 ) &
56- ])
53+ type (test_result) function larger_network_size() result(res)
54+ type (network) :: net
55+
56+ res % name = ' larger network training'
57+
58+ ! A bit larger multi-layer network
59+ net = network([ &
60+ input(784 ), &
61+ dense(30 ), &
62+ dense(20 ), &
63+ dense(10 ) &
64+ ])
5765
58- if (.not. size (net % layers) == 4 ) then
59- write (stderr, ' (a)' ) ' dense network should have 4 layers.. failed'
60- ok = .false.
61- end if
66+ res % ok = size (net % layers) == 4
6267
63- if (ok) then
64- print ' (a)' , ' test_dense_network: All tests passed.'
65- else
66- write (stderr, ' (a)' ) ' test_dense_network: One or more tests failed.'
67- stop 1
68- end if
68+ end function larger_network_size
6969
70- end program test_dense_network
70+ end program test_dense_network
0 commit comments