@@ -33,11 +33,12 @@ def forward(self, x):
33
33
class LowRankRotateLayer(torch.nn.Module):
    """Low-rank linear map with optionally orthonormal columns at init.

    Projects inputs of dimension ``n`` down to dimension ``m`` (assumes
    ``n > m``) through a single trainable ``n x m`` weight matrix.
    """

    def __init__(self, n, m, init_orth=True):
        super().__init__()
        # n > m: the weight spans a low-rank subspace of the input space.
        weight = torch.empty(n, m)
        if init_orth:
            # Orthonormal columns. With init_orth=False the tensor is left
            # uninitialized — presumably the caller loads weights (e.g. from a
            # checkpoint) before use; verify against callers.
            torch.nn.init.orthogonal_(weight)
        self.weight = torch.nn.Parameter(weight, requires_grad=True)

    def forward(self, x):
        """Return ``x @ weight``, casting ``x`` to the weight's dtype first."""
        return torch.matmul(x.to(self.weight.dtype), self.weight)
@@ -46,11 +47,12 @@ def forward(self, x):
46
47
class SubspaceLowRankRotateLayer(torch.nn.Module):
    """Low-rank linear map with optionally orthonormal init, applied on a
    column subspace.

    Like ``LowRankRotateLayer`` but ``forward`` multiplies only against the
    column slice ``weight[:, l:r]`` selected per call.
    """

    def __init__(self, n, m, init_orth=True):
        super().__init__()
        # n > m: the weight spans a low-rank subspace of the input space.
        weight = torch.empty(n, m)
        if init_orth:
            # Orthonormal columns. With init_orth=False the tensor is left
            # uninitialized — presumably the caller loads weights (e.g. from a
            # checkpoint) before use; verify against callers.
            torch.nn.init.orthogonal_(weight)
        self.weight = torch.nn.Parameter(weight, requires_grad=True)

    def forward(self, x, l, r):
        """Return ``x @ weight[:, l:r]``, casting ``x`` to the weight dtype."""
        subspace = self.weight[:, l:r]
        return torch.matmul(x.to(self.weight.dtype), subspace)
0 commit comments