
feat: support infer slice tensor at dim > 0 and optimize memory #106

Merged

chaokunyang merged 1 commit into inclusionAI:main from PrometheusComing:xxj_hccl on Apr 13, 2026

feat:suppot infer slice tensor at dim > 0 and optimize memory#106
chaokunyang merged 1 commit intoinclusionAI:mainfrom
PrometheusComing:xxj_hccl

Conversation

@PrometheusComing (Contributor) commented Apr 9, 2026

What does this PR do?

Support inferring sliced tensors at dim > 0, which can make p2p tensors non-contiguous in NCCL or HCCL, and optimize memory usage.

1. Support tensor slicing at dim > 0 (e.g. attention.dense.weight) when the inference TP size is smaller than the training TP size, which can make the p2p tensor non-contiguous in NCCL or HCCL. Another case is vLLM-Ascend, where the expert weights have been transposed but the shared_experts weights have not.
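The striding issue behind point 1 can be illustrated with a minimal PyTorch sketch (the tensor shape and names are illustrative, not the PR's actual code): slicing along any dim > 0 yields a non-contiguous view, which must be packed into a dense buffer before a p2p send.

```python
import torch

# A toy 2-D weight, standing in for something like attention.dense.weight.
full = torch.arange(24, dtype=torch.float32).reshape(4, 6)

# Slicing along dim 1 (dim > 0) produces a strided, non-contiguous view:
# the rows of the slice are not adjacent in memory.
shard = full[:, 0:3]
assert not shard.is_contiguous()

# Packing into a contiguous buffer before the p2p send avoids handing
# NCCL/HCCL a non-contiguous tensor.
send_buf = shard.contiguous()
assert send_buf.is_contiguous()
assert torch.equal(send_buf, shard)
```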

2. Optimize memory by supporting one-by-one p2p send/receive. This is useful for debugging, or when memory is insufficient while inferring with sleep mode disabled; it also helps when the hardware differs between sides (e.g. 910B2 and 910B1), where batch_send_recv raises errors. Compared with batched send/receive, one-by-one transfer increases time consumption by less than 10% but reduces peak memory by 25% for Qwen3-30B.
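The memory trade-off in point 2 comes down to buffer staging: batched p2p keeps every packed send buffer alive at once, while one-by-one transfer holds only a single buffer at a time. A toy accounting sketch (not the actual transfer code) of the peak:

```python
def peak_staged_bytes(tensor_sizes, one_by_one):
    """Model peak bytes held in packed send buffers during a transfer."""
    peak = staged = 0
    for size in tensor_sizes:
        staged += size              # allocate the packed send buffer
        peak = max(peak, staged)
        if one_by_one:
            staged -= size          # freed right after this send completes
    return peak

sizes = [100, 50, 200, 25]
# Batched: all buffers staged simultaneously -> peak is the total.
assert peak_staged_bytes(sizes, one_by_one=False) == 375
# One-by-one: peak is only the largest single buffer.
assert peak_staged_bytes(sizes, one_by_one=True) == 200
```

The extra latency of one-by-one mode comes from serializing the sends, which matches the PR's observed trade of under 10% more time for a 25% lower peak on Qwen3-30B.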

3. Optimize memory by using a local-part process group and destroying the weight-exchange process group in NPU HCCL. HCCL_BUFFERSIZE actually occupies twice its configured memory, and it must also be kept consistent between inference and training; using a local-part process group frees this portion of memory.
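The saving in point 3 is back-of-the-envelope arithmetic: per the PR description, each live HCCL group costs twice its configured HCCL_BUFFERSIZE per rank, so tearing down the global weight-exchange group after the transfer (keeping only the local-part group) halves that overhead. The buffer size below is illustrative, not a measured value from the PR:

```python
def hccl_buffer_cost_mb(buffer_size_mb, num_live_groups):
    """Rough per-rank HCCL buffer cost, assuming each group's
    HCCL_BUFFERSIZE actually occupies twice its configured value."""
    return 2 * buffer_size_mb * num_live_groups

# Global weight-exchange group kept alive alongside the local-part group:
assert hccl_buffer_cost_mb(200, num_live_groups=2) == 800
# Weight-exchange group destroyed after the transfer; local group only:
assert hccl_buffer_cost_mb(200, num_live_groups=1) == 400
```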

Related issues

No

Does this PR introduce any user-facing change?

No

  • Does this PR introduce any public API change? No
  • Does this PR introduce any binary protocol compatibility change? No

@gemini-code-assist (Bot) left a comment

Code Review

This pull request refactors the NCCL and HCCL weight transfer logic to support non-contiguous tensors and introduces a one-by-one communication mode for debugging and memory efficiency. It also updates the SGLang converter to handle shared experts in MoE models and adds configuration for HCCL buffer sizes. Review feedback suggests removing redundant del statements for local list references to improve code clarity.

Comment on lines +336 to +337

```python
non_contiguous_tensor_pairs.clear()
del non_contiguous_tensor_pairs
```

medium

The del non_contiguous_tensor_pairs statement is redundant. The list is cleared on the previous line, and the local reference will be garbage collected when the function returns. Removing this line will make the code cleaner without affecting functionality.

Suggested change:

```diff
-non_contiguous_tensor_pairs.clear()
-del non_contiguous_tensor_pairs
+non_contiguous_tensor_pairs.clear()
```
Comment on lines +182 to +183

```python
non_contiguous_tensor_pairs.clear()
del non_contiguous_tensor_pairs
```

medium

The del non_contiguous_tensor_pairs statement is redundant here. The list is cleared on the previous line, and the local reference will be garbage collected automatically. Removing this line would improve code clarity.

Suggested change:

```diff
-non_contiguous_tensor_pairs.clear()
-del non_contiguous_tensor_pairs
+non_contiguous_tensor_pairs.clear()
```

@chaokunyang (Collaborator) left a comment

LGTM

@chaokunyang chaokunyang merged commit 7739878 into inclusionAI:main Apr 13, 2026
3 checks passed
2 participants