Skip to content

Commit b44a151

Browse files
committed
Support of the argument stride for locally_connected2d_layer
1 parent c406b42 commit b44a151

File tree

3 files changed

+38
-19
lines changed

3 files changed

+38
-19
lines changed

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,14 @@ module function conv2d(filters, kernel_width, kernel_height, activation) result(
8080

8181
end function conv2d
8282

83-
module function locally_connected2d(filters, kernel_size, activation) result(res)
83+
module function locally_connected2d(filters, kernel_size, activation, stride) result(res)
8484
integer, intent(in) :: filters
8585
integer, intent(in) :: kernel_size
8686
class(activation_function), intent(in), optional :: activation
87+
integer, intent(in), optional :: stride
8788
type(layer) :: res
8889

90+
integer :: stride_tmp
8991
class(activation_function), allocatable :: activation_tmp
9092

9193
res % name = 'locally_connected2d'
@@ -98,9 +100,18 @@ module function locally_connected2d(filters, kernel_size, activation) result(res
98100

99101
res % activation = activation_tmp % get_name()
100102

103+
if (present(stride)) then
104+
stride_tmp = stride
105+
else
106+
stride_tmp = 1
107+
endif
108+
109+
if (stride_tmp < 1) &
110+
error stop 'stride must be >= 1 in a conv1d layer'
111+
101112
allocate( &
102113
res % p, &
103-
source=locally_connected2d_layer(filters, kernel_size, activation_tmp) &
114+
source=locally_connected2d_layer(filters, kernel_size, activation_tmp, stride_tmp) &
104115
)
105116

106117
end function locally_connected2d

src/nf/nf_locally_connected2d_layer.f90

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ module nf_locally_connected2d_layer
1515
integer :: channels
1616
integer :: kernel_size
1717
integer :: filters
18+
integer :: stride
1819

1920
real, allocatable :: biases(:,:) ! size(filters)
2021
real, allocatable :: kernel(:,:,:,:) ! filters x channels x window x window
@@ -40,12 +41,13 @@ module nf_locally_connected2d_layer
4041
end type locally_connected2d_layer
4142

4243
interface locally_connected2d_layer
43-
module function locally_connected2d_layer_cons(filters, kernel_size, activation) &
44+
module function locally_connected2d_layer_cons(filters, kernel_size, activation, stride) &
4445
result(res)
4546
!! `locally_connected2d_layer` constructor function
4647
integer, intent(in) :: filters
4748
integer, intent(in) :: kernel_size
4849
class(activation_function), intent(in) :: activation
50+
integer, intent(in) :: stride
4951
type(locally_connected2d_layer) :: res
5052
end function locally_connected2d_layer_cons
5153
end interface locally_connected2d_layer
@@ -91,7 +93,9 @@ end function get_num_params
9193
module subroutine get_params_ptr(self, w_ptr, b_ptr)
9294
class(locally_connected2d_layer), intent(in), target :: self
9395
real, pointer, intent(out) :: w_ptr(:)
96+
!! Pointer to the kernel weights (flattened)
9497
real, pointer, intent(out) :: b_ptr(:)
98+
!! Pointer to the biases
9599
end subroutine get_params_ptr
96100

97101
module function get_gradients(self) result(gradients)
@@ -106,7 +110,9 @@ end function get_gradients
106110
module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
107111
class(locally_connected2d_layer), intent(in), target :: self
108112
real, pointer, intent(out) :: dw_ptr(:)
113+
!! Pointer to the kernel weight gradients (flattened)
109114
real, pointer, intent(out) :: db_ptr(:)
115+
!! Pointer to the bias gradients
110116
end subroutine get_gradients_ptr
111117

112118
end interface

src/nf/nf_locally_connected2d_layer_submodule.f90

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,17 @@
77

88
contains
99

10-
module function locally_connected2d_layer_cons(filters, kernel_size, activation) result(res)
11-
implicit none
10+
module function locally_connected2d_layer_cons(filters, kernel_size, activation, stride) result(res)
1211
integer, intent(in) :: filters
1312
integer, intent(in) :: kernel_size
1413
class(activation_function), intent(in) :: activation
14+
integer, intent(in) :: stride
1515
type(locally_connected2d_layer) :: res
1616

1717
res % kernel_size = kernel_size
1818
res % filters = filters
1919
res % activation_name = activation % get_name()
20+
res % stride = stride
2021
allocate(res % activation, source = activation)
2122
end function locally_connected2d_layer_cons
2223

@@ -26,8 +27,11 @@ module subroutine init(self, input_shape)
2627
integer, intent(in) :: input_shape(:)
2728

2829
self % channels = input_shape(1)
29-
self % width = input_shape(2) - self % kernel_size + 1
30+
self % width = (input_shape(2) - self % kernel_size) / self % stride +1
31+
32+
if (mod(input_shape(2) - self % kernel_size , self % stride) /= 0) self % width = self % width + 1
3033

34+
! Output of shape: filters x width
3135
allocate(self % output(self % filters, self % width))
3236
self % output = 0
3337

@@ -63,10 +67,10 @@ pure module subroutine forward(self, input)
6367
input_width = size(input, dim=2)
6468

6569
do j = 1, self % width
66-
iws = j
67-
iwe = j + self % kernel_size - 1
70+
iws = self % stride * (j-1) + 1
71+
iwe = min(iws + self % kernel_size - 1, input_width)
6872
do n = 1, self % filters
69-
self % z(n, j) = sum(self % kernel(n, j, :, :) * input(:, iws:iwe)) + self % biases(n, j)
73+
self % z(n, j) = sum(self % kernel(n, j, :, 1:iwe-iws+1) * input(:, iws:iwe)) + self % biases(n, j)
7074
end do
7175
end do
7276
self % output = self % activation % eval(self % z)
@@ -77,7 +81,7 @@ pure module subroutine backward(self, input, gradient)
7781
class(locally_connected2d_layer), intent(in out) :: self
7882
real, intent(in) :: input(:,:)
7983
real, intent(in) :: gradient(:,:)
80-
integer :: input_channels, input_width, output_width
84+
integer :: input_channels, input_width
8185
integer :: j, n, k
8286
integer :: iws, iwe
8387
real :: gdz(self % filters, self % width)
@@ -86,14 +90,13 @@ pure module subroutine backward(self, input, gradient)
8690

8791
input_channels = size(input, dim=1)
8892
input_width = size(input, dim=2)
89-
output_width = self % width
9093

91-
do j = 1, output_width
94+
do j = 1, self % width
9295
gdz(:, j) = gradient(:, j) * self % activation % eval_prime(self % z(:, j))
9396
end do
9497

9598
do n = 1, self % filters
96-
do j = 1, output_width
99+
do j = 1, self % width
97100
db_local(n, j) = gdz(n, j)
98101
end do
99102
end do
@@ -102,12 +105,12 @@ pure module subroutine backward(self, input, gradient)
102105
self % gradient = 0.0
103106

104107
do n = 1, self % filters
105-
do j = 1, output_width
106-
iws = j
107-
iwe = j + self % kernel_size - 1
108+
do j = 1, self % width
109+
iws = self % stride * (j-1) + 1
110+
iwe = min(iws + self % kernel_size - 1, input_width)
108111
do k = 1, self % channels
109-
dw_local(n, j, k, :) = dw_local(n, j, k, :) + input(k, iws:iwe) * gdz(n, j)
110-
self % gradient(k, iws:iwe) = self % gradient(k, iws:iwe) + self % kernel(n, j, k, :) * gdz(n, j)
112+
dw_local(n, j, k, 1:iwe-iws+1) = dw_local(n, j, k, 1:iwe-iws+1) + input(k, iws:iwe) * gdz(n, j)
113+
self % gradient(k, iws:iwe) = self % gradient(k, iws:iwe) + self % kernel(n, j, k, 1:iwe-iws+1) * gdz(n, j)
111114
end do
112115
end do
113116
end do
@@ -144,5 +147,4 @@ module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
144147
db_ptr(1:size(self % db)) => self % db
145148
end subroutine get_gradients_ptr
146149

147-
148150
end submodule nf_locally_connected2d_layer_submodule

0 commit comments

Comments
 (0)