diff --git a/paddle/fluid/framework/new_executor/instruction/cuda_graph_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cuda_graph_instruction.cc
index ad63e8c363683f..5622b9a1e9676a 100644
--- a/paddle/fluid/framework/new_executor/instruction/cuda_graph_instruction.cc
+++ b/paddle/fluid/framework/new_executor/instruction/cuda_graph_instruction.cc
@@ -232,8 +232,6 @@ void CudaGraphInstruction::Run() {
 
     cuda_graph_ = platform::EndCUDAGraphCapture();
    VLOG(4) << "Finish capturing cuda graph @" << cuda_graph_.get();
-    // compute the right result
-    cuda_graph_->Replay();
   } else {
     VLOG(4) << "Run interpreter without cuda graph";
     interpreter_->Run({}, false);
diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc
index 3aa492ceff87c1..0d7fdb9a9d52df 100644
--- a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc
+++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc
@@ -161,7 +161,16 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
             ->GetDevContext());
     return dev_ctx;
   }
-
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  // If the current op is inside a CudaGraphOp, we must use the same device
+  // context as the parent CudaGraphOp, mainly to ensure that
+  // cuda_graph_allocator_ is not nullptr. This is necessary for correct
+  // CUDA Graph capture and memory allocation.
+  if (op->GetParentOp()->isa<CudaGraphOp>()) {
+    VLOG(4) << "CudaGraphOp detected, using original device context";
+    return origin_dev_ctx;
+  }
+#endif
   // handle comm op
   if (op_attributes.count("ring_id") != 0) {
     int ring_id =