Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// - Define module variables and OpenMP/OpenACC declarative constructs so
// they are available before lowering any function that may use them.
bool hasMainProgram = false;
const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr;
llvm::SmallVector<const Fortran::semantics::Symbol *>
globalOmpRequiresSymbols;
createBuilderOutsideOfFuncOpAndDo([&]() {
for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
Fortran::common::visit(
Expand All @@ -424,21 +425,23 @@ class FirConverter : public Fortran::lower::AbstractConverter {
if (f.isMainProgram())
hasMainProgram = true;
declareFunction(f);
if (!globalOmpRequiresSymbol)
globalOmpRequiresSymbol = f.getScope().symbol();
globalOmpRequiresSymbols.push_back(f.getScope().symbol());
},
[&](Fortran::lower::pft::ModuleLikeUnit &m) {
lowerModuleDeclScope(m);
for (Fortran::lower::pft::ContainedUnit &unit :
m.containedUnitList)
if (auto *f =
std::get_if<Fortran::lower::pft::FunctionLikeUnit>(
&unit))
&unit)) {
declareFunction(*f);
globalOmpRequiresSymbols.push_back(
f->getScope().symbol());
}
globalOmpRequiresSymbols.push_back(m.getScope().symbol());
},
[&](Fortran::lower::pft::BlockDataUnit &b) {
if (!globalOmpRequiresSymbol)
globalOmpRequiresSymbol = b.symTab.symbol();
globalOmpRequiresSymbols.push_back(b.symTab.symbol());
},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
Expand Down Expand Up @@ -481,7 +484,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
Fortran::common::LanguageFeature::Coarray));
});

finalizeOpenMPLowering(globalOmpRequiresSymbol);
finalizeOpenMPLowering(globalOmpRequiresSymbols);
}

/// Declare a function.
Expand Down Expand Up @@ -6681,7 +6684,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// Performing OpenMP lowering actions that were deferred to the end of
/// lowering.
void finalizeOpenMPLowering(
const Fortran::semantics::Symbol *globalOmpRequiresSymbol) {
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
&globalOmpRequiresSymbol) {
if (!ompDeferredDeclareTarget.empty()) {
bool deferredDeviceFuncFound =
Fortran::lower::markOpenMPDeferredDeclareTargetFunctions(
Expand All @@ -6690,9 +6694,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}

// Set the module attribute related to OpenMP requires directives
if (ompDeviceCodeFound)
Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(),
globalOmpRequiresSymbol);
if (ompDeviceCodeFound) {
for (auto sym : globalOmpRequiresSymbol)
Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), sym);
}
}

/// Record fir.dummy_scope operation for this function.
Expand Down
22 changes: 22 additions & 0 deletions flang/test/Lower/OpenMP/requires-usm.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
! RUN: bbc -fopenmp -emit-hlfir %s -o - | FileCheck %s
! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-hlfir %s -o - | FileCheck %s

! Verify that we pick up USM and apply it correctly when it is specified
! outside of the program.

!CHECK: module attributes {
!CHECK-SAME: omp.requires = #omp<clause_requires unified_shared_memory>
module declare_mod
implicit none
!$omp requires unified_shared_memory
contains
end module

program main
use declare_mod
implicit none
!$omp target
!$omp end target
end program