Skip to content

Commit 7d92406

Browse files
authored
[SYCL][Graph] Avoid command graph shared_ptr return (#17473)
Introduce functions checking if the command graph exists, which can be used instead of getting the command graph shared ptr by value.
1 parent 27c5aff commit 7d92406

File tree

7 files changed

+16
-13
lines changed

7 files changed

+16
-13
lines changed

sycl/source/detail/event_impl.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,8 @@ class event_impl {
300300
return MGraph.lock();
301301
}
302302

303+
bool hasCommandGraph() const { return !MGraph.expired(); }
304+
303305
void setEventFromSubmittedExecCommandBuffer(bool value) {
304306
MEventFromSubmittedExecCommandBuffer = value;
305307
}

sycl/source/detail/graph_impl.cpp

+3-4
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,7 @@ std::vector<sycl::detail::EventImplPtr> graph_impl::getExitNodesEvents(
739739
void graph_impl::beginRecording(
740740
std::shared_ptr<sycl::detail::queue_impl> Queue) {
741741
graph_impl::WriteLock Lock(MMutex);
742-
if (Queue->getCommandGraph() == nullptr) {
742+
if (!Queue->hasCommandGraph()) {
743743
Queue->setCommandGraph(shared_from_this());
744744
addQueue(Queue);
745745
}
@@ -1875,8 +1875,7 @@ void modifiable_command_graph::begin_recording(
18751875
auto QueueImpl = sycl::detail::getSyclObjImpl(RecordingQueue);
18761876
assert(QueueImpl);
18771877

1878-
auto QueueGraph = QueueImpl->getCommandGraph();
1879-
if (QueueGraph != nullptr) {
1878+
if (QueueImpl->hasCommandGraph()) {
18801879
throw sycl::exception(sycl::make_error_code(errc::invalid),
18811880
"begin_recording cannot be called for a queue which "
18821881
"is already in the recording state.");
@@ -1918,7 +1917,7 @@ void modifiable_command_graph::end_recording(queue &RecordingQueue) {
19181917
graph_impl::WriteLock Lock(impl->MMutex);
19191918
impl->removeQueue(QueueImpl);
19201919
}
1921-
if (QueueImpl->getCommandGraph() != nullptr)
1920+
if (QueueImpl->hasCommandGraph())
19221921
throw sycl::exception(sycl::make_error_code(errc::invalid),
19231922
"end_recording called for a queue which is recording "
19241923
"to a different graph.");

sycl/source/detail/graph_impl.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
870870
/// @param NodeImpl Node to associate with event in map.
871871
void addEventForNode(std::shared_ptr<sycl::detail::event_impl> EventImpl,
872872
std::shared_ptr<node_impl> NodeImpl) {
873-
if (!(EventImpl->getCommandGraph()))
873+
if (!(EventImpl->hasCommandGraph()))
874874
EventImpl->setCommandGraph(shared_from_this());
875875
MEventsMap[EventImpl] = NodeImpl;
876876
}

sycl/source/detail/queue_impl.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,8 @@ class queue_impl {
703703
return MGraph.lock();
704704
}
705705

706+
bool hasCommandGraph() const { return !MGraph.expired(); }
707+
706708
unsigned long long getQueueID() { return MQueueID; }
707709

708710
void *getTraceEvent() { return MTraceEvent; }

sycl/source/event.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ event::get_backend_info() const {
8181
template <typename Param>
8282
typename detail::is_event_profiling_info_desc<Param>::return_type
8383
event::get_profiling_info() const {
84-
if (impl->getCommandGraph()) {
84+
if (impl->hasCommandGraph()) {
8585
throw sycl::exception(make_error_code(errc::invalid),
8686
"Profiling information is unavailable for events "
8787
"returned from a submission to a queue in the "

sycl/source/handler.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ event handler::finalize() {
466466
}
467467

468468
if (MQueue && !impl->MGraph && !impl->MSubgraphNode &&
469-
!MQueue->getCommandGraph() && !impl->CGData.MRequirements.size() &&
469+
!MQueue->hasCommandGraph() && !impl->CGData.MRequirements.size() &&
470470
!MStreamStorage.size() &&
471471
(!impl->CGData.MEvents.size() ||
472472
(MQueue->isInOrder() &&
@@ -2036,7 +2036,7 @@ void handler::setNDRangeUsed(bool Value) { (void)Value; }
20362036
void handler::registerDynamicParameter(
20372037
ext::oneapi::experimental::detail::dynamic_parameter_base &DynamicParamBase,
20382038
int ArgIndex) {
2039-
if (MQueue && MQueue->getCommandGraph()) {
2039+
if (MQueue && MQueue->hasCommandGraph()) {
20402040
throw sycl::exception(
20412041
make_error_code(errc::invalid),
20422042
"Dynamic Parameters cannot be used with Graph Queue recording.");

sycl/source/queue.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ context queue::get_context() const { return impl->get_context(); }
107107
device queue::get_device() const { return impl->get_device(); }
108108

109109
ext::oneapi::experimental::queue_state queue::ext_oneapi_get_state() const {
110-
return impl->getCommandGraph()
110+
return impl->hasCommandGraph()
111111
? ext::oneapi::experimental::queue_state::recording
112112
: ext::oneapi::experimental::queue_state::executing;
113113
}
@@ -299,7 +299,7 @@ getBarrierEventForInorderQueueHelper(const detail::QueueImplPtr QueueImpl) {
299299
// as a graph can record from multiple queues and we cannot guarantee the
300300
// last node added by an in-order queue will be the last node added to the
301301
// graph.
302-
assert(!QueueImpl->getCommandGraph() &&
302+
assert(!QueueImpl->hasCommandGraph() &&
303303
"Should not be called in on graph recording.");
304304

305305
sycl::detail::optional<event> LastEvent = QueueImpl->getLastEvent();
@@ -319,7 +319,7 @@ getBarrierEventForInorderQueueHelper(const detail::QueueImplPtr QueueImpl) {
319319
/// \return a SYCL event object, which corresponds to the queue the command
320320
/// group is being enqueued on.
321321
event queue::ext_oneapi_submit_barrier(const detail::code_location &CodeLoc) {
322-
if (is_in_order() && !impl->getCommandGraph() && !impl->MDiscardEvents &&
322+
if (is_in_order() && !impl->hasCommandGraph() && !impl->MDiscardEvents &&
323323
!impl->MIsProfilingEnabled) {
324324
event InOrderLastEvent = getBarrierEventForInorderQueueHelper(impl);
325325
// If the last event was discarded, fall back to enqueuing a barrier.
@@ -345,9 +345,9 @@ event queue::ext_oneapi_submit_barrier(const std::vector<event> &WaitList,
345345
begin(WaitList), end(WaitList), [&](const event &Event) -> bool {
346346
auto EventImpl = detail::getSyclObjImpl(Event);
347347
return (EventImpl->isDefaultConstructed() || EventImpl->isNOP()) &&
348-
!EventImpl->getCommandGraph();
348+
!EventImpl->hasCommandGraph();
349349
});
350-
if (is_in_order() && !impl->getCommandGraph() && !impl->MDiscardEvents &&
350+
if (is_in_order() && !impl->hasCommandGraph() && !impl->MDiscardEvents &&
351351
!impl->MIsProfilingEnabled && AllEventsEmptyOrNop) {
352352
event InOrderLastEvent = getBarrierEventForInorderQueueHelper(impl);
353353
// If the last event was discarded, fall back to enqueuing a barrier.

0 commit comments

Comments
 (0)