-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[Flang][OpenMP] Increase detection capability for requires usm (and others) #162971
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…thers) Currently, the compiler only picks up some cases where requires is designated such as in the main program. However, it'll gloss over cases such as when it is specified by a user in a module, that's then used elsewhere. This patch attempts to amend that by searching the varying scopes in the current program module more comprehensively.
@llvm/pr-subscribers-flang-fir-hlfir Author: None (agozillon) ChangesCurrently, the compiler only picks up some cases where requires is designated such as in the main program. However, it'll gloss over cases such as when it is specified by a user in a module, that's then used elsewhere. This patch attempts to amend that by searching the varying scopes in the current program module more comprehensively. Full diff: https://github.com/llvm/llvm-project/pull/162971.diff 2 Files Affected:
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 68adf346fe8c0..358b57d76d32e 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -415,7 +415,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// - Define module variables and OpenMP/OpenACC declarative constructs so
// they are available before lowering any function that may use them.
bool hasMainProgram = false;
- const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr;
+ llvm::SmallVector<const Fortran::semantics::Symbol *>
+ globalOmpRequiresSymbols;
createBuilderOutsideOfFuncOpAndDo([&]() {
for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
Fortran::common::visit(
@@ -424,8 +425,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
if (f.isMainProgram())
hasMainProgram = true;
declareFunction(f);
- if (!globalOmpRequiresSymbol)
- globalOmpRequiresSymbol = f.getScope().symbol();
+ globalOmpRequiresSymbols.push_back(f.getScope().symbol());
},
[&](Fortran::lower::pft::ModuleLikeUnit &m) {
lowerModuleDeclScope(m);
@@ -433,12 +433,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
m.containedUnitList)
if (auto *f =
std::get_if<Fortran::lower::pft::FunctionLikeUnit>(
- &unit))
+ &unit)) {
declareFunction(*f);
+ globalOmpRequiresSymbols.push_back(
+ f->getScope().symbol());
+ }
+ globalOmpRequiresSymbols.push_back(m.getScope().symbol());
},
[&](Fortran::lower::pft::BlockDataUnit &b) {
- if (!globalOmpRequiresSymbol)
- globalOmpRequiresSymbol = b.symTab.symbol();
+ globalOmpRequiresSymbols.push_back(b.symTab.symbol());
},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
@@ -481,7 +484,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
Fortran::common::LanguageFeature::Coarray));
});
- finalizeOpenMPLowering(globalOmpRequiresSymbol);
+ finalizeOpenMPLowering(globalOmpRequiresSymbols);
}
/// Declare a function.
@@ -6681,7 +6684,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// Performing OpenMP lowering actions that were deferred to the end of
/// lowering.
void finalizeOpenMPLowering(
- const Fortran::semantics::Symbol *globalOmpRequiresSymbol) {
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+ &globalOmpRequiresSymbol) {
if (!ompDeferredDeclareTarget.empty()) {
bool deferredDeviceFuncFound =
Fortran::lower::markOpenMPDeferredDeclareTargetFunctions(
@@ -6690,9 +6694,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
// Set the module attribute related to OpenMP requires directives
- if (ompDeviceCodeFound)
- Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(),
- globalOmpRequiresSymbol);
+ if (ompDeviceCodeFound) {
+ for (auto sym : globalOmpRequiresSymbol)
+ Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), sym);
+ }
}
/// Record fir.dummy_scope operation for this function.
diff --git a/flang/test/Lower/OpenMP/requires-usm.f90 b/flang/test/Lower/OpenMP/requires-usm.f90
new file mode 100644
index 0000000000000..600e8387d6ad4
--- /dev/null
+++ b/flang/test/Lower/OpenMP/requires-usm.f90
@@ -0,0 +1,22 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -emit-hlfir %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck %s
+
+! Verify that we pick up USM and apply it correctly when it is specified
+! outside of the program.
+
+!CHECK: module attributes {
+!CHECK-SAME: omp.requires = #omp<clause_requires unified_shared_memory>
+module declare_mod
+ implicit none
+!$omp requires unified_shared_memory
+ contains
+end module
+
+program main
+ use declare_mod
+ implicit none
+!$omp target
+!$omp end target
+end program
|
@llvm/pr-subscribers-flang-openmp Author: None (agozillon) ChangesCurrently, the compiler only picks up some cases where requires is designated such as in the main program. However, it'll gloss over cases such as when it is specified by a user in a module, that's then used elsewhere. This patch attempts to amend that by searching the varying scopes in the current program module more comprehensively. Full diff: https://github.com/llvm/llvm-project/pull/162971.diff 2 Files Affected:
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 68adf346fe8c0..358b57d76d32e 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -415,7 +415,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// - Define module variables and OpenMP/OpenACC declarative constructs so
// they are available before lowering any function that may use them.
bool hasMainProgram = false;
- const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr;
+ llvm::SmallVector<const Fortran::semantics::Symbol *>
+ globalOmpRequiresSymbols;
createBuilderOutsideOfFuncOpAndDo([&]() {
for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
Fortran::common::visit(
@@ -424,8 +425,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
if (f.isMainProgram())
hasMainProgram = true;
declareFunction(f);
- if (!globalOmpRequiresSymbol)
- globalOmpRequiresSymbol = f.getScope().symbol();
+ globalOmpRequiresSymbols.push_back(f.getScope().symbol());
},
[&](Fortran::lower::pft::ModuleLikeUnit &m) {
lowerModuleDeclScope(m);
@@ -433,12 +433,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
m.containedUnitList)
if (auto *f =
std::get_if<Fortran::lower::pft::FunctionLikeUnit>(
- &unit))
+ &unit)) {
declareFunction(*f);
+ globalOmpRequiresSymbols.push_back(
+ f->getScope().symbol());
+ }
+ globalOmpRequiresSymbols.push_back(m.getScope().symbol());
},
[&](Fortran::lower::pft::BlockDataUnit &b) {
- if (!globalOmpRequiresSymbol)
- globalOmpRequiresSymbol = b.symTab.symbol();
+ globalOmpRequiresSymbols.push_back(b.symTab.symbol());
},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
@@ -481,7 +484,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
Fortran::common::LanguageFeature::Coarray));
});
- finalizeOpenMPLowering(globalOmpRequiresSymbol);
+ finalizeOpenMPLowering(globalOmpRequiresSymbols);
}
/// Declare a function.
@@ -6681,7 +6684,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// Performing OpenMP lowering actions that were deferred to the end of
/// lowering.
void finalizeOpenMPLowering(
- const Fortran::semantics::Symbol *globalOmpRequiresSymbol) {
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+ &globalOmpRequiresSymbol) {
if (!ompDeferredDeclareTarget.empty()) {
bool deferredDeviceFuncFound =
Fortran::lower::markOpenMPDeferredDeclareTargetFunctions(
@@ -6690,9 +6694,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
// Set the module attribute related to OpenMP requires directives
- if (ompDeviceCodeFound)
- Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(),
- globalOmpRequiresSymbol);
+ if (ompDeviceCodeFound) {
+ for (auto sym : globalOmpRequiresSymbol)
+ Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), sym);
+ }
}
/// Record fir.dummy_scope operation for this function.
diff --git a/flang/test/Lower/OpenMP/requires-usm.f90 b/flang/test/Lower/OpenMP/requires-usm.f90
new file mode 100644
index 0000000000000..600e8387d6ad4
--- /dev/null
+++ b/flang/test/Lower/OpenMP/requires-usm.f90
@@ -0,0 +1,22 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -emit-hlfir %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck %s
+
+! Verify that we pick up USM and apply it correctly when it is specified
+! outside of the program.
+
+!CHECK: module attributes {
+!CHECK-SAME: omp.requires = #omp<clause_requires unified_shared_memory>
+module declare_mod
+ implicit none
+!$omp requires unified_shared_memory
+ contains
+end module
+
+program main
+ use declare_mod
+ implicit none
+!$omp target
+!$omp end target
+end program
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Currently, the compiler only picks up some cases where requires is designated such as in the main program. However, it'll gloss over cases such as when it is specified by a user in a module, that's then used elsewhere.
This patch attempts to amend that by searching the varying scopes in the current program module more comprehensively.