@@ -179,16 +179,24 @@ class PerCudaCtxPerThreadSingletonCreator
// Constructs the creator from a pair of callbacks.
//
// creator: stored in mCreator (moved in).
// deleter: stored in mDeleter (moved in); invoked on the managed object by the cleanup code
//          further down in this class.
//
// mObservers is allocated on the heap and held through a raw pointer rather than as a direct
// member. The destructor deletes the map and resets the pointer to nullptr, and the cleanup code
// checks the pointer against nullptr before touching the map.
// NOTE(review): presumably this is meant to let cleanup that runs after this object has been
// destroyed (e.g. during static destruction) bail out early - confirm. Reading a member of an
// already-destroyed object is still undefined behavior, so verify the intended lifetime.
179179 PerCudaCtxPerThreadSingletonCreator (CreatorFunc creator, DeleterFunc deleter)
180180 : mCreator {std::move (creator)}
181181 , mDeleter {std::move (deleter)}
182+ , mObservers {new std::unordered_map<CacheKey, std::weak_ptr<T>, hash<CacheKey>>()}
182183 {
183184 }
185+
// Releases the heap-allocated observer map.
//
// mMutex is held while the map is deleted and the pointer is reset, so the teardown is serialized
// with any other code path that takes mMutex before using mObservers.
// NOTE(review): the class now owns a raw pointer; it stays non-copyable only because of the
// std::mutex member - keep that in mind if the mutex is ever replaced.
186+ ~PerCudaCtxPerThreadSingletonCreator ()
187+ {
188+ std::lock_guard<std::mutex> lk{mMutex };
189+ delete mObservers ;
// Reset to nullptr so later nullptr checks on mObservers see that the map is gone.
190+ mObservers = nullptr ;
191+ }
184192
185193 std::shared_ptr<T> operator ()()
186194 {
187195 std::lock_guard<std::mutex> lk{mMutex };
188196 CUcontext ctx{getCurrentCudaCtx ()};
189197 std::thread::id thread = std::this_thread::get_id ();
190198 auto const key = std::make_tuple (ctx, thread);
191- std::shared_ptr<T> result = mObservers [key].lock ();
199+ std::shared_ptr<T> result = (* mObservers ) [key].lock ();
192200 if (result == nullptr )
193201 {
194202 TLLM_LOG_TRACE (" creating singleton instance for CUDA context %lu and thread %lu" , ctx, thread);
@@ -202,6 +210,11 @@ class PerCudaCtxPerThreadSingletonCreator
202210 }
203211 mDeleter (obj);
204212
213+ if (mObservers == nullptr )
214+ {
215+ return ;
216+ }
217+
205218 // Clear the observer to avoid unbounded growth of mObservers in case users create/destroy CUDA contexts
206219 // frequently.
207220 std::shared_ptr<T> observedObjHolder; // Delay destroy to avoid dead lock.
@@ -210,17 +223,18 @@ class PerCudaCtxPerThreadSingletonCreator
210223 // thread just before we lock mMutex. We can't infer that the observer is stale from the fact that
211224 // obj is destroyed, because shared_ptr ref-count checking and observer removing are not in one
212225 // atomic operation, and the observer may be changed to observe another instance.
213- if (mObservers .find (key) == mObservers .end ())
226+ auto it = mObservers ->find (key);
227+ if (it == mObservers ->end ())
214228 {
215229 return ;
216230 }
217- observedObjHolder = mObservers . at (key) .lock ();
231+ observedObjHolder = it-> second .lock ();
218232 if (observedObjHolder == nullptr )
219233 {
220- mObservers . erase (key );
234+ mObservers -> erase (it );
221235 }
222236 }};
223- mObservers . at (key) = result;
237+ (* mObservers )[key] = result;
224238 }
225239 else
226240 {
@@ -235,7 +249,7 @@ class PerCudaCtxPerThreadSingletonCreator
235249 mutable std::mutex mMutex ;
236250 // CUDA resources are per-context and per-thread.
237251 using CacheKey = std::tuple<CUcontext, std::thread::id>;
238- std::unordered_map<CacheKey, std::weak_ptr<T>, hash<CacheKey>> mObservers ;
252+ std::unordered_map<CacheKey, std::weak_ptr<T>, hash<CacheKey>>* mObservers ;
239253};
240254
241255} // namespace
@@ -253,6 +267,7 @@ std::shared_ptr<cublasHandle_t> getCublasHandle()
253267 {
254268 TLLM_CUDA_CHECK (cublasDestroy (*handle));
255269 delete handle;
270+ handle = nullptr ;
256271 });
257272 return creator ();
258273}
@@ -270,6 +285,7 @@ std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
270285 {
271286 TLLM_CUDA_CHECK (cublasLtDestroy (*handle));
272287 delete handle;
288+ handle = nullptr ;
273289 });
274290 return creator ();
275291}
0 commit comments