308 changes: 304 additions & 4 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -18,6 +18,7 @@
#include "WebAssemblySubtarget.h"
#include "WebAssemblyTargetMachine.h"
#include "WebAssemblyUtilities.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -91,6 +92,19 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
setOperationAction(ISD::LOAD, T, Custom);
setOperationAction(ISD::STORE, T, Custom);
}

// Likewise, transform zext/sext/anyext extending loads from address space 1
// (WASM globals)
setLoadExtAction({ISD::EXTLOAD, ISD::ZEXTLOAD, ISD::SEXTLOAD}, MVT::i32,
{MVT::i8, MVT::i16}, Custom);
setLoadExtAction({ISD::EXTLOAD, ISD::ZEXTLOAD, ISD::SEXTLOAD}, MVT::i64,
{MVT::i8, MVT::i16, MVT::i32}, Custom);

// Compensate for the EXTLOADs being custom by reimplementing some combiner
// logic
setTargetDAGCombine(ISD::AND);
setTargetDAGCombine(ISD::SIGN_EXTEND_INREG);

if (Subtarget->hasSIMD128()) {
for (auto T : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v4f32, MVT::v2i64,
MVT::v2f64}) {
@@ -1707,6 +1721,11 @@ static bool IsWebAssemblyGlobal(SDValue Op) {
if (const GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Op))
return WebAssembly::isWasmVarAddressSpace(GA->getAddressSpace());

if (Op->getOpcode() == WebAssemblyISD::Wrapper)
if (const GlobalAddressSDNode *GA =
dyn_cast<GlobalAddressSDNode>(Op->getOperand(0)))
return WebAssembly::isWasmVarAddressSpace(GA->getAddressSpace());

return false;
}

@@ -1764,16 +1783,115 @@ SDValue WebAssemblyTargetLowering::LowerLoad(SDValue Op,
LoadSDNode *LN = cast<LoadSDNode>(Op.getNode());
const SDValue &Base = LN->getBasePtr();
const SDValue &Offset = LN->getOffset();
ISD::LoadExtType ExtType = LN->getExtensionType();
EVT ResultType = LN->getValueType(0);

if (IsWebAssemblyGlobal(Base)) {
if (!Offset->isUndef())
report_fatal_error(
"unexpected offset when loading from webassembly global", false);

SDVTList Tys = DAG.getVTList(LN->getValueType(0), MVT::Other);
SDValue Ops[] = {LN->getChain(), Base};
return DAG.getMemIntrinsicNode(WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops,
LN->getMemoryVT(), LN->getMemOperand());
if (!ResultType.isInteger() && !ResultType.isFloatingPoint()) {
SDVTList Tys = DAG.getVTList(ResultType, MVT::Other);
SDValue Ops[] = {LN->getChain(), Base};
SDValue GlobalGetNode =
DAG.getMemIntrinsicNode(WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops,
LN->getMemoryVT(), LN->getMemOperand());
return GlobalGetNode;
}

EVT GT = MVT::INVALID_SIMPLE_VALUE_TYPE;

if (auto *GA = dyn_cast<GlobalAddressSDNode>(
Base->getOpcode() == WebAssemblyISD::Wrapper ? Base->getOperand(0)
: Base))
GT = EVT::getEVT(GA->getGlobal()->getValueType());

if (GT != MVT::i8 && GT != MVT::i16 && GT != MVT::i32 && GT != MVT::i64 &&
GT != MVT::f32 && GT != MVT::f64)
report_fatal_error("encountered unexpected global type for Base when "
"loading from webassembly global",
false);

EVT PromotedGT = getTypeToTransformTo(*DAG.getContext(), GT);

if (ExtType == ISD::NON_EXTLOAD) {
// A normal, non-extending load may try to load more or less than the
// underlying global, which is invalid. We lower this to a load of the
// global (i32 or i64) then truncate or extend as needed

// Modify the MMO to load the full global
MachineMemOperand *OldMMO = LN->getMemOperand();
[Review thread on the MMO cloning below]

Contributor: Seems like this can be made into a local helper, since you're reusing it at the bottom as well?

Author: So is this the correct approach (and worth making into a helper)? I wasn't sure whether or not it's safe to modify the existing MMO, so I figured I'd better clone it. Otherwise, I could call MachineMemOperand::setType on the existing MMO.

Contributor (@badumbatish, Sep 3, 2025): Looks safe enough; the LLT in each MMO is per SDNode, with no references.

Author: But could the MMO as a whole ever be shared? Or is that also created per node at first?

Contributor: Hmm, this I actually don't know; it is possible. I'll tag someone in after finishing the review.
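For illustration, a minimal sketch of the local helper being discussed, assuming the clone-based approach stays. The name cloneMMOWithType is hypothetical; the body simply factors out the getMachineMemOperand call that appears twice in this patch:

static MachineMemOperand *cloneMMOWithType(SelectionDAG &DAG,
                                           const MachineMemOperand *OldMMO,
                                           EVT NewVT) {
  // Clone the operand wholesale, overriding only the memory type, exactly as
  // the two inline copies in this patch do.
  return DAG.getMachineFunction().getMachineMemOperand(
      OldMMO->getPointerInfo(), OldMMO->getFlags(), LLT(NewVT.getSimpleVT()),
      OldMMO->getBaseAlign(), OldMMO->getAAInfo(), OldMMO->getRanges(),
      OldMMO->getSyncScopeID(), OldMMO->getSuccessOrdering(),
      OldMMO->getFailureOrdering());
}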

MachineMemOperand *NewMMO = DAG.getMachineFunction().getMachineMemOperand(
OldMMO->getPointerInfo(), OldMMO->getFlags(),
LLT(PromotedGT.getSimpleVT()), OldMMO->getBaseAlign(),
OldMMO->getAAInfo(), OldMMO->getRanges(), OldMMO->getSyncScopeID(),
OldMMO->getSuccessOrdering(), OldMMO->getFailureOrdering());

SDVTList Tys = DAG.getVTList(PromotedGT, MVT::Other);
SDValue Ops[] = {LN->getChain(), Base};
SDValue GlobalGetNode = DAG.getMemIntrinsicNode(
WebAssemblyISD::GLOBAL_GET, DL, Tys, Ops, PromotedGT, NewMMO);

if (ResultType.bitsEq(PromotedGT)) {
return GlobalGetNode;
}

SDValue ValRes;
if (ResultType.isFloatingPoint())
ValRes = DAG.getFPExtendOrRound(GlobalGetNode, DL, ResultType);
else
ValRes = DAG.getAnyExtOrTrunc(GlobalGetNode, DL, ResultType);

return DAG.getMergeValues({ValRes, GlobalGetNode.getValue(1)}, DL);
}

if (ExtType == ISD::ZEXTLOAD || ExtType == ISD::SEXTLOAD) {
// Turn the unsupported load into an EXTLOAD followed by an
// explicit zero/sign extend inreg. Same as Expand

SDValue Result =
DAG.getExtLoad(ISD::EXTLOAD, DL, ResultType, LN->getChain(), Base,
LN->getMemoryVT(), LN->getMemOperand());
SDValue ValRes;
if (ExtType == ISD::SEXTLOAD)
ValRes = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, Result.getValueType(),
Result, DAG.getValueType(LN->getMemoryVT()));
else
ValRes = DAG.getZeroExtendInReg(Result, DL, LN->getMemoryVT());

return DAG.getMergeValues({ValRes, Result.getValue(1)}, DL);
}

if (ExtType == ISD::EXTLOAD) {
// Expand the EXTLOAD into a regular LOAD of the global, and if
// needed, a zero-extension

EVT OldLoadType = LN->getMemoryVT();
EVT NewLoadType = getTypeToTransformTo(*DAG.getContext(), OldLoadType);

// Modify the MMO to load a whole WASM "register"'s worth
MachineMemOperand *OldMMO = LN->getMemOperand();
MachineMemOperand *NewMMO = DAG.getMachineFunction().getMachineMemOperand(
OldMMO->getPointerInfo(), OldMMO->getFlags(),
LLT(NewLoadType.getSimpleVT()), OldMMO->getBaseAlign(),
OldMMO->getAAInfo(), OldMMO->getRanges(), OldMMO->getSyncScopeID(),
OldMMO->getSuccessOrdering(), OldMMO->getFailureOrdering());

SDValue Result =
DAG.getLoad(NewLoadType, DL, LN->getChain(), Base, NewMMO);

if (NewLoadType != ResultType) {
SDValue ValRes = DAG.getNode(ISD::ANY_EXTEND, DL, ResultType, Result);
return DAG.getMergeValues({ValRes, Result.getValue(1)}, DL);
}

return Result;
}

report_fatal_error(
"encountered unexpected ExtType when loading from webassembly global",
false);
}

if (std::optional<unsigned> Local = IsWebAssemblyLocal(Base, DAG)) {
@@ -3637,6 +3755,184 @@ static SDValue performMulCombine(SDNode *N,
}
}

[Review thread on this function]

Author: Well, I found one solution, but I don't like it: duplicate the combiner logic here. It works for simd.ll, but the instruction cost calculation for the LoopVectorize-related tests is still off.

Contributor: Yes, I agree... this shouldn't be the way to go.

static SDValue performANDCombine(SDNode *N,
                                 TargetLowering::DAGCombinerInfo &DCI) {
// Copied and modified from DAGCombiner::visitAND(SDNode *N)
// We have to do this because the original combiner doesn't work when ZEXTLOAD
// has custom lowering

SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
SDLoc DL(N);

// fold (and (X (load ([non_ext|any_ext|zero_ext] V))), c) ->
// (X (load ([non_ext|zero_ext] V))) if 'and' only clears top bits which must
// already be zero by virtue of the width of the base type of the load.
//
// the 'X' node here can either be nothing or an extract_vector_elt to catch
// more cases.
if ((N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
N0.getValueSizeInBits() == N0.getOperand(0).getScalarValueSizeInBits() &&
N0.getOperand(0).getOpcode() == ISD::LOAD &&
N0.getOperand(0).getResNo() == 0) ||
(N0.getOpcode() == ISD::LOAD && N0.getResNo() == 0)) {
auto *Load =
cast<LoadSDNode>((N0.getOpcode() == ISD::LOAD) ? N0 : N0.getOperand(0));

// Get the constant (if applicable) the zero'th operand is being ANDed with.
// This can be a pure constant or a vector splat, in which case we treat the
// vector as a scalar and use the splat value.
APInt Constant = APInt::getZero(1);
if (const ConstantSDNode *C = isConstOrConstSplat(
N1, /*AllowUndefs=*/false, /*AllowTruncation=*/true)) {
Constant = C->getAPIntValue();
} else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
unsigned EltBitWidth = Vector->getValueType(0).getScalarSizeInBits();
APInt SplatValue, SplatUndef;
unsigned SplatBitSize;
bool HasAnyUndefs;
// Endianness should not matter here. Code below makes sure that we only
// use the result if the SplatBitSize is a multiple of the vector element
// size. And after that we AND all element sized parts of the splat
// together. So the end result should be the same regardless of in which
// order we do those operations.
const bool IsBigEndian = false;
bool IsSplat =
Vector->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
HasAnyUndefs, EltBitWidth, IsBigEndian);

// Make sure that variable 'Constant' is only set if 'SplatBitSize' is a
// multiple of 'BitWidth'. Otherwise, we could propagate a wrong value.
if (IsSplat && (SplatBitSize % EltBitWidth) == 0) {
// Undef bits can contribute to a possible optimisation if set, so
// set them.
SplatValue |= SplatUndef;

// The splat value may be something like "0x00FFFFFF", which means 0 for
// the first vector value and FF for the rest, repeating. We need a mask
// that will apply equally to all members of the vector, so AND all the
// lanes of the constant together.
Constant = APInt::getAllOnes(EltBitWidth);
for (unsigned i = 0, n = (SplatBitSize / EltBitWidth); i < n; ++i)
Constant &= SplatValue.extractBits(EltBitWidth, i * EltBitWidth);
}
}

// If we want to change an EXTLOAD to a ZEXTLOAD, ensure a ZEXTLOAD is
// actually legal and isn't going to get expanded, else this is a false
// optimisation.

/*bool CanZextLoadProfitably = TLI.isLoadExtLegal(ISD::ZEXTLOAD,
Load->getValueType(0),
Load->getMemoryVT());*/
// MODIFIED: this is the one difference in the logic; we allow ZEXT combine
// only in addrspace 0, where it's legal
bool CanZextLoadProfitably = Load->getAddressSpace() == 0;

// Resize the constant to the same size as the original memory access before
// extension. If it is still the AllOnesValue then this AND is completely
// unneeded.
Constant = Constant.zextOrTrunc(Load->getMemoryVT().getScalarSizeInBits());

bool B;
switch (Load->getExtensionType()) {
default:
B = false;
break;
case ISD::EXTLOAD:
B = CanZextLoadProfitably;
break;
case ISD::ZEXTLOAD:
case ISD::NON_EXTLOAD:
B = true;
break;
}

if (B && Constant.isAllOnes()) {
// If the load type was an EXTLOAD, convert to ZEXTLOAD in order to
// preserve semantics once we get rid of the AND.
SDValue NewLoad(Load, 0);

// Fold the AND away. NewLoad may get replaced immediately.
DCI.CombineTo(N, (N0.getNode() == Load) ? NewLoad : N0);

if (Load->getExtensionType() == ISD::EXTLOAD) {
NewLoad = DCI.DAG.getLoad(
Load->getAddressingMode(), ISD::ZEXTLOAD, Load->getValueType(0),
SDLoc(Load), Load->getChain(), Load->getBasePtr(),
Load->getOffset(), Load->getMemoryVT(), Load->getMemOperand());
// Replace uses of the EXTLOAD with the new ZEXTLOAD.
if (Load->getNumValues() == 3) {
// PRE/POST_INC loads have 3 values.
SDValue To[] = {NewLoad.getValue(0), NewLoad.getValue(1),
NewLoad.getValue(2)};
DCI.CombineTo(Load, ArrayRef<SDValue>(To, 3), true);
} else {
DCI.CombineTo(Load, NewLoad.getValue(0), NewLoad.getValue(1));
}
}

return SDValue(N, 0); // Return N so it doesn't get rechecked!
}
}
return SDValue();
}
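As context for the fold above (an AND that only clears bits a narrow zero-extending load already guarantees to be zero is redundant), a standalone sanity check of the underlying bit arithmetic; plain C++, independent of any LLVM API:

#include <cassert>
#include <cstdint>

int main() {
  // An i8 zextload into an i32 register leaves bits 8..31 clear, so a
  // following AND with 0xFF (the all-ones mask at the memory width) is a
  // no-op; this mirrors the Constant.isAllOnes() check above.
  uint8_t loaded = 0xAB;        // byte value in memory
  uint32_t zextloaded = loaded; // zero-extended load result
  assert((zextloaded & 0xFF) == zextloaded);
  return 0;
}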

static SDValue
performSIGN_EXTEND_INREGCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
// Copied and modified from DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N)
// We have to do this because the original combiner doesn't work when SEXTLOAD
// has custom lowering

SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N->getValueType(0);
EVT ExtVT = cast<VTSDNode>(N1)->getVT();
SDLoc DL(N);

// fold (sext_inreg (extload x)) -> (sextload x)
// If sextload is not supported by target, we can only do the combine when
// load has one use. Doing otherwise can block folding the extload with other
// extends that the target does support.

// MODIFIED: replaced TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) with
// cast<LoadSDNode>(N0)->getAddressSpace() == 0)
if (ISD::isEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
((!DCI.isAfterLegalizeDAG() && cast<LoadSDNode>(N0)->isSimple() &&
N0.hasOneUse()) ||
cast<LoadSDNode>(N0)->getAddressSpace() == 0)) {
auto *LN0 = cast<LoadSDNode>(N0);
SDValue ExtLoad =
DCI.DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
LN0->getBasePtr(), ExtVT, LN0->getMemOperand());
DCI.CombineTo(N, ExtLoad);
DCI.CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
DCI.AddToWorklist(ExtLoad.getNode());
return SDValue(N, 0); // Return N so it doesn't get rechecked!
}

// fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use

// MODIFIED: replaced TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, ExtVT)) with
// cast<LoadSDNode>(N0)->getAddressSpace() == 0)
if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
N0.hasOneUse() && ExtVT == cast<LoadSDNode>(N0)->getMemoryVT() &&
((!DCI.isAfterLegalizeDAG() && cast<LoadSDNode>(N0)->isSimple()) &&
cast<LoadSDNode>(N0)->getAddressSpace() == 0)) {
auto *LN0 = cast<LoadSDNode>(N0);
SDValue ExtLoad =
DCI.DAG.getExtLoad(ISD::SEXTLOAD, DL, VT, LN0->getChain(),
LN0->getBasePtr(), ExtVT, LN0->getMemOperand());
DCI.CombineTo(N, ExtLoad);
DCI.CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
return SDValue(N, 0); // Return N so it doesn't get rechecked!
}

return SDValue();
}
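Similarly, the scalar rationale for the sext_inreg folds above: sign-extending the low bits of an any-extended load produces exactly what a sign-extending load would have. A minimal standalone check, again independent of LLVM:

#include <cassert>
#include <cstdint>

int main() {
  int8_t inMemory = -5;                           // i8 value in memory
  uint32_t anyExt = (uint8_t)inMemory;            // EXTLOAD: only low 8 bits meaningful
  int32_t sextInreg = (int8_t)(anyExt & 0xFF);    // SIGN_EXTEND_INREG from i8
  assert(sextInreg == (int32_t)inMemory);         // equals a direct i8 sextload
  return 0;
}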

SDValue
WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
@@ -3672,5 +3968,9 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
}
case ISD::MUL:
return performMulCombine(N, DCI);
case ISD::AND:
return performANDCombine(N, DCI);
case ISD::SIGN_EXTEND_INREG:
return performSIGN_EXTEND_INREGCombine(N, DCI);
}
}
3 changes: 1 addition & 2 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
@@ -89,8 +89,7 @@ class WebAssemblyTargetLowering final : public TargetLowering {
bool CanLowerReturn(CallingConv::ID CallConv, MachineFunction &MF,
bool isVarArg,
const SmallVectorImpl<ISD::OutputArg> &Outs,
LLVMContext &Context,
const Type *RetTy) const override;
LLVMContext &Context, const Type *RetTy) const override;
SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
const SmallVectorImpl<ISD::OutputArg> &Outs,
const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl,