
Conversation

@Xing-lil (Contributor) commented Dec 30, 2025

PR Category

Auto Parallel

PR Types

Bug fixes

Description

Support the FSDP + auto_dp combination, verified on a multimodal model.

  • Without `auto_dp`, FSDP automatically inserts `reshard` ops (`all_gather` and `slice`) in the forward and backward passes based on sharding propagation.
  • With `auto_dp`, the dp dimension is marked as `fake_replicate`, which changes the original sharding propagation, so `all_gather` and `slice` are no longer inserted automatically.
  • This PR therefore manually inserts `reshard` in `pre_forward`, `post_forward`, `pre_backward`, and `post_backward` hooks.
  • Furthermore, with shared parameters (e.g. `lm_head` and `embedding` using the same weight), the communication is inserted via layer-level backward hooks; a parameter-level hook such as `param._register_backward_hook` would only be added at the `embedding` layer, which is not the desired behavior. A minimal sketch of the hook scheme follows this list.
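
Below is a minimal sketch of the hook scheme described above, assuming a dygraph semi-auto setup. `model`, `_gather_params`, `_reshard_params`, and `DP_MESH_DIM` are illustrative names, not the PR's actual code; `dist.reshard`, `register_forward_pre_hook`, and `_share_data_with` are the primitives the PR itself touches. Backward hooks would mirror the same pattern.

```python
# Minimal sketch, not the PR's exact code. `model` and DP_MESH_DIM are
# illustrative assumptions; pre/post backward hooks would mirror this.
import paddle.distributed as dist

DP_MESH_DIM = 0  # mesh axis that auto_dp marks as fake_replicate (assumed)

def _gather_params(layer, _inputs):
    # pre_forward: gather sharded params into full replicas.
    for param in layer.parameters(include_sublayers=False):
        if not param.trainable:
            continue
        target = [dist.Replicate() for _ in param.placements]
        if list(param.placements) == target:
            continue  # already replicated; skip the redundant reshard
        full = dist.reshard(param, param.process_mesh, target)
        param.get_tensor()._share_data_with(full.get_tensor())

def _reshard_params(layer, _inputs, _outputs):
    # post_forward: slice params back to the sharded layout.
    for param in layer.parameters(include_sublayers=False):
        if param.trainable and not isinstance(
            param.placements[DP_MESH_DIM], dist.Shard
        ):
            target = list(param.placements)
            target[DP_MESH_DIM] = dist.Shard(0)  # shard tensor dim 0
            sharded = dist.reshard(param, param.process_mesh, target)
            param.get_tensor()._share_data_with(sharded.get_tensor())

# Layer-level hooks (not param-level), so shared weights such as
# lm_head/embedding get communication at every layer that uses them.
for sublayer in model.sublayers():
    sublayer.register_forward_pre_hook(_gather_params)
    sublayer.register_forward_post_hook(_reshard_params)
```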

card-92763

@paddle-bot bot commented Dec 30, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

codecov-commenter commented Dec 31, 2025

Codecov Report

❌ Patch coverage is 96.82540% with 2 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@7a7d953). Learn more about missing BASE report.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...on/paddle/distributed/auto_parallel/fully_shard.py | 96.82% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             develop   #77147   +/-   ##
==========================================
  Coverage           ?   96.82%           
==========================================
  Files              ?        1           
  Lines              ?       63           
  Branches           ?        0           
==========================================
  Hits               ?       61           
  Misses             ?        2           
  Partials           ?        0           

☔ View full report in Codecov by Sentry.

- self._shard_fn._register_hook_for_param_grad(param)
+ if not in_auto_dp_mode():
+     self._shard_fn._register_hook_for_param_grad(param)
+ if in_auto_dp_mode():
Contributor: Could L78 just be changed to an `else`?

Contributor Author: The `_register_comm_hook` here is outside the for loop, so it cannot be changed to an `else`.
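
A simplified view of the control flow the author is describing (illustrative, not the file's exact code): the grad hook is registered per parameter inside the loop, while the comm hook is registered once on the whole model after the loop, so the two `if`s live at different nesting levels.

```python
# Simplified control flow (illustrative): the second `if` cannot be
# an `else` of the first because they are at different nesting levels.
for param in model.parameters():
    if not in_auto_dp_mode():
        self._shard_fn._register_hook_for_param_grad(param)

if in_auto_dp_mode():
    self._register_comm_hook(model)
```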

if not param.trainable:
    continue
new_placements = [dist.Replicate() for _ in param.placements]
replicte_param = dist.reshard(
Contributor:
1. Typo: `replicte_param` -> `replicate_param`.
2. Every parameter is resharded here; is that unnecessary overhead? Should it first check whether the target state is already satisfied?

Contributor Author: Added a check that skips the redundant operation when the param already matches the target. Thanks!
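
A sketch of what that skip could look like (assumed shape, mirroring the fragment above, with the typo also fixed):

```python
# Assumed shape of the fix: bail out early when no communication is needed.
if not param.trainable:
    continue
new_placements = [dist.Replicate() for _ in param.placements]
if list(param.placements) == new_placements:
    continue  # already replicated; reshard would add overhead for nothing
replicate_param = dist.reshard(param, param.process_mesh, new_placements)
```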

Contributor Author: Fixed the typo, thanks!

def shard_comm(*_):
    for key, param in sublayers._parameters.items():
        if param.trainable:
            new_placements = get_placement_with_sharding(param, 0)
Contributor: Does the `0` here mean dimension 0? Suggest naming it with a variable to improve readability.

Contributor Author: Changed all the related dimensions in fully_shard.py to the dp dimension. Thanks!
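
For illustration, the rename could look like this (`DP_MESH_DIM` is a hypothetical constant name, not necessarily the one used in fully_shard.py):

```python
DP_MESH_DIM = 0  # the dp axis of the process mesh, named for readability
new_placements = get_placement_with_sharding(param, DP_MESH_DIM)
```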

self._register_comm_hook(model)
os.environ["skip_sharding3_output_reshard"] = "1"

def _register_comm_hook(self, model):
Contributor: The communication logic here handles each parameter individually; consider improving communication efficiency later?

Contributor Author: Added a check that skips the redundant operation when the param already matches the target. Thanks!

Contributor Author: We plan to improve communication efficiency later with a tensor_fusion-style approach.

replicte_param = dist.reshard(
    param, param.process_mesh, new_placements
)
param.get_tensor()._share_data_with(replicte_param.get_tensor())
Contributor: Could this in-place operation cause the backward pass to fail to find the correct data address?

Contributor Author: No, it will not cause the backward pass to lose the correct address.
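
A quick way to validate that claim in a test (hypothetical check; `model` and `batch` stand in for the multimodal test setup):

```python
# Run one forward/backward after the in-place buffer share and confirm
# every trainable parameter still receives a gradient.
loss = model(batch).mean()
loss.backward()
assert all(p.grad is not None for p in model.parameters() if p.trainable)
```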

@liym27 (Contributor) left a comment

LGTM

@Xing-lil Xing-lil merged commit 76b2761 into PaddlePaddle:develop Jan 7, 2026
101 of 110 checks passed