diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 68adf346fe8c0..358b57d76d32e 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -415,7 +415,8 @@ class FirConverter : public Fortran::lower::AbstractConverter { // - Define module variables and OpenMP/OpenACC declarative constructs so // they are available before lowering any function that may use them. bool hasMainProgram = false; - const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr; + llvm::SmallVector + globalOmpRequiresSymbols; createBuilderOutsideOfFuncOpAndDo([&]() { for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) { Fortran::common::visit( @@ -424,8 +425,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { if (f.isMainProgram()) hasMainProgram = true; declareFunction(f); - if (!globalOmpRequiresSymbol) - globalOmpRequiresSymbol = f.getScope().symbol(); + globalOmpRequiresSymbols.push_back(f.getScope().symbol()); }, [&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerModuleDeclScope(m); @@ -433,12 +433,15 @@ class FirConverter : public Fortran::lower::AbstractConverter { m.containedUnitList) if (auto *f = std::get_if( - &unit)) + &unit)) { declareFunction(*f); + globalOmpRequiresSymbols.push_back( + f->getScope().symbol()); + } + globalOmpRequiresSymbols.push_back(m.getScope().symbol()); }, [&](Fortran::lower::pft::BlockDataUnit &b) { - if (!globalOmpRequiresSymbol) - globalOmpRequiresSymbol = b.symTab.symbol(); + globalOmpRequiresSymbols.push_back(b.symTab.symbol()); }, [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {}, [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {}, @@ -481,7 +484,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { Fortran::common::LanguageFeature::Coarray)); }); - finalizeOpenMPLowering(globalOmpRequiresSymbol); + finalizeOpenMPLowering(globalOmpRequiresSymbols); } /// Declare a function. @@ -6681,7 +6684,8 @@ class FirConverter : public Fortran::lower::AbstractConverter { /// Performing OpenMP lowering actions that were deferred to the end of /// lowering. void finalizeOpenMPLowering( - const Fortran::semantics::Symbol *globalOmpRequiresSymbol) { + llvm::SmallVectorImpl + &globalOmpRequiresSymbol) { if (!ompDeferredDeclareTarget.empty()) { bool deferredDeviceFuncFound = Fortran::lower::markOpenMPDeferredDeclareTargetFunctions( @@ -6690,9 +6694,10 @@ class FirConverter : public Fortran::lower::AbstractConverter { } // Set the module attribute related to OpenMP requires directives - if (ompDeviceCodeFound) - Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), - globalOmpRequiresSymbol); + if (ompDeviceCodeFound) { + for (auto sym : globalOmpRequiresSymbol) + Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), sym); + } } /// Record fir.dummy_scope operation for this function. diff --git a/flang/test/Lower/OpenMP/requires-usm.f90 b/flang/test/Lower/OpenMP/requires-usm.f90 new file mode 100644 index 0000000000000..600e8387d6ad4 --- /dev/null +++ b/flang/test/Lower/OpenMP/requires-usm.f90 @@ -0,0 +1,22 @@ +! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s +! RUN: bbc -fopenmp -emit-hlfir %s -o - | FileCheck %s +! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck %s + +! Verify that we pick up USM and apply it correctly when it is specified +! outside of the program. + +!CHECK: module attributes { +!CHECK-SAME: omp.requires = #omp +module declare_mod + implicit none +!$omp requires unified_shared_memory + contains +end module + +program main + use declare_mod + implicit none +!$omp target +!$omp end target +end program