@@ -691,8 +691,12 @@ module function get_params(self) result(params)
691
691
params = this_layer % get_params()
692
692
type is (embedding_layer)
693
693
params = this_layer % get_params()
694
+
694
695
type is (layernorm_layer)
695
- params = this_layer % get_params()
696
+ call this_layer % get_params_ptr(w_ptr, b_ptr)
697
+ allocate (params(size (w_ptr) + size (b_ptr)))
698
+ params(1 :size (w_ptr)) = w_ptr
699
+ params(size (w_ptr)+ 1 :) = b_ptr
696
700
class default
697
701
error stop ' Unknown layer type.'
698
702
end select
@@ -703,6 +707,8 @@ end function get_params
703
707
module subroutine set_params (self , params )
704
708
class(layer), intent (in out ) :: self
705
709
real , intent (in ) :: params(:)
710
+ real , pointer :: w_ptr(:)
711
+ real , pointer :: b_ptr(:)
706
712
707
713
! Check that the number of parameters is correct.
708
714
! This check will still pass if the size(params) == 0 and the layer is a
@@ -736,37 +742,55 @@ module subroutine set_params(self, params)
736
742
// ' on a zero-parameter layer; nothing to do.'
737
743
738
744
type is (dense_layer)
739
- call this_layer % set_params(params)
745
+ call this_layer % get_params_ptr(w_ptr, b_ptr)
746
+
747
+ w_ptr = params(1 :size (w_ptr))
748
+ b_ptr = params(size (w_ptr)+ 1 :)
740
749
741
750
type is (dropout_layer)
742
751
! No parameters to set.
743
752
write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
744
753
// ' on a zero-parameter layer; nothing to do.'
745
754
746
755
type is (conv1d_layer)
747
- call this_layer % set_params(params)
756
+ call this_layer % get_params_ptr(w_ptr, b_ptr)
757
+
758
+ w_ptr = params(1 :size (w_ptr))
759
+ b_ptr = params(size (w_ptr)+ 1 :)
748
760
749
761
type is (conv2d_layer)
750
- call this_layer % set_params(params)
762
+ call this_layer % get_params_ptr(w_ptr, b_ptr)
763
+
764
+ w_ptr = params(1 :size (w_ptr))
765
+ b_ptr = params(size (w_ptr)+ 1 :)
751
766
752
767
type is (locally_connected2d_layer)
753
- call this_layer % set_params(params)
768
+ call this_layer % get_params_ptr(w_ptr, b_ptr)
769
+
770
+ w_ptr = params(1 :size (w_ptr))
771
+ b_ptr = params(size (w_ptr)+ 1 :)
754
772
755
773
type is (maxpool1d_layer)
756
774
! No parameters to set.
757
775
write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
758
776
// ' on a zero-parameter layer; nothing to do.'
759
777
760
778
type is (linear2d_layer)
761
- call this_layer % set_params(params)
779
+ call this_layer % get_params_ptr(w_ptr, b_ptr)
780
+
781
+ w_ptr = params(1 :size (w_ptr))
782
+ b_ptr = params(size (w_ptr)+ 1 :)
762
783
763
784
type is (self_attention_layer)
764
785
call this_layer % set_params(params)
765
786
type is (embedding_layer)
766
787
call this_layer % set_params(params)
767
788
768
789
type is (layernorm_layer)
769
- call this_layer % set_params(params)
790
+ call this_layer % get_params_ptr(w_ptr, b_ptr)
791
+
792
+ w_ptr = params(1 :size (w_ptr))
793
+ b_ptr = params(size (w_ptr)+ 1 :)
770
794
771
795
type is (maxpool2d_layer)
772
796
! No parameters to set.
0 commit comments