Skip to content

JIT: Always create FIELD_LIST for struct args in physical promotion #118778

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 64 additions & 164 deletions src/coreclr/jit/promotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -733,40 +733,10 @@ class LocalUses

if (otherAccess.CountRegCallArgs > 0)
{
auto willPassFieldInRegister = [=, &access, &otherAccess]() {
if (access.Offset < otherAccess.Offset)
{
return false;
}

unsigned layoutOffset = access.Offset - otherAccess.Offset;
if ((layoutOffset % TARGET_POINTER_SIZE) != 0)
{
return false;
}

unsigned accessSize = genTypeSize(access.AccessType);
if (accessSize == TARGET_POINTER_SIZE)
{
return true;
}

const SegmentList& significantSegments = otherAccess.Layout->GetNonPadding(comp);
if (!significantSegments.Intersects(
SegmentList::Segment(layoutOffset + accessSize, layoutOffset + TARGET_POINTER_SIZE)))
{
return true;
}

return false;
};
// We may be able to decompose the call argument to require no
// The call argument will be decomposed and will not require a
// write-back.
if (willPassFieldInRegister())
{
countOverlappedCallArg -= otherAccess.CountRegCallArgs;
countOverlappedCallArgWtd -= otherAccess.CountRegCallArgsWtd;
}
countOverlappedCallArg -= otherAccess.CountRegCallArgs;
countOverlappedCallArgWtd -= otherAccess.CountRegCallArgsWtd;
}
}

Expand Down Expand Up @@ -2244,16 +2214,16 @@ GenTree** ReplaceVisitor::InsertMidTreeReadBacks(GenTree** use)
// Usually this amounts to replacing the struct local by a FIELD_LIST with
// the promoted fields, but merged returns require more complicated handling.
//
bool ReplaceVisitor::ReplaceStructLocal(GenTree* user, GenTreeLclVarCommon* value)
bool ReplaceVisitor::ReplaceStructLocal(GenTree* user, GenTree** use, GenTreeLclVarCommon* value)
{
if (user->IsCall())
{
return ReplaceCallArgWithFieldList(user->AsCall(), value);
return ReplaceCallArgWithFieldList(user->AsCall(), use, value);
}
else
{
assert(user->OperIs(GT_RETURN, GT_SWIFT_ERROR_RET));
return ReplaceReturnedStructLocal(user->AsOp(), value);
return ReplaceReturnedStructLocal(user->AsOp(), use, value);
}
}

