[WebAssembly] Fix lowering of (extending) loads from addrspace(1) globals #155937
Conversation
@llvm/pr-subscribers-backend-webassembly

Author: None (QuantumSegfault)

Changes

Implements custom EXTLOAD lowering logic to enable correct lowering of such loads from WASM (address space 1) globals, similarly to the existing LOAD logic.

These fixes also ensure that the global.get is done for the type of the underlying global (accounting for globals declared smaller than i32), and that the loaded value is extended or truncated to the desired result width (preventing invalid instruction sequences).

Loads from integral or floating-point globals of widths other than 8, 16, 32, or 64 bits are explicitly disallowed.

Full diff: https://github.com/llvm/llvm-project/pull/155937.diff

2 Files Affected:
- llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
- llvm/test/CodeGen/WebAssembly/lower-load-wasm-global.ll
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 5a45134692865..343f85e475016 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -19,6 +19,7 @@
#include "WebAssemblyTargetMachine.h"
#include "WebAssemblyUtilities.h"
#include "llvm/CodeGen/CallingConvLower.h"
+#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
@@ -111,6 +112,17 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
}
}
+ // Likewise, transform extending loads for address space 1
+ for (auto T : {MVT::i32, MVT::i64}) {
+ for (auto S : {MVT::i8, MVT::i16, MVT::i32}) {
+ if (T != S) {
+ setLoadExtAction(ISD::EXTLOAD, T, S, Custom);
+ setLoadExtAction(ISD::ZEXTLOAD, T, S, Custom);
+ setLoadExtAction(ISD::SEXTLOAD, T, S, Custom);
+ }
+ }
+ }
+
setOperationAction(ISD::GlobalAddress, MVTPtr, Custom);
setOperationAction(ISD::GlobalTLSAddress, MVTPtr, Custom);
setOperationAction(ISD::ExternalSymbol, MVTPtr, Custom);
@@ -1707,6 +1719,11 @@ static bool IsWebAssemblyGlobal(SDValue Op) {
if (const GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Op))
return WebAssembly::isWasmVarAddressSpace(GA->getAddressSpace());
+ if (Op->getOpcode() == WebAssemblyISD::Wrapper)
+ if (const GlobalAddressSDNode *GA =
+ dyn_cast<GlobalAddressSDNode>(Op->getOperand(0)))
+ return WebAssembly::isWasmVarAddressSpace(GA->getAddressSpace());
+
return false;
}
@@ -1764,16 +1781,110 @@ SDValue WebAssemblyTargetLowering::LowerLoad(SDValue Op,
LoadSDNode *LN = cast<LoadSDNode>(Op.getNode());
const SDValue &Base = LN->getBasePtr();
const SDValue &Offset = LN->getOffset();
+ ISD::LoadExtType ExtType = LN->getExtensionType();
+ EVT ResultType = LN->getValueType(0);
if (IsWebAssemblyGlobal(Base)) {
if (!Offset->isUndef())
report_fatal_error(
"unexpected offset when loading from webassembly global", false);
- SDVTList Tys = DAG.getVTList(LN->getValueType(0), MVT::Other);
- SDValue Ops[] = {LN->getChain(), Base};
- return DAG.getMemIntrinsicNode(WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops,
- LN->getMemoryVT(), LN->getMemOperand());
+ EVT GT = MVT::INVALID_SIMPLE_VALUE_TYPE;
+
+ if (const GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Base))
+ GT = EVT::getEVT(GA->getGlobal()->getValueType());
+ if (Base->getOpcode() == WebAssemblyISD::Wrapper)
+ if (const GlobalAddressSDNode *GA =
+ dyn_cast<GlobalAddressSDNode>(Base->getOperand(0)))
+ GT = EVT::getEVT(GA->getGlobal()->getValueType());
+
+ if (GT != MVT::i8 && GT != MVT::i16 && GT != MVT::i32 && GT != MVT::i64 &&
+ GT != MVT::f32 && GT != MVT::f64)
+ report_fatal_error("encountered unexpected global type for Base when "
+ "loading from webassembly global",
+ false);
+
+ EVT PromotedGT = (GT == MVT::i8 || GT == MVT::i16) ? MVT::i32 : GT;
+
+ if (ExtType == ISD::NON_EXTLOAD) {
+ // A normal, non-extending load may try to load more or less than the
+ // underlying global, which is invalid. We lower this to a load of the
+ // global (i32 or i64) then truncate or extend as needed
+
+ // Modify the MMO to load the full global
+ MachineMemOperand *OldMMO = LN->getMemOperand();
+ MachineMemOperand *NewMMO = DAG.getMachineFunction().getMachineMemOperand(
+ OldMMO->getPointerInfo(), OldMMO->getFlags(),
+ LLT(PromotedGT.getSimpleVT()), OldMMO->getBaseAlign(),
+ OldMMO->getAAInfo(), OldMMO->getRanges(), OldMMO->getSyncScopeID(),
+ OldMMO->getSuccessOrdering(), OldMMO->getFailureOrdering());
+
+ SDVTList Tys = DAG.getVTList(PromotedGT, MVT::Other);
+ SDValue Ops[] = {LN->getChain(), Base};
+ SDValue GlobalGetNode = DAG.getMemIntrinsicNode(
+ WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops, PromotedGT, NewMMO);
+
+ if (ResultType.bitsEq(PromotedGT)) {
+ return GlobalGetNode;
+ }
+
+ SDValue ValRes;
+ if (ResultType.isFloatingPoint())
+ ValRes = DAG.getFPExtendOrRound(GlobalGetNode, DL, ResultType);
+ else
+ ValRes = DAG.getAnyExtOrTrunc(GlobalGetNode, DL, ResultType);
+
+ return DAG.getMergeValues({ValRes, LN->getChain()}, DL);
+ }
+
+ if (ExtType == ISD::ZEXTLOAD || ExtType == ISD::SEXTLOAD) {
+ // Turn the unsupported load into an EXTLOAD followed by an
+ // explicit zero/sign extend inreg. Same as Expand
+
+ SDValue Result =
+ DAG.getExtLoad(ISD::EXTLOAD, DL, ResultType, LN->getChain(), Base,
+ LN->getMemoryVT(), LN->getMemOperand());
+ SDValue ValRes;
+ if (ExtType == ISD::SEXTLOAD)
+ ValRes = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, Result.getValueType(),
+ Result, DAG.getValueType(LN->getMemoryVT()));
+ else
+ ValRes = DAG.getZeroExtendInReg(Result, DL, LN->getMemoryVT());
+
+ return DAG.getMergeValues({ValRes, LN->getChain()}, DL);
+ }
+
+ if (ExtType == ISD::EXTLOAD) {
+ // Expand the EXTLOAD into a regular LOAD of the global, and if
+ // needed, a zero-extension
+
+ EVT OldLoadType = LN->getMemoryVT();
+ EVT NewLoadType = (OldLoadType == MVT::i8 || OldLoadType == MVT::i16)
+ ? MVT::i32
+ : OldLoadType;
+
+ // Modify the MMO to load a whole WASM "register"'s worth
+ MachineMemOperand *OldMMO = LN->getMemOperand();
+ MachineMemOperand *NewMMO = DAG.getMachineFunction().getMachineMemOperand(
+ OldMMO->getPointerInfo(), OldMMO->getFlags(),
+ LLT(NewLoadType.getSimpleVT()), OldMMO->getBaseAlign(),
+ OldMMO->getAAInfo(), OldMMO->getRanges(), OldMMO->getSyncScopeID(),
+ OldMMO->getSuccessOrdering(), OldMMO->getFailureOrdering());
+
+ SDValue Result =
+ DAG.getLoad(NewLoadType, DL, LN->getChain(), Base, NewMMO);
+
+ if (NewLoadType != ResultType) {
+ SDValue ValRes = DAG.getNode(ISD::ANY_EXTEND, DL, ResultType, Result);
+ return DAG.getMergeValues({ValRes, LN->getChain()}, DL);
+ }
+
+ return Result;
+ }
+
+ report_fatal_error(
+ "encountered unexpected ExtType when loading from webassembly global",
+ false);
}
if (std::optional<unsigned> Local = IsWebAssemblyLocal(Base, DAG)) {
diff --git a/llvm/test/CodeGen/WebAssembly/lower-load-wasm-global.ll b/llvm/test/CodeGen/WebAssembly/lower-load-wasm-global.ll
new file mode 100644
index 0000000000000..0112296df1aa8
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/lower-load-wasm-global.ll
@@ -0,0 +1,185 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+
+; Test various loads from WASM (address space 1) globals lower as intended
+
+target triple = "wasm32-unknown-unknown"
+
+
+@globalI8 = local_unnamed_addr addrspace(1) global i8 undef
+@globalI32 = local_unnamed_addr addrspace(1) global i32 undef
+@globalI64 = local_unnamed_addr addrspace(1) global i64 undef
+
+
+define i32 @zext_i8_i32() {
+; CHECK-LABEL: zext_i8_i32:
+; CHECK: .functype zext_i8_i32 () -> (i32)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: global.get globalI32
+; CHECK-NEXT: i32.const 255
+; CHECK-NEXT: i32.and
+; CHECK-NEXT: # fallthrough-return
+ %v = load i8, ptr addrspace(1) @globalI32
+ %e = zext i8 %v to i32
+ ret i32 %e
+}
+
+define i32 @sext_i8_i32() {
+; CHECK-LABEL: sext_i8_i32:
+; CHECK: .functype sext_i8_i32 () -> (i32)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: global.get globalI32
+; CHECK-NEXT: i32.extend8_s
+; CHECK-NEXT: # fallthrough-return
+ %v = load i8, ptr addrspace(1) @globalI32
+ %e = sext i8 %v to i32
+ ret i32 %e
+}
+
+define i32 @zext_i16_i32() {
+; CHECK-LABEL: zext_i16_i32:
+; CHECK: .functype zext_i16_i32 () -> (i32)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: global.get globalI32
+; CHECK-NEXT: i32.const 65535
+; CHECK-NEXT: i32.and
+; CHECK-NEXT: # fallthrough-return
+ %v = load i16, ptr addrspace(1) @globalI32
+ %e = zext i16 %v to i32
+ ret i32 %e
+}
+
+define i32 @sext_i16_i32() {
+; CHECK-LABEL: sext_i16_i32:
+; CHECK: .functype sext_i16_i32 () -> (i32)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: global.get globalI32
+; CHECK-NEXT: i32.extend16_s
+; CHECK-NEXT: # fallthrough-return
+ %v = load i16, ptr addrspace(1) @globalI32
+ %e = sext i16 %v to i32
+ ret i32 %e
+}
+
+
+define i64 @zext_i8_i64() {
+; CHECK-LABEL: zext_i8_i64:
+; CHECK: .functype zext_i8_i64 () -> (i64)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: global.get globalI64
+; CHECK-NEXT: i64.const 255
+; CHECK-NEXT: i64.and
+; CHECK-NEXT: # fallthrough-return
+ %v = load i8, ptr addrspace(1) @globalI64
+ %e = zext i8 %v to i64
+ ret i64 %e
+}
+
+define i64 @sext_i8_i64() {
+; CHECK-LABEL: sext_i8_i64:
+; CHECK: .functype sext_i8_i64 () -> (i64)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: global.get globalI64
+; CHECK-NEXT: i64.extend8_s
+; CHECK-NEXT: # fallthrough-return
+ %v = load i8, ptr addrspace(1) @globalI64
+ %e = sext i8 %v to i64
+ ret i64 %e
+}
+
+define i64 @zext_i16_i64() {
+; CHECK-LABEL: zext_i16_i64:
+; CHECK: .functype zext_i16_i64 () -> (i64)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: global.get globalI64
+; CHECK-NEXT: i64.const 65535
+; CHECK-NEXT: i64.and
+; CHECK-NEXT: # fallthrough-return
+ %v = load i16, ptr addrspace(1) @globalI64
+ %e = zext i16 %v to i64
+ ret i64 %e
+}
+
+define i64 @sext_i16_i64() {
+; CHECK-LABEL: sext_i16_i64:
+; CHECK: .functype sext_i16_i64 () -> (i64)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: global.get globalI64
+; CHECK-NEXT: i64.extend16_s
+; CHECK-NEXT: # fallthrough-return
+ %v = load i16, ptr addrspace(1) @globalI64
+ %e = sext i16 %v to i64
+ ret i64 %e
+}
+
+define i64 @zext_i32_i64() {
+; CHECK-LABEL: zext_i32_i64:
+; CHECK: .functype zext_i32_i64 () -> (i64)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: global.get globalI64
+; CHECK-NEXT: i64.const 4294967295
+; CHECK-NEXT: i64.and
+; CHECK-NEXT: # fallthrough-return
+ %v = load i32, ptr addrspace(1) @globalI64
+ %e = zext i32 %v to i64
+ ret i64 %e
+}
+
+define i64 @sext_i32_i64() {
+; CHECK-LABEL: sext_i32_i64:
+; CHECK: .functype sext_i32_i64 () -> (i64)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: global.get globalI64
+; CHECK-NEXT: i64.extend32_s
+; CHECK-NEXT: # fallthrough-return
+ %v = load i32, ptr addrspace(1) @globalI64
+ %e = sext i32 %v to i64
+ ret i64 %e
+}
+
+
+define i64 @load_i64_from_i32() {
+; CHECK-LABEL: load_i64_from_i32:
+; CHECK: .functype load_i64_from_i32 () -> (i64)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: global.get globalI32
+; CHECK-NEXT: i64.extend_i32_u
+; CHECK-NEXT: # fallthrough-return
+ %v = load i64, ptr addrspace(1) @globalI32
+ ret i64 %v
+}
+
+define i32 @load_i32_from_i64() {
+; CHECK-LABEL: load_i32_from_i64:
+; CHECK: .functype load_i32_from_i64 () -> (i32)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: global.get globalI64
+; CHECK-NEXT: i32.wrap_i64
+; CHECK-NEXT: # fallthrough-return
+ %v = load i32, ptr addrspace(1) @globalI64
+ ret i32 %v
+}
+
+define i8 @load_i8() {
+; CHECK-LABEL: load_i8:
+; CHECK: .functype load_i8 () -> (i32)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: global.get globalI8
+; CHECK-NEXT: # fallthrough-return
+ %v = load i8, ptr addrspace(1) @globalI8
+ ret i8 %v
+}
+
+define i64 @load_i16_from_i8_zext_to_i64() {
+; CHECK-LABEL: load_i16_from_i8_zext_to_i64:
+; CHECK: .functype load_i16_from_i8_zext_to_i64 () -> (i64)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: global.get globalI8
+; CHECK-NEXT: i64.extend_i32_u
+; CHECK-NEXT: i64.const 65535
+; CHECK-NEXT: i64.and
+; CHECK-NEXT: # fallthrough-return
+ %v = load i16, ptr addrspace(1) @globalI8
+ %e = zext i16 %v to i64
+ ret i64 %e
+}
I've never worked with the LLVM code before, so please let me know if there was a simpler or cleaner way to do any of that.
I ran the test-suite, and my changes introduced regressions in three tests.
The SIMD test now fails because various ANDs are no longer being folded into the zext loads. I'm guessing that's because I marked extending loads Custom, so the combiner can no longer reason about them? The other two tests seem to fail for similar reasons. What should I do about it? Is my code flawed, or do the tests need to be regenerated?
Force-pushed from 06c51ab to 6e3d317 (Compare)
Implements custom EXTLOAD lowering logic to enable correct lowering of such loads from WASM (address space 1) globals, similarly to the existing LOAD logic.
These fixes also ensure that the global.get is done for the type of the underlying global (accounting for globals declared smaller than i32), and that the loaded value is extended or truncated to the desired result width (preventing invalid instruction sequences).
Loads from integral or floating-point globals of widths other than 8, 16, 32, or 64 bits are explicitly disallowed.
Fixes: #155880