
Commit cce7cc0

Add an option for JittableModule to dedup parameters. (#8965)
1 parent 7882475 commit cce7cc0

File tree: 3 files changed, +98 / -21 lines
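In short: when a module registers the same tensor under more than one name (weight sharing / tying), JittableModule now keeps a single copy in its params dict, records the dropped aliases, and restores them inside functional_call so the reparametrized module still sees a complete state dict. A minimal usage sketch, adapted from the new test below (the module name TiedLinear and the shapes are illustrative, not part of the commit):

import torch
import torchax
from torchax import interop

torchax.enable_globally()

class TiedLinear(torch.nn.Module):
  # 'a' and 'b' are the same Linear layer, so its weight and bias show up
  # in the state dict under two different prefixes.

  def __init__(self):
    super().__init__()
    self.a = torch.nn.Linear(10, 10)
    self.b = self.a

  def forward(self, x):
    return self.a(self.b(x))

m = TiedLinear().to('jax')
jitted = interop.JittableModule(m, dedup_parameters=True)

len(m.state_dict())   # 4: a.weight, a.bias, b.weight, b.bias
len(jitted.params)    # 2: the b.* aliases are deduped away
out = jitted(torch.randn(10, 10).to('jax'))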

torchax/test/test_interop.py (+74, -19)
@@ -1,34 +1,40 @@
+import functools
 import torch
 import unittest
 import torchax
 from torchax import interop
+import torchax
 
-class M1(torch.nn.Module):
 
-  def __init__(self):
-    super().__init__()
-    self.x = torch.ones(10, 10)
+class InteropTest(unittest.TestCase):
 
-class M(torch.nn.Module):
+  def setUp(self):
+    torchax.enable_globally()
 
-  def __init__(self):
-    super().__init__()
-    self.a = torch.nn.Linear(100, 100)
-    self.b = torch.nn.Parameter(
-        torch.ones(10, 10)
-    )
-    c = torch.ones(10, 10)
-    self.register_buffer('c', c)
-    self.register_buffer('c2', c, persistent=False)
-    self.d = torch.ones(10, 10)
-    self.m1 = M1()
 
+  def test_mod_attr(self):
 
-class InteropTest(unittest.TestCase):
+    class Child(torch.nn.Module):
 
+      def __init__(self):
+        super().__init__()
+        self.x = torch.ones(10, 10)
 
-  def test_mod_attr(self):
-    m = M()
+    class ModuleWithUnregisteredTensor(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.a = torch.nn.Linear(100, 100)
+        self.b = torch.nn.Parameter(
+            torch.ones(10, 10)
+        )
+        c = torch.ones(10, 10)
+        self.register_buffer('c', c)
+        self.register_buffer('c2', c, persistent=False)
+        self.d = torch.ones(10, 10)
+        self.m1 = Child()
+
+    m = ModuleWithUnregisteredTensor()
     params, buffers = interop.extract_all_buffers(m)
     self.assertEqual(
         set(params.keys()), {'a.weight', 'a.bias', 'b'}
@@ -75,6 +81,55 @@ def fn(x):
     expected = torch.ones(2, 2) * 2
     torch.testing.assert_close(x.grad, expected, check_device=False)
 
+  def test_module_with_shared_weights(self):
+
+    # arrange
+    class ModuleWithSharedWeights(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.a = torch.nn.Linear(10, 10)
+        self.b = self.a
+
+      def forward(self, x):
+        return self.a(self.b(x))
+
+    m = ModuleWithSharedWeights().to('jax')
+
+    m_jitted = interop.JittableModule(m, dedup_parameters=True)
+
+    # a's weights and bias and b's weights and bias
+    self.assertEqual(len(m.state_dict()), 4)
+
+    # b's weights and bias are deduped
+    self.assertEqual(len(m_jitted.params), 2)
+    x = torch.randn(10, 10).to('jax')
+    expected = m(x)
+
+    # act
+    actual = m_jitted(x)
+
+    # assert
+    torch.testing.assert_allclose(actual, expected)
+
+    # arrange
+    # make sure buffer donation works
+    functional_forward = interop.jax_jit(
+        functools.partial(m_jitted.functional_call, 'forward'),
+        kwargs_for_jax_jit={
+            'donate_argnums': (0, )
+        }
+    )
+
+    # act
+    actual = functional_forward(m_jitted.params, m_jitted.buffers, x)
+    # assert
+    torch.testing.assert_allclose(actual, expected)
+
+
+
+
+
 
 if __name__ == '__main__':
   unittest.main()

torchax/torchax/interop.py (+22, -2)
@@ -1,3 +1,4 @@
+import collections
 import copy
 import functools
 import torch
@@ -51,14 +52,27 @@ def set_one(module, prefix):
 
 class JittableModule(torch.nn.Module):
 
-  def __init__(self, m: torch.nn.Module, extra_jit_args={}):
+  def __init__(self, m: torch.nn.Module, extra_jit_args={}, dedup_parameters=True):
     super().__init__()
     self.params, self.buffers = extract_all_buffers(m)
     self._model = m
     self._jitted = {}
 
     self._extra_jit_args = extra_jit_args
 
+    self._extra_dumped_weights = {}
+
+    if dedup_parameters:
+      temp = collections.defaultdict(list)
+      for k, v in self.params.items():
+        temp[id(v)].append(k)
+
+      for v in temp.values():
+        if len(v) > 1:
+          # duplicated weights with different name
+          self._extra_dumped_weights[v[0]] = v[1:]
+          for extra_keys in v[1:]:
+            del self.params[extra_keys]
 
   def __call__(self, *args, **kwargs):
     return self.forward(*args, **kwargs)
@@ -69,6 +83,10 @@ def functional_call(
     kwargs = kwargs or {}
     params_copy = copy.copy(params)
     params_copy.update(buffers)
+    # reinflate the state dict so there are not any missing keys
+    for k, v in self._extra_dumped_weights.items():
+      for new_key in v:
+        params_copy[new_key] = params_copy[k]
     with torch_stateless._reparametrize_module(self._model, params_copy):
       res = getattr(self._model, method_name)(*args, **kwargs)
     return res
@@ -285,11 +303,13 @@ def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None):
   return torch_view(jitted)
 
 
-def jax_jit(torch_function, kwargs_for_jax_jit=None):
+def jax_jit(torch_function, kwargs_for_jax_jit=None, fix_for_buffer_donation=False):
   return wrap_jax_jit(torch_function, jax_jit_func=jax.jit,
                       kwargs_for_jax=kwargs_for_jax_jit)
 
 
+
+
 def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
   return wrap_jax_jit(torch_function, jax_jit_func=shard_map,
                       kwargs_for_jax=kwargs_for_jax_shard_map)
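The dedup bookkeeping in __init__ is simply grouping parameter keys by object identity: every key whose value is the same tensor collapses onto the first key, and the dropped aliases are remembered so functional_call can rebuild a complete dict right before _reparametrize_module runs. A standalone sketch of that logic outside the class (the dict and variable names here are illustrative):

import collections
import torch

shared = torch.ones(3, 3)
params = {'a.weight': shared, 'b.weight': shared, 'a.bias': torch.zeros(3)}

# Group keys by the identity of the tensor they point to.
groups = collections.defaultdict(list)
for key, value in params.items():
  groups[id(value)].append(key)

# Keep the first key of each group; remember the rest as aliases.
aliases = {}
for keys in groups.values():
  if len(keys) > 1:
    aliases[keys[0]] = keys[1:]
    for extra in keys[1:]:
      del params[extra]

# params now holds only 'a.weight' and 'a.bias'; aliases is {'a.weight': ['b.weight']}.
# What functional_call does before reparametrizing: put the aliases back so
# no state-dict key is missing.
full = dict(params)
for kept, dropped in aliases.items():
  for name in dropped:
    full[name] = full[kept]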

torchax/torchax/ops/jaten.py (+2)
@@ -1590,12 +1590,14 @@ def _aten_bitwise_not(self):
 
 
 # aten.bitwise_left_shift
+@op(torch.ops.aten.__lshift__)
 @op(torch.ops.aten.bitwise_left_shift)
 def _aten_bitwise_left_shift(input, other):
   return jnp.left_shift(input, other)
 
 
 # aten.bitwise_right_shift
+@op(torch.ops.aten.__rshift__)
 @op(torch.ops.aten.bitwise_right_shift)
 def _aten_bitwise_right_shift(input, other):
   return jnp.right_shift(input, other)
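The jaten.py change only registers the aten dunder shift ops on the existing lowerings, so the << and >> operators on jax-backed integer tensors also route to jnp.left_shift / jnp.right_shift instead of falling through as unsupported ops. A quick sketch of the kind of expression this is meant to cover (values are illustrative; exact dispatch depends on the tensor dtype and overload):

import torch
import torchax

torchax.enable_globally()

x = torch.tensor([1, 2, 4]).to('jax')
x << 1   # expected to lower via _aten_bitwise_left_shift -> jnp.left_shift
x >> 1   # expected to lower via _aten_bitwise_right_shift -> jnp.right_shift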
