Skip to content

Commit c9a00a7

Browse files
authored
Merge pull request #160 from stanfordnlp/zen/remove_orth_init
[Minor] Remove ortho init for DAS
2 parents f92a379 + 0983f40 commit c9a00a7

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

pyvene/models/layers.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,12 @@ def forward(self, x):
3333
class LowRankRotateLayer(torch.nn.Module):
3434
"""A linear transformation with orthogonal initialization."""
3535

36-
def __init__(self, n, m):
36+
def __init__(self, n, m, init_orth=True):
3737
super().__init__()
3838
# n > m
3939
self.weight = torch.nn.Parameter(torch.empty(n, m), requires_grad=True)
40-
torch.nn.init.orthogonal_(self.weight)
40+
if init_orth:
41+
torch.nn.init.orthogonal_(self.weight)
4142

4243
def forward(self, x):
4344
return torch.matmul(x.to(self.weight.dtype), self.weight)
@@ -46,11 +47,12 @@ def forward(self, x):
4647
class SubspaceLowRankRotateLayer(torch.nn.Module):
4748
"""A linear transformation with orthogonal initialization with subspace."""
4849

49-
def __init__(self, n, m):
50+
def __init__(self, n, m, init_orth=True):
5051
super().__init__()
5152
# n > m
5253
self.weight = torch.nn.Parameter(torch.empty(n, m), requires_grad=True)
53-
torch.nn.init.orthogonal_(self.weight)
54+
if init_orth:
55+
torch.nn.init.orthogonal_(self.weight)
5456

5557
def forward(self, x, l, r):
5658
return torch.matmul(x.to(self.weight.dtype), self.weight[:, l:r])

0 commit comments

Comments
 (0)