@@ -21,7 +21,9 @@ SPDX-License-Identifier: MIT
2121#include  < llvm/Analysis/ScalarEvolution.h> 
2222#include  < llvm/Analysis/ScalarEvolutionExpressions.h> 
2323#include  < llvm/Analysis/TargetFolder.h> 
24+ #include  < llvm/Analysis/ValueTracking.h> 
2425#include  < llvm/IR/GetElementPtrTypeIterator.h> 
26+ #include  < llvm/Support/KnownBits.h> 
2527#include  < llvm/Transforms/Utils/ScalarEvolutionExpander.h> 
2628#include  < llvm/Transforms/Utils/Local.h> 
2729#include  " llvmWrapper/IR/Intrinsics.h" 
@@ -63,6 +65,7 @@ class GenIRLowering : public FunctionPass {
6365
6466  bool  combineFMaxFMin (CallInst *GII, BasicBlock::iterator &BBI) const ;
6567  bool  combineSelectInst (SelectInst *Sel, BasicBlock::iterator &BBI) const ;
68+   bool  combinePack4i8Or2i16 (Instruction *inst, uint64_t  numBits) const ;
6669
6770  bool  constantFoldFMaxFMin (CallInst *GII, BasicBlock::iterator &BBI) const ;
6871};
@@ -362,6 +365,15 @@ bool GenIRLowering::runOnFunction(Function &F) {
362365          Changed |= combineSelectInst (cast<SelectInst>(Inst), BI);
363366        }
364367        break ;
368+       case  Instruction::Or:
369+         if  (Inst->getType ()->isIntegerTy (32 )) {
370+           //  Detect packing of 4 i8 values and convert to a pattern that is
371+           //  matched CodeGenPatternMatch::MatchPack4i8
372+           Changed |= combinePack4i8Or2i16 (Inst, 8  /* numBits*/ 
373+           //  TODO: also detect <2 x i16> packing once PatternMatch is updated
374+           //  to packing of 16-bit values.
375+         }
376+         break ;
365377      }
366378    }
367379  }
@@ -1000,6 +1012,162 @@ bool GenIRLowering::combineSelectInst(SelectInst *Sel, BasicBlock::iterator &BBI
10001012  return  false ;
10011013}
10021014
1015+ // //////////////////////////////////////////////////////////////////////////////
1016+ //  Detect complex patterns that pack 2 16-bit or 4 8-bit integers into a 32-bit
1017+ //  value. Generate equivalent sequence of instructions that is later matched in
1018+ //  the CodeGenPatternMatch::MatchPack4i8().
1019+ //  Pattern example for <4 x i8> packing:
1020+ //    %x1 = and i32 %x, 127
1021+ //    %x2 = lshr i32 %x, 24
1022+ //    %x3 = and i32 %x2, 128
1023+ //    %x4 = or i32 %x3, %x1
1024+ //    %y1 = and i32 %y, 127
1025+ //    %y2 = lshr i32 %y, 24
1026+ //    %y3 = and i32 %y2, 128
1027+ //    %y4 = or i32 %y3, %y1
1028+ //    %y5 = shl nuw nsw i32 %y4, 8
1029+ //    %xy = or i32 %x4, %y5
1030+ //    %z1 = and i32 %z, 127
1031+ //    %z2 = lshr i32 %z, 24
1032+ //    %z3 = and i32 %z2, 128
1033+ //    %z4 = or i32 %z3, %z1
1034+ //    %z5 = shl nuw nsw i32 %z4, 16
1035+ //    %xyz = or i32 %xy, %z5
1036+ //    %w1 = shl nsw i32 %w, 24
1037+ //    %w2 = and i32 %w1, 2130706432
1038+ //    %w3 = and i32 %w, -2147483648
1039+ //    %w4 = or i32 %w2, %w3
1040+ //    %xyzw = or i32 %xyz, %w4
1041+ //  and generate:
1042+ //    %0 = trunc i32 %x to i8
1043+ //    %1 = insertelement <4 x i8> poison, i8 %0, i32 0
1044+ //    %2 = trunc i32 %y to i8
1045+ //    %3 = insertelement <4 x i8> %1, i8 %2, i32 1
1046+ //    %4 = trunc i32 %z to i8
1047+ //    %5 = insertelement <4 x i8> %3, i8 %4, i32 2
1048+ //    %6 = trunc i32 %w to i8
1049+ //    %7 = insertelement <4 x i8> %5, i8 %6, i32 3
1050+ //    %8 = bitcast <4 x i8> %7 to i32
1051+ bool  GenIRLowering::combinePack4i8Or2i16 (Instruction *inst, uint64_t  numBits) const  {
1052+   using  namespace  llvm ::PatternMatch; 
1053+ 
1054+   const  DataLayout &DL = inst->getModule ()->getDataLayout ();
1055+   //  Vector of 4 or 2 values that will be packed into a single 32-bit value.
1056+   SmallVector<Value *, 4 > toPack;
1057+   IGC_ASSERT (numBits == 8  || numBits == 16 );
1058+   uint64_t  packedVecSize = 32  / numBits;
1059+   toPack.resize (packedVecSize);
1060+   uint64_t  cSignMask = QWBIT (numBits - 1 );
1061+   uint64_t  cMagnMask = BITMASK (numBits - 1 );
1062+   SmallVector<std::pair<Value *, uint64_t >, 4 > args;
1063+   args.push_back ({isa<BitCastInst>(inst) ? inst->getOperand (0 ) : inst, 0 });
1064+   //  In the first step traverse the chain of `or` and `shl` instructions
1065+   //  and find all elements of the packed vector.
1066+   while  (!args.empty ()) {
1067+     auto  [v, prevShlBits] = args.pop_back_val ();
1068+     Value *lOp = nullptr ;
1069+     Value *rOp = nullptr ;
1070+ 
1071+     //  Detect left shift by multiple of `numBits`. The `shl` operation sets the
1072+     //  `index` argument in the corresponding InsertElement instruction in the
1073+     //  final packing sequence. This operation can also be viewed as repacking
1074+     //  of already packed vector into another packed vector.
1075+     uint64_t  shlBits = 0 ;
1076+     if  (match (v, m_Shl (m_Value (lOp), m_ConstantInt (shlBits))) && (shlBits % numBits) == 0 ) {
1077+       args.push_back ({lOp, shlBits + prevShlBits});
1078+       continue ;
1079+     }
1080+     //  Detect values that fit into `numBits` bits - a single element of
1081+     //  the packed vector.
1082+     KnownBits kb = computeKnownBits (v, DL);
1083+     uint32_t  nonZeroBits = ~(static_cast <uint32_t >(kb.Zero .getZExtValue ()));
1084+     uint32_t  lsb = findFirstSet (nonZeroBits);
1085+     uint32_t  msb = findLastSet (nonZeroBits);
1086+     if  (msb != lsb && (msb / numBits) == (lsb / numBits)) {
1087+       uint32_t  idx = (prevShlBits / numBits) + (lsb / numBits);
1088+       if  (idx < packedVecSize && toPack[idx] == nullptr ) {
1089+         toPack[idx] = v;
1090+         continue ;
1091+       }
1092+     }
1093+     //  Detect packing of two disjoint values. This `or` operation corresponds
1094+     //  to an InsertElement instruction in the final packing sequence.
1095+     if  (match (v, m_Or (m_Value (lOp), m_Value (rOp)))) {
1096+       KnownBits kbL = computeKnownBits (lOp, DL);
1097+       KnownBits kbR = computeKnownBits (rOp, DL);
1098+       uint32_t  nonZeroBitsL = ~(static_cast <uint32_t >(kbL.Zero .getZExtValue ()));
1099+       uint32_t  nonZeroBitsR = ~(static_cast <uint32_t >(kbR.Zero .getZExtValue ()));
1100+       if  ((nonZeroBitsL & nonZeroBitsR) == 0 ) {
1101+         args.push_back ({lOp, prevShlBits});
1102+         args.push_back ({rOp, prevShlBits});
1103+       }
1104+       continue ;
1105+     }
1106+     if  (std::all_of (toPack.begin (), toPack.end (), [](const  Value *c) { return  c != nullptr ; })) {
1107+       break ;
1108+     }
1109+     //  Unsupported pattern.
1110+     return  false ;
1111+   }
1112+   if  (std::any_of (toPack.begin (), toPack.end (), [](const  Value *c) { return  c == nullptr ; })) {
1113+     return  false ;
1114+   }
1115+   //  In the second step match the pattern that packs sign and magnitude parts
1116+   //  and simple masking with `and` instruction.
1117+   for  (uint32_t  i = 0 ; i < packedVecSize; ++i) {
1118+     Value *v = toPack[i];
1119+     Value *lOp = nullptr ;
1120+     Value *rOp = nullptr ;
1121+     uint64_t  lMask = 0 ;
1122+     uint64_t  rMask = 0 ;
1123+     //  Match patterns that pack the sign and magnitude parts.
1124+     if  (match (v, m_Or (m_And (m_Value (lOp), m_ConstantInt (lMask)), m_And (m_Value (rOp), m_ConstantInt (rMask)))) &&
1125+       (countPopulation (rMask) == 1  || countPopulation (lMask) == 1 )) {
1126+       Value *signOp = countPopulation (rMask) == 1  ? rOp : lOp;
1127+       Value *magnOp = countPopulation (rMask) == 1  ? lOp : rOp;
1128+       uint64_t  signMask = countPopulation (rMask) == 1  ? rMask : lMask;
1129+       uint64_t  magnMask = countPopulation (rMask) == 1  ? lMask : rMask;
1130+       uint64_t  shlBits = 0 ;
1131+       uint64_t  shrBits = 0 ;
1132+       //  %b = shl nsw i32 %a, 24
1133+       //  %c = and i32 %b, 2130706432
1134+       //  %sign = and i32 %a, -2147483648
1135+       //  %e = or i32 %sign, %c
1136+       if  (match (magnOp, m_Shl (m_Value (v), m_ConstantInt (shlBits))) && v == signOp && (shlBits % numBits) == 0  &&
1137+           shlBits == (i * numBits) && (cSignMask << shlBits) == signMask && (cMagnMask << shlBits) == magnMask) {
1138+         toPack[i] = v;
1139+         continue ;
1140+       }
1141+       //  %b = and i32 %a, 127
1142+       //  %c = lshr i32 %a, 24
1143+       //  %sign = and i32 %c, 128
1144+       //  %e = or i32 %sign, %b
1145+       if  (match (signOp, m_LShr (m_Value (v), m_ConstantInt (shrBits))) && v == magnOp && shrBits == (32  - numBits) &&
1146+           cSignMask == signMask && cMagnMask == magnMask) {
1147+         toPack[i] = v;
1148+         continue ;
1149+       }
1150+     }
1151+     uint64_t  andMask = 0 ;
1152+     if  (match (v, m_And (m_Value (lOp), m_ConstantInt (andMask))) && (andMask & BITMASK (numBits)) == andMask) {
1153+       toPack[i] = lOp;
1154+       continue ;
1155+     }
1156+   }
1157+ 
1158+   //  Create the packing sequence that is matched in the PatternMatch later.
1159+   Type *elemTy = Builder->getIntNTy (numBits);
1160+   Value *packed = PoisonValue::get (IGCLLVM::FixedVectorType::get (elemTy, packedVecSize));
1161+   for  (uint32_t  i = 0 ; i < packedVecSize; ++i) {
1162+     Value *elem = Builder->CreateTrunc (toPack[i], elemTy);
1163+     packed = Builder->CreateInsertElement (packed, elem, Builder->getInt32 (i));
1164+   }
1165+   packed = Builder->CreateBitCast (packed, inst->getType ());
1166+   inst->replaceAllUsesWith (packed);
1167+   inst->eraseFromParent ();
1168+   return  true ;
1169+ }
1170+ 
10031171FunctionPass *IGC::createGenIRLowerPass () { return  new  GenIRLowering (); }
10041172
10051173//  Register pass to igc-opt
0 commit comments