Skip to content

Commit 0a3dfc6

Browse files
committed
Success reporing consistent with existing tests
1 parent dbf6e83 commit 0a3dfc6

File tree

3 files changed

+43
-45
lines changed

3 files changed

+43
-45
lines changed

test/test_dense_layer.f90

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
program test_dense_layer
2-
use iso_fortran_env, only: stderr => error_unit
32
use nf, only: dense, layer, relu
43
use tuff, only: test, test_result
54
implicit none
@@ -11,7 +10,7 @@ program test_dense_layer
1110
layer3 = dense(20)
1211
call layer3 % init(layer1)
1312

14-
tests = test("Dense layer", [ &
13+
tests = test("test_dense_layer", [ &
1514
test("layer name is set", layer1 % name == 'dense'), &
1615
test("layer shape is correct", all(layer1 % layer_shape == [10])), &
1716
test("layer is initialized", layer3 % initialized), &

test/test_dense_network.f90

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,40 @@
11
program test_dense_network
22
use iso_fortran_env, only: stderr => error_unit
3-
use nf, only: dense, input, network
4-
use nf_optimizers, only: sgd
3+
use nf, only: dense, input, network, sgd
4+
use tuff, only: test, test_result
55
implicit none
66
type(network) :: net
7-
logical :: ok = .true.
7+
type(test_result) :: tests
88

99
! Minimal 2-layer network
1010
net = network([ &
1111
input(1), &
1212
dense(1) &
1313
])
1414

15-
if (.not. size(net % layers) == 2) then
16-
write(stderr, '(a)') 'dense network should have 2 layers.. failed'
17-
ok = .false.
18-
end if
15+
tests = test("test_dense_network", [ &
16+
test("network has 2 layers", size(net % layers) == 2), &
17+
test("network predicts 0.5 for input 0", all(net % predict([0.]) == 0.5)), &
18+
test(simple_training), &
19+
test(larger_network_size) &
20+
])
1921

20-
if (.not. all(net % predict([0.]) == 0.5)) then
21-
write(stderr, '(a)') &
22-
'dense network should output exactly 0.5 for input 0.. failed'
23-
ok = .false.
24-
end if
22+
contains
2523

26-
training: block
24+
type(test_result) function simple_training() result(res)
2725
real :: x(1), y(1)
2826
real :: tolerance = 1e-3
2927
integer :: n
30-
integer, parameter :: num_iterations = 1000
28+
integer, parameter :: num_iterations = 1000
29+
type(network) :: net
30+
31+
res % name = 'simple training'
32+
33+
! Minimal 2-layer network
34+
net = network([ &
35+
input(1), &
36+
dense(1) &
37+
])
3138

3239
x = [0.123]
3340
y = [0.765]
@@ -39,32 +46,25 @@ program test_dense_network
3946
if (all(abs(net % predict(x) - y) < tolerance)) exit
4047
end do
4148

42-
if (.not. n <= num_iterations) then
43-
write(stderr, '(a)') &
44-
'dense network should converge in simple training.. failed'
45-
ok = .false.
46-
end if
49+
res % ok = n <= num_iterations
4750

48-
end block training
51+
end function simple_training
4952

50-
! A bit larger multi-layer network
51-
net = network([ &
52-
input(784), &
53-
dense(30), &
54-
dense(20), &
55-
dense(10) &
56-
])
53+
type(test_result) function larger_network_size() result(res)
54+
type(network) :: net
55+
56+
res % name = 'larger network training'
57+
58+
! A bit larger multi-layer network
59+
net = network([ &
60+
input(784), &
61+
dense(30), &
62+
dense(20), &
63+
dense(10) &
64+
])
5765

58-
if (.not. size(net % layers) == 4) then
59-
write(stderr, '(a)') 'dense network should have 4 layers.. failed'
60-
ok = .false.
61-
end if
66+
res % ok = size(net % layers) == 4
6267

63-
if (ok) then
64-
print '(a)', 'test_dense_network: All tests passed.'
65-
else
66-
write(stderr, '(a)') 'test_dense_network: One or more tests failed.'
67-
stop 1
68-
end if
68+
end function larger_network_size
6969

70-
end program test_dense_network
70+
end program test_dense_network

test/tuff.f90

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@ module tuff
1313
end type test_result
1414

1515
interface test
16-
module procedure test_logical
17-
module procedure test_func
18-
module procedure test_array
16+
module procedure test_logical, test_func, test_array
1917
end interface test
2018

2119
abstract interface
@@ -64,8 +62,9 @@ type(test_result) function test_array(name, tests) result(suite)
6462
type(test_result), intent(in) :: tests(:)
6563
suite % ok = all(tests % ok)
6664
suite % elapsed = sum(tests % elapsed)
67-
if (.not. suite % ok) then
68-
! Report to stderr only on failure.
65+
if (suite % ok) then
66+
write(stdout, '(a)') trim(name) // ": All tests passed."
67+
else
6968
write(stderr, '(i0,a,i0,a)') count(.not. tests % ok), '/', size(tests), &
7069
" tests failed in suite: " // trim(name)
7170
end if

0 commit comments

Comments
 (0)