[WebAssembly] Fix lowering of (extending) loads from addrspace(1) globals #155937
@@ -18,6 +18,7 @@
 #include "WebAssemblySubtarget.h"
 #include "WebAssemblyTargetMachine.h"
 #include "WebAssemblyUtilities.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -91,6 +92,19 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
     setOperationAction(ISD::LOAD, T, Custom);
     setOperationAction(ISD::STORE, T, Custom);
   }
+
+  // Likewise, transform zext/sext/anyext extending loads from address space 1
+  // (WASM globals).
+  setLoadExtAction({ISD::EXTLOAD, ISD::ZEXTLOAD, ISD::SEXTLOAD}, MVT::i32,
+                   {MVT::i8, MVT::i16}, Custom);
+  setLoadExtAction({ISD::EXTLOAD, ISD::ZEXTLOAD, ISD::SEXTLOAD}, MVT::i64,
+                   {MVT::i8, MVT::i16, MVT::i32}, Custom);
+
+  // Compensate for the EXTLOADs being custom by reimplementing some combiner
+  // logic.
+  setTargetDAGCombine(ISD::AND);
+  setTargetDAGCombine(ISD::SIGN_EXTEND_INREG);
+
   if (Subtarget->hasSIMD128()) {
     for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v4f32, MVT::v2i64,
                    MVT::v2f64}) {
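For orientation (not part of the diff): operations marked Custom are routed through the target's LowerOperation hook during legalization, so these extending loads now reach the same LowerLoad path as the plain loads registered just above. A simplified paraphrase of the existing dispatch in this file:

// Paraphrased sketch of the existing dispatch in WebAssemblyISelLowering.cpp
// (simplified, not part of this diff): Custom-marked (ext)loads land in
// LowerLoad during legalization.
SDValue WebAssemblyTargetLowering::LowerOperation(SDValue Op,
                                                  SelectionDAG &DAG) const {
  switch (Op.getOpcode()) {
  // ...
  case ISD::LOAD:
    return LowerLoad(Op, DAG); // now also reached by ext/sext/zext loads
  // ...
  default:
    llvm_unreachable("unimplemented operation lowering");
  }
}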
@@ -1707,6 +1721,11 @@ static bool IsWebAssemblyGlobal(SDValue Op) {
   if (const GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Op))
     return WebAssembly::isWasmVarAddressSpace(GA->getAddressSpace());

+  if (Op->getOpcode() == WebAssemblyISD::Wrapper)
+    if (const GlobalAddressSDNode *GA =
+            dyn_cast<GlobalAddressSDNode>(Op->getOperand(0)))
+      return WebAssembly::isWasmVarAddressSpace(GA->getAddressSpace());
+
   return false;
 }
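The same peel-the-Wrapper pattern recurs in LowerLoad below (the ternary on Base). A hedged sketch of a tiny helper that could factor it out; the name stripWrapper is hypothetical and not part of this PR:

// Hypothetical refactoring sketch (not in this PR): unwrap a
// WebAssemblyISD::Wrapper so callers can match the GlobalAddressSDNode
// underneath directly.
static SDValue stripWrapper(SDValue Op) {
  if (Op->getOpcode() == WebAssemblyISD::Wrapper)
    return Op->getOperand(0);
  return Op;
}

Both IsWebAssemblyGlobal and the global-type lookup in LowerLoad could then use dyn_cast<GlobalAddressSDNode>(stripWrapper(Op)).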
@@ -1764,16 +1783,115 @@ SDValue WebAssemblyTargetLowering::LowerLoad(SDValue Op,
   LoadSDNode *LN = cast<LoadSDNode>(Op.getNode());
   const SDValue &Base = LN->getBasePtr();
   const SDValue &Offset = LN->getOffset();
+  ISD::LoadExtType ExtType = LN->getExtensionType();
+  EVT ResultType = LN->getValueType(0);

   if (IsWebAssemblyGlobal(Base)) {
     if (!Offset->isUndef())
       report_fatal_error(
           "unexpected offset when loading from webassembly global", false);

-    SDVTList Tys = DAG.getVTList(LN->getValueType(0), MVT::Other);
-    SDValue Ops[] = {LN->getChain(), Base};
-    return DAG.getMemIntrinsicNode(WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops,
-                                   LN->getMemoryVT(), LN->getMemOperand());
+    if (!ResultType.isInteger() && !ResultType.isFloatingPoint()) {
+      SDVTList Tys = DAG.getVTList(ResultType, MVT::Other);
+      SDValue Ops[] = {LN->getChain(), Base};
+      SDValue GlobalGetNode =
+          DAG.getMemIntrinsicNode(WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops,
+                                  LN->getMemoryVT(), LN->getMemOperand());
+      return GlobalGetNode;
+    }
+
+    EVT GT = MVT::INVALID_SIMPLE_VALUE_TYPE;
+
+    if (auto *GA = dyn_cast<GlobalAddressSDNode>(
+            Base->getOpcode() == WebAssemblyISD::Wrapper ? Base->getOperand(0)
+                                                         : Base))
+      GT = EVT::getEVT(GA->getGlobal()->getValueType());
+
+    if (GT != MVT::i8 && GT != MVT::i16 && GT != MVT::i32 && GT != MVT::i64 &&
+        GT != MVT::f32 && GT != MVT::f64)
+      report_fatal_error("encountered unexpected global type for Base when "
+                         "loading from webassembly global",
+                         false);
+
+    EVT PromotedGT = getTypeToTransformTo(*DAG.getContext(), GT);
+
+    if (ExtType == ISD::NON_EXTLOAD) {
+      // A normal, non-extending load may try to load more or less than the
+      // underlying global, which is invalid. We lower this to a load of the
+      // global (i32 or i64), then truncate or extend as needed.
+
+      // Modify the MMO to load the full global.
+      MachineMemOperand *OldMMO = LN->getMemOperand();
Review thread on the MMO cloning above:
- Reviewer: Seems like this can be made into a local helper, since you're reusing it at the bottom as well?
- Author: So is this the correct approach (and worth making into a helper)? I wasn't sure whether or not it's safe to modify the existing MMO, so I figured I'd better clone it. Otherwise, I could call …
- Reviewer: Looks safe enough; the LLT in each MMO is per SDNode, with no references.
- Author: But could the MMO as a whole ever be shared? Or is that also created per node at first?
- Reviewer: Hmm, this I actually don't know. It is possible; I'll tag someone in after finishing the review.
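Following the reviewer's suggestion, a minimal sketch of what such a local helper could look like. The name cloneMMOWithType and its placement are hypothetical, not part of this PR; the body is lifted from the two call sites in this hunk, which clone an MMO changing only the memory type:

// Hypothetical helper (suggested in review, not in the current diff): clone
// an existing MachineMemOperand, changing only the memory type. Both the
// NON_EXTLOAD and EXTLOAD paths below could call this instead of repeating
// the full getMachineMemOperand argument list.
static MachineMemOperand *cloneMMOWithType(SelectionDAG &DAG,
                                           const MachineMemOperand *OldMMO,
                                           EVT NewVT) {
  return DAG.getMachineFunction().getMachineMemOperand(
      OldMMO->getPointerInfo(), OldMMO->getFlags(), LLT(NewVT.getSimpleVT()),
      OldMMO->getBaseAlign(), OldMMO->getAAInfo(), OldMMO->getRanges(),
      OldMMO->getSyncScopeID(), OldMMO->getSuccessOrdering(),
      OldMMO->getFailureOrdering());
}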
+      MachineMemOperand *NewMMO = DAG.getMachineFunction().getMachineMemOperand(
+          OldMMO->getPointerInfo(), OldMMO->getFlags(),
+          LLT(PromotedGT.getSimpleVT()), OldMMO->getBaseAlign(),
+          OldMMO->getAAInfo(), OldMMO->getRanges(), OldMMO->getSyncScopeID(),
+          OldMMO->getSuccessOrdering(), OldMMO->getFailureOrdering());
+
+      SDVTList Tys = DAG.getVTList(PromotedGT, MVT::Other);
+      SDValue Ops[] = {LN->getChain(), Base};
+      SDValue GlobalGetNode = DAG.getMemIntrinsicNode(
+          WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops, PromotedGT, NewMMO);
+
+      if (ResultType.bitsEq(PromotedGT)) {
+        return GlobalGetNode;
+      }
+
+      SDValue ValRes;
+      if (ResultType.isFloatingPoint())
+        ValRes = DAG.getFPExtendOrRound(GlobalGetNode, DL, ResultType);
+      else
+        ValRes = DAG.getAnyExtOrTrunc(GlobalGetNode, DL, ResultType);
+
+      return DAG.getMergeValues({ValRes, GlobalGetNode.getValue(1)}, DL);
+    }
+
+    if (ExtType == ISD::ZEXTLOAD || ExtType == ISD::SEXTLOAD) {
+      // Turn the unsupported load into an EXTLOAD followed by an explicit
+      // zero/sign extend inreg. Same as Expand.
+
+      SDValue Result =
+          DAG.getExtLoad(ISD::EXTLOAD, DL, ResultType, LN->getChain(), Base,
+                         LN->getMemoryVT(), LN->getMemOperand());
+      SDValue ValRes;
+      if (ExtType == ISD::SEXTLOAD)
+        ValRes = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, Result.getValueType(),
+                             Result, DAG.getValueType(LN->getMemoryVT()));
+      else
+        ValRes = DAG.getZeroExtendInReg(Result, DL, LN->getMemoryVT());
+
+      return DAG.getMergeValues({ValRes, Result.getValue(1)}, DL);
+    }
+
+    if (ExtType == ISD::EXTLOAD) {
+      // Expand the EXTLOAD into a regular LOAD of the global and, if needed,
+      // an any-extension.
+
+      EVT OldLoadType = LN->getMemoryVT();
+      EVT NewLoadType = getTypeToTransformTo(*DAG.getContext(), OldLoadType);
+
+      // Modify the MMO to load a whole WASM "register"'s worth.
+      MachineMemOperand *OldMMO = LN->getMemOperand();
+      MachineMemOperand *NewMMO = DAG.getMachineFunction().getMachineMemOperand(
+          OldMMO->getPointerInfo(), OldMMO->getFlags(),
+          LLT(NewLoadType.getSimpleVT()), OldMMO->getBaseAlign(),
+          OldMMO->getAAInfo(), OldMMO->getRanges(), OldMMO->getSyncScopeID(),
+          OldMMO->getSuccessOrdering(), OldMMO->getFailureOrdering());
+
+      SDValue Result =
+          DAG.getLoad(NewLoadType, DL, LN->getChain(), Base, NewMMO);
+
+      if (NewLoadType != ResultType) {
+        SDValue ValRes = DAG.getNode(ISD::ANY_EXTEND, DL, ResultType, Result);
+        return DAG.getMergeValues({ValRes, Result.getValue(1)}, DL);
+      }
+
+      return Result;
+    }
+
+    report_fatal_error(
+        "encountered unexpected ExtType when loading from webassembly global",
+        false);
+  }

   if (std::optional<unsigned> Local = IsWebAssemblyLocal(Base, DAG)) {
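To summarize the control flow above for a hypothetical i8 global in address space 1 (illustrative shapes only; not verbatim DAG dumps or tests from this PR):

// Illustrative outcomes, assuming @g is a hypothetical i8 wasm global that
// type legalization promotes to i32 (so GLOBAL_GET produces an i32):
//
//   plain load i8 from @g  -> GLOBAL_GET:i32, then any_ext/trunc to result
//   sextload i8 -> i32     -> EXTLOAD (re-enters below) + sign_extend_inreg
//   zextload i8 -> i32     -> EXTLOAD (re-enters below) + zero-extend-in-reg
//   extload  i8 -> i64     -> plain load of the promoted i32 + any_extend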
@@ -3637,6 +3755,184 @@ static SDValue performMulCombine(SDNode *N,
   }
 }

Review thread on the combiner reimplementation below:
- Author: Well, I found one solution, but I don't like it: duplicate the combiner logic here. It works for simd.ll, but the instruction-cost calculation for the LoopVectorize-related tests is still off.
- Reviewer: Yes, I agree... this shouldn't be the way to go.
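For context on why the stock folds stop firing (not part of the diff): DAGCombiner gates these combines on TLI.isLoadExtLegal, which accepts only Legal, not Custom, so marking the extloads Custom above disables them. A simplified paraphrase of the check from TargetLowering:

// Simplified paraphrase of TargetLoweringBase::isLoadExtLegal (from
// llvm/include/llvm/CodeGen/TargetLowering.h): a Custom load-extend action
// fails this test, which is why the folds below had to be reimplemented.
bool isLoadExtLegal(unsigned ExtType, EVT ValVT, EVT MemVT) const {
  return ValVT.isSimple() && MemVT.isSimple() &&
         getLoadExtAction(ExtType, ValVT, MemVT) == Legal;
}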
+static SDValue performANDCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI) {
+  // Copied and modified from DAGCombiner::visitAND(SDNode *N). We have to do
+  // this because the original combiner doesn't work when ZEXTLOAD has custom
+  // lowering.
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDLoc DL(N);
+
+  // fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
+  // (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
+  // already be zero by virtue of the width of the base type of the load.
+  //
+  // The 'X' node here can either be nothing or an extract_vector_elt to catch
+  // more cases.
+  if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+       N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
+       N0.getOperand(0).getOpcode() == ISD::LOAD &&
+       N0.getOperand(0).getResNo() == 0) ||
+      (N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
+    auto *Load =
+        cast<LoadSDNode>((N0.getOpcode() == ISD::LOAD) ? N0 : N0.getOperand(0));
+
+    // Get the constant (if applicable) the zeroth operand is being ANDed with.
+    // This can be a pure constant or a vector splat, in which case we treat
+    // the vector as a scalar and use the splat value.
+    APInt Constant = APInt::getZero(1);
+    if (const ConstantSDNode *C = isConstOrConstSplat(
+            N1, /*AllowUndefs=*/false, /*AllowTruncation=*/true)) {
+      Constant = C->getAPIntValue();
+    } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
+      unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();
+      APInt SplatValue, SplatUndef;
+      unsigned SplatBitSize;
+      bool HasAnyUndefs;
+      // Endianness should not matter here. Code below makes sure that we only
+      // use the result if the SplatBitSize is a multiple of the vector element
+      // size. And after that we AND all element-sized parts of the splat
+      // together. So the end result should be the same regardless of in which
+      // order we do those operations.
+      const bool IsBigEndian = false;
+      bool IsSplat =
+          Vector->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
+                                  HasAnyUndefs, EltBitWidth, IsBigEndian);
+
+      // Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
+      // multiple of 'BitWidth'. Otherwise, we could propagate a wrong value.
+      if (IsSplat && (SplatBitSize % EltBitWidth) == 0) {
+        // Undef bits can contribute to a possible optimisation if set, so
+        // set them.
+        SplatValue |= SplatUndef;
+
+        // The splat value may be something like "0x00FFFFFF", which means 0
+        // for the first vector value and FF for the rest, repeating. We need
+        // a mask that will apply equally to all members of the vector, so AND
+        // all the lanes of the constant together.
+        Constant = APInt::getAllOnes(EltBitWidth);
+        for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
+          Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
+      }
+    }
+
+    // If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
+    // actually legal and isn't going to get expanded, else this is a false
+    // optimisation.
+
+    // MODIFIED: this is the one difference in the logic. The original read:
+    //   bool CanZextLoadProfitably = TLI.isLoadExtLegal(
+    //       ISD::ZEXTLOAD, Load->getValueType(0), Load->getMemoryVT());
+    // We allow the ZEXT combine only in address space 0, where it's legal.
+    bool CanZextLoadProfitably = Load->getAddressSpace() == 0;
+
+    // Resize the constant to the same size as the original memory access
+    // before extension. If it is still the AllOnesValue then this AND is
+    // completely unneeded.
+    Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());
+
+    bool B;
+    switch (Load->getExtensionType()) {
+    default:
+      B = false;
+      break;
+    case ISD::EXTLOAD:
+      B = CanZextLoadProfitably;
+      break;
+    case ISD::ZEXTLOAD:
+    case ISD::NON_EXTLOAD:
+      B = true;
+      break;
+    }
+
+    if (B && Constant.isAllOnes()) {
+      // If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
+      // preserve semantics once we get rid of the AND.
+      SDValue NewLoad(Load, 0);
+
+      // Fold the AND away. NewLoad may get replaced immediately.
+      DCI.CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);
+
+      if (Load->getExtensionType() == ISD::EXTLOAD) {
+        NewLoad = DCI.DAG.getLoad(
+            Load->getAddressingMode(), ISD::ZEXTLOAD, Load->getValueType(0),
+            SDLoc(Load), Load->getChain(), Load->getBasePtr(),
+            Load->getOffset(), Load->getMemoryVT(), Load->getMemOperand());
+        // Replace uses of the EXTLOAD with the new ZEXTLOAD.
+        if (Load->getNumValues() == 3) {
+          // PRE/POST_INC loads have 3 values.
+          SDValue To[] = {NewLoad.getValue(0), NewLoad.getValue(1),
+                          NewLoad.getValue(2)};
+          DCI.CombineTo(Load, ArrayRef<SDValue>(To, 3), true);
+        } else {
+          DCI.CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
+        }
+      }
+
+      return SDValue(N, 0); // Return N so it doesn't get rechecked!
+    }
+  }
+  return SDValue();
+}
+
+static SDValue
+performSIGN_EXTEND_INREGCombine(SDNode *N,
+                                TargetLowering::DAGCombinerInfo &DCI) {
+  // Copied and modified from DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N).
+  // We have to do this because the original combiner doesn't work when
+  // SEXTLOAD has custom lowering.
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  EVT ExtVT = cast<VTSDNode>(N1)->getVT();
+  SDLoc DL(N);
+
+  // fold (sext_inreg (extload x)) -> (sextload x)
+  // If sextload is not supported by target, we can only do the combine when
+  // load has one use. Doing otherwise can block folding the extload with other
+  // extends that the target does support.
+
+  // MODIFIED: replaced TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT) with
+  // cast<LoadSDNode>(N0)->getAddressSpace() == 0
+  if (ISD::isEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
+      ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
+      ((!DCI.isAfterLegalizeDAG() && cast<LoadSDNode>(N0)->isSimple() &&
+        N0.hasOneUse()) ||
+       cast<LoadSDNode>(N0)->getAddressSpace() == 0)) {
+    auto *LN0 = cast<LoadSDNode>(N0);
+    SDValue ExtLoad =
+        DCI.DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
+                           LN0->getBasePtr(), ExtVT, LN0->getMemOperand());
+    DCI.CombineTo(N, ExtLoad);
+    DCI.CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
+    DCI.AddToWorklist(ExtLoad.getNode());
+    return SDValue(N, 0); // Return N so it doesn't get rechecked!
+  }
+
+  // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
+
+  // MODIFIED: replaced TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT) with
+  // cast<LoadSDNode>(N0)->getAddressSpace() == 0
+  if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
+      N0.hasOneUse() && ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
+      ((!DCI.isAfterLegalizeDAG() && cast<LoadSDNode>(N0)->isSimple()) &&
+       cast<LoadSDNode>(N0)->getAddressSpace() == 0)) {
+    auto *LN0 = cast<LoadSDNode>(N0);
+    SDValue ExtLoad =
+        DCI.DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
+                           LN0->getBasePtr(), ExtVT, LN0->getMemOperand());
+    DCI.CombineTo(N, ExtLoad);
+    DCI.CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
+    return SDValue(N, 0); // Return N so it doesn't get rechecked!
+  }
+
+  return SDValue();
+}
+
 SDValue
 WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
                                              DAGCombinerInfo &DCI) const {

@@ -3672,5 +3968,9 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
   }
   case ISD::MUL:
     return performMulCombine(N, DCI);
+  case ISD::AND:
+    return performANDCombine(N, DCI);
+  case ISD::SIGN_EXTEND_INREG:
+    return performSIGN_EXTEND_INREGCombine(N, DCI);
   }
 }