
Conversation

@lhb8125 (Contributor) commented Dec 1, 2025

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@lhb8125 lhb8125 marked this pull request as draft December 1, 2025 05:34
greptile-apps bot (Contributor) commented Dec 1, 2025

Greptile Overview

Greptile Summary

This PR adds support for CUDA graph capture with stream offloading by introducing three new optional hook parameters (pre_warmup_hook, post_warmup_hook, init_chunk_handler) and making CUDA graph execution offloadable to a separate stream via required cuda_graph_stream and cuda_graph_event kwargs.

Key Changes:

  • Added three optional hook parameters to enable custom logic during warmup and chunk initialization phases
  • Modified the Graphed autograd function to accept and use cuda_graph_stream and cuda_graph_event for offloading graph replay to a separate CUDA stream
  • Updated stream synchronization logic in both forward and backward passes to properly coordinate between the main stream and the graph stream
  • Extended function signatures throughout the call chain to support the new parameters
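As a rough illustration of the first bullet, the optional hooks might thread through a warmup loop like this; the hook names come from this PR, but the loop body is a simplified stand-in, not TransformerEngine's actual make_graphed_callables logic:

```python
# Hedged sketch: how pre_warmup_hook / post_warmup_hook (names from this
# PR) could wrap a warmup loop. The loop body is a simplified stand-in,
# not the real make_graphed_callables implementation.
def run_warmup(module_fn, sample_input, num_iters=3,
               pre_warmup_hook=None, post_warmup_hook=None):
    if pre_warmup_hook is not None:
        pre_warmup_hook()          # custom setup before warmup iterations
    outputs = [module_fn(sample_input) for _ in range(num_iters)]
    if post_warmup_hook is not None:
        post_warmup_hook()         # custom logic before graph capture
    return outputs
```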

Critical Issues:

  • Missing event recording after graph replays in both forward (lines 709-712) and backward (lines 729-732) passes: the cuda_graph_event is never recorded after the graph executes, so the subsequent wait_event calls wait on unrecorded events, potentially causing race conditions or hangs
  • Hard assertions for required kwargs (lines 755-756) will crash with uninformative error messages instead of guiding users on proper usage
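The ordering bug in the first issue can be checked on the CPU with stand-in stream/event objects. A hedged sketch follows: FakeStream, FakeEvent, and FakeGraph are invented for illustration, and only the call pattern mirrors the PR's graph.py.

```python
# Hedged sketch: stand-in stream/event/graph objects (all invented for
# illustration) that log the call order of the corrected forward path.
# Only the call pattern mirrors the PR's graph.py; no GPU is needed.
class FakeStream:
    def __init__(self, log, name):
        self.log, self.name = log, name
    def wait_stream(self, other):
        self.log.append(f"{self.name}.wait_stream({other.name})")
    def wait_event(self, event):
        self.log.append(f"{self.name}.wait_event")
    def __enter__(self):
        return self
    def __exit__(self, *exc):
        return False

class FakeEvent:
    def __init__(self, log):
        self.log = log
    def record(self):
        self.log.append("event.record")

class FakeGraph:
    def __init__(self, log):
        self.log = log
    def replay(self):
        self.log.append("graph.replay")

def replay_on_side_stream(graph, side_stream, event, main_stream):
    # Corrected pattern: wait, replay, record, then make the main stream
    # wait on the *recorded* event. Without the record() call, the final
    # wait_event has nothing to wait on and provides no ordering.
    side_stream.wait_stream(main_stream)
    with side_stream:
        graph.replay()
        event.record()  # the step the review flags as missing
    main_stream.wait_event(event)
```

The test hook here is simply that `event.record` must land between `graph.replay` and `main.wait_event` in the log.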

Additional Notes:

  • The PR description is incomplete: the summary, issue reference, type-of-change selection, and checklist are all unfilled
  • No tests added to verify the new functionality works correctly with stream offloading

Confidence Score: 2/5

  • This PR has critical synchronization bugs that could cause race conditions or hangs in production
  • Score reflects two critical logic errors in CUDA event/stream synchronization that will likely cause runtime failures, plus incomplete PR documentation and missing tests for the new functionality
  • Pay close attention to transformer_engine/pytorch/graph.py lines 709-712 and 729-732 where event recording is missing after graph replays

Important Files Changed

File Analysis

Filename: transformer_engine/pytorch/graph.py
Score: 3/5
Overview: Adds CUDA graph stream offloading support with new hooks and stream synchronization parameters, but has critical bugs in event recording and required parameter handling

Sequence Diagram

sequenceDiagram
    participant User
    participant Callable
    participant Graphed as Graphed.apply
    participant MainStream as Main CUDA Stream
    participant GraphStream as CUDA Graph Stream
    participant Event as CUDA Event

    User->>Callable: forward(*args, cuda_graph_stream, cuda_graph_event, **kwargs)
    Callable->>Graphed: apply(skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *inputs)
    
    Note over Graphed: Store stream & event in ctx
    Graphed->>Graphed: Copy inputs to static tensors
    
    Graphed->>GraphStream: wait_stream(current_stream)
    Note over GraphStream: Wait for main stream operations
    
    Graphed->>GraphStream: fwd_graph.replay()
    Note over GraphStream: Execute forward graph
    
    GraphStream->>Event: record() [MISSING!]
    Note over Event: Should record completion
    
    Graphed->>MainStream: wait_event(cuda_graph_event)
    Note over MainStream: Wait for graph completion
    
    Graphed-->>User: static_outputs
    
    Note over User: Backward pass triggered
    
    User->>Graphed: backward(*grads)
    Graphed->>Graphed: Copy grads to static tensors
    
    Graphed->>GraphStream: wait_stream(current_stream)
    Graphed->>GraphStream: bwd_graph.replay()
    
    GraphStream->>Event: record() [MISSING!]
    
    Graphed->>MainStream: wait_event(cuda_graph_event)
    
    Graphed-->>User: static_grad_inputs

@greptile-apps bot (Contributor) left a comment


1 file reviewed, 3 comments


Comment on lines +709 to +712

    cuda_graph_stream.wait_stream(torch.cuda.current_stream())
    with cuda_graph_stream:
        fwd_graph.replay()
    torch.cuda.current_stream().wait_event(cuda_graph_event)


logic: Missing event recording after graph replay. The cuda_graph_event should be recorded on cuda_graph_stream after the forward graph replay completes; waiting on a never-recorded (or stale) event provides no ordering, so the wait_event on line 712 does not actually make the main stream wait for the replay.

Suggested change

    cuda_graph_stream.wait_stream(torch.cuda.current_stream())
    with cuda_graph_stream:
        fwd_graph.replay()
    torch.cuda.current_stream().wait_event(cuda_graph_event)

becomes:

    cuda_graph_stream.wait_stream(torch.cuda.current_stream())
    with cuda_graph_stream:
        fwd_graph.replay()
        cuda_graph_event.record()
    torch.cuda.current_stream().wait_event(cuda_graph_event)

Comment on lines +729 to +732
ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with ctx.cuda_graph_stream:
bwd_graph.replay()
torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
Copy link
Contributor


logic: Missing event recording after backward graph replay. The cuda_graph_event should be recorded in the cuda_graph_stream after the backward graph replay completes.

Suggested change

    ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
    with ctx.cuda_graph_stream:
        bwd_graph.replay()
    torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)

becomes:

    ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
    with ctx.cuda_graph_stream:
        bwd_graph.replay()
        ctx.cuda_graph_event.record()
    torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)

Comment on lines +755 to +756

    assert "cuda_graph_stream" in user_kwargs
    assert "cuda_graph_event" in user_kwargs


style: Hard assertions will crash the program when users don't provide these required kwargs. Use a more informative error message that guides users on proper usage.

Suggested change

    assert "cuda_graph_stream" in user_kwargs
    assert "cuda_graph_event" in user_kwargs

becomes:

    if "cuda_graph_stream" not in user_kwargs:
        raise TypeError(
            "Graphed callable requires 'cuda_graph_stream' kwarg. "
            "Please provide a torch.cuda.Stream for graph execution."
        )
    if "cuda_graph_event" not in user_kwargs:
        raise TypeError(
            "Graphed callable requires 'cuda_graph_event' kwarg. "
            "Please provide a torch.cuda.Event for synchronization."
        )
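In the spirit of that suggestion, the validation could also be factored into a small standalone check. A hedged sketch follows; require_kwargs is a hypothetical helper, not part of the PR:

```python
# Hedged sketch: a hypothetical helper (not in the PR) that validates
# required kwargs up front and raises one informative TypeError listing
# everything that is missing, along with a usage hint for each.
def require_kwargs(user_kwargs, required):
    # `required` maps kwarg name -> short hint shown in the error.
    missing = [name for name in required if name not in user_kwargs]
    if missing:
        hints = "; ".join(f"'{name}' ({required[name]})" for name in missing)
        raise TypeError(f"Graphed callable requires kwargs: {hints}")

require_kwargs(
    {"cuda_graph_stream": object(), "cuda_graph_event": object()},
    {"cuda_graph_stream": "a torch.cuda.Stream for graph execution",
     "cuda_graph_event": "a torch.cuda.Event for synchronization"},
)  # passes silently when both kwargs are present
```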