Expand All @@ -2263,6 +2233,7 @@ bool ReplaceVisitor::ReplaceStructLocal(GenTree* user, GenTreeLclVarCommon* valu
//
// Parameters:
// ret - The return node
// use - The edge pointing to 'value'
// value - The struct local
//
// Returns:
Expand All @@ -2276,7 +2247,7 @@ bool ReplaceVisitor::ReplaceStructLocal(GenTree* user, GenTreeLclVarCommon* valu
// being merged. Due to that, and for CQ, we instead decompose a store to the
// return local for that case.
//
bool ReplaceVisitor::ReplaceReturnedStructLocal(GenTreeOp* ret, GenTreeLclVarCommon* value)
bool ReplaceVisitor::ReplaceReturnedStructLocal(GenTreeOp* ret, GenTree** use, GenTreeLclVarCommon* value)
{
if (m_compiler->genReturnLocal != BAD_VAR_NUM)
{
Expand All @@ -2302,16 +2273,43 @@ bool ReplaceVisitor::ReplaceReturnedStructLocal(GenTreeOp* ret, GenTreeLclVarCom
return true;
}

AggregateInfo* agg = m_aggregates.Lookup(value->GetLclNum());
ClassLayout* layout = value->GetLayout(m_compiler);
GenTreeFieldList* fieldList = CreateFieldListForStructLocal(value);

if (fieldList == nullptr)
{
return false;
}

*use = fieldList;

m_madeChanges = true;
return true;
}

//------------------------------------------------------------------------
// CreateFieldListForStructLocal:
// Create a FIELD_LIST node that corresponds to a struct local that has uses
// of promoted fields.
//
// Parameters:
// lcl - The local
//
// Returns:
// A field list node, or null pointer if a FIELD_LIST cannot be created for
// this use of the local.
//
GenTreeFieldList* ReplaceVisitor::CreateFieldListForStructLocal(GenTreeLclVarCommon* lcl)
{
AggregateInfo* agg = m_aggregates.Lookup(lcl->GetLclNum());
ClassLayout* layout = lcl->GetLayout(m_compiler);
assert(layout != nullptr);

unsigned startOffset = value->GetLclOffs();
unsigned startOffset = lcl->GetLclOffs();
unsigned returnValueSize = layout->GetSize();
if (agg->Unpromoted.Intersects(SegmentList::Segment(startOffset, startOffset + returnValueSize)))
{
// TODO-CQ: We could handle cases where the intersected remainder is simple
return false;
return nullptr;
}

auto checkPartialOverlap = [=](Replacement& rep) {
Expand All @@ -2328,12 +2326,12 @@ bool ReplaceVisitor::ReplaceReturnedStructLocal(GenTreeOp* ret, GenTreeLclVarCom
return false;
};

if (!VisitOverlappingReplacements(value->GetLclNum(), startOffset, returnValueSize, checkPartialOverlap))
if (!VisitOverlappingReplacements(lcl->GetLclNum(), startOffset, returnValueSize, checkPartialOverlap))
{
return false;
return nullptr;
}

StructDeaths deaths = m_liveness->GetDeathsForStructLocal(value);
StructDeaths deaths = m_liveness->GetDeathsForStructLocal(lcl);
GenTreeFieldList* fieldList = m_compiler->gtNewFieldList();

auto addField = [=](Replacement& rep) {
Expand All @@ -2342,18 +2340,20 @@ bool ReplaceVisitor::ReplaceReturnedStructLocal(GenTreeOp* ret, GenTreeLclVarCom
{
fieldValue = m_compiler->gtNewLclvNode(rep.LclNum, rep.AccessType);

assert(deaths.IsReplacementDying(static_cast<unsigned>(&rep - agg->Replacements.data())));
fieldValue->gtFlags |= GTF_VAR_DEATH;
CheckForwardSubForLastUse(rep.LclNum);
if (deaths.IsReplacementDying(static_cast<unsigned>(&rep - agg->Replacements.data())))
{
fieldValue->gtFlags |= GTF_VAR_DEATH;
CheckForwardSubForLastUse(rep.LclNum);
}
}
else
{
// Replacement local is not up to date.
fieldValue = m_compiler->gtNewLclFldNode(value->GetLclNum(), rep.AccessType, rep.Offset);
fieldValue = m_compiler->gtNewLclFldNode(lcl->GetLclNum(), rep.AccessType, rep.Offset);

if (!m_compiler->lvaGetDesc(value->GetLclNum())->lvDoNotEnregister)
if (!m_compiler->lvaGetDesc(lcl->GetLclNum())->lvDoNotEnregister)
{
m_compiler->lvaSetVarDoNotEnregister(value->GetLclNum() DEBUGARG(DoNotEnregisterReason::LocalField));
m_compiler->lvaSetVarDoNotEnregister(lcl->GetLclNum() DEBUGARG(DoNotEnregisterReason::LocalField));
}
}

Expand All @@ -2362,12 +2362,9 @@ bool ReplaceVisitor::ReplaceReturnedStructLocal(GenTreeOp* ret, GenTreeLclVarCom
return true;
};

VisitOverlappingReplacements(value->GetLclNum(), startOffset, returnValueSize, addField);

ret->SetReturnValue(fieldList);
VisitOverlappingReplacements(lcl->GetLclNum(), startOffset, returnValueSize, addField);

m_madeChanges = true;
return true;
return fieldList;
}

//------------------------------------------------------------------------
Expand All @@ -2377,9 +2374,14 @@ bool ReplaceVisitor::ReplaceReturnedStructLocal(GenTreeOp* ret, GenTreeLclVarCom
//
// Parameters:
// call - The call
// use - The edge pointing to argNode
// argNode - The argument node
//
bool ReplaceVisitor::ReplaceCallArgWithFieldList(GenTreeCall* call, GenTreeLclVarCommon* argNode)
// Returns:
// True if the call argument was replaced with a FIELD_LIST; false if the
// argument could not be represented as a FIELD_LIST.
//
bool ReplaceVisitor::ReplaceCallArgWithFieldList(GenTreeCall* call, GenTree** use, GenTreeLclVarCommon* argNode)
{
CallArg* callArg = call->gtArgs.FindByNode(argNode);
if (callArg == nullptr)
Expand All @@ -2393,65 +2395,13 @@ bool ReplaceVisitor::ReplaceCallArgWithFieldList(GenTreeCall* call, GenTreeLclVa
return false;
}

AggregateInfo* agg = m_aggregates.Lookup(argNode->GetLclNum());
ClassLayout* layout = argNode->GetLayout(m_compiler);
assert(layout != nullptr);
StructDeaths deaths = m_liveness->GetDeathsForStructLocal(argNode);
GenTreeFieldList* fieldList = m_compiler->gtNewFieldList();
for (const ABIPassingSegment& seg : callArg->AbiInfo.Segments())
GenTreeFieldList* fieldList = CreateFieldListForStructLocal(argNode);
if (fieldList == nullptr)
{
Replacement* rep = nullptr;
if (agg->OverlappingReplacements(argNode->GetLclOffs() + seg.Offset, seg.Size, &rep, nullptr) &&
!rep->NeedsReadBack)
{
GenTreeLclVar* fieldValue = m_compiler->gtNewLclvNode(rep->LclNum, rep->AccessType);

if (deaths.IsReplacementDying(static_cast<unsigned>(rep - agg->Replacements.data())))
{
fieldValue->gtFlags |= GTF_VAR_DEATH;
CheckForwardSubForLastUse(rep->LclNum);
}

fieldList->AddField(m_compiler, fieldValue, seg.Offset, rep->AccessType);
}
else
{
// Unpromoted part, or replacement local is not up to date.
var_types type;
if (rep != nullptr)
{
type = rep->AccessType;
}
else if (genIsValidFloatReg(seg.GetRegister()))
{
type = seg.GetRegisterType();
}
else
{
if ((seg.Offset % TARGET_POINTER_SIZE) == 0 && (seg.Size == TARGET_POINTER_SIZE))
{
type = layout->GetGCPtrType(seg.Offset / TARGET_POINTER_SIZE);
}
else
{
type = seg.GetRegisterType();
}
}

GenTree* fieldValue =
m_compiler->gtNewLclFldNode(argNode->GetLclNum(), type, argNode->GetLclOffs() + seg.Offset);
fieldList->AddField(m_compiler, fieldValue, seg.Offset, type);

if (!m_compiler->lvaGetDesc(argNode->GetLclNum())->lvDoNotEnregister)
{
m_compiler->lvaSetVarDoNotEnregister(argNode->GetLclNum() DEBUGARG(DoNotEnregisterReason::LocalField));
}
}
return false;
}

assert(callArg->GetEarlyNode() == argNode);
callArg->SetEarlyNode(fieldList);

*use = fieldList;
m_madeChanges = true;
return true;
}
Expand All @@ -2476,57 +2426,7 @@ bool ReplaceVisitor::CanReplaceCallArgWithFieldListOfReplacements(GenTreeCall*
{
// We should have computed ABI information during the costing phase.
assert(call->gtArgs.IsAbiInformationDetermined());

if (callArg->AbiInfo.HasAnyStackSegment() || callArg->AbiInfo.IsPassedByReference())
{
return false;
}

AggregateInfo* agg = m_aggregates.Lookup(lcl->GetLclNum());
assert(agg != nullptr);

bool anyReplacements = false;
for (const ABIPassingSegment& seg : callArg->AbiInfo.Segments())
{
assert(seg.IsPassedInRegister());

auto callback = [=, &anyReplacements, &seg](Replacement& rep) {
anyReplacements = true;

// Replacement must start at the right offset...
if (rep.Offset != lcl->GetLclOffs() + seg.Offset)
{
return false;
}

// It must not be too long..
unsigned repSize = genTypeSize(rep.AccessType);
if (repSize > seg.Size)
{
return false;
}

// If it is too short, the remainder that would be passed in the
// register should be padding. We can check that by only checking
// whether the remainder intersects anything unpromoted, since if
// the remainder is a different promotion we will return false when
// the replacement is visited in this callback.
if ((repSize < seg.Size) &&
agg->Unpromoted.Intersects(SegmentList::Segment(rep.Offset + repSize, rep.Offset + seg.Size)))
{
return false;
}

return true;
};

if (!VisitOverlappingReplacements(lcl->GetLclNum(), lcl->GetLclOffs() + seg.Offset, seg.Size, callback))
{
return false;
}
}

return anyReplacements;
return !callArg->AbiInfo.HasAnyStackSegment() && !callArg->AbiInfo.IsPassedByReference();
}

//------------------------------------------------------------------------
Expand Down Expand Up @@ -2696,7 +2596,7 @@ void ReplaceVisitor::ReplaceLocal(GenTree** use, GenTree* user)

assert(effectiveUser->OperIs(GT_CALL, GT_RETURN, GT_SWIFT_ERROR_RET));

if (!ReplaceStructLocal(effectiveUser, lcl))
if (!ReplaceStructLocal(effectiveUser, use, lcl))
{
unsigned size = lcl->GetLayout(m_compiler)->GetSize();
WriteBackBeforeUse(use, lclNum, lcl->GetLclOffs(), size);
Expand Down
21 changes: 11 additions & 10 deletions src/coreclr/jit/promotion.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,17 +294,18 @@ class ReplaceVisitor : public GenTreeVisitor<ReplaceVisitor>
void InsertPreStatementWriteBacks();
GenTree** InsertMidTreeReadBacks(GenTree** use);

bool ReplaceStructLocal(GenTree* user, GenTreeLclVarCommon* value);
bool ReplaceReturnedStructLocal(GenTreeOp* ret, GenTreeLclVarCommon* value);
bool ReplaceCallArgWithFieldList(GenTreeCall* call, GenTreeLclVarCommon* callArg);
bool ReplaceStructLocal(GenTree* user, GenTree** use, GenTreeLclVarCommon* value);
bool ReplaceReturnedStructLocal(GenTreeOp* ret, GenTree** use, GenTreeLclVarCommon* value);
bool ReplaceCallArgWithFieldList(GenTreeCall* call, GenTree** use, GenTreeLclVarCommon* callArg);
bool CanReplaceCallArgWithFieldListOfReplacements(GenTreeCall* call, CallArg* callArg, GenTreeLclVarCommon* lcl);
void ReadBackAfterCall(GenTreeCall* call, GenTree* user);
bool IsPromotedStructLocalDying(GenTreeLclVarCommon* structLcl);
void ReplaceLocal(GenTree** use, GenTree* user);
void CheckForwardSubForLastUse(unsigned lclNum);
void WriteBackBeforeCurrentStatement(unsigned lcl, unsigned offs, unsigned size);
void WriteBackBeforeUse(GenTree** use, unsigned lcl, unsigned offs, unsigned size);
void MarkForReadBack(GenTreeLclVarCommon* lcl, unsigned size DEBUGARG(const char* reason));
GenTreeFieldList* CreateFieldListForStructLocal(GenTreeLclVarCommon* value);
void ReadBackAfterCall(GenTreeCall* call, GenTree* user);
bool IsPromotedStructLocalDying(GenTreeLclVarCommon* structLcl);
void ReplaceLocal(GenTree** use, GenTree* user);
void CheckForwardSubForLastUse(unsigned lclNum);
void WriteBackBeforeCurrentStatement(unsigned lcl, unsigned offs, unsigned size);
void WriteBackBeforeUse(GenTree** use, unsigned lcl, unsigned offs, unsigned size);
void MarkForReadBack(GenTreeLclVarCommon* lcl, unsigned size DEBUGARG(const char* reason));

void HandleStructStore(GenTree** use, GenTree* user);
bool OverlappingReplacements(GenTreeLclVarCommon* lcl,
Expand Down
Loading