Skip to content

Commit 4b7e1cc

Browse files
committed
change var mem to heap, to manualy control
Signed-off-by: yunruis <[email protected]>
1 parent b326be2 commit 4b7e1cc

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

cpp/tensorrt_llm/common/opUtils.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,16 +179,24 @@ class PerCudaCtxPerThreadSingletonCreator
179179
PerCudaCtxPerThreadSingletonCreator(CreatorFunc creator, DeleterFunc deleter)
180180
: mCreator{std::move(creator)}
181181
, mDeleter{std::move(deleter)}
182+
, mObservers{new std::unordered_map<CacheKey, std::weak_ptr<T>, hash<CacheKey>>()}
182183
{
183184
}
185+
186+
~PerCudaCtxPerThreadSingletonCreator()
187+
{
188+
std::lock_guard<std::mutex> lk{mMutex};
189+
delete mObservers;
190+
mObservers = nullptr;
191+
}
184192

185193
std::shared_ptr<T> operator()()
186194
{
187195
std::lock_guard<std::mutex> lk{mMutex};
188196
CUcontext ctx{getCurrentCudaCtx()};
189197
std::thread::id thread = std::this_thread::get_id();
190198
auto const key = std::make_tuple(ctx, thread);
191-
std::shared_ptr<T> result = mObservers[key].lock();
199+
std::shared_ptr<T> result = (*mObservers)[key].lock();
192200
if (result == nullptr)
193201
{
194202
TLLM_LOG_TRACE("creating singleton instance for CUDA context %lu and thread %lu", ctx, thread);
@@ -202,6 +210,11 @@ class PerCudaCtxPerThreadSingletonCreator
202210
}
203211
mDeleter(obj);
204212

213+
if (mObservers == nullptr)
214+
{
215+
return;
216+
}
217+
205218
// Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts
206219
// frequently.
207220
std::shared_ptr<T> observedObjHolder; // Delay destroy to avoid dead lock.
@@ -210,17 +223,18 @@ class PerCudaCtxPerThreadSingletonCreator
210223
// thread just before we lock mMutex. We can't infer that the observer is stale from the fact that
211224
// obj is destroyed, because shared_ptr ref-count checking and observer removing are not in one
212225
// atomic operation, and the observer may be changed to observe another instance.
213-
if (mObservers.find(key) == mObservers.end())
226+
auto it = mObservers->find(key);
227+
if (it == mObservers->end())
214228
{
215229
return;
216230
}
217-
observedObjHolder = mObservers.at(key).lock();
231+
observedObjHolder = it->second.lock();
218232
if (observedObjHolder == nullptr)
219233
{
220-
mObservers.erase(key);
234+
mObservers->erase(it);
221235
}
222236
}};
223-
mObservers.at(key) = result;
237+
(*mObservers)[key] = result;
224238
}
225239
else
226240
{
@@ -235,7 +249,7 @@ class PerCudaCtxPerThreadSingletonCreator
235249
mutable std::mutex mMutex;
236250
// CUDA resources are per-context and per-thread.
237251
using CacheKey = std::tuple<CUcontext, std::thread::id>;
238-
std::unordered_map<CacheKey, std::weak_ptr<T>, hash<CacheKey>> mObservers;
252+
std::unordered_map<CacheKey, std::weak_ptr<T>, hash<CacheKey>>* mObservers;
239253
};
240254

241255
} // namespace
@@ -253,6 +267,7 @@ std::shared_ptr<cublasHandle_t> getCublasHandle()
253267
{
254268
TLLM_CUDA_CHECK(cublasDestroy(*handle));
255269
delete handle;
270+
handle = nullptr;
256271
});
257272
return creator();
258273
}
@@ -270,6 +285,7 @@ std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
270285
{
271286
TLLM_CUDA_CHECK(cublasLtDestroy(*handle));
272287
delete handle;
288+
handle = nullptr;
273289
});
274290
return creator();
275291
}

0 commit comments

Comments
 (0)