+import functools
 import torch
 import unittest
 import torchax
 from torchax import interop
 
-class M1(torch.nn.Module):
 
-  def __init__(self):
-    super().__init__()
-    self.x = torch.ones(10, 10)
+class InteropTest(unittest.TestCase):
 
-class M(torch.nn.Module):
+  def setUp(self):
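+    # setUp runs before each test: enabling torchax globally lets the
+    # torch ops below dispatch through JAX.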
+    torchax.enable_globally()
 
-  def __init__(self):
-    super().__init__()
-    self.a = torch.nn.Linear(100, 100)
-    self.b = torch.nn.Parameter(
-        torch.ones(10, 10)
-    )
-    c = torch.ones(10, 10)
-    self.register_buffer('c', c)
-    self.register_buffer('c2', c, persistent=False)
-    self.d = torch.ones(10, 10)
-    self.m1 = M1()
 
+  def test_mod_attr(self):
 
-class InteropTest(unittest.TestCase):
+    class Child(torch.nn.Module):
 
+      def __init__(self):
+        super().__init__()
+        self.x = torch.ones(10, 10)
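+        # x is a plain tensor attribute: neither an nn.Parameter nor a
+        # registered buffer.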
 
-  def test_mod_attr(self):
-    m = M()
+    class ModuleWithUnregisteredTensor(torch.nn.Module):
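+      # a mix of registered state (Linear a, Parameter b, buffers c and c2)
+      # and plain unregistered tensors (d, and Child's x through m1)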
+
+      def __init__(self):
+        super().__init__()
+        self.a = torch.nn.Linear(100, 100)
+        self.b = torch.nn.Parameter(
+            torch.ones(10, 10)
+        )
+        c = torch.ones(10, 10)
+        self.register_buffer('c', c)
+        self.register_buffer('c2', c, persistent=False)
+        self.d = torch.ones(10, 10)
+        self.m1 = Child()
+
+    m = ModuleWithUnregisteredTensor()
     params, buffers = interop.extract_all_buffers(m)
     self.assertEqual(
         set(params.keys()), {'a.weight', 'a.bias', 'b'}
@@ -75,6 +81,55 @@ def fn(x):
     expected = torch.ones(2, 2) * 2
     torch.testing.assert_close(x.grad, expected, check_device=False)
 
+  def test_module_with_shared_weights(self):
+
+    # arrange
+    class ModuleWithSharedWeights(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.a = torch.nn.Linear(10, 10)
+        self.b = self.a
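+        # b aliases a: one underlying Linear reachable under two names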
+
+      def forward(self, x):
+        return self.a(self.b(x))
+
+    m = ModuleWithSharedWeights().to('jax')
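+    # .to('jax') moves the module's tensors onto torchax's JAX-backed device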
+
+    m_jitted = interop.JittableModule(m, dedup_parameters=True)
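+    # dedup_parameters should keep a single copy of each aliased tensor in
+    # m_jitted.params, which the assertions below check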
+
+    # the state_dict still lists both names: a.weight, a.bias, b.weight, b.bias
+    self.assertEqual(len(m.state_dict()), 4)
+
+    # b's weights and bias are deduped
+    self.assertEqual(len(m_jitted.params), 2)
+    x = torch.randn(10, 10).to('jax')
+    expected = m(x)
+
+    # act
+    actual = m_jitted(x)
+
+    # assert
+    torch.testing.assert_close(actual, expected)
+
+    # arrange
+    # make sure buffer donation works
+    functional_forward = interop.jax_jit(
+        functools.partial(m_jitted.functional_call, 'forward'),
+        kwargs_for_jax_jit={
+            'donate_argnums': (0, )
+        }
+    )
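+    # donate_argnums is forwarded to jax.jit: donating argument 0 lets XLA
+    # reuse the memory of m_jitted.params, the first positional argument below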
+
+    # act
+    actual = functional_forward(m_jitted.params, m_jitted.buffers, x)
+    # assert
+    torch.testing.assert_close(actual, expected)
+
 
 if __name__ == '__main__':
   unittest.main()