diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/low_precision.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/low_precision.py new file mode 100644 index 0000000000..0485db83c9 --- /dev/null +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/low_precision.py @@ -0,0 +1,311 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from importlib.metadata import version +from typing import List, Tuple, Optional + +import torch +from packaging.version import Version as PkgVersion + +logger = logging.getLogger(__name__) + +# Detect if Transformer Engine is installed +try: + import transformer_engine # pylint: disable=W0611 + from transformer_engine.pytorch.module.base import TransformerEngineBaseModule + + HAVE_TE = True +except (ImportError, ModuleNotFoundError): + HAVE_TE = False + logger.info("Using Megatron-FSDP without Transformer Engine.") + +# Detect the Transformer Engine version +try: + import transformer_engine as te + + if hasattr(te, "__version__"): + TE_VERSION = PkgVersion(str(te.__version__)) + else: + TE_VERSION = PkgVersion(version("transformer-engine")) +except: + TE_VERSION = None + +# Detect the FP8 tensor class +try: + from transformer_engine.pytorch.tensor import QuantizedTensor + + HAVE_TE_FP8_TENSOR_CLASS = True + FP8_TENSOR_CLASS = QuantizedTensor +except: + try: + from transformer_engine.pytorch.float8_tensor import Float8Tensor + + HAVE_TE_FP8_TENSOR_CLASS = True + FP8_TENSOR_CLASS = Float8Tensor + except: + HAVE_TE_FP8_TENSOR_CLASS = False + +# Detect the MXFP8 tensor class +try: + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor + + HAVE_TE_MXFP8TENSOR = True +except: + HAVE_TE_MXFP8TENSOR = False + +# Detect the "cast_master_weights_to_fp8" function of Transformer Engine +try: + from transformer_engine.pytorch.tensor.utils import cast_master_weights_to_fp8 + + HAVE_TE_CAST_MASTER_WEIGHTS_TO_FP8 = True +except: + HAVE_TE_CAST_MASTER_WEIGHTS_TO_FP8 = False + + # Try to import multi_tensor_apply, used in the fallback of fp8 quantization. + try: + from transformer_engine.pytorch.optimizers import multi_tensor_applier, multi_tensor_scale + + multi_tensor_scale_impl = multi_tensor_scale + except ImportError: + try: + import amp_C + from apex.multi_tensor_apply import multi_tensor_applier + + multi_tensor_scale_impl = amp_C.multi_tensor_scale + except ImportError: + import warnings + + warnings.warn( + "Transformer Engine and Apex are not installed. 
" + "Falling back to local implementations of " + "multi_tensor_applier and multi_tensor_scale" + ) + + def local_multi_tensor_applier(op, noop_flag_buffer, tensor_lists, *args): + """Multi tensor op applier""" + return op(2048 * 32, noop_flag_buffer, tensor_lists, *args) + + def local_multi_tensor_scale(chunk_size, noop_flag, tensor_lists, scale): + """Works as a drop-in replacement for amp_C.multi_tensor_scale.""" + for src, dst in zip(tensor_lists[0], tensor_lists[1]): + dst.copy_(src * scale) + + multi_tensor_applier = local_multi_tensor_applier + multi_tensor_scale_impl = local_multi_tensor_scale + + def _multi_tensor_copy_this_to_that( + this: List[torch.Tensor], + that: List[torch.Tensor], + overflow_buf: Optional[torch.Tensor] = None, + ): + """ + Use multi-tensor-applier to copy values from one list to another. + We don't have a bfloat16 implementation so for now if the overflow_buf + is not provided, we default back to simple loop copy to be compatible + with bfloat16. + """ + if overflow_buf is not None: + overflow_buf.fill_(0) + # Scaling with factor `1.0` is equivalent to copy. + multi_tensor_applier(multi_tensor_scale_impl, overflow_buf, [this, that], 1.0) + else: + for this_, that_ in zip(this, that): + that_.copy_(this_) + +# Detect the "post_all_gather_processing" function of Transformer Engine +try: + from transformer_engine.pytorch.tensor.utils import post_all_gather_processing + + HAVE_TE_POST_ALL_GATHER_PROCESSING = True +except: + HAVE_TE_POST_ALL_GATHER_PROCESSING = False + + +def is_te_min_version(vers, check_equality=True): + """Check if minimum version of `transformer-engine` is installed.""" + if not isinstance(TE_VERSION, PkgVersion): + return False + + if check_equality: + return TE_VERSION >= PkgVersion(vers) + else: + return TE_VERSION > PkgVersion(vers) + + +def is_float8tensor(tensor: torch.Tensor) -> bool: + """Check if a tensor is a FP8 tensor.""" + return HAVE_TE and isinstance(tensor, FP8_TENSOR_CLASS) + + +def fp8_need_transpose_data(tensor: torch.Tensor) -> bool: + """Check if a FP8 tensor needs transpose data.""" + return HAVE_TE_MXFP8TENSOR and isinstance(tensor, MXFP8Tensor) + + +def fp8_need_transpose_data_for_meta_device_init(module: TransformerEngineBaseModule) -> bool: + """Check if a FP8 tensor needs transpose data, for meta device init scenario.""" + return HAVE_TE_MXFP8TENSOR and module.fp8_meta["recipe"].mxfp8() + + +def fp8_discard_transpose_cache(tensor: torch.Tensor) -> None: + """Discard the transpose cache of a FP8 tensor.""" + assert is_float8tensor(tensor), f"Type {type(tensor)} is not a FP8 tensor" + + if hasattr(tensor, "_transpose_invalid"): + tensor._transpose_invalid = True + tensor._transpose = None + elif not fp8_need_transpose_data(tensor): + tensor.update_usage(rowwise_usage=True, columnwise_usage=False) + + +def fp8_create_transpose_cache(tensors: List[torch.Tensor]) -> None: + """Create the transpose cache of a FP8 tensor.""" + if HAVE_TE_POST_ALL_GATHER_PROCESSING: + post_all_gather_processing(tensors) + else: + _fp8_create_transpose_cache_fallback(tensors) + + +def _fp8_create_transpose_cache_fallback(tensors: List[torch.Tensor]) -> None: + if not isinstance(tensors, list): + tensors = [tensors] + for tensor in tensors: + assert is_float8tensor(tensor), f"Type {type(tensor)} is not a FP8 tensor" + if hasattr(tensor, "_create_transpose"): + tensor._create_transpose() + else: + tensor._create_columnwise() + + +def fp8_set_raw_data(tensor: torch.Tensor, data: torch.Tensor, set_transpose: bool = False) -> None: + """Set the 
raw data of a Transformer Engine Float8Tensor.""" + assert is_float8tensor(tensor), f"Type {type(tensor)} is not a FP8 tensor" + + if set_transpose: + assert fp8_need_transpose_data(tensor), f"Type {type(tensor)} does not need transpose data" + data_attr = "_columnwise_data" + else: + data_attr = "_rowwise_data" if hasattr(tensor, "_rowwise_data") else "_data" + + old_data = getattr(tensor, data_attr) + assert old_data.dtype == data.dtype, "The data types of raw data don't match" + assert old_data.shape == data.shape, \ + f"Shape {old_data.shape} of old_data doesn't match {data.shape} of new_data" + setattr(tensor, data_attr, data) + + +def fp8_get_raw_data(tensor: torch.Tensor, get_transpose: bool = False) -> torch.Tensor: + assert is_float8tensor(tensor), f"Type {type(tensor)} is not a FP8 tensor" + + if get_transpose: + assert fp8_need_transpose_data(tensor), f"Type {type(tensor)} does not need transpose data" + data_attr = "_columnwise_data" + else: + data_attr = "_rowwise_data" if hasattr(tensor, "_rowwise_data") else "_data" + + return getattr(tensor, data_attr) + + +def fp8_dequantize(tensor: torch.Tensor) -> torch.Tensor: + assert is_float8tensor(tensor), f"Type {type(tensor)} is not a FP8 tensor" + assert is_te_min_version("2.0"), \ + "Transformer Engine >= 2.0 is required for dequantizing parameters." + return tensor.dequantize() + + +def fp8_quantize( + model_params: List[torch.Tensor], + main_params: List[torch.Tensor], + start_offsets: List[int], + data_parallel_group: torch.distributed.ProcessGroup, + fsdp_shard_model_params: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] +) -> None: + if len(model_params) == 0: + return + fsdp_shard_model_params = [x[0] if x[1] is None else x for x in fsdp_shard_model_params] + + if HAVE_TE_CAST_MASTER_WEIGHTS_TO_FP8: + cast_master_weights_to_fp8( + model_params, main_params, start_offsets, data_parallel_group, fsdp_shard_model_params + ) + else: + _fp8_quantize_fallback( + model_params, main_params, start_offsets, data_parallel_group, fsdp_shard_model_params + ) + + +def _fp8_quantize_fallback( + model_params: List[torch.Tensor], + main_params: List[torch.Tensor], + start_offsets: List[int], + data_parallel_group: torch.distributed.ProcessGroup, + fsdp_shard_model_params: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] +) -> None: + for model_param, main_param, start_offset, fsdp_shard_model_param in zip( + model_params, main_params, start_offsets, fsdp_shard_model_params + ): + if main_param is None: + continue + + if fsdp_shard_model_param is not None: + shard_model_param = fsdp_shard_model_param + else: + shard_model_param = model_param._data.view(-1)[ + start_offset : start_offset + main_param.numel() + ] + + quantizer = model_param._quantizer + # When not using fp8 params, the main_param (fp32) is first cast to bf16/fp16, and then + # cast to fp8 during forward. This logic keeps numerical consistency with bf16 params. 
+ main_param = main_param.to(model_param.dtype) + out = Float8Tensor( + shape=main_param.size(), + dtype=model_param.dtype, + requires_grad=False, + data=shard_model_param, + fp8_scale_inv=model_param._scale_inv, + fp8_dtype=model_param._fp8_dtype, + quantizer=quantizer, + ) + quantizer.update_quantized(main_param, out) + + amaxes = [] + scales = [] + scale_invs = [] + for model_param in model_params: + quantizer = model_param._quantizer + amaxes.append(quantizer.amax.view(1)) + scales.append(quantizer.scale.view(1)) + scale_invs.append(model_param._scale_inv.view(1)) + model_param._reset_caches() + + dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") + + # Update scaling factors. + packed_scales = torch.empty(len(scales), dtype=torch.float32, device=scales[0].device) + packed_scale_views = [packed_scales[i].view(1) for i in range(len(scales))] + _multi_tensor_copy_this_to_that(scales, packed_scale_views, dummy_overflow_buf) + torch.reciprocal(packed_scales, out=packed_scales) + _multi_tensor_copy_this_to_that(packed_scale_views, scale_invs, dummy_overflow_buf) + + # Reduce amaxes. + # Note: Assume each param has a separate amax. + packed_amaxes = torch.empty(len(amaxes), dtype=torch.float32, device=amaxes[0].device) + packed_amax_views = [packed_amaxes[i].view(1) for i in range(len(amaxes))] + _multi_tensor_copy_this_to_that(amaxes, packed_amax_views, dummy_overflow_buf) + torch.distributed.all_reduce( + packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=data_parallel_group + ) + _multi_tensor_copy_this_to_that(packed_amax_views, amaxes, dummy_overflow_buf) diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py index 2f02792146..cc5c716e9d 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py @@ -24,7 +24,16 @@ from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten from .utils import FSDPDistributedIndex - +from .low_precision import is_float8tensor, fp8_discard_transpose_cache, fp8_create_transpose_cache +from .param_and_grad_buffer import ( + AllGatherPipeline, + BucketingPolicy, + GradReducePipeline, + ParamAndGradBuffer, + PrefetchOrder, + override_sharded_param_methods_with_safety_checks, + to_local_if_dtensor, +) logger = logging.getLogger(__name__) @@ -34,23 +43,12 @@ from megatron.core.distributed.distributed_data_parallel_config import ( DistributedDataParallelConfig, ) - from megatron.core.fp8_utils import is_float8tensor from megatron.core.utils import is_submodule except ImportError: # Megatron-LM is not installed, use Megatron-FSDP as a standalone module. 
logger.info("Megatron Core is not installed, Megatron-FSDP will run without Megatron Core.") from .distributed_data_parallel_config import DistributedDataParallelConfig - from .utils import is_float8tensor, is_submodule - -from .param_and_grad_buffer import ( - AllGatherPipeline, - BucketingPolicy, - GradReducePipeline, - ParamAndGradBuffer, - PrefetchOrder, - override_sharded_param_methods_with_safety_checks, - to_local_if_dtensor, -) + from .utils import is_submodule class TrainingState(Enum): @@ -397,6 +395,7 @@ def all_gather_and_wait_parameters_ready( prefetch=True, prefetch_order=PrefetchOrder.FORWARD_PASS_ORDER, wait_bucket_ready=True, + bwd=False, ): """ All-gather parameters across the data parallel group and wait for @@ -423,11 +422,14 @@ def all_gather_and_wait_parameters_ready( and self.ddp_config.outer_dp_sharding_strategy != "no_shard" and (self.microbatch_count == 0 or self.model_auto_sync) ), + bwd=bwd, ) if wait_bucket_ready: for param in params: bucket_id = self.param_and_grad_buffer.param_to_param_group[param] - ag_pipeline.wait_bucket_ready(bucket_id) + ag_pipeline.wait_bucket_ready(bucket_id, bwd) + if bwd and is_float8tensor(param): + fp8_create_transpose_cache(param) for param in params: # This setting is needed to make FSDP store the weight object when used @@ -481,10 +483,10 @@ def _register_fsdp_hooks(self, root_module): """ fsdp_unit_modules = self.fsdp_unit_modules - def release_module_parameters(module, *unused): + def release_module_parameters(module, bwd, *unused): for param in module.parameters(): bucket_id = self.param_and_grad_buffer.param_to_param_group[param] - self.all_gather_pipeline.release_bucket(bucket_id) + self.all_gather_pipeline.release_bucket(bucket_id, bwd) if not self.ddp_config.keep_fp8_transpose_cache: release_params_fp8_transpose_cache(module.parameters()) @@ -492,8 +494,7 @@ def release_module_parameters(module, *unused): def release_params_fp8_transpose_cache(params): for param in params: if is_float8tensor(param): - param._transpose_invalid = True - param._transpose = None + fp8_discard_transpose_cache(param) def _grad_acc(param): """ @@ -550,7 +551,7 @@ def _post_backward(module, *unused): if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params": # Deallocate the module parameters after the backward pass, # because we have our data-parallel gradients computed. - release_module_parameters(module) + release_module_parameters(module, bwd=True) module._training_state = TrainingState.IDLE param_list = list(module.parameters()) else: @@ -606,6 +607,12 @@ def _pre_forward_param_unshard( # that are not FSDP units. Do not recurse unless absolutely necessary, # to allocate as little memory as possible for this forward pass. param_list = list(module.parameters(recurse=False)) + # TODO(mxfp8): Do we really need this? + self.all_gather_and_wait_parameters_ready( + params=param_list, + prefetch=False, + bwd=True, + ) # All-gather the parameters before the forward pass. self.all_gather_and_wait_parameters_ready( @@ -718,7 +725,9 @@ def _pre_backward(module: nn.Module, *unused): if isinstance(module, tuple(fsdp_unit_modules)): # All-gather / unshard the module parameters before the backward pass. 
self.all_gather_and_wait_parameters_ready( - list(module.parameters()), prefetch_order=PrefetchOrder.BACKWARD_PASS_ORDER + list(module.parameters()), + prefetch_order=PrefetchOrder.BACKWARD_PASS_ORDER, + bwd=True ) self._root_pre_backward_hook_issued = False @@ -746,7 +755,9 @@ def _root_pre_backward(module: nn.Module, *unused): for bucket_id in range(ag_pipeline.num_buckets): group = self.param_and_grad_buffer.parameter_groups[bucket_id] if group.fsdp_unit_id is not None: - ag_pipeline.bucket_can_be_released[bucket_id] = True + ag_pipeline.bucket_can_be_released[ + ag_pipeline.get_bucket_key(bucket_id, bwd=False) + ] = True # Track parameters that require gradient reduction and optimization. self._params_require_handle_grad = set() for param_group in self.param_and_grad_buffer.parameter_groups: @@ -769,7 +780,7 @@ def _post_forward(module: nn.Module, input: Any, output: Any): return output # Release the module parameters after the forward pass to save memory. - release_module_parameters(module) + release_module_parameters(module, bwd=False) module._training_state = TrainingState.IDLE return output @@ -977,17 +988,30 @@ def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bo else: self.synchronize_param_gather() for bucket_id in range(self.all_gather_pipeline.num_buckets): - self.all_gather_pipeline.async_bucket_gather(bucket_id=bucket_id) group = self.param_and_grad_buffer.parameter_groups[bucket_id] + + self.all_gather_pipeline.async_bucket_gather(bucket_id=bucket_id, bwd=False) + # TODO(mxfp8): Is this correct? + if group.transpose_weight_buffer is not None: + self.all_gather_pipeline.async_bucket_gather(bucket_id=bucket_id, bwd=True) + if group.model_weight_buffer is None: continue if group.model_weight_buffer.is_data_distributed: # If model weight is sharded, we wait for the all-gather to complete and # then release the bucket immediately to save memory usage. - self.all_gather_pipeline.wait_bucket_ready(bucket_id) + self.all_gather_pipeline.wait_bucket_ready(bucket_id, False) + # TODO(mxfp8): Is this correct? + if group.transpose_weight_buffer is not None: + self.all_gather_pipeline.wait_bucket_ready(bucket_id, True) + for bucket_id in range(self.all_gather_pipeline.num_buckets): - self.all_gather_pipeline.wait_bucket_ready(bucket_id) + group = self.param_and_grad_buffer.parameter_groups[bucket_id] + self.all_gather_pipeline.wait_bucket_ready(bucket_id, False) + # TODO(mxfp8): Is this correct? 
+ if group.transpose_weight_buffer is not None: + self.all_gather_pipeline.wait_bucket_ready(bucket_id, True) def start_grad_sync(self, *unused): """ diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py index a987ec2cec..0c57e835d4 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py @@ -35,6 +35,17 @@ from .uneven_dtensor import update_uneven_dtensor_chunk_metadata, validate_uneven_dtensor from .utils import _MODEL_PARALLEL_RNG_TRACKER_NAME, FSDPDistributedIndex, get_global_memory_buffer +from .low_precision import ( + is_te_min_version, + is_float8tensor, + fp8_need_transpose_data, + fp8_need_transpose_data_for_meta_device_init, + fp8_discard_transpose_cache, + fp8_set_raw_data, + fp8_get_raw_data, + fp8_dequantize, + fp8_quantize, +) logger = logging.getLogger(__name__) @@ -44,27 +55,15 @@ from megatron.core.distributed.distributed_data_parallel_config import ( DistributedDataParallelConfig, ) - from megatron.core.fp8_utils import ( - is_float8tensor, - modify_underlying_storage, - quantize_param_shard, - ) from megatron.core.tensor_parallel import get_cuda_rng_tracker - from megatron.core.utils import is_submodule, is_te_min_version + from megatron.core.utils import is_submodule logger.info("Detected Megatron Core, using Megatron-FSDP with Megatron.") except ImportError: # Megatron-LM is not installed, use Megatron-FSDP as a standalone module. from .distributed_data_parallel_config import DistributedDataParallelConfig - from .utils import ( - get_cuda_rng_tracker, - is_float8tensor, - is_submodule, - is_te_min_version, - modify_underlying_storage, - quantize_param_shard, - ) + from .utils import get_cuda_rng_tracker, is_submodule logger.info("Megatron Core is not installed, Megatron-FSDP will run without Megatron Core.") @@ -804,7 +803,7 @@ def __init__( data_parallel_group: Optional[torch.distributed.ProcessGroup] = None, dp_rank: Optional[int] = None, temporary_bucket_allocator: Optional[TemporaryBucketAllocator] = None, - is_dtype_float8: bool = False, + is_transpose_buffer: bool = False, gradient_scaling_factor: Optional[float] = None, chunk_size_factor: int = 1, mem_alloc_context: Optional[Callable] = None, @@ -837,7 +836,7 @@ def __init__( self.temporary_bucket_allocator = ( temporary_bucket_allocator if temporary_bucket_allocator else TemporaryBucketAllocator() ) - self.is_dtype_float8 = is_dtype_float8 + self.is_transpose_buffer = is_transpose_buffer self.gradient_scaling_factor = gradient_scaling_factor self.mem_alloc_context = mem_alloc_context if mem_alloc_context else nullcontext @@ -933,11 +932,11 @@ def fetch_bucket( for p in self.params: item_id = self.param_idx[p] p = to_local_if_dtensor(p) + data = self.get_item_from_bucket(bucket, item_id).view(p.shape) if is_float8tensor(p): - p._data = self.get_item_from_bucket(bucket, item_id).view(p.shape) + fp8_set_raw_data(p, data, self.is_transpose_buffer) else: - p.data = self.get_item_from_bucket(bucket, item_id).view(p.shape) - + p.data = data return bucket def free_bucket_storage(self): @@ -1106,6 +1105,9 @@ def set_item(self, item_id: int, item_data: torch.Tensor) -> None: # When fully sharded, we need to get the slice of the item to be stored in this shard. # Otherwise, we can just flatten the entire item since this buffer contains # the entire bucket. 
+ if is_float8tensor(item_data): + item_data = fp8_get_raw_data(item_data, self.is_transpose_buffer) + if self.is_data_distributed: # Get the coordinates of the slice of the item that is contained in this shard. slice_start, slice_end = self._get_item_slice_in_shard(item_id) @@ -1212,6 +1214,8 @@ class ParameterGroup: Factor determining chunk size for grouped parameter processing. model_weight_buffer (Optional[DataParallelBuffer]): Buffer used to store model weights for data-parallel operations. + transpose_weight_buffer (Optional[DataParallelBuffer]): + Buffer used to store transpose weights for data-parallel operations. main_weight_buffer (Optional[DataParallelBuffer]): Buffer used to store main model weights for data-parallel operations. main_grad_buffer (Optional[DataParallelBuffer]): @@ -1231,6 +1235,7 @@ class ParameterGroup: fsdp_unit_id: Optional[int] = None chunk_size_factor: int = 1 model_weight_buffer: Optional[DataParallelBuffer] = None + transpose_weight_buffer: Optional[DataParallelBuffer] = None main_weight_buffer: Optional[DataParallelBuffer] = None main_grad_buffer: Optional[DataParallelBuffer] = None hsdp_wbuf: Optional[DataParallelBuffer] = None @@ -1301,12 +1306,10 @@ def _does_param_require_new_bucket(param): parameter_groups = [] for name, param in module.named_parameters(): # We need this information to correctly dynamically allocate Tensors! + is_fp8 = is_float8tensor(param) + is_fp8_meta_device_init = meta_device_init_fp8_params.get(name, (False, False))[0] param_attrs = dict( - dtype=( - "float8" - if is_float8tensor(param) or meta_device_init_fp8_params.get(name, False) - else param.dtype - ), + dtype="float8" if (is_fp8 or is_fp8_meta_device_init) else param.dtype, is_expert_param=is_expert_parameter(param), requires_grad=param.requires_grad, fsdp_unit_id=None, @@ -1626,7 +1629,9 @@ def __init__( # to determine whether this parameter is fp8 or not. fp8_meta_index = m.param_init_meta[name].fp8_meta_index if m.primary_weights_in_fp8 and fp8_meta_index is not None: - meta_device_init_fp8_params[self.param_to_name[param]] = True + meta_device_init_fp8_params[self.param_to_name[param]] = ( + True, fp8_need_transpose_data_for_meta_device_init(m) + ) # Get the parameter groups. (self.parameter_groups, self.param_to_param_group, self.bucket_to_bucket_group) = ( @@ -1689,6 +1694,7 @@ def _bytes_to_mb(bytes_val: int) -> str: numel = sum(to_local_if_dtensor(p).shape.numel() for p in group.params) buffers = { "weight": group.model_weight_buffer, + "transpose_weight": group.transpose_weight_buffer, "main_weight": group.main_weight_buffer, "grad": group.main_grad_buffer, } @@ -1758,12 +1764,14 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params): self.weight_alloc = FixedPoolAllocator( name="fsdp_params", fsdp_param_groups=self.parameter_groups, size=UB_BUFFER_NUM ) + # TODO(mxfp8): Do we need separate alloc for transpose buffer? self.main_grad_alloc = FixedPoolAllocator( name="fsdp_grads", fsdp_param_groups=self.parameter_groups, size=UB_BUFFER_NUM ) self.double_buf_units = self.weight_alloc.fsdp_double_buffer_units else: self.weight_alloc = StorageResizeBasedBucketAllocator() + # TODO(mxfp8): Do we need separate alloc for transpose buffer? self.main_grad_alloc = None self.double_buf_units = [] @@ -1804,8 +1812,8 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params): # Check if the parameter group is FP8. 
one_param = group.params[0] is_dtype_float8 = is_float8tensor(one_param) or meta_device_init_fp8_params.get( - self.param_to_name[one_param], False - ) + self.param_to_name[one_param], (False, False) + )[0] if is_dtype_float8: param_dtype = torch.uint8 grad_dtype = torch.bfloat16 @@ -1813,6 +1821,15 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params): param_dtype = group.params[0].dtype grad_dtype = param_dtype + # Check if the parameter group needs a transpose buffer for model weights. + # Currently, only mxfp8 needs it. + need_transpose_data = is_float8tensor(one_param) and fp8_need_transpose_data(one_param) + need_transpose_data_for_meta_device_init = meta_device_init_fp8_params.get( + self.param_to_name[one_param], (False, False) + )[1] + should_create_transpose_weight_buffer = \ + need_transpose_data or need_transpose_data_for_meta_device_init + # Check if the parameter group requires a grad buffer or main weight buffer. should_create_grad_buffer_or_main_weight_buffer = ( not self.only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad @@ -1829,13 +1846,29 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params): dtype=param_dtype, device=self.device, data_parallel_group=main_buf_dp_group, - is_dtype_float8=is_dtype_float8, + is_transpose_buffer=False, temporary_bucket_allocator=self.weight_alloc, bucket_id=group_id, chunk_size_factor=group.chunk_size_factor, mem_alloc_context=self.mem_alloc_context, **main_buf_extra_kwargs, ) + if should_create_transpose_weight_buffer: + group.transpose_weight_buffer = DataParallelBuffer( + self.ddp_config, + group.params, + is_data_distributed=is_model_weight_buffer_distributed + and main_buf_dp_group.size() > 1, + dtype=param_dtype, + device=self.device, + data_parallel_group=main_buf_dp_group, + is_transpose_buffer=True, + temporary_bucket_allocator=self.weight_alloc, # TODO(mxfp8): Do we need separate alloc for transpose buffer? + bucket_id=group_id, + chunk_size_factor=group.chunk_size_factor, + mem_alloc_context=self.mem_alloc_context, + **main_buf_extra_kwargs, + ) # Initialize the main weight buffer. if should_create_grad_buffer_or_main_weight_buffer and preserve_fp32_weights: @@ -1867,7 +1900,7 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params): dtype=torch.float32 if grad_reduce_in_fp32 else grad_dtype, device=self.device, data_parallel_group=main_buf_dp_group, - is_dtype_float8=False, + is_transpose_buffer=False, temporary_bucket_allocator=self.main_grad_alloc, gradient_scaling_factor=gradient_scaling_factor, bucket_id=group_id, @@ -1891,7 +1924,7 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params): dtype=wbuf.dtype, device=wbuf.device, data_parallel_group=hsdp_buf_dp_group, - is_dtype_float8=wbuf.is_dtype_float8, + is_transpose_buffer=False, temporary_bucket_allocator=self.weight_alloc, bucket_id=group_id, chunk_size_factor=group.chunk_size_factor, @@ -1907,6 +1940,9 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params): ), ) + if group.transpose_weight_buffer is not None: + raise NotImplementedError("HSDP for transpose buffer is not implemented yet") + if should_create_grad_buffer_or_main_weight_buffer: # Initialize the HSDP grad buffer. 
gbuf = group.main_grad_buffer @@ -1918,7 +1954,7 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params): dtype=gbuf.dtype, device=gbuf.device, data_parallel_group=hsdp_buf_dp_group, - is_dtype_float8=gbuf.is_dtype_float8, + is_transpose_buffer=False, temporary_bucket_allocator=self.main_grad_alloc, gradient_scaling_factor=gradient_scaling_factor, bucket_id=group_id, @@ -2001,6 +2037,20 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params): torch.empty(wbuf.data_size, dtype=wbuf.dtype, device=self.device) ) bucket = wbuf.fetch_bucket() + + tbuf = group.transpose_weight_buffer + if tbuf: + with self.mem_alloc_context(): + if group.hsdp_wbuf: + raise NotImplementedError( + "HSDP for transpose buffer is not implemented yet" + ) + else: + tbuf.init_data( + torch.empty(tbuf.data_size, dtype=tbuf.dtype, device=self.device) + ) + transpose_bucket = tbuf.fetch_bucket() + mbuf = group.main_weight_buffer if mbuf: # Manually instantiate an empty tensor into the main weight buffer. @@ -2054,25 +2104,41 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params): if not self.ddp_config.keep_fp8_transpose_cache: for _param in m.parameters(recurse=False): if is_float8tensor(_param): - _param._transpose_invalid = True - _param._transpose = None + fp8_discard_transpose_cache(_param) # Raise error if a meta parameter still exists after initialization. assert not p.is_meta, (self.param_to_name[p], module_reset_flag) + p_local = to_local_if_dtensor(p) + # Copy the model weight parameter tensor into the buffer. # When distributed, this shards and preserves the data across all ranks. - wbuf.set_item(item_id, to_local_if_dtensor(p)) + wbuf.set_item(item_id, p_local) + if tbuf: + tbuf.set_item(item_id, p_local) # Retrieve the newly allocated parameter data from the global bucket. # Attach the bucket-allocated parameter data to the module parameter, # to use the bucket-allocated data for autograd and NCCL. - new_param_data = wbuf.get_item_from_bucket(bucket, item_id).view( - to_local_if_dtensor(p).shape - ) - if is_float8tensor(p): - # Needed to instantiate FP8 parameters. Requires installing - # TransformerEngine. - modify_underlying_storage(p, new_param_data) + new_param_data = wbuf.get_item_from_bucket(bucket, item_id).view(p_local.shape) + if tbuf: + new_transpose_data = tbuf.get_item_from_bucket( + transpose_bucket, item_id + ).view(p_local.shape) + else: + new_transpose_data = None + + if is_float8tensor(p_local): + old_param_data = fp8_get_raw_data(p_local) + assert old_param_data._base is None + new_param_data.detach().copy_(old_param_data) + fp8_set_raw_data(p_local, new_param_data) + del old_param_data + if new_transpose_data is not None: + old_transpose_data = fp8_get_raw_data(p_local, True) + assert old_transpose_data._base is None + new_transpose_data.detach().copy_(old_transpose_data) + fp8_set_raw_data(p_local, new_transpose_data, True) + del old_transpose_data elif isinstance(p, DTensor): old_param_data = p._local_tensor.data p._local_tensor.data = new_param_data @@ -2110,7 +2176,11 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params): # the (high-precision) main weight buffer. # Nothing else needs to be done, because the main weights # do not require autograd operations, only possibly sharding. 
- mbuf.set_item(item_id, to_local_if_dtensor(p)) + p_local = to_local_if_dtensor(p) + if is_float8tensor(p_local): + mbuf.set_item(item_id, fp8_dequantize(p_local)) + else: + mbuf.set_item(item_id, p_local) if wbuf and wbuf.is_data_distributed: # Free the memory backing the temporarily-allocated bucket associated @@ -2241,6 +2311,7 @@ def _reset_parameters(self, old_params, new_params): group.params[item_id] = new_p for buf in [ group.model_weight_buffer, + group.transpose_weight_buffer, group.main_weight_buffer, group.main_grad_buffer, group.hsdp_wbuf, @@ -2288,6 +2359,7 @@ def _init_distributed_params(self): dist_main_weight = {} for pg in self.parameter_groups: wbuf = pg.model_weight_buffer + tbuf = pg.transpose_weight_buffer mbuf = pg.main_weight_buffer for item_id, orig_param in enumerate(pg.params): param_name = self.param_to_name[orig_param] @@ -2314,6 +2386,8 @@ def _init_distributed_params(self): ) dist_main_weight[param_name] = dist_param elif wbuf: + assert tbuf is None, \ + "Transpose buffer should only exist when main params exist" dist_param = make_fsdp_dtensor( local_tensor=wbuf.get_item(item_id, only_shard=sharded_optimizer_state), param=orig_param, @@ -2388,6 +2462,7 @@ def set_param_attribute(): # MCore. mbuf = pg.model_weight_buffer if mbuf: + # TODO(mxfp8): Do we need to consider transpose buffer? _start, _end = mbuf._get_item_slice_in_shard(item_id) setattr(dist_param, "megatron_fsdp_slice", slice(_start, _end)) @@ -2480,6 +2555,7 @@ def copy_main_weights_to_model_weights(self): for pg in self.parameter_groups: mbuf = pg.main_weight_buffer wbuf = pg.model_weight_buffer + tbuf = pg.transpose_weight_buffer if mbuf is None: continue @@ -2500,9 +2576,17 @@ def copy_main_weights_to_model_weights(self): if wbuf: if wbuf.is_data_distributed or mbuf.is_data_distributed: model_param = wbuf.get_item(item_id, only_shard=True) + if tbuf: + transpose_param = tbuf.get_item(item_id, only_shard=True) + else: + transpose_param = None main_weight = mbuf.get_item(item_id, only_shard=True) else: model_param = wbuf.get_item(item_id) + if tbuf: + transpose_param = tbuf.get_item(item_id) + else: + transpose_param = None main_weight = mbuf.get_item(item_id) else: assert not mbuf.is_data_distributed @@ -2514,11 +2598,11 @@ def copy_main_weights_to_model_weights(self): if model_param.numel() == 0: shard_fp32_from_fp8.append(None) shard_offsets_in_fp8.append(None) - shard_model_params.append(None) + shard_model_params.append([None, None]) else: shard_fp32_from_fp8.append(main_weight) shard_offsets_in_fp8.append(wbuf.locate_item_in_global_item(item_id)[0]) - shard_model_params.append(model_param) + shard_model_params.append([model_param, transpose_param]) continue if model_param.numel() > 0: @@ -2527,12 +2611,12 @@ def copy_main_weights_to_model_weights(self): if len(dense_param_quantize_kwargs["model_params"]) > 0: # If we have FP8 parameters, we need to quantize them. dense_param_quantize_kwargs["data_parallel_group"] = data_parallel_group - quantize_param_shard(**dense_param_quantize_kwargs) + fp8_quantize(**dense_param_quantize_kwargs) if len(expert_param_quantize_kwargs["model_params"]) > 0: # If we have FP8 expert parameters, we need to quantize them. 
expert_param_quantize_kwargs["data_parallel_group"] = expert_data_parallel_group - quantize_param_shard(**expert_param_quantize_kwargs) + fp8_quantize(**expert_param_quantize_kwargs) @torch.no_grad() def copy_model_weights_to_main_weights(self): @@ -2550,6 +2634,7 @@ def copy_model_weights_to_main_weights(self): f"Master weight buffer size {mbuf.data.numel()} does not match " f"model weight buffer size {copyin_data.numel()}" ) + # TODO(mxfp8): Make sure it's not a fp8 buf? mbuf.data.copy_(copyin_data.data) def all_gather_parameters(self, async_op: bool = True): @@ -2567,15 +2652,18 @@ def all_gather_parameters(self, async_op: bool = True): all_gather_ops = [] for g in self.parameter_groups: - shard = g.model_weight_buffer.get_shard_from_local_buffer() - all_gather_handler = torch.distributed.all_gather_into_tensor( - output_tensor=g.model_weight_buffer.data, - input_tensor=shard, - group=g.model_weight_buffer.data_parallel_group, - async_op=async_op, - ) - if async_op: - all_gather_ops.append(all_gather_handler) + for buf in [g.model_weight_buffer, g.transpose_weight_buffer]: + if buf is None: + continue + shard = buf.get_shard_from_local_buffer() + all_gather_handler = torch.distributed.all_gather_into_tensor( + output_tensor=buf.data, + input_tensor=shard, + group=buf.data_parallel_group, + async_op=async_op, + ) + if async_op: + all_gather_ops.append(all_gather_handler) for op in all_gather_ops: op.wait() @@ -2597,7 +2685,7 @@ def reduce_scatter_gradients(self, async_op: bool = True): for g in self.parameter_groups: gbuf = g.main_grad_buffer if gbuf is not None: - continue + continue # TODO(mxfp8): This is an error? scaling_factor = gbuf.gradient_scaling_factor reduce_op = gradient_reduce_preprocessing(gbuf.data, scaling_factor, self.ddp_config) reduce_scatter_handler = torch.distributed.reduce_scatter_tensor( @@ -3043,9 +3131,16 @@ def __init__( # Track the status of all-gather operations for each bucket. self.param_gather_event_map = {} # All buckets are initially deallocated / empty after initialization of ParamAndGradBuffer. - self.bucket_status = {i: BucketStatus.EMPTY for i in range(self.buffer.num_buckets)} + self.bucket_status = {} + for i in range(self.buffer.num_buckets): + for bwd in [False, True]: + self.bucket_status[self.get_bucket_key(i, bwd)] = BucketStatus.EMPTY + # Track whether each bucket can be deallocated. - self.bucket_can_be_released = {i: False for i in range(self.buffer.num_buckets)} + self.bucket_can_be_released = {} + for i in range(self.buffer.num_buckets): + for bwd in [False, True]: + self.bucket_can_be_released[self.get_bucket_key(i, bwd)] = False # Map each bucket to the bucket group it belongs to by enumerated ID. # Made to collect a subset of buckets in the same bucket group. @@ -3070,6 +3165,11 @@ def __init__( # all-gather parameters across groups. 
self.outer_fsdp_group_param_gather_stream = torch.cuda.Stream() + def get_bucket_key(self, bucket_id, bwd): + has_transpose_buffer = \ + self.buffer.parameter_groups[bucket_id].transpose_weight_buffer is not None + return (bucket_id, has_transpose_buffer and bwd) + @property def num_buckets(self): """Return the number of buckets.""" @@ -3086,10 +3186,11 @@ def reset(self): UserWarning, ) while len(self.param_gather_event_map) > 0: - bucket_id = next(iter(self.param_gather_event_map)) - self.wait_bucket_ready(bucket_id) + (bucket_id, bwd) = next(iter(self.param_gather_event_map)) + self.wait_bucket_ready(bucket_id, bwd) for bucket_id in range(self.num_buckets): - self.bucket_can_be_released[bucket_id] = True + for bwd in [False, True]: + self.bucket_can_be_released[self.get_bucket_key(bucket_id, bwd)] = True self.recycle_unused_buckets() assert all([status is BucketStatus.EMPTY for status in self.bucket_status.values()]), ( @@ -3111,6 +3212,7 @@ def all_gather_params( suggested_AG_prefetch_size: Optional[int] = None, async_param_gather: bool = True, outer_fsdp_group_param_gather: bool = False, + bwd: bool = False, ): """All-gather the params. If prefetch is enabled, prefetch next buckets in the order of `prefetch_order`. @@ -3145,7 +3247,7 @@ def all_gather_params( # Do not release the buckets that are being all-gathered. for bucket_id in ag_buckets: - self.bucket_can_be_released[bucket_id] = False + self.bucket_can_be_released[self.get_bucket_key(bucket_id, bwd)] = False # If prefetch is enabled, we will add prefetch buckets to ag_buckets. if prefetch: @@ -3217,7 +3319,11 @@ def need_skip_prefetch(bucket_id): bucket_id = next_bucket_id(ag_buckets) # Only all-gather on buckets that have not been allocated yet. - ag_buckets = [i for i in ag_buckets if self.bucket_status[i] == BucketStatus.EMPTY] + ag_buckets = [ + bucket_id + for bucket_id in ag_buckets + if self.bucket_status[self.get_bucket_key(bucket_id, bwd)] == BucketStatus.EMPTY + ] if len(ag_buckets) == 0: return @@ -3236,6 +3342,7 @@ def need_skip_prefetch(bucket_id): self.ag_stream if self.ag_stream is not None else torch.cuda.current_stream() ) if outer_fsdp_group_param_gather: + # TODO(mxfp8): Support hsdp self.outer_fsdp_group_param_gather_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self.outer_fsdp_group_param_gather_stream): outer_fsdp_group = self.buffer.dist_index.get_outer_fsdp_group() @@ -3263,12 +3370,13 @@ def need_skip_prefetch(bucket_id): for bucket_id in buckets: # All-gather the module weights from each FSDP buffer shard # into an allocated bucket containing unsharded weights. - self.async_bucket_gather(bucket_id) + self.async_bucket_gather(bucket_id, bwd) # Replace the parameter all-gather event with coalescing event. 
for bucket_id in buckets: - _, mark_bucket_ready_to_use = self.param_gather_event_map[bucket_id] - self.param_gather_event_map[bucket_id] = ( + bucket_key = self.get_bucket_key(bucket_id, bwd) + _, mark_bucket_ready_to_use = self.param_gather_event_map[bucket_key] + self.param_gather_event_map[bucket_key] = ( coalescing_event, mark_bucket_ready_to_use, ) @@ -3276,14 +3384,16 @@ def need_skip_prefetch(bucket_id): # Wait for all-gather to finish if not async_param_gather: for bucket_id in buckets: - self.wait_bucket_ready(bucket_id) + self.wait_bucket_ready(bucket_id, bwd) - def wait_bucket_ready(self, bucket_id, empty_ok=False): + def wait_bucket_ready(self, bucket_id, bwd, empty_ok=False): """Wait for the bucket to be ready.""" - if self.bucket_status[bucket_id] == BucketStatus.READY_TO_USE: + bucket_key = self.get_bucket_key(bucket_id, bwd) + + if self.bucket_status[bucket_key] == BucketStatus.READY_TO_USE: # Already ready to use. return - if self.bucket_status[bucket_id] == BucketStatus.EMPTY: + if self.bucket_status[bucket_key] == BucketStatus.EMPTY: if empty_ok: return # Bucket shouldn't be empty, this implies that the bucket @@ -3291,48 +3401,67 @@ def wait_bucket_ready(self, bucket_id, empty_ok=False): raise ValueError(f"Bucket {bucket_id} is empty.") # Wait for asynchronous / overlapped NCCL operations to complete. - param_gather_event, mark_bucket_ready_to_use = self.param_gather_event_map.pop(bucket_id) + param_gather_event, mark_bucket_ready_to_use = self.param_gather_event_map.pop(bucket_key) param_gather_event.wait() mark_bucket_ready_to_use() @torch.no_grad() - def release_bucket(self, bucket_id: int): + def release_bucket(self, bucket_id, bwd): """Release the bucket.""" - if self.bucket_status[bucket_id] == BucketStatus.EMPTY: + # TODO(mxfp8): In some cases, there won't be ag before bwd? 
+ bucket_key = self.get_bucket_key(bucket_id, bwd) + + if self.bucket_status[bucket_key] == BucketStatus.EMPTY: return - self.wait_bucket_ready(bucket_id, empty_ok=True) - if self.bucket_status[bucket_id] == BucketStatus.COMMUNICATING: + self.wait_bucket_ready(bucket_id, bwd, empty_ok=True) + if self.bucket_status[bucket_key] == BucketStatus.COMMUNICATING: raise ValueError(f"Bucket {bucket_id} is communicating.") - wbuf = self.buffer.parameter_groups[bucket_id].model_weight_buffer - wbuf.free_bucket_storage() - self.bucket_status[bucket_id] = BucketStatus.EMPTY + if ( + bwd + and + self.buffer.parameter_groups[bucket_id].transpose_weight_buffer is not None + ): + buf = self.buffer.parameter_groups[bucket_id].transpose_weight_buffer + else: + buf = self.buffer.parameter_groups[bucket_id].model_weight_buffer + buf.free_bucket_storage() + + self.bucket_status[bucket_key] = BucketStatus.EMPTY def recycle_unused_buckets(self): """Recycle the unused buckets.""" - for bucket_id, can_be_released in self.bucket_can_be_released.items(): + for (bucket_id, bwd), can_be_released in self.bucket_can_be_released.items(): if can_be_released: - self.release_bucket(bucket_id) - self.bucket_can_be_released[bucket_id] = False + self.release_bucket(bucket_id, bwd) + self.bucket_can_be_released[(bucket_id, bwd)] = False - def get_fsdp_buffer(self, bucket_id: int) -> DataParallelBuffer: + def get_fsdp_buffer(self, bucket_id: int, bwd = False) -> DataParallelBuffer: """Get the FSDP buffer with the given bucket ID.""" param_group = self.buffer.parameter_groups[bucket_id] if self.buffer.ddp_config.outer_dp_sharding_strategy != "no_shard": - return param_group.hsdp_wbuf - return param_group.model_weight_buffer + if bwd and param_group.transpose_weight_buffer is not None: + raise RuntimeError("Transpose buffer is not supported for HSDP") + else: + return param_group.model_weight_buffer + if bwd and param_group.transpose_weight_buffer is not None: + return param_group.transpose_weight_buffer + else: + return param_group.model_weight_buffer @torch.no_grad() - def async_bucket_gather(self, bucket_id: int) -> None: + def async_bucket_gather(self, bucket_id, bwd) -> None: """All-gather the bucket and set the items.""" - self.bucket_can_be_released[bucket_id] = False - if self.bucket_status[bucket_id] != BucketStatus.EMPTY: + bucket_key = self.get_bucket_key(bucket_id, bwd) + + self.bucket_can_be_released[bucket_key] = False + if self.bucket_status[bucket_key] != BucketStatus.EMPTY: return - self.bucket_status[bucket_id] = BucketStatus.COMMUNICATING + self.bucket_status[bucket_key] = BucketStatus.COMMUNICATING - wbuf = self.get_fsdp_buffer(bucket_id) + wbuf = self.get_fsdp_buffer(bucket_id, bwd) # Lazy release the unused buckets. self.recycle_unused_buckets() @@ -3347,18 +3476,20 @@ def async_bucket_gather(self, bucket_id: int) -> None: async_op=True, ) - def get_closure(bucket_id): + def get_closure(bucket_id, bwd): @torch.no_grad() def mark_bucket_ready_to_use(): # Mark the bucket as ready to use - all NCCL operations are complete. - self.bucket_status[bucket_id] = BucketStatus.READY_TO_USE + self.bucket_status[self.get_bucket_key(bucket_id, bwd)] = BucketStatus.READY_TO_USE return mark_bucket_ready_to_use - mark_bucket_ready_to_use = get_closure(bucket_id) + mark_bucket_ready_to_use = get_closure(bucket_id, bwd) # Track the async all-gather operation for the bucket. 
- self.param_gather_event_map[bucket_id] = (param_gather_event, mark_bucket_ready_to_use) + self.param_gather_event_map[self.get_bucket_key(bucket_id, bwd)] = ( + param_gather_event, mark_bucket_ready_to_use + ) @torch.no_grad() @@ -3452,6 +3583,7 @@ def override_sharded_param_methods_with_safety_checks(params, all_gather_pipelin def override_sharded_param_to_function_closure(p, to_function): def override_sharded_param_to_function(*args, **kwargs): bucket_id = all_gather_pipeline.buffer.param_to_param_group[p] + # TODO(mxfp8): Fix key of bucket_status. status = all_gather_pipeline.bucket_status[bucket_id] if status == BucketStatus.READY_TO_USE: return to_function(*args, **kwargs) @@ -3468,6 +3600,7 @@ def override_sharded_param_to_function(*args, **kwargs): def override_sharded_param_cpu_function_closure(p, cpu_function): def override_sharded_param_cpu_function(*args, **kwargs): bucket_id = all_gather_pipeline.buffer.param_to_param_group[p] + # TODO(mxfp8): Fix key of bucket_status. status = all_gather_pipeline.bucket_status[bucket_id] if status == BucketStatus.READY_TO_USE: return cpu_function(*args, **kwargs) diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py index 1dfe08b90f..1a16e517dc 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py @@ -79,52 +79,6 @@ def is_te_min_version(vers, check_equality=True): return te_version > PkgVersion(vers) -# Check if Transformer Engine has class for fp8 tensors. -try: - if is_te_min_version("2.0"): - # In TE2.x, QuantizedTensor is the base class for all different type of fp8 tensors, - # including fp8 tensor for delayed scaling, current scaling and mxfp8, etc. - from transformer_engine.pytorch.tensor import QuantizedTensor as FP8_TENSOR_CLASS - else: - from transformer_engine.pytorch.float8_tensor import Float8Tensor as FP8_TENSOR_CLASS - - HAVE_TE_FP8_TENSOR_CLASS = True -except (ImportError, ModuleNotFoundError): - # FP8 tensor class not found - HAVE_TE_FP8_TENSOR_CLASS = False - -try: - from transformer_engine.pytorch.optimizers import multi_tensor_applier, multi_tensor_scale - - multi_tensor_scale_impl = multi_tensor_scale -except ImportError: - try: - import amp_C - from apex.multi_tensor_apply import multi_tensor_applier - - multi_tensor_scale_impl = amp_C.multi_tensor_scale - except ImportError: - import warnings - - warnings.warn( - "Transformer Engine and Apex are not installed. " - "Falling back to local implementations of " - "multi_tensor_applier and multi_tensor_scale" - ) - - def local_multi_tensor_applier(op, noop_flag_buffer, tensor_lists, *args): - """Multi tensor op applier""" - return op(2048 * 32, noop_flag_buffer, tensor_lists, *args) - - def local_multi_tensor_scale(chunk_size, noop_flag, tensor_lists, scale): - """Works as a drop-in replacement for amp_C.multi_tensor_scale.""" - for src, dst in zip(tensor_lists[0], tensor_lists[1]): - dst.copy_(src * scale) - - multi_tensor_applier = local_multi_tensor_applier - multi_tensor_scale_impl = local_multi_tensor_scale - - def is_submodule(module, parent_module, strict=True): """ Check if a module is a submodule of another module. @@ -138,18 +92,6 @@ def is_submodule(module, parent_module, strict=True): return False -def is_float8tensor(tensor: torch.Tensor) -> bool: - """Check if a tensor is a Transformer Engine Float8Tensor. 
- - Note that in TE2.x, in order to support more recipes, the design of the fp8 tensor class has - changed. Now Float8Tensor is only used for current scaling and delayed scaling. And mxfp8 - and blockwise scaling have their own fp8 tensor classes. These different fp8 tensor classes - are both inherited from QuantizedTensor. So, for TE1.x, FP8_TENSOR_CLASS is Float8Tensor, - and for TE2.x, FP8_TENSOR_CLASS is QuantizedTensor. - """ - return HAVE_TE_FP8_TENSOR_CLASS and isinstance(tensor, FP8_TENSOR_CLASS) - - def get_mesh_names(device_mesh: Optional[DeviceMesh] = None) -> list[str]: """ Get all the sub-mesh names in the DeviceMesh. @@ -191,198 +133,6 @@ def contains_submesh( return all(submesh_name in device_mesh_names for submesh_name in submesh_names) -def _multi_tensor_copy_this_to_that( - this: List[torch.Tensor], that: List[torch.Tensor], overflow_buf: Optional[torch.Tensor] = None -): - """ - Use multi-tensor-applier to copy values from one list to another. - We don't have a bfloat16 implementation so for now if the overflow_buf - is not provided, we default back to simple loop copy to be compatible - with bfloat16. - """ - if overflow_buf is not None: - overflow_buf.fill_(0) - # Scaling with factor `1.0` is equivalent to copy. - multi_tensor_applier(multi_tensor_scale_impl, overflow_buf, [this, that], 1.0) - else: - for this_, that_ in zip(this, that): - that_.copy_(this_) - - -""" -The code below abstracts the functionalities needed for implementing "--fp8-param-gather" into -several functions. It provides different implementations for each function based on different -versions of TE, ensuring compatibility across various TE versions. - -Currently, there are three functions: - - modify_underlying_storage - This function is used in DDP to place all parameters into a contiguous buffer. For - non-fp8 tensors, replacing their data is simple, just using code like - "tensor.data = new_data". However, for fp8 tensors, their raw data is not stored in the - ".data" attribute, and it varies with different TE versions and different recipes. This - function provides a unified interface to replace the underlying storage of a fp8 tensor. - - quantize_param_shard - This function is used in dist-opt to cast fp32 main params to fp8 params. For non-fp8 - params, this casting is as simple as "bf16_params.copy_(fp32_main_params)"; but for fp8 - params, the casting logic varies with different TE versions and different recipes. This - function provides a unified interface to cast fp32 main params to fp8 params, and also - updates the necessary attributes (like amax, scale, scale_inv or transpose cache) of the - fp8 model params. - - correct_amax_history_if_needed - This function is used to correct the amax history of fp8 tensors. In TE1.x, some inplace - copy operations will write unwanted values to the amax_history of fp8 tensors. This function - corrects the amax_history back. For TE2.x, it's an empty function. - Only useful for delayed scaling. 
-""" -if HAVE_TE and is_te_min_version("2.2"): - # Supported TE versions: 2.2+ - from transformer_engine.pytorch.tensor import QuantizedTensor - - def _modify_underlying_storage_impl( - fp8_tensor: QuantizedTensor, new_raw_data: torch.Tensor - ) -> None: - from transformer_engine.pytorch.tensor.utils import replace_raw_data - - replace_raw_data(fp8_tensor, new_raw_data) - - def _quantize_param_shard_impl( - model_params: List[QuantizedTensor], - main_params: List[torch.Tensor], - start_offsets: List[int], - data_parallel_group: ProcessGroup, - fsdp_shard_model_params: Optional[List[torch.Tensor]] = None, - ) -> None: - if len(model_params) == 0: - return - - from transformer_engine.pytorch.tensor.utils import cast_master_weights_to_fp8 - - args = [model_params, main_params, start_offsets, data_parallel_group] - if fsdp_shard_model_params is not None: - if get_te_version() == PkgVersion("2.3.0.dev0+5fdd7bb") or is_te_min_version("2.3.0"): - args.append(fsdp_shard_model_params) - else: - raise NotImplementedError( - f"FSDP with --fp8-param-gather is not supported in TE v{get_te_version()}" - ) - cast_master_weights_to_fp8(*args) - -elif HAVE_TE and is_te_min_version("2.0"): - # Supported TE versions: 2.0 - from transformer_engine.pytorch.tensor import QuantizedTensor - from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor - - def _modify_underlying_storage_impl( - fp8_tensor: QuantizedTensor, new_raw_data: torch.Tensor - ) -> None: - old_raw_data = fp8_tensor._data - assert old_raw_data.dtype == new_raw_data.dtype - new_raw_data.detach().copy_(old_raw_data) - fp8_tensor._data = new_raw_data - del old_raw_data - - def _quantize_param_shard_impl( - model_params: List[QuantizedTensor], - main_params: List[torch.Tensor], - start_offsets: List[int], - data_parallel_group: ProcessGroup, - fsdp_shard_model_params: Optional[List[torch.Tensor]] = None, - ) -> None: - if len(model_params) == 0: - return - - if fsdp_shard_model_params is None: - fsdp_shard_model_params = [None] * len(model_params) - - for model_param, main_param, start_offset, fsdp_shard_model_param in zip( - model_params, main_params, start_offsets, fsdp_shard_model_params - ): - if main_param is None: - continue - - if fsdp_shard_model_param is not None: - shard_model_param = fsdp_shard_model_param - else: - shard_model_param = model_param._data.view(-1)[ - start_offset : start_offset + main_param.numel() - ] - - quantizer = model_param._quantizer - # When not using --fp8-param-gather, the main_param (fp32) is first cast to bf16/fp16, - # and then cast to fp8 during forward. - # Although it's not necessary when --fp8-param-gather is enabled, we still keep this - # logic to keep numerical consistency. So here cast the main_param to model_param.dtype. - main_param = main_param.to(model_param.dtype) - out = Float8Tensor( - shape=main_param.size(), - dtype=model_param.dtype, - requires_grad=False, - data=shard_model_param, - fp8_scale_inv=model_param._scale_inv, - fp8_dtype=model_param._fp8_dtype, - quantizer=quantizer, - ) - quantizer.update_quantized(main_param, out) - - amaxes = [] - scales = [] - scale_invs = [] - for model_param in model_params: - quantizer = model_param._quantizer - amaxes.append(quantizer.amax.view(1)) - scales.append(quantizer.scale.view(1)) - scale_invs.append(model_param._scale_inv.view(1)) - model_param._reset_caches() - - dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") - - # Update scaling factors. 
- packed_scales = torch.empty(len(scales), dtype=torch.float32, device=scales[0].device) - packed_scale_views = [packed_scales[i].view(1) for i in range(len(scales))] - _multi_tensor_copy_this_to_that(scales, packed_scale_views, dummy_overflow_buf) - torch.reciprocal(packed_scales, out=packed_scales) - _multi_tensor_copy_this_to_that(packed_scale_views, scale_invs, dummy_overflow_buf) - - # Reduce amaxes. - # Note: Assume each param has a separate amax. - packed_amaxes = torch.empty(len(amaxes), dtype=torch.float32, device=amaxes[0].device) - packed_amax_views = [packed_amaxes[i].view(1) for i in range(len(amaxes))] - _multi_tensor_copy_this_to_that(amaxes, packed_amax_views, dummy_overflow_buf) - torch.distributed.all_reduce( - packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=data_parallel_group - ) - _multi_tensor_copy_this_to_that(packed_amax_views, amaxes, dummy_overflow_buf) - -else: - # Fallback impl if TE version is invalid or TE is not installed. - def _modify_underlying_storage_impl(*args, **kwargs): - raise RuntimeError( - "Invalid Transformer Engine version for FP8 distributed optimizer, " - "please install Transformer Engine 2.0+ or install Megatron-Core" - ) - - def _quantize_param_shard_impl(*args, **kwargs): - raise RuntimeError( - "Invalid Transformer Engine version for FP8 distributed optimizer, " - "please install Transformer Engine 2.0+ or install Megatron-Core" - ) - - -def modify_underlying_storage(tensor: torch.Tensor, new_raw_data: torch.Tensor): - """Replace the underlying raw data of a tensor with new data.""" - _modify_underlying_storage_impl(tensor, new_raw_data) - - -def quantize_param_shard( - model_params, main_params, start_offsets, data_parallel_group, fsdp_shard_model_params=None -): - """Cast shard fp32 main params to fp8 model params.""" - assert HAVE_TE, "Transformer Engine is required for quantizing parameters." - _quantize_param_shard_impl( - model_params, main_params, start_offsets, data_parallel_group, fsdp_shard_model_params - ) - - def _get_cuda_rng_state( device: Union[int, str, torch.device] = "cuda", clone: bool = False, graph_safe: bool = False ) -> torch.Tensor: diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index fa9a0f6d75..eefe71c2c9 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -719,6 +719,9 @@ def validate_args(args, defaults={}): assert args.ckpt_format == "fsdp_dtensor", \ "Megatron FSDP only supports fsdp_dtensor checkpoint format" + if args.use_megatron_fsdp: + args.reuse_grad_buf_for_mxfp8_param_ag = False + # Parameters dtype. args.params_dtype = torch.float if args.fp16:
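
Editor's note, not part of the patch above: the core mechanism this change introduces is keying the all-gather pipeline's per-bucket state by (bucket_id, bwd) so that an MXFP8 parameter group, which carries a separate column-wise (transpose) weight buffer, can be gathered once with row-wise data for the forward pass and once with column-wise data for the backward pass, while non-MXFP8 groups collapse both phases onto the same key. The sketch below is a minimal, self-contained illustration of that keying under those assumptions; all names (BucketStateSketch, Status, has_transpose_buffer) are hypothetical and the NCCL calls are elided.

from enum import Enum

class Status(Enum):
    EMPTY = 0
    COMMUNICATING = 1
    READY_TO_USE = 2

class BucketStateSketch:
    def __init__(self, has_transpose_buffer: dict[int, bool]):
        self.has_transpose_buffer = has_transpose_buffer
        # One status slot per (bucket_id, bwd) key, mirroring bucket_status in the patch.
        self.status = {
            (bucket_id, bwd): Status.EMPTY
            for bucket_id in has_transpose_buffer
            for bwd in (False, True)
        }

    def key(self, bucket_id: int, bwd: bool) -> tuple[int, bool]:
        # Backward only gets its own key when a transpose buffer exists;
        # otherwise forward and backward share the row-wise gather.
        return (bucket_id, self.has_transpose_buffer[bucket_id] and bwd)

    def gather(self, bucket_id: int, bwd: bool) -> None:
        k = self.key(bucket_id, bwd)
        if self.status[k] is Status.EMPTY:
            self.status[k] = Status.COMMUNICATING  # async all-gather would start here
            self.status[k] = Status.READY_TO_USE   # ...and be marked ready on completion

    def release(self, bucket_id: int, bwd: bool) -> None:
        # Frees only the buffer that backs this (bucket_id, bwd) key.
        self.status[self.key(bucket_id, bwd)] = Status.EMPTY

# Example: bucket 0 is MXFP8 (has a transpose buffer), bucket 1 is plain bf16.
state = BucketStateSketch({0: True, 1: False})
assert state.key(0, True) == (0, True)    # MXFP8 gets a separate backward gather
assert state.key(1, True) == (1, False)   # bf16 reuses the forward-pass bucket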