@@ -33,8 +33,10 @@ module nf_conv2d_layer
3333 procedure :: forward
3434 procedure :: backward
3535 procedure :: get_gradients
36+ procedure :: get_gradients_ptr
3637 procedure :: get_num_params
3738 procedure :: get_params
39+ procedure :: get_params_ptr
3840 procedure :: init
3941 procedure :: set_params
4042
@@ -98,6 +100,16 @@ module function get_params(self) result(params)
98100 ! ! Parameters to get
99101 end function get_params
100102
103+ module subroutine get_params_ptr (self , w_ptr , b_ptr )
104+ ! ! Return pointers to the parameters (weights and biases) of this layer.
105+ class(conv2d_layer), intent (in ), target :: self
106+ ! ! A `conv2d_layer` instance
107+ real , pointer , intent (out ) :: w_ptr(:)
108+ ! ! Pointer to the kernel weights (flattened)
109+ real , pointer , intent (out ) :: b_ptr(:)
110+ ! ! Pointer to the biases
111+ end subroutine get_params_ptr
112+
101113 module function get_gradients (self ) result(gradients)
102114 ! ! Return the gradients of this layer.
103115 ! ! The gradients are ordered as weights first, biases second.
@@ -107,6 +119,16 @@ module function get_gradients(self) result(gradients)
107119 ! ! Gradients to get
108120 end function get_gradients
109121
122+ module subroutine get_gradients_ptr (self , dw_ptr , db_ptr )
123+ ! ! Return pointers to the gradients of this layer.
124+ class(conv2d_layer), intent (in ), target :: self
125+ ! ! A `conv2d_layer` instance
126+ real , pointer , intent (out ) :: dw_ptr(:)
127+ ! ! Pointer to the kernel weight gradients (flattened)
128+ real , pointer , intent (out ) :: db_ptr(:)
129+ ! ! Pointer to the bias gradients
130+ end subroutine get_gradients_ptr
131+
110132 module subroutine set_params (self , params )
111133 ! ! Set the parameters of the layer.
112134 class(conv2d_layer), intent (in out ) :: self
0 commit comments