support cuda graph capture offloading module #2435
base: main
Conversation
Signed-off-by: Hongbin Liu <[email protected]>
for more information, see https://pre-commit.ci
Greptile Summary

This PR adds support for CUDA graph capture with stream offloading by introducing three new optional hook parameters.

Key Changes:
Critical Issues:
Additional Notes:
Confidence Score: 2/5
Important Files Changed

File Analysis
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant Callable
    participant Graphed as Graphed.apply
    participant MainStream as Main CUDA Stream
    participant GraphStream as CUDA Graph Stream
    participant Event as CUDA Event
    User->>Callable: forward(*args, cuda_graph_stream, cuda_graph_event, **kwargs)
    Callable->>Graphed: apply(skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *inputs)
    Note over Graphed: Store stream & event in ctx
    Graphed->>Graphed: Copy inputs to static tensors
    Graphed->>GraphStream: wait_stream(current_stream)
    Note over GraphStream: Wait for main stream operations
    Graphed->>GraphStream: fwd_graph.replay()
    Note over GraphStream: Execute forward graph
    GraphStream->>Event: record() [MISSING!]
    Note over Event: Should record completion
    Graphed->>MainStream: wait_event(cuda_graph_event)
    Note over MainStream: Wait for graph completion
    Graphed-->>User: static_outputs
    Note over User: Backward pass triggered
    User->>Graphed: backward(*grads)
    Graphed->>Graphed: Copy grads to static tensors
    Graphed->>GraphStream: wait_stream(current_stream)
    Graphed->>GraphStream: bwd_graph.replay()
    GraphStream->>Event: record() [MISSING!]
    Graphed->>MainStream: wait_event(cuda_graph_event)
    Graphed-->>User: static_grad_inputs
```
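The stream-and-event handshake in the diagram can be traced with a plain-Python mock. `MockStream` and `MockEvent` below are illustrative stand-ins, not Transformer Engine or `torch.cuda` APIs; they only log the ordering constraints, including the `record()` step the review flags as missing.

```python
# Plain-Python mock of the forward sequence above; names are hypothetical.
class MockEvent:
    def __init__(self):
        self.recorded = False

    def record(self):
        self.recorded = True


class MockStream:
    def __init__(self, name, log):
        self.name, self.log = name, log

    def wait_stream(self, other):
        self.log.append(f"{self.name} waits for {other.name}")

    def wait_event(self, event):
        state = "recorded" if event.recorded else "UNRECORDED"
        self.log.append(f"{self.name} waits on {state} event")


log = []
main = MockStream("main", log)
graph = MockStream("graph", log)
event = MockEvent()

graph.wait_stream(main)        # GraphStream waits for main-stream ops
log.append("graph replays fwd_graph")
event.record()                 # the record() step the review says is missing
main.wait_event(event)         # main stream resumes only after the event
print(log)
```

Without the `event.record()` call, the final `wait_event` would see an unrecorded event, which is the synchronization gap both review comments below point at.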
1 file reviewed, 3 comments
```python
cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with cuda_graph_stream:
    fwd_graph.replay()
torch.cuda.current_stream().wait_event(cuda_graph_event)
```
logic: Missing event recording after graph replay. The cuda_graph_event should be recorded in the cuda_graph_stream after the forward graph replay completes, otherwise wait_event on line 712 will wait on an unrecorded event which may cause synchronization issues.
Suggested change:

```diff
 cuda_graph_stream.wait_stream(torch.cuda.current_stream())
 with cuda_graph_stream:
     fwd_graph.replay()
+    cuda_graph_event.record()
 torch.cuda.current_stream().wait_event(cuda_graph_event)
```
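To see why the missing `record()` matters, here is an analogy sketch using `threading.Event` in place of the CUDA event (real CUDA semantics differ: waiting on a never-recorded CUDA event completes immediately, so the consumer silently races ahead rather than blocking; the timeout below simulates that). All names here are illustrative.

```python
import threading
import time

event = threading.Event()  # stands in for cuda_graph_event
order = []

def replay(record: bool):
    """Simulates the graph stream replaying the captured graph."""
    time.sleep(0.05)           # simulate graph replay latency
    order.append("replay done")
    if record:
        event.set()            # models cuda_graph_event.record()

t = threading.Thread(target=replay, args=(False,))  # record() is skipped
t.start()
event.wait(timeout=0.01)       # models wait_event on an unrecorded event
order.append("main continued") # main "stream" proceeds before replay finishes
t.join()
print(order)
```

With `record=True` the `set()`/`wait()` pair enforces the intended ordering; without it, the consumer proceeds while the producer is still running, which is the race the comment describes.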
```python
ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with ctx.cuda_graph_stream:
    bwd_graph.replay()
torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
```
logic: Missing event recording after backward graph replay. The cuda_graph_event should be recorded in the cuda_graph_stream after the backward graph replay completes.
Suggested change:

```diff
 ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
 with ctx.cuda_graph_stream:
     bwd_graph.replay()
+    ctx.cuda_graph_event.record()
 torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
```
```python
assert "cuda_graph_stream" in user_kwargs
assert "cuda_graph_event" in user_kwargs
```
style: Hard assertions will crash the program when users don't provide these required kwargs. Use a more informative error message that guides users on proper usage.
Suggested change:

```diff
-assert "cuda_graph_stream" in user_kwargs
-assert "cuda_graph_event" in user_kwargs
+if "cuda_graph_stream" not in user_kwargs:
+    raise TypeError(
+        "Graphed callable requires 'cuda_graph_stream' kwarg. "
+        "Please provide a torch.cuda.Stream for graph execution."
+    )
+if "cuda_graph_event" not in user_kwargs:
+    raise TypeError(
+        "Graphed callable requires 'cuda_graph_event' kwarg. "
+        "Please provide a torch.cuda.Event for synchronization."
+    )
```
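The suggested validation can be exercised standalone. The helper name `validate_graph_kwargs` is hypothetical (the PR inlines these checks); the error messages match the suggestion above.

```python
def validate_graph_kwargs(user_kwargs: dict) -> None:
    """Raise an informative TypeError instead of a bare assert."""
    if "cuda_graph_stream" not in user_kwargs:
        raise TypeError(
            "Graphed callable requires 'cuda_graph_stream' kwarg. "
            "Please provide a torch.cuda.Stream for graph execution."
        )
    if "cuda_graph_event" not in user_kwargs:
        raise TypeError(
            "Graphed callable requires 'cuda_graph_event' kwarg. "
            "Please provide a torch.cuda.Event for synchronization."
        )

# Missing 'cuda_graph_event' now yields a message that names the fix,
# rather than a bare AssertionError with no guidance.
try:
    validate_graph_kwargs({"cuda_graph_stream": object()})
except TypeError as exc:
    print(exc)
```

Raising `TypeError` also survives `python -O`, which strips `assert` statements entirely.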
Signed-off-by: Hongbin Liu <[email protected]>
Signed-off-by: Hongbin Liu <[email protected]>
…ub.com/lhb8125/TransformerEngine into hongbinl/offload_activation_cuda_graph
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: