Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,6 @@ void CudaGraphInstruction::Run() {
cuda_graph_ = platform::EndCUDAGraphCapture();
VLOG(4) << "Finish capturing cuda graph @" << cuda_graph_.get();

// compute the right result
cuda_graph_->Replay();
} else {
VLOG(4) << "Run interpreter without cuda graph";
interpreter_->Run({}, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,16 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
->GetDevContext());
return dev_ctx;
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// If the current OP is inside a CUDAGraphOp,
// we must use the same device context as the parent CUDAGraphOp,
// mainly to ensure that cuda_graph_allocator_ is not nullptr.
// This is necessary for correct CUDA Graph capture and memory allocation.
if (op->GetParentOp()->isa<paddle::dialect::CudaGraphOp>()) {
VLOG(4) << "CudaGraphOp detected, using original device context";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

说明是 CUDAGraphOp 内的 OP,并且说明要确保是同一个 devcie context

return origin_dev_ctx;
}
#endif
// handle comm op
if (op_attributes.count("ring_id") != 0) {
int ring_id =
Expand Down
Loading