Skip to content

Commit 29fe90e

Browse files
eellisoneellisonmalfet
authored
[release/1.6] [JIT] Don't include view ops in autodiff graphs (pytorch#42029)
* Don't include view ops in autodiff graphs * skip view ops in autodiff testing * two more tests * appease clang-format * Pacify clang-format Co-authored-by: eellison <[email protected]> Co-authored-by: Nikita Shulga <[email protected]>
1 parent 35ad2d8 commit 29fe90e

File tree

3 files changed

+45
-23
lines changed

3 files changed

+45
-23
lines changed

test/test_jit.py

+3
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ def LSTMCellF(input, hx, cx, *params):
103103

104104

105105
def doAutodiffCheck(testname):
106+
# TODO: setting False on the test entry itself is not working, so skip these tests by name
107+
if "test_t_" in testname or testname == "test_t":
108+
return False
106109

107110
if GRAPH_EXECUTOR == ProfilingMode.SIMPLE:
108111
return False

torch/csrc/jit/passes/create_autodiff_subgraphs.cpp

+19
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,20 @@ class SubgraphSlicer {
110110
return result;
111111
}
112112

113+
// Returns true when `n`'s kind is one of the view-like ops enumerated in the
// switch below (view, view_as, reshape, reshape_as, transpose, expand,
// expand_as); every other node kind falls through to `return false`.
// NOTE(review): aten::t is not listed here, although the 't' method test is
// also switched off in this change -- confirm whether that is intentional.
bool isViewOp(Node* n) {
114+
switch (n->kind()) {
115+
case aten::view:
116+
case aten::view_as:
117+
case aten::reshape:
118+
case aten::reshape_as:
119+
case aten::transpose:
120+
case aten::expand:
121+
case aten::expand_as:
122+
return true;
123+
}
124+
return false;
125+
}
126+
113127
bool shouldConsiderForMerge(Node* node) {
114128
// if we're already in the process of merging
115129
if (node->kind() == prim::DifferentiableGraph) {
@@ -118,6 +132,11 @@ class SubgraphSlicer {
118132
if (node->kind() == prim::Constant) {
119133
return false;
120134
}
135+
// view ops as outputs of differentiable subgraphs can cause incorrect
136+
// differentiation; for now, do not include them in the subgraph
137+
if (isViewOp(node)) {
138+
return false;
139+
}
121140
return isDifferentiable(node);
122141
}
123142

torch/testing/_internal/common_methods_invocations.py

+23-23
Original file line numberDiff line numberDiff line change
@@ -179,22 +179,22 @@ def method_tests():
179179
('__rpow__', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant', (True, 'aten::pow')),
180180
('pow', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True,)),
181181
('__rpow__', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True, 'aten::pow')),
182-
('transpose', (1, 2, 3), (1, 2), 'dim', (True,), [0, 1]),
183-
('transpose', (), (0, 0), 'scalar', (True,)),
184-
('transpose', (1,), (0, 0), '1d', (True,)),
185-
('transpose', (L, L), (0, 1), '2d', (True,)),
186-
('transpose', (S, S, S), (2, 0), '3d', (True,)),
187-
('t', (1, 2), NO_ARGS, '', (True,)),
188-
('view', (S, S, S), (S * S, S), '', (True,)),
189-
('view', (S, S, S), (torch.Size([S * S, S]),), 'size', (True,)),
190-
('view', (S,), (S,), '1d', (True,)),
191-
('view', (), (dont_convert(()),), 'scalar_to_scalar', (True,)),
192-
('view', (), (1,), 'scalar_to_1d', (True,)),
193-
('reshape', (S, S, S), (S * S, S), '', (True,)),
194-
('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size', (True,)),
195-
('reshape', (S,), (S,), '1d', (True,)),
196-
('reshape', (), (dont_convert(()),), 'scalar_to_scalar', (True,)),
197-
('reshape', (), (1,), 'scalar_to_1d', (True,)),
182+
('transpose', (1, 2, 3), (1, 2), 'dim', (False,), [0, 1]),
183+
('transpose', (), (0, 0), 'scalar', (False,)),
184+
('transpose', (1,), (0, 0), '1d', (False,)),
185+
('transpose', (L, L), (0, 1), '2d', (False,)),
186+
('transpose', (S, S, S), (2, 0), '3d', (False,)),
187+
('t', (1, 2), NO_ARGS, '', (False,)),
188+
('view', (S, S, S), (S * S, S), '', (False,)),
189+
('view', (S, S, S), (torch.Size([S * S, S]),), 'size', (False,)),
190+
('view', (S,), (S,), '1d', (False,)),
191+
('view', (), (dont_convert(()),), 'scalar_to_scalar', (False,)),
192+
('view', (), (1,), 'scalar_to_1d', (False,)),
193+
('reshape', (S, S, S), (S * S, S), '', (False,)),
194+
('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size', (False,)),
195+
('reshape', (S,), (S,), '1d', (False,)),
196+
('reshape', (), (dont_convert(()),), 'scalar_to_scalar', (False,)),
197+
('reshape', (), (1,), 'scalar_to_1d', (False,)),
198198
('reshape_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
199199
('reshape_as', (), (non_differentiable(torch.tensor(42.)),), 'scalar'),
200200
('reshape_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
@@ -220,14 +220,14 @@ def method_tests():
220220
('view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
221221
('view_as', (), (non_differentiable(torch.tensor(5.5)),), 'scalar'),
222222
('view_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
223-
('expand', (S, 1, 1), (S, S, S), '', (True,)),
224-
('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size', (True,)),
225-
('expand', (S, 1), (S, S, S), 'new_dim', (True,)),
226-
('expand', (1,), (S, S, S), '1_element', (True,)),
227-
('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1', (True,)),
223+
('expand', (S, 1, 1), (S, S, S), '', (False,)),
224+
('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size', (False,)),
225+
('expand', (S, 1), (S, S, S), 'new_dim', (False,)),
226+
('expand', (1,), (S, S, S), '1_element', (False,)),
227+
('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1', (False,)),
228228
('expand', (), (dont_convert(()),), 'scalar_to_scalar'),
229-
('expand', (), (1, 3, 2), 'scalar_to_dims', (True,)),
230-
('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (True,)),
229+
('expand', (), (1, 3, 2), 'scalar_to_dims', (False,)),
230+
('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (False,)),
231231
('exp', (S, S, S), NO_ARGS, '', (True,)),
232232
('exp', (), NO_ARGS, 'scalar', (True,)),
233233
('expm1', (S, S, S), NO_ARGS, '', (True,)),

0 commit comments

Comments
 (0)