
Conversation

@buptzyb
Contributor

@buptzyb buptzyb commented Dec 16, 2025

Description

Support cudagraph recomputation with two changes:

  1. Replace autograd.grad with autograd.backward in cudagraph capturing (see the sketch after this list).
  2. Get the default RNG states in a graph-safe manner when the tracker states are also graph-safe.
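
For illustration, a minimal sketch of the autograd.grad → autograd.backward swap (not the Transformer Engine code itself; the toy layer and tensor names are made up, and a CUDA device is assumed):

```python
import torch

# Toy stand-in for a graphed module; the real change lives in
# transformer_engine/pytorch/graph.py.
layer = torch.nn.Linear(16, 16).cuda()
x = torch.randn(4, 16, device="cuda", requires_grad=True)

# Old pattern: autograd.grad returns the input gradient directly.
out = layer(x)
(dx_returned,) = torch.autograd.grad(
    outputs=(out,), inputs=(x,), grad_outputs=(torch.ones_like(out),)
)

# New pattern: autograd.backward writes into .grad, and the gradient is
# read back from the input tensor afterwards.
x.grad = None
out = layer(x)
torch.autograd.backward((out,), grad_tensors=(torch.ones_like(out),))
dx_from_grad = x.grad

assert torch.allclose(dx_returned, dx_from_grad)
```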

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Replaced torch.autograd.grad with torch.autograd.backward in the cudagraph capture path, reading input gradients from .grad afterwards (transformer_engine/pytorch/graph.py)
  • Made checkpoint RNG state handling fetch the default RNG state in a graph-safe manner when the tracker states are graph-safe (transformer_engine/pytorch/distributed.py)

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Dec 16, 2025

Greptile Overview

Greptile Summary

This PR enables cudagraph recomputation support through two key changes:

Key Changes:

  • Replaced torch.autograd.grad with torch.autograd.backward in graph.py during cudagraph capturing (lines 454-460, 638-643, 729-734). This change is necessary because autograd.grad creates intermediate tensors that cannot be captured in CUDA graphs, while autograd.backward writes gradients directly into the inputs' .grad attributes.
  • Added _none_grad_context_wrapper helper to temporarily clear input gradients during backward graph capture to handle gradient accumulation correctly
  • Introduced is_graph_safe_rng_state helper function in distributed.py to detect graph-safe RNG states
  • Updated checkpoint RNG state management to determine the graph_safe parameter dynamically, based on whether the tracker states are graph-safe, enabling proper RNG state handling during cudagraph recomputation (a sketch of this check follows below)
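
As a rough illustration of the RNG part only: the sketch below is not the actual is_graph_safe_rng_state helper, whose implementation may differ. It assumes that graph-safe states come from the CUDA generator's graphsafe_get_state API, while non-graph-safe states are the CPU byte tensors returned by torch.cuda.get_rng_state().

```python
import torch

def is_graph_safe_rng_state_sketch(state) -> bool:
    # Assumption: a graph-safe state is the generator-state handle returned by
    # graphsafe_get_state(), while the legacy path returns a CPU uint8 tensor
    # from torch.cuda.get_rng_state(). The real helper may use another check.
    return not isinstance(state, torch.Tensor)

# Fetch the default CUDA RNG state to match the tracker: graph-safe when the
# running PyTorch build exposes the graph-safe generator API, legacy otherwise.
gen = torch.cuda.default_generators[torch.cuda.current_device()]
if hasattr(gen, "graphsafe_get_state"):
    default_state = gen.graphsafe_get_state()   # usable during graph capture
else:
    default_state = torch.cuda.get_rng_state()  # plain byte tensor, not graph-safe

print(is_graph_safe_rng_state_sketch(default_state))
```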

Implementation Details:
The backward computation now extracts gradients from input.grad after calling torch.autograd.backward, rather than receiving them as return values from autograd.grad. The context wrapper ensures gradients start as None to avoid incorrect accumulation during graph capture.
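
To make that flow concrete, here is a sketch of the capture pattern using plain torch.cuda.graph APIs. The names, shapes, and warmup count are illustrative and this is not Transformer Engine's _make_graphed_callables; a CUDA device is assumed.

```python
import torch

layer = torch.nn.Linear(32, 32).cuda()
static_input = torch.randn(8, 32, device="cuda", requires_grad=True)
static_grad_output = torch.ones(8, 32, device="cuda")

# Warmup on a side stream so workspaces are allocated before capture.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        out = layer(static_input)
        torch.autograd.backward((out,), grad_tensors=(static_grad_output,))
        static_input.grad = None
        layer.zero_grad(set_to_none=True)
torch.cuda.current_stream().wait_stream(side)

pool = torch.cuda.graph_pool_handle()
fwd_graph, bwd_graph = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()

with torch.cuda.graph(fwd_graph, pool=pool):
    static_output = layer(static_input)

# Gradients start as None so the captured backward allocates fresh, static
# gradient buffers instead of accumulating into pre-existing ones (the role
# played by _none_grad_context_wrapper in this PR).
static_input.grad = None
with torch.cuda.graph(bwd_graph, pool=pool):
    torch.autograd.backward((static_output,), grad_tensors=(static_grad_output,))

# The input gradient is read from .grad rather than returned by autograd.grad.
static_grad_input = static_input.grad
```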

Potential Issue:
One logic issue identified: _none_grad_context_wrapper assumes all inputs are tensors but doesn't validate this before accessing the .grad attribute.

Confidence Score: 4/5

  • Safe to merge with minor risk - the changes are well-isolated to cudagraph functionality
  • The PR makes focused changes to enable cudagraph recomputation. The autograd.grad-to-autograd.backward conversion is a well-understood pattern for CUDA graphs. However, one potential logic issue was identified in the new _none_grad_context_wrapper that could cause an AttributeError if non-tensor inputs are passed. The RNG state changes are sound and properly detect graph-safe states before using them.
  • Pay close attention to transformer_engine/pytorch/graph.py lines 66-77, where the new _none_grad_context_wrapper could fail with non-tensor inputs.

Important Files Changed

File Analysis

Filename | Score | Overview
transformer_engine/pytorch/distributed.py | 4/5 | Added the is_graph_safe_rng_state helper and updated checkpoint RNG state management to dynamically determine graph-safe mode based on the tracker state, ensuring compatibility with cudagraph recomputation
transformer_engine/pytorch/graph.py | 4/5 | Replaced autograd.grad with autograd.backward in cudagraph capturing and added _none_grad_context_wrapper to properly handle gradient accumulation during graph capture

Sequence Diagram

sequenceDiagram
    participant User
    participant make_graphed_callables
    participant _make_graphed_callables
    participant Module
    participant _none_grad_context_wrapper
    participant autograd

    User->>make_graphed_callables: Call with modules & sample_args
    make_graphed_callables->>_make_graphed_callables: Forward callables
    
    Note over _make_graphed_callables: Warmup Phase
    loop num_warmup_iters
        _make_graphed_callables->>Module: Forward pass
        Module-->>_make_graphed_callables: outputs
        _make_graphed_callables->>_none_grad_context_wrapper: Enter context with inputs
        _none_grad_context_wrapper->>_none_grad_context_wrapper: Save original grads
        _none_grad_context_wrapper->>_none_grad_context_wrapper: Set input.grad = None
        _make_graphed_callables->>autograd: backward(outputs, grad_tensors)
        autograd->>autograd: Write gradients to input.grad
        _make_graphed_callables->>_make_graphed_callables: Extract grad_inputs from input.grad
        _none_grad_context_wrapper->>_none_grad_context_wrapper: Restore original grads
    end
    
    Note over _make_graphed_callables: Graph Capture Phase
    _make_graphed_callables->>Module: Capture forward graph
    Module-->>_make_graphed_callables: static_outputs
    _make_graphed_callables->>_none_grad_context_wrapper: Enter context with inputs
    _none_grad_context_wrapper->>_none_grad_context_wrapper: Clear input.grad
    _make_graphed_callables->>autograd: backward(static_outputs, static_grad_outputs)
    autograd->>autograd: Capture backward graph
    _make_graphed_callables->>_make_graphed_callables: Capture grad_inputs from input.grad
    _none_grad_context_wrapper->>_none_grad_context_wrapper: Restore grads
    
    _make_graphed_callables-->>make_graphed_callables: Return graphed callables
    make_graphed_callables-->>User: Return graphed modules

Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, no comments

Signed-off-by: Robin Zhang <[email protected]>
@buptzyb changed the title from "[PyTorch] Use autograd.backward to capture cudagraph backward" to "[PyTorch] Support cudagraph recomputation" on Dec 16, 2025
Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, no comments

Signed-off-by: Robin Zhang <[email protected]>
Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/pytorch/graph.py, line 66-77 (link)

    logic: potential AttributeError if the inputs tuple contains non-tensor items

    The code assumes all items in inputs are tensors with a .grad attribute, but the wrapper doesn't validate this. If non-tensor items slip through, accessing .grad would raise an AttributeError (a guarded sketch follows this comment).

    Are all items in the inputs tuple guaranteed to be tensors at the call sites, or could there be edge cases with non-tensor inputs?
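
For reference, a hypothetical guarded version of the wrapper along the lines suggested above; the actual helper in graph.py may be structured differently. Callers would read the freshly written .grad values inside the with block, before the restore runs.

```python
import contextlib
import torch

@contextlib.contextmanager
def _none_grad_context_sketch(inputs):
    # Only touch tensor inputs, sidestepping the AttributeError risk raised above.
    tensors = [t for t in inputs if isinstance(t, torch.Tensor)]
    saved = [t.grad for t in tensors]   # remember any pre-existing gradients
    for t in tensors:
        t.grad = None                   # let backward write fresh buffers
    try:
        yield
    finally:
        for t, grad in zip(tensors, saved):
            t.grad = grad               # restore the caller's gradients
```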

2 files reviewed, 1 comment
