Skip to content

Conversation

agozillon
Copy link
Contributor

Currently, the compiler only picks up some cases where requires is designated such as in the main program. However, it'll gloss over cases such as when it is specified by a user in a module, that's then used elsewhere.

This patch attempts to amend that by searching the varying scopes in the current program module more comprehensively.

…thers)

Currently, the compiler only picks up some cases where requires is designated
such as in the main program. However, it'll gloss over cases such as when it
is specified by a user in a module, that's then used elsewhere.

This patch attempts to amend that by searching the varying scopes in the
current program module more comprehensively.
@llvmbot
Copy link
Member

llvmbot commented Oct 11, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (agozillon)

Changes

Currently, the compiler only picks up some cases where requires is designated such as in the main program. However, it'll gloss over cases such as when it is specified by a user in a module, that's then used elsewhere.

This patch attempts to amend that by searching the varying scopes in the current program module more comprehensively.


Full diff: https://github.com/llvm/llvm-project/pull/162971.diff

2 Files Affected:

  • (modified) flang/lib/Lower/Bridge.cpp (+16-11)
  • (added) flang/test/Lower/OpenMP/requires-usm.f90 (+22)
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 68adf346fe8c0..358b57d76d32e 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -415,7 +415,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     // - Define module variables and OpenMP/OpenACC declarative constructs so
     //   they are available before lowering any function that may use them.
     bool hasMainProgram = false;
-    const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr;
+    llvm::SmallVector<const Fortran::semantics::Symbol *>
+        globalOmpRequiresSymbols;
     createBuilderOutsideOfFuncOpAndDo([&]() {
       for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
         Fortran::common::visit(
@@ -424,8 +425,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                   if (f.isMainProgram())
                     hasMainProgram = true;
                   declareFunction(f);
-                  if (!globalOmpRequiresSymbol)
-                    globalOmpRequiresSymbol = f.getScope().symbol();
+                  globalOmpRequiresSymbols.push_back(f.getScope().symbol());
                 },
                 [&](Fortran::lower::pft::ModuleLikeUnit &m) {
                   lowerModuleDeclScope(m);
@@ -433,12 +433,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                        m.containedUnitList)
                     if (auto *f =
                             std::get_if<Fortran::lower::pft::FunctionLikeUnit>(
-                                &unit))
+                                &unit)) {
                       declareFunction(*f);
+                      globalOmpRequiresSymbols.push_back(
+                          f->getScope().symbol());
+                    }
+                  globalOmpRequiresSymbols.push_back(m.getScope().symbol());
                 },
                 [&](Fortran::lower::pft::BlockDataUnit &b) {
-                  if (!globalOmpRequiresSymbol)
-                    globalOmpRequiresSymbol = b.symTab.symbol();
+                  globalOmpRequiresSymbols.push_back(b.symTab.symbol());
                 },
                 [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
                 [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
@@ -481,7 +484,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                                   Fortran::common::LanguageFeature::Coarray));
       });
 
-    finalizeOpenMPLowering(globalOmpRequiresSymbol);
+    finalizeOpenMPLowering(globalOmpRequiresSymbols);
   }
 
   /// Declare a function.
@@ -6681,7 +6684,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// Performing OpenMP lowering actions that were deferred to the end of
   /// lowering.
   void finalizeOpenMPLowering(
-      const Fortran::semantics::Symbol *globalOmpRequiresSymbol) {
+      llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+          &globalOmpRequiresSymbol) {
     if (!ompDeferredDeclareTarget.empty()) {
       bool deferredDeviceFuncFound =
           Fortran::lower::markOpenMPDeferredDeclareTargetFunctions(
@@ -6690,9 +6694,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     }
 
     // Set the module attribute related to OpenMP requires directives
-    if (ompDeviceCodeFound)
-      Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(),
-                                        globalOmpRequiresSymbol);
+    if (ompDeviceCodeFound) {
+      for (auto sym : globalOmpRequiresSymbol)
+        Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), sym);
+    }
   }
 
   /// Record fir.dummy_scope operation for this function.
