@@ -16,6 +16,8 @@ SPDX-License-Identifier: MIT
1616#include " llvm/IR/IRBuilder.h"
1717#include " common/LLVMWarningsPop.hpp"
1818
19+ #include " GenISAIntrinsics/GenIntrinsicInst.h"
20+
1921// Simple pass which sinks constant add operations in pointer calculations
2022// It changes following pattern:
2123// %addr_part1 = <base1> + const
@@ -41,7 +43,9 @@ class SinkPointerConstAddPass : public llvm::FunctionPass {
4143 static char ID;
4244
4345private:
44- bool getConstantOffset (llvm::Value *value, int &offset);
46+ bool getConstantOffset (llvm::Value *value, std::vector<llvm::Instruction *> &zexts, int &offset);
47+ void zextToSext (std::vector<llvm::Instruction *> &zexts);
48+ bool skipZextToSext (llvm::Instruction *op, llvm::BasicBlock *parentBB);
4549};
4650
4751#define PASS_FLAG " igc-sink-ptr-const-add"
@@ -50,7 +54,8 @@ class SinkPointerConstAddPass : public llvm::FunctionPass {
5054#define PASS_ANALYSIS false
5155IGC_INITIALIZE_PASS (SinkPointerConstAddPass, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
5256
53- bool SinkPointerConstAddPass::getConstantOffset(llvm::Value *value, int &offset) {
57+ bool SinkPointerConstAddPass::getConstantOffset(llvm::Value *value, std::vector<llvm::Instruction *> &zexts,
58+ int &offset) {
5459 // Recursively search for constant add operations - this will stop after the first const add found,
5560 // and should be called repeatedly until no more const adds can be sunk.
5661
@@ -65,7 +70,11 @@ bool SinkPointerConstAddPass::getConstantOffset(llvm::Value *value, int &offset)
6570 llvm::Instruction *op = llvm::dyn_cast<llvm::Instruction>(instr->getOperand (0 ));
6671 // This is a simple pass, only sink within the same basic block
6772 if (op && instr->getParent () == op->getParent ()) {
68- return getConstantOffset (instr->getOperand (0 ), offset);
73+ // Collect zext instructions for later processing
74+ if (instr->getOpcode () == llvm::Instruction::ZExt) {
75+ zexts.push_back (instr);
76+ }
77+ return getConstantOffset (instr->getOperand (0 ), zexts, offset);
6978 } else {
7079 return false ;
7180 }
@@ -89,8 +98,8 @@ bool SinkPointerConstAddPass::getConstantOffset(llvm::Value *value, int &offset)
8998 llvm::Instruction *op0 = llvm::dyn_cast<llvm::Instruction>(instr->getOperand (0 ));
9099 llvm::Instruction *op1 = llvm::dyn_cast<llvm::Instruction>(instr->getOperand (1 ));
91100 // This is a simple pass, only sink within the same basic block
92- return (op0 && instr->getParent () == op0->getParent () && getConstantOffset (op0, offset)) ? true
93- : (op1 && instr->getParent () == op1->getParent ()) ? getConstantOffset (op1, offset)
101+ return (op0 && instr->getParent () == op0->getParent () && getConstantOffset (op0, zexts, offset)) ? true
102+ : (op1 && instr->getParent () == op1->getParent ()) ? getConstantOffset (op1, zexts, offset)
94103 : false ;
95104 }
96105 }
@@ -99,6 +108,51 @@ bool SinkPointerConstAddPass::getConstantOffset(llvm::Value *value, int &offset)
99108 return false ;
100109}
101110
111+ bool SinkPointerConstAddPass::skipZextToSext (llvm::Instruction *op, llvm::BasicBlock *parentBB) {
112+ // This is a simple pass, only sink within the same basic block
113+ if (op && parentBB == op->getParent ()) {
114+ // Do not change zext of pushed constants or loaded values - UMD provides unsigned offsets
115+ if (llvm::GenIntrinsicInst *instr = llvm::dyn_cast<llvm::GenIntrinsicInst>(op)) {
116+ if (instr->getIntrinsicID () == llvm::GenISAIntrinsic::GenISA_RuntimeValue) {
117+ return true ;
118+ }
119+ } else if (llvm::dyn_cast<llvm::Argument>(op) || llvm::dyn_cast<llvm::LoadInst>(op)) {
120+ return true ;
121+ } else if (llvm::BinaryOperator *bo = llvm::dyn_cast<BinaryOperator>(op)) {
122+ llvm::ConstantInt *cOp0 = llvm::dyn_cast<llvm::ConstantInt>(bo->getOperand (0 ));
123+ llvm::ConstantInt *cOp1 = llvm::dyn_cast<llvm::ConstantInt>(bo->getOperand (1 ));
124+ if ((cOp0 && cOp0->isNegative ()) || (cOp1 && cOp1->isNegative ())) {
125+ return false ;
126+ } else if (bo->getOpcode () == llvm::Instruction::Sub) {
127+ return false ;
128+ } else {
129+ return (skipZextToSext (llvm::dyn_cast<llvm::Instruction>(bo->getOperand (0 )), parentBB) &&
130+ skipZextToSext (llvm::dyn_cast<llvm::Instruction>(bo->getOperand (1 )), parentBB));
131+ }
132+ } else {
133+ return false ;
134+ }
135+ }
136+ return true ;
137+ }
138+
139+ void SinkPointerConstAddPass::zextToSext (std::vector<llvm::Instruction *> &zexts) {
140+ // Remove duplicates
141+ std::sort (zexts.begin (), zexts.end ());
142+ zexts.erase (std::unique (zexts.begin (), zexts.end ()), zexts.end ());
143+ // Convert zext instructions to sext instructions
144+ for (auto &zext : zexts) {
145+ llvm::Instruction *op = llvm::dyn_cast<llvm::Instruction>(zext->getOperand (0 ));
146+ if (skipZextToSext (op, zext->getParent ())) {
147+ continue ;
148+ }
149+ llvm::IRBuilder<> builder (zext);
150+ llvm::Value *sext = builder.CreateSExt (zext->getOperand (0 ), zext->getType ());
151+ zext->replaceAllUsesWith (sext);
152+ zext->eraseFromParent ();
153+ }
154+ }
155+
102156bool SinkPointerConstAddPass::runOnFunction (llvm::Function &F) {
103157 bool changed = false ;
104158 std::vector<llvm::IntToPtrInst *> intToPtrInsts;
@@ -115,11 +169,20 @@ bool SinkPointerConstAddPass::runOnFunction(llvm::Function &F) {
115169
116170 for (auto &intrinsic : intToPtrInsts) {
117171 int offset = 0 ;
172+ bool localChanged = false ;
173+ std::vector<llvm::Instruction *> zexts;
118174 // Keep sinking constant adds until no more can be sunk
119- while (getConstantOffset (intrinsic->getOperand (0 ), offset)) {
175+ while (getConstantOffset (intrinsic->getOperand (0 ), zexts, offset)) {
176+ localChanged = true ;
120177 changed = true ;
121178 }
122179
180+ if (localChanged) {
181+ // In some cases, sinking constant add may introduce negative values in pointer calculations.
182+ // Convert affected zext instructions to sext instructions to avoid potential issues.
183+ zextToSext (zexts);
184+ }
185+
123186 // If we found any constant offset, create new pointer calculation
124187 if (offset != 0 ) {
125188 llvm::IRBuilder<> builder (intrinsic);
0 commit comments