41 changes: 31 additions & 10 deletions transformer_engine/pytorch/distributed.py
@@ -90,6 +90,11 @@ def graph_safe_rng_available() -> bool:
)


def is_graph_safe_rng_state(state: Union[torch.Tensor, torch.Generator]) -> bool:
"""Returns whether the rng state is a graph safe version."""
return graph_safe_rng_available() and isinstance(state, torch.Generator)


def _get_cuda_rng_state(
device: Union[int, str, torch.device] = "cuda",
clone: bool = False,
@@ -340,9 +345,16 @@ def forward(

# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
if get_rng_state_tracker is not None:
ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
ctx.graph_safe_rng_state = (
is_graph_safe_rng_state(next(iter(ctx.fwd_cuda_rng_state_tracker.values())))
if ctx.fwd_cuda_rng_state_tracker
else False
)
else:
ctx.graph_safe_rng_state = False
ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state)

if context_fn is not None:
forward_ctx, recompute_ctx = context_fn()
@@ -406,13 +418,13 @@ def backward(

# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state)
if get_rng_state_tracker is not None:
bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()

# Set the states to what it used to be before the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state)
if get_rng_state_tracker is not None:
get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

@@ -427,7 +439,7 @@

# Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False)
_set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state)
if get_rng_state_tracker is not None:
get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker)

@@ -470,12 +482,21 @@ def __init__(self, recompute_fn: Callable, get_rng_state_tracker: Callable):

def cache_rng_states(self, forward=True):
"""Cache fwd/bwd RNG states in the frame to restore later."""
rng_states = (
torch.get_rng_state(),
_get_cuda_rng_state(graph_safe=False),
)
rng_states = (torch.get_rng_state(),)
if self.get_rng_state_tracker is not None:
rng_states += (self.get_rng_state_tracker().get_states(),)
tracker_states = self.get_rng_state_tracker().get_states()
self.graph_safe_rng_state = (
is_graph_safe_rng_state(next(iter(tracker_states.values())))
if tracker_states
else False
)
rng_states += (
_get_cuda_rng_state(graph_safe=self.graph_safe_rng_state),
tracker_states,
)
else:
self.graph_safe_rng_state = False
rng_states += (_get_cuda_rng_state(graph_safe=self.graph_safe_rng_state),)

if forward:
self.fwd_rng_states = rng_states
@@ -490,7 +511,7 @@ def restore_rng_states(self, forward=True):
rng_states = self.bwd_rng_states

torch.set_rng_state(rng_states[0])
_set_cuda_rng_state(rng_states[1], graph_safe=False)
_set_cuda_rng_state(rng_states[1], graph_safe=self.graph_safe_rng_state)
if self.get_rng_state_tracker is not None:
self.get_rng_state_tracker().set_states(rng_states[2])

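Note: the new helper infers "graph-safe" by type: a graph-safe CUDA RNG state is a torch.Generator handle, while the legacy torch.cuda.get_rng_state() snapshot is a ByteTensor. A minimal standalone sketch of that distinction (assuming a CUDA build of PyTorch new enough to expose Generator.graphsafe_get_state; the names below are illustrative, not the TE helpers):

import torch

def is_graph_safe(state):
    # Graph-safe states are torch.Generator handles; legacy snapshots are tensors.
    return isinstance(state, torch.Generator)

legacy_state = torch.cuda.get_rng_state()            # torch.Tensor (ByteTensor)
print(is_graph_safe(legacy_state))                   # False

gen = torch.cuda.default_generators[0]
if hasattr(gen, "graphsafe_get_state"):              # newer PyTorch only
    graph_safe_state = gen.graphsafe_get_state()     # torch.Generator handle
    print(is_graph_safe(graph_safe_state))           # True

The checkpoint and frame code above then probes the first state returned by the RNG tracker and passes the matching graph_safe= flag to _get_cuda_rng_state / _set_cuda_rng_state, instead of hard-coding graph_safe=False.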
71 changes: 50 additions & 21 deletions transformer_engine/pytorch/graph.py
@@ -62,6 +62,21 @@ def graph_pool_handle():
return _graph_pool_handle()


@contextlib.contextmanager
def _none_grad_context_wrapper(inputs):
"""
Wrapper that sets the gradients of the inputs to None for the duration
of the context, in case the backward pass accumulates gradients.
"""
original_input_grads = []
for input_tensor in inputs:
original_input_grads.append(input_tensor.grad)
input_tensor.grad = None
yield
for input_tensor, original_grad in zip(inputs, original_input_grads):
input_tensor.grad = original_grad
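A standalone sketch of what this wrapper guards against (illustrative names, not the TE code): pre-existing .grad values are stashed and cleared so that torch.autograd.backward writes fresh gradients instead of accumulating into stale ones, and the originals are restored on exit:

import contextlib
import torch

@contextlib.contextmanager
def none_grad(tensors):
    # Stash any pre-existing .grad values, clear them for the duration,
    # and put the originals back afterwards.
    saved = [t.grad for t in tensors]
    for t in tensors:
        t.grad = None
    try:
        yield
    finally:
        for t, g in zip(tensors, saved):
            t.grad = g

x = torch.randn(4, requires_grad=True)
x.grad = torch.full((4,), 123.0)       # previously accumulated grad
with none_grad([x]):
    (2 * x).sum().backward()
    fresh = x.grad.clone()             # only the grad produced inside the context
print(fresh)                           # all 2s, no mixing with the 123s
print(x.grad)                          # the original grad is back afterwards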


@contextlib.contextmanager
def _graph_context_wrapper(*args, **kwargs):
"""Wrapper around `torch.cuda.graph`.
@@ -434,13 +449,15 @@ def hook_fn(
for hook in hooks:
hook.remove()
if is_training:
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
only_inputs=True,
allow_unused=allow_unused_input,
)
inputs = tuple(i for i in static_input_surface if i.requires_grad)
with _none_grad_context_wrapper(inputs):
torch.autograd.backward(
tuple(o for o in outputs if o.requires_grad),
grad_tensors=tuple(
torch.empty_like(o) for o in outputs if o.requires_grad
),
)
grad_inputs = tuple(input.grad for input in inputs)

# Filter module params that get None grad from grad_inputs and remove them
# from static_input_surface. This is to ensure that the backward hooks
@@ -455,6 +472,14 @@
module_params_with_grad = []
for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx):
if (
grad_inputs[grad_inputs_idx] is None
and grad_inputs_idx < num_required_grad_sample_args
):
assert allow_unused_input, (
"The input tensor requires grad, but the grad is None after"
" backward pass."
)
elif (
grad_inputs[grad_inputs_idx] is not None
and grad_inputs_idx >= num_required_grad_sample_args
):
@@ -606,15 +631,17 @@ def hook_fn(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
if is_training:
with _graph_context_wrapper(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
only_inputs=True,
allow_unused=allow_unused_input,
inputs = tuple(i for i in static_input_surface if i.requires_grad)
with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
bwd_graph, pool=mempool
):
torch.autograd.backward(
tuple(o for o in static_outputs if o.requires_grad),
grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
retain_graph=retain_graph_in_backward,
)
grad_inputs = tuple(input.grad for input in inputs)

# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs
# that don't require grad. I couldn't think of a one-liner for this pattern.
@@ -695,15 +722,17 @@ def hook_fn(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
if is_training:
with _graph_context_wrapper(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
only_inputs=True,
allow_unused=allow_unused_input,
inputs = tuple(i for i in static_input_surface if i.requires_grad)
with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
bwd_graph, pool=mempool
):
torch.autograd.backward(
tuple(o for o in static_outputs if o.requires_grad),
grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
retain_graph=retain_graph_in_backward,
)
grad_inputs = tuple(input.grad for input in inputs)

if need_bwd_dw_graph[bwd_idx]:
with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
for module in visited_te_modules[bwd_idx]:
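Note on the recurring change in this file: torch.autograd.grad is replaced by torch.autograd.backward followed by reading each input's .grad, with _none_grad_context_wrapper ensuring nothing is mixed into (or leaked from) pre-existing gradients, and the new None-grad assert taking over the role of allow_unused for the sample args. A minimal sketch of why the two patterns produce the same gradients (standalone tensors, not the TE capture code):

import torch

x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)
out = (x * w).sum()
grad_out = torch.ones_like(out)

# Functional pattern (before): gradients are returned, .grad is left untouched.
gx, gw = torch.autograd.grad(
    outputs=out, inputs=(x, w), grad_outputs=grad_out, retain_graph=True
)

# Accumulating pattern (after): gradients land in .grad, so clear it first
# and read it back once backward finishes.
x.grad, w.grad = None, None
torch.autograd.backward(out, grad_tensors=grad_out)
assert torch.equal(gx, x.grad) and torch.equal(gw, w.grad)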