diff --git a/flang/test/Lower/OpenMP/requires-usm.f90 b/flang/test/Lower/OpenMP/requires-usm.f90
new file mode 100644
index 0000000000000..600e8387d6ad4
--- /dev/null
+++ b/flang/test/Lower/OpenMP/requires-usm.f90
@@ -0,0 +1,22 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -emit-hlfir %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck %s
+
+! Verify that we pick up USM and apply it correctly when it is specified
+! outside of the program.
+
+!CHECK:      module attributes {
+!CHECK-SAME: omp.requires = #omp<clause_requires unified_shared_memory>
+module declare_mod
+    implicit none
+!$omp requires unified_shared_memory
+ contains
+end module
+
+program main
+ use declare_mod
+ implicit none
+!$omp target
+!$omp end target
+end program

@llvmbot
Copy link
Member

llvmbot commented Oct 11, 2025

@llvm/pr-subscribers-flang-openmp

Author: None (agozillon)

Changes

Currently, the compiler only picks up some cases where requires is designated such as in the main program. However, it'll gloss over cases such as when it is specified by a user in a module, that's then used elsewhere.

This patch attempts to amend that by searching the varying scopes in the current program module more comprehensively.


Full diff: https://github.com/llvm/llvm-project/pull/162971.diff

2 Files Affected:

  • (modified) flang/lib/Lower/Bridge.cpp (+16-11)
  • (added) flang/test/Lower/OpenMP/requires-usm.f90 (+22)
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 68adf346fe8c0..358b57d76d32e 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -415,7 +415,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     // - Define module variables and OpenMP/OpenACC declarative constructs so
     //   they are available before lowering any function that may use them.
     bool hasMainProgram = false;
-    const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr;
+    llvm::SmallVector<const Fortran::semantics::Symbol *>
+        globalOmpRequiresSymbols;
     createBuilderOutsideOfFuncOpAndDo([&]() {
       for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
         Fortran::common::visit(
@@ -424,8 +425,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                   if (f.isMainProgram())
                     hasMainProgram = true;
                   declareFunction(f);
-                  if (!globalOmpRequiresSymbol)
-                    globalOmpRequiresSymbol = f.getScope().symbol();
+                  globalOmpRequiresSymbols.push_back(f.getScope().symbol());
                 },
                 [&](Fortran::lower::pft::ModuleLikeUnit &m) {
                   lowerModuleDeclScope(m);
@@ -433,12 +433,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                        m.containedUnitList)
                     if (auto *f =
                             std::get_if<Fortran::lower::pft::FunctionLikeUnit>(
-                                &unit))
+                                &unit)) {
                       declareFunction(*f);
+                      globalOmpRequiresSymbols.push_back(
+                          f->getScope().symbol());
+                    }
+                  globalOmpRequiresSymbols.push_back(m.getScope().symbol());
                 },
                 [&](Fortran::lower::pft::BlockDataUnit &b) {
-                  if (!globalOmpRequiresSymbol)
-                    globalOmpRequiresSymbol = b.symTab.symbol();
+                  globalOmpRequiresSymbols.push_back(b.symTab.symbol());
                 },
                 [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
                 [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
@@ -481,7 +484,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                                   Fortran::common::LanguageFeature::Coarray));
       });
 
-    finalizeOpenMPLowering(globalOmpRequiresSymbol);
+    finalizeOpenMPLowering(globalOmpRequiresSymbols);
   }
 
   /// Declare a function.
@@ -6681,7 +6684,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// Performing OpenMP lowering actions that were deferred to the end of
   /// lowering.
   void finalizeOpenMPLowering(
-      const Fortran::semantics::Symbol *globalOmpRequiresSymbol) {
+      llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+          &globalOmpRequiresSymbol) {
     if (!ompDeferredDeclareTarget.empty()) {
       bool deferredDeviceFuncFound =
           Fortran::lower::markOpenMPDeferredDeclareTargetFunctions(
@@ -6690,9 +6694,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     }
 
     // Set the module attribute related to OpenMP requires directives
-    if (ompDeviceCodeFound)
-      Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(),
-                                        globalOmpRequiresSymbol);
+    if (ompDeviceCodeFound) {
+      for (auto sym : globalOmpRequiresSymbol)
+        Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), sym);
+    }
   }
 
   /// Record fir.dummy_scope operation for this function.
diff --git a/flang/test/Lower/OpenMP/requires-usm.f90 b/flang/test/Lower/OpenMP/requires-usm.f90
new file mode 100644
index 0000000000000..600e8387d6ad4
--- /dev/null
+++ b/flang/test/Lower/OpenMP/requires-usm.f90
@@ -0,0 +1,22 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -emit-hlfir %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck %s
+
+! Verify that we pick up USM and apply it correctly when it is specified
+! outside of the program.
+
+!CHECK:      module attributes {
+!CHECK-SAME: omp.requires = #omp<clause_requires unified_shared_memory>
+module declare_mod
+    implicit none
+!$omp requires unified_shared_memory
+ contains
+end module
+
+program main
+ use declare_mod
+ implicit none
+!$omp target
+!$omp end target
+end program

Copy link
Contributor

@mjklemm mjklemm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants