[PyTorch] Support cudagraph recomputation #2518
base: main
Conversation
Signed-off-by: Robin Zhang <[email protected]>
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
This PR enables cudagraph recomputation support through two key changes:
Key Changes:
Implementation Details:
Potential Issue:
Confidence Score: 4/5
Important Files Changed
File Analysis
Sequence Diagram
sequenceDiagram
participant User
participant make_graphed_callables
participant _make_graphed_callables
participant Module
participant _none_grad_context_wrapper
participant autograd
User->>make_graphed_callables: Call with modules & sample_args
make_graphed_callables->>_make_graphed_callables: Forward callables
Note over _make_graphed_callables: Warmup Phase
loop num_warmup_iters
_make_graphed_callables->>Module: Forward pass
Module-->>_make_graphed_callables: outputs
_make_graphed_callables->>_none_grad_context_wrapper: Enter context with inputs
_none_grad_context_wrapper->>_none_grad_context_wrapper: Save original grads
_none_grad_context_wrapper->>_none_grad_context_wrapper: Set input.grad = None
_make_graphed_callables->>autograd: backward(outputs, grad_tensors)
autograd->>autograd: Write gradients to input.grad
_make_graphed_callables->>_make_graphed_callables: Extract grad_inputs from input.grad
_none_grad_context_wrapper->>_none_grad_context_wrapper: Restore original grads
end
Note over _make_graphed_callables: Graph Capture Phase
_make_graphed_callables->>Module: Capture forward graph
Module-->>_make_graphed_callables: static_outputs
_make_graphed_callables->>_none_grad_context_wrapper: Enter context with inputs
_none_grad_context_wrapper->>_none_grad_context_wrapper: Clear input.grad
_make_graphed_callables->>autograd: backward(static_outputs, static_grad_outputs)
autograd->>autograd: Capture backward graph
_make_graphed_callables->>_make_graphed_callables: Capture grad_inputs from input.grad
_none_grad_context_wrapper->>_none_grad_context_wrapper: Restore grads
_make_graphed_callables-->>make_graphed_callables: Return graphed callables
make_graphed_callables-->>User: Return graphed modules
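The diagram above refers to a `_none_grad_context_wrapper` around the backward calls. The following is a hedged sketch of that pattern written from the diagram alone: save the inputs' existing `.grad`, clear it so `autograd.backward` writes fresh gradients that can be harvested, then restore the originals. The actual implementation in `transformer_engine/pytorch/graph.py` may differ, and every name below other than the concepts taken from the diagram is hypothetical.

```python
import contextlib
import torch

@contextlib.contextmanager
def none_grad_context(inputs):
    # Save whatever gradients the leaf inputs currently hold.
    saved = [t.grad for t in inputs]
    # Clear .grad so the upcoming backward writes fresh gradients.
    for t in inputs:
        t.grad = None
    try:
        yield
    finally:
        # Restore the original gradients once the caller has read the fresh ones.
        for t, g in zip(inputs, saved):
            t.grad = g

# Warmup-style usage as described in the diagram: run backward,
# then harvest grad_inputs from the inputs' .grad attributes.
x = torch.randn(4, requires_grad=True)
loss = (x * 2).sum()
with none_grad_context((x,)):
    torch.autograd.backward(loss)
    grad_inputs = tuple(t.grad for t in (x,))
```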
1 file reviewed, no comments
Signed-off-by: Robin Zhang <[email protected]>
2 files reviewed, no comments
Signed-off-by: Robin Zhang <[email protected]>
Additional Comments (1)
- transformer_engine/pytorch/graph.py, lines 66-77 (logic): potential issue if the `inputs` tuple contains non-tensor items. The code assumes all items in `inputs` are tensors with `.grad` attributes, but the wrapper doesn't validate this; if non-tensor items slip through, accessing `.grad` would raise `AttributeError`. Are all items in the `inputs` tuple guaranteed to be tensors at the call sites, or could there be edge cases with non-tensor inputs?
2 files reviewed, 1 comment
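For illustration only, here is a defensive variant of the wrapper that filters out non-tensor items, in the spirit of the reviewer's question above. Whether the real call sites can actually pass non-tensor inputs is not settled here, and the names below are hypothetical rather than taken from the PR.

```python
import contextlib
import torch

@contextlib.contextmanager
def none_grad_context_safe(inputs):
    # Only touch actual tensors; silently skip anything else in the tuple.
    tensors = [t for t in inputs if isinstance(t, torch.Tensor)]
    saved = [t.grad for t in tensors]
    for t in tensors:
        t.grad = None
    try:
        yield
    finally:
        for t, g in zip(tensors, saved):
            t.grad = g
```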
Description
Support cudagraph recomputation with two changes:
- Replace `autograd.grad` with `autograd.backward` in cudagraph capturing (a minimal sketch follows at the end of this description).
Type of change
Changes
Please list the changes introduced in this PR:
Checklist:
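As a minimal sketch (toy tensors, not the PR's actual code) of the mechanical difference behind the Description item above: `torch.autograd.grad` returns gradients as values, whereas `torch.autograd.backward` accumulates them into `.grad`, which is why the graphed path clears and then reads `input.grad`.

```python
import torch

x = torch.randn(3, requires_grad=True)
loss = (x ** 2).sum()

# Old style: gradients come back as return values.
(gx,) = torch.autograd.grad(loss, (x,), retain_graph=True)

# New style: gradients are accumulated into x.grad and read back from there.
x.grad = None
torch.autograd.backward(loss)

assert torch.allclose(gx, x.grad)
```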