Skip to content

Conversation

DrRyanHuang
Copy link
Contributor

@DrRyanHuang DrRyanHuang commented Oct 14, 2025

本 PR 的主要操作是将SOT warmup的过程延后,在CUDAGraph的Capture阶段进行warmup


本PR的前置PR


这个PR终于跑通 SOT + CUDAGraph + 开启子图切分的整个流程,在这个PR中梳理一下整个过程:

单卡模型

以下是 @zyfncg 的前置PR,SOT下实现CudaGraph子图捕获功能,我们快速跑通了 ERNIE45T 0.3B✅

但是依旧存在显存拷贝的问题,如果输入的位置发生改变,则将新输入Copy到Capture的位置

        // https://github.com/PaddlePaddle/Paddle/blob/89f4bd92f49e15a9e1803a9e582526b2b8e4557d/paddle/fluid/framework/new_executor/instruction/cuda_graph_instruction.cc#L179-L187
        if (tensor->data() != input_tensors_.at(i).data()) {
          LOG(WARNING) << "The input [" << i << "] tensor addr for "
                       << "cuda graph is changed. Pay attention to this!";
          if (phi::is_gpu_place(tensor->place())) {
            const auto* dev_ctx =
                phi::DeviceContextPool::Instance().Get(place_);
            phi::Copy(*dev_ctx, *tensor, place_, false, &input_tensors_.at(i)); // <----- 这一行
          }
        }

为了解决这个问题,#3302 添加了 append_attention_with_output,在运行 append_attention 之前,优先创建一个 empty Tensor 作为 append_attention 的输出——在外部管理 append_attention 的显存

#3694 修复了 #3302 导致的打断,#4340(是 #3694 的一部分)移除了一些无用的输出,避免动静不统一导致的BUG

#3694 依赖 Paddle 主框架的两个PR:

到此为止,单卡的 ERNIE45T 21B和0.3B 都能跑通✅,且不存在 Copy,但是多卡运行会遇到CUDA700的问题


多卡模型

遇到到第一个CUDA700问题,不开CUDAGraph,只开SOT就能复现

cuda-gdb 分析这个CUDA700的问题,定位到是 Custom Allreduce 的问题,可暂时通过 --max-num-batched-tokens 2000 规避

@zhink 沟通后,可通过调大 Custom Allreduce 的 Buffer 大小来规避,即:

class CustomAllreduce:

    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]

    # max_size: max supported allreduce size
-   def __init__(self, group: Group, max_size: int = 8192 * 1024) -> None:
+   def __init__(self, group: Group, max_size: int = 8192 * 1024 * 32 * 2) -> None:
        # This is a buffer for storing the tuples of pointers pointing to
        # IPC buffers from all ranks. Each registered tuple has size of
        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
        # is enough for 131072 such tuples. The largest model I've seen only
        # needs less than 10000 of registered tuples.
-      self.rank_data = paddle.empty([8 * 1024 * 1024], dtype=paddle.uint8)
+      self.rank_data = paddle.empty([8192 * 1024 * 32 * 2], dtype=paddle.uint8)

此时SOT + Custom Allreduce 可以跑通推理流程 ✅

但开启 CUDAGraph 遇到 mp_allreduce_sum 对应的 DeviceContext 存在 cudagraph allocator 为空指针的问题
我把所有 Instruction 对应的 DeviceContext 指针打印了出来,统计了一下,共有1个CPUContext+2个GPUContext
出现次数分别是几十次、几千次和几万次,而动态图+CUDAGraph只有1个CPUContext+1个GPUContext

定位到是 phi::DeviceContext* ParseDeviceContext 这个函数将 GPUContext (有cudagraph allocator)转化成了另一个 GPUContext(无cudagraph allocator),这里为了先跑通,就先直接 return origin_dev_ctx; 了,(后续PR: TBC)但依旧会有 CUDA700 的问题

中间其实也尝试了很多其他方法,和老代码 battle 了好久,遇到了多 CUDA Stream 的问题,CUDA90X之类的,这时,@zyfncg 说:

你先别管多流的问题,就按之前的跑,DeviceContext改对之后,暴露的新CUDA700的问题,才是目前需要解的

好,那就先解这个CUDA700,悲催的是,用之前的方法:cuda-gdb 分析,会出现连环 core dump 的问题,在生成 coredump 文件中,又报一个 CUDA700 🤦‍♂️

@zyfncg 又问了一个关键的问题:目前的问题是 Capture 阶段还是 Replay 阶段?我们把 Capture 过程中的 Replay 全都注释掉看看

  // paddle/fluid/framework/new_executor/instruction/cuda_graph_instruction.cc
  if (*cuda_graph_state_ref_ == 2 && cuda_graph_ == nullptr) {
	......
	// 以下是 Capture 阶段
    platform::BeginCUDAGraphCapture(
        place_, cudaStreamCaptureModeRelaxed, cuda_graph_capture_pool_id_);

    auto RecordTensorsForReplay = [&](const std::vector<Variable*>& vars) {
      ......
      return record_tensors;
    };

    // record the input tensors for replay
    input_tensors_ = RecordTensorsForReplay(input_vars_);
    interpreter_->Run({}, false);

    // record the output tensors for replay
    output_tensors_ = RecordTensorsForReplay(output_vars_);
    cuda_graph_ = platform::EndCUDAGraphCapture();
	
    // 以下是 Capture 阶段中的 Replay
    cuda_graph_->Replay(); // 这里注释就好了
  }

把 Capture 过程中的 Replay 全都注释掉后,服务就能启动了,这其实说明是Replay阶段报的错,Capture阶段没问题
TBC

Copy link

paddle-bot bot commented Oct 14, 2025

Thanks for your contribution!

entry.captured = True
with self.cuda_graph_manager.run_impl_guard():
entry.runnable(**kwargs)
with capture_custom_allreduce():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

静态图也用custom all reduce 对吧

self.proposer.update_task_chunk_prefill(task)
task.chunk_idx += 1

@sot_warmup_guard(True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SOT的 Warm Up 延后是为了避免 custom all reduce 的什么问题呢

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants