60 changes: 56 additions & 4 deletions python/paddle/distributed/auto_parallel/api.py
@@ -1167,6 +1167,11 @@ def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
         self._set_and_check_sharding_prop_from_param()
         self._shard_fn._set_sharding_axis(self._sharding_axis)
 
+        # Invoke register hook for sharding stage 2 strategy
+        if isinstance(self._shard_fn, ShardingStage2) and not in_auto_dp_mode():
+            for param in self._inner_opt._parameter_list:
+                self._shard_fn._register_hook_for_param_grad(param)
+
         # Invoke shard_parameter in sharding stage 3 strategy
         if isinstance(self._shard_fn, ShardingStage3):
             for param in self._inner_opt._parameter_list:
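The hunk above registers a per-parameter gradient hook whenever the shard_fn is a ShardingStage2 instance and auto-DP mode is off. For readers unfamiliar with the mechanism, here is a minimal, hedged sketch of how Paddle's dygraph `Tensor.register_hook` behaves (single process, no distributed setup; the hook merely rescales the gradient instead of resharding it as `_grad_hook` does):

```python
import paddle

x = paddle.ones([2, 2])
x.stop_gradient = False  # make x a leaf tensor that receives gradients

def scale_grad(grad):
    # Runs after x's gradient is computed in backward; the returned tensor
    # replaces the original gradient (analogous to the reshard performed
    # by ShardingStage2._grad_hook).
    return grad * 2

x.register_hook(scale_grad)

loss = (x * 3).sum()
loss.backward()
print(x.grad)  # each entry is 6.0: the hook doubled the gradient of 3.0
```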
@@ -2147,7 +2152,7 @@ def __call__(self, key: str, param: Tensor, tensor: Tensor) -> Tensor:
         return self._apply_placement(tensor, param, placements)
 
 
-class ShardingStage2(ShardingStage1):
+class ShardingStage2(_ShardingStageBase):
     """
     A builtin shard_fn for shard_optimizer interface, users can pass it to shard_optimizer to implement sharding optimization with stage 2.
 
@@ -2186,9 +2191,56 @@ class ShardingStage2(ShardingStage1):
             >>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
     """
 
-    # Note(luchang): Due to reshard optimizations in Paddle where all-reduce + slicing is fused into reduce_scatter,
-    # the current behavior of ShardingStage2 is effectively the same as ShardingStage1.
-    pass
+    def __init__(
+        self,
+        sharding_mesh_dim: int | str,
+        mesh: ProcessMesh | None = None,
+    ) -> None:
+        super().__init__(mesh, sharding_mesh_dim)
+
+    def __call__(self, key: str, param: Tensor, tensor: Tensor) -> Tensor:
+        if param.is_dist():
+            # Only deal with momentum in optimizer, beta should be replicated cross param's mesh
+            if 'beta' not in key:
+                placements = get_placement_with_sharding(
+                    param, self._sharding_axis
+                )
+            else:
+                placements = [
+                    dist.Replicate()
+                    for _ in range(len(param.process_mesh.shape))
+                ]
+            return shard_tensor(
+                tensor,
+                mesh=param.process_mesh,
+                placements=placements,
+            )
+        return tensor
+
+    @staticmethod
+    def _grad_hook(grad):
+        # do reshard only if the grad is dist tensor and in partial status
+        if grad.is_dist():
+            partial_mesh_axis = None
+            for mesh_axis, placement in enumerate(grad.placements):
+                if isinstance(placement, dist.Partial):
+                    partial_mesh_axis = mesh_axis
+            if partial_mesh_axis is not None:
+                new_placements = get_placement_with_sharding(
+                    grad, partial_mesh_axis
+                )
+                return reshard(grad, grad.process_mesh, new_placements)
+
+        return grad
+
+    def _register_hook_for_param_grad(self, param):
+        if param.is_dense() and self._mesh is not None:
+            placements = []
+            for _ in range(len(self._mesh.shape)):
+                placements.append(dist.Replicate())
+            param._to_dist_(placements, self._mesh)
+        if param.is_dist():
+            param.register_hook(ShardingStage2._grad_hook)
 
 
 class ShardingStage3(_ShardingStageBase):
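The core of the change is `_grad_hook`: a gradient that arrives in `Partial` status (a sum still pending across the sharding axis) is resharded to `Shard`, so each rank materializes only its slice; this is what makes the behavior stage-2 (sharded gradients and optimizer states) rather than stage-1. Below is a hedged illustration of that Partial-to-Shard transition using only public `paddle.distributed` APIs. It assumes a two-GPU launch and that a matmul of tensors sharded along the contracted dimension yields a `Partial` output, so treat it as a sketch rather than a test from this PR:

```python
# Run with: python -m paddle.distributed.launch --gpus=0,1 this_script.py
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# Sharding the contracted dimension should leave the matmul result in
# Partial status on the "x" axis (a sum is still pending across ranks).
a = dist.shard_tensor(paddle.rand([4, 8]), mesh, [dist.Shard(1)])
b = dist.shard_tensor(paddle.rand([8, 4]), mesh, [dist.Shard(0)])
partial_out = paddle.matmul(a, b)
print(partial_out.placements)  # expected: a Partial placement

# What _grad_hook does, in spirit: replace Partial with Shard so each rank
# keeps only a slice (reduce_scatter) instead of a full all-reduced tensor.
sharded_out = dist.reshard(partial_out, mesh, [dist.Shard(0)])
print(sharded_out.placements)  # expected: [Shard(dim=0)]
```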
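Finally, for orientation, a hedged end-to-end usage sketch of the new shard_fn, matching the constructor added in this diff (`ShardingStage2(sharding_mesh_dim, mesh)`) and the `shard_optimizer` interface named in the docstring. The toy `Linear` layer, optimizer choice, and training loop are illustrative, not taken from the PR's tests:

```python
# Run with: python -m paddle.distributed.launch --gpus=0,1 this_script.py
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

layer = paddle.nn.Linear(16, 16)
opt = paddle.optimizer.AdamW(parameters=layer.parameters())

# Wrap the optimizer with the stage-2 shard_fn: optimizer states are sharded
# along mesh axis "x", and __init__ registers _grad_hook on each parameter so
# Partial gradients are resharded (reduce_scatter) during backward.
opt = dist.shard_optimizer(opt, dist.ShardingStage2("x", mesh))

for _ in range(3):
    loss = layer(paddle.rand([8, 16])).mean()
    loss.backward()
    opt.step()
    opt.clear_grad()
```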