Skip to content

Commit

Permalink
[CUDA][shared memory allocation]fix 'ptxas error : Entry function 'fu…
Browse files Browse the repository at this point in the history
…sion_##' uses too much shared data'
  • Loading branch information
AIYoungcino committed Aug 12, 2024
1 parent b3d01c2 commit 6e09af0
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class PointerTypeNode : public TypeNode {
/*!
* \brief The storage scope of the pointer
*/
String storage_scope;
mutable String storage_scope;

void VisitAttrs(AttrVisitor* v) {
v->Visit("element_type", &element_type);
Expand Down
6 changes: 4 additions & 2 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -642,12 +642,14 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) {
}

/*!
 * \brief Write the CUDA declaration qualifier that corresponds to a storage scope.
 * \param scope Storage scope name; the values compared here are "global", "shared" and "shared.dyn".
 * \param os Output stream that receives the qualifier text. Scopes that match none of the
 *        branches below write nothing.
 *
 * NOTE(review): this listing appears to be a diff view with removed and added lines
 * interleaved (an active ICHECK_NE directly followed by the same check commented out).
 * Confirm against the real file which of the two is in effect before relying on it.
 */
void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
// Reject the "global" scope with an explanatory message.
ICHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. You must pass "
"all global arrays as input instead";
// Same check, disabled.
// ICHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. You must pass "
// "all global arrays as input instead";
if (scope == "shared") {
// Statically sized shared memory.
os << "__shared__ ";
} else if (scope == "shared.dyn") {
// Shared memory declared extern; presumably sized at kernel launch -- TODO confirm.
os << "extern __shared__ ";
} else if (scope == "global") {
// NOTE(review): with the active ICHECK_NE above this branch looks unreachable, assuming the
// check aborts on failure. If the check is in fact disabled, a function-scope
// `__device__ static` object is presumably a single instance shared by every thread of
// every block rather than a per-block buffer -- verify that this is intended.
os << "__device__ static ";
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/tir/transforms/storage_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1518,8 +1518,12 @@ class StorageFlattener : public StmtExprMutator {
StorageScope skey = StorageScope::Create(GetPtrStorageScope(op->buffer->data));

// use small alignment for small arrays
auto* ptr_type = op->buffer->data->type_annotation.as<PointerTypeNode>();
auto dtype = op->buffer->dtype;
size_t const_size = AllocateNode::ConstantAllocationSize(op->buffer->shape);
if (const_size > 41984) {
ptr_type->storage_scope = tvm::runtime::String("global");
}
int align = GetTempAllocaAlignment(dtype, const_size);
if (skey.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(skey.to_string());
Expand Down

0 comments on commit 6e09af0

Please sign in to comment.