FSDP2 Not Compatible w/ HSDP (DeviceMesh Ignored By State) #3916

@schopra8

**Environment**

---------------------------------                                                                                                                                                                                                                                            
System Environment Report                                                                                                                                                                                                                                                    
Created: 2025-07-28 20:34:47 UTC                                                                                                                                                                                                                                             
---------------------------------                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                             
PyTorch information                                                                                                                                                                                                                                                          
-------------------                                                                                                                                                                                                                                                          
PyTorch version: 2.7.1+cu128                                                                                                                                                                                                                                                 
Is debug build: False                                                                                                                                                                                                                                                        
CUDA used to build PyTorch: 12.8                                                                                                                                                                                                                                             
ROCM used to build PyTorch: N/A                                                                                                                                                                                                                                              
                                                                                                                                                                                                                                                                             
OS: Ubuntu 22.04.4 LTS (x86_64)                                                                                                                                                                                                                                              
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0                                                                                                                                                                                                                           
Clang version: Could not collect                                                                                                                                                                                                                                             
CMake version: Could not collect                                                                                                                                                                                                                                             
Libc version: glibc-2.35                                                                                                                                                                                                                                                     
                                                                                                                                                                                                                                                                             
Python version: 3.10.12 (main, Mar 22 2024, 16:50:05) [GCC 11.4.0] (64-bit runtime)                                                                                                                                                                                          
Python platform: Linux-6.5.0-1023-aws-x86_64-with-glibc2.35                                                                                                                                                                                                                  
Is CUDA available: True                                                                                                                                                                                                                                                      
CUDA runtime version: Could not collect                                                                                                                                                                                                                                      
CUDA_MODULE_LOADING set to: LAZY                                                                                                                                                                                                                                             
GPU models and configuration:                                                                                                                                                                                                                                                
GPU 0: NVIDIA H100 80GB HBM3                                                                                                                                                                                                                                                 
GPU 1: NVIDIA H100 80GB HBM3                                                                                                                                                                                                                                                 
GPU 2: NVIDIA H100 80GB HBM3                                                                                                                                                                                                                                                 
GPU 3: NVIDIA H100 80GB HBM3                                                                                                                                                                                                                                                 
GPU 4: NVIDIA H100 80GB HBM3                                                                                                                                                                                                                                                 
GPU 5: NVIDIA H100 80GB HBM3                                                                                                                                                                                                                                                 
GPU 6: NVIDIA H100 80GB HBM3                                                                                                                                                                                                                                                 
GPU 7: NVIDIA H100 80GB HBM3 

Nvidia driver version: 535.183.01                                                                                                                                                                                                                                            
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             104
On-line CPU(s) list:                0-103
Vendor ID:                          GenuineIntel
Model name:                         Intel Xeon Processor (SapphireRapids)
CPU family:                         6
Model:                              143
Thread(s) per core:                 1
Core(s) per socket:                 104
Socket(s):                          1
Stepping:                           4
BogoMIPS:                           5600.00
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk avx512_fp16 arch_capabilities
Virtualization:                     VT-x
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          3.3 MiB (104 instances)
L1i cache:                          3.3 MiB (104 instances)
L2 cache:                           416 MiB (104 instances)
L3 cache:                           16 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-103
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Unknown: No mitigations
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI Syscall hardening, KVM SW loop
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] flake8==7.3.0
[pip3] mypy==1.16.1
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.1.1
[pip3] nvidia-cublas-cu12==12.8.3.14
[pip3] nvidia-cuda-cupti-cu12==12.8.57
[pip3] nvidia-cuda-nvrtc-cu12==12.8.61
[pip3] nvidia-cuda-runtime-cu12==12.8.57
[pip3] nvidia-cudnn-cu12==9.7.1.26
[pip3] nvidia-cufft-cu12==11.3.3.41
[pip3] nvidia-curand-cu12==10.3.9.55
[pip3] nvidia-cusolver-cu12==11.7.2.55
[pip3] nvidia-cusparse-cu12==12.5.7.53
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.8.61
[pip3] nvidia-nvtx-cu12==12.8.55
[pip3] pytorch-lightning==2.5.2
[pip3] pytorch-ranger==0.1.1
[pip3] torch==2.7.1+cu128
[pip3] torch-optimizer==0.3.0
[pip3] torchaudio==2.7.1+cu128
[pip3] torchmetrics==1.7.1
[pip3] torchvision==0.22.1+cu128
[pip3] triton==3.3.1
[conda] Could not collect


Composer information
--------------------
Composer Version: 0.32.0
Composer Commit Hash: None
CPU Model: Intel Xeon Processor (SapphireRapids)
CPU Count: 104
Number of Nodes: 1
GPU Model: NVIDIA H100 80GB HBM3
GPUs per Node: 1
GPU Count: 1
CUDA Device Count: 8

**To reproduce**

Steps to reproduce the behavior:

  1. Train with this parallelism_config
  parallelism_config:
    fsdp: # FSDP2
      device_mesh:
        - [0, 1, 2, 3]
        - [4, 5, 6, 7]
      reshard_after_forward: False # Zero-2 (Gradients + Optimizer Sharding)
      activation_checkpointing: True # Activation Checkpointing
      load_monolith_rank0_only: True  # only rank0 touches the file
      state_dict_type: 'full'         # we are loading/saving full ckpts
  2. Add a print statement after
    self.device_mesh: Optional[DeviceMesh] = _create_device_mesh(self.device, self.fsdp_config, self.tp_config)
    to log the created device mesh. The logged mesh is flat (1-D) instead of 2-D.
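
For reference, the shape the `device_mesh` setting above describes can be checked without any GPUs. The following is a minimal sketch in plain Python; `mesh_shape` is a hypothetical helper written for this illustration and is not part of Composer or PyTorch. It assumes the mesh is given as a nested list of rank ids, as in the config above.

```python
from typing import Any, Tuple


def mesh_shape(mesh: Any) -> Tuple[int, ...]:
    """Return the shape of a nested-list mesh, e.g. [[0, 1], [2, 3]] -> (2, 2).

    Hypothetical helper for illustration only; not a Composer or PyTorch API.
    """
    if not isinstance(mesh, (list, tuple)):
        return ()  # a bare rank id contributes no dimension
    if not mesh:
        return (0,)
    # Every row must have the same shape for the mesh to be rectangular.
    inner_shapes = {mesh_shape(row) for row in mesh}
    if len(inner_shapes) != 1:
        raise ValueError("device mesh rows must all have the same shape")
    return (len(mesh),) + inner_shapes.pop()


configured = [[0, 1, 2, 3], [4, 5, 6, 7]]  # device_mesh from the config above
flat = list(range(8))                      # a flat 1-D mesh over the same 8 ranks

print(mesh_shape(configured))  # (2, 4)
print(mesh_shape(flat))        # (8,)
```

The configured mesh has two dimensions, `(2, 4)`, whereas the mesh logged in step 2 has only one.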

Expected behavior

Per the Composer documentation for FSDP2, we should be able to provide a DeviceMesh in the FSDP2Config. When I provide a 2D mesh (as shown above), it is ignored.

If you read the code in state.py, the device_mesh property of FSDP2Config is never used. Instead, the HSDP (2D) mesh is built only if fsdp_config.data_parallel_replicate_degree is set. With FSDP this should work, because that property can be set. In FSDP2, however, these properties are hardcoded and cannot be set.

As a result, it is not possible to use FSDP2 with HSDP (2D) configs.
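
The behavior described above can be modeled in a few lines. This is a simplified, hypothetical sketch and not Composer's actual code: `modeled_mesh_shape` and its parameters are names invented for this illustration, and the only rule it encodes is the one stated above, namely that a 2D mesh is built only when a replicate degree is set.

```python
from typing import Optional, Tuple


def modeled_mesh_shape(world_size: int, replicate_degree: Optional[int]) -> Tuple[int, ...]:
    """Model of the reported logic (hypothetical, for illustration only).

    A user-supplied device_mesh plays no role here, mirroring the report
    that the device_mesh property is never read.
    """
    if replicate_degree is None:
        # No replicate degree set: the mesh stays flat (plain FSDP sharding).
        return (world_size,)
    if world_size % replicate_degree != 0:
        raise ValueError("world_size must be divisible by replicate_degree")
    # Replicate degree set: 2D (replicate, shard) mesh, i.e. HSDP.
    return (replicate_degree, world_size // replicate_degree)


# If the replicate degree cannot be set (the FSDP2 case described above),
# 8 ranks always yield a flat mesh.
print(modeled_mesh_shape(8, None))  # (8,)

# If the replicate degree can be set, the same 8 ranks yield the 2D mesh.
print(modeled_mesh_shape(8, 2))     # (2, 4)
```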

Additional context
