diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
index aa946f2c967d..7fe9f4d2725c 100644
--- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -1048,6 +1048,46 @@ StringRef getAMDArch(Operation *module) {
   return ref.drop_front(4); // drop the "hip:"
 }
 
+// Rough utility for obtaining a SharedEnc for a LinearEncoding,
+// as we've replaced DotOpEnc with Linear in some cases
+// (specifically, fp4ToFp and similar unpack-upcast thru join)
+std::optional<ttg::SwizzledSharedEncodingAttr>
+getSharedForLinear(ttg::LinearEncodingAttr enc,
+                   ArrayRef<unsigned> globalOrder, ArrayRef<int64_t> shape,
+                   unsigned elemBitWidth, ttg::CTALayoutAttr ctaLayout) {
+  auto ctx = enc.getContext();
+  auto ll = enc.getLinearLayout();
+  auto rank = shape.size();
+
+  if (rank != 2)
+    return std::nullopt;
+
+  auto order = enc.getOrder();
+  assert(globalOrder.size() == rank);
+  // TODO add memdesc_trans support for dot(trans(cvt(src) #linear) #dot_op)
+  if (order != globalOrder)
+    return std::nullopt;
+
+  auto innerDim = order[0];
+  auto outerDim = order[1];
+  auto contigPerWarp = enc.getContigPerWarp();
+
+  constexpr unsigned BANK_SIZE{128};
+  auto elemBytes = elemBitWidth / 8;
+
+  auto vec = contigPerWarp[innerDim];
+  auto rowSize = elemBytes * (unsigned)shape[innerDim];
+  auto perPhase = std::max(BANK_SIZE / rowSize, 1u);
+  auto maxPhase = std::max(contigPerWarp[outerDim] / perPhase, 1u);
+
+  // cp.async does not support transfer size < 4B
+  if (vec * elemBytes < 4 && perPhase < maxPhase)
+    return std::nullopt;
+
+  return ttg::SwizzledSharedEncodingAttr::get(ctx, vec, perPhase, maxPhase,
+                                              order, ctaLayout);
+}
+
 // If all the transitive uses of the given value have are used by a convert to
 // the same dot operand encoding, return the shared encoding that needs to be
 // used to be compatible with users' layouts. If there are incompatible shared
@@ -1074,18 +1114,28 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
     } else {
       if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
         return std::nullopt;
-      auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
+      auto enc =
           cast<triton::gpu::TensorOrMemDesc>(user->getResult(0).getType())
-              .getEncoding());
-      if (!dotOpEnc)
-        return std::nullopt;
+              .getEncoding();
       auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
       auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
       auto order = getOrderForMemory(srcTy);
       unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
+
+      if (auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(enc)) {
+        tempAttr = ttg::SwizzledSharedEncodingAttr::get(
+            val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout,
+            bitWidth, /*needTrans=*/false);
+      } else if (auto linearEnc = dyn_cast<ttg::LinearEncodingAttr>(enc)) {
+
+        auto attrOpt = getSharedForLinear(linearEnc, order, srcTy.getShape(),
+                                          bitWidth, CTALayout);
+        if (!attrOpt)
+          return std::nullopt;
+        tempAttr = *attrOpt;
+      } else {
+        return std::nullopt;
+      }
     }
     // Check that the shared encodings needed by the users are compatible.
     if (attr != nullptr && attr != tempAttr) {
diff --git a/python/src/llvm.cc b/python/src/llvm.cc
index 222ff3f8f9fc..f1d976ed5425 100644
--- a/python/src/llvm.cc
+++ b/python/src/llvm.cc
@@ -59,7 +59,7 @@ createTargetMachine(llvm::Module *module, std::string proc,
   opt.MCOptions.AsmVerbose = true;
   opt.MCOptions.PreserveAsmComments = true;
   std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
-      module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
+      module->getTargetTriple().str(), proc, features, opt, llvm::Reloc::PIC_,
       std::nullopt,
       disableLLVMOpt ? llvm::CodeGenOptLevel::None
                      : llvm::CodeGenOptLevel::Aggressive)};