Commit 821b6e4

require copy src/dst to be tensor slice not tensor itself; delete a bunch of dead code
1 parent 2ccd218 commit 821b6e4

14 files changed: 111 additions, 451 deletions
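The behavioral change in the CopyLowering pattern below can be summed up as a pairing rule on the two operands of ttl.copy. A minimal illustrative sketch follows; the enum mirrors the pass's CopyOperandKind, while isLegalCopyPair and the asserts are stand-ins, not repository code.

#include <cassert>

// Mirrors CopyOperandKind in the pass; whole tensors no longer get a kind.
enum class CopyOperandKind { TensorSlice, CircularBuffer, Unknown };

// Illustrative helper: a ttl.copy is accepted only when exactly one operand
// is a tensor_slice and the other a circular_buffer, in either direction.
static bool isLegalCopyPair(CopyOperandKind src, CopyOperandKind dst) {
  return (src == CopyOperandKind::TensorSlice &&
          dst == CopyOperandKind::CircularBuffer) ||
         (src == CopyOperandKind::CircularBuffer &&
          dst == CopyOperandKind::TensorSlice);
}

int main() {
  assert(isLegalCopyPair(CopyOperandKind::TensorSlice,
                         CopyOperandKind::CircularBuffer));   // slice -> CB
  assert(isLegalCopyPair(CopyOperandKind::CircularBuffer,
                         CopyOperandKind::TensorSlice));      // CB -> slice
  assert(!isLegalCopyPair(CopyOperandKind::Unknown,
                          CopyOperandKind::CircularBuffer));  // plain tensor -> CB: rejected
  assert(!isLegalCopyPair(CopyOperandKind::CircularBuffer,
                          CopyOperandKind::CircularBuffer));  // CB -> CB: rejected
}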

lib/Dialect/TTL/Transforms/ConvertTTLToTTKernel.cpp

Lines changed: 29 additions & 209 deletions
@@ -431,50 +431,16 @@ struct StoreLowering : OpConversionPattern<StoreOp> {
   }
 };
 
-enum class CopySourceKind {
-  TensorAccessor,
-  TensorSlice,
-  CircularBuffer,
-  Pipe,
-  Unknown
-};
-enum class CopyDestKind {
-  TensorAccessor,
-  TensorSlice,
-  CircularBuffer,
-  Pipe,
-  Unknown
-};
-
-static bool isTensorAccessorLike(Type t) {
-  return llvm::isa<ttk::TensorAccessorType>(t) ||
-         llvm::isa<RankedTensorType>(t);
-}
-
-static CopySourceKind classifySrc(Value v) {
-  if (llvm::isa<CircularBufferType>(v.getType())) {
-    return CopySourceKind::CircularBuffer;
-  }
-  if (llvm::isa<TensorSliceType>(v.getType())) {
-    return CopySourceKind::TensorSlice;
-  }
-  if (isTensorAccessorLike(v.getType())) {
-    return CopySourceKind::TensorAccessor;
-  }
-  return CopySourceKind::Unknown;
-}
+enum class CopyOperandKind { TensorSlice, CircularBuffer, Unknown };
 
-static CopyDestKind classifyDst(Value v) {
+static CopyOperandKind classifyOperand(Value v) {
   if (llvm::isa<CircularBufferType>(v.getType())) {
-    return CopyDestKind::CircularBuffer;
+    return CopyOperandKind::CircularBuffer;
   }
   if (llvm::isa<TensorSliceType>(v.getType())) {
-    return CopyDestKind::TensorSlice;
+    return CopyOperandKind::TensorSlice;
   }
-  if (isTensorAccessorLike(v.getType())) {
-    return CopyDestKind::TensorAccessor;
-  }
-  return CopyDestKind::Unknown;
+  return CopyOperandKind::Unknown;
 }
 
 static Value makeZeroI32(Location loc, ConversionPatternRewriter &rewriter) {
@@ -612,140 +578,6 @@ static Value linearizeTileIndex(OpBuilder &builder, Location loc, Value row,
   return builder.create<arith::AddIOp>(loc, rowOffset, col);
 }
 
-/// Lower tensor->CB copy: read from DRAM/L1 tensor into circular buffer.
-static LogicalResult lowerTensorToCB(CopyOp op, Value srcTensor, Value dstCB,
-                                     ConversionPatternRewriter &rewriter,
-                                     const TypeConverter &typeConverter) {
-  auto loc = op.getLoc();
-
-  // Get tensor L1 address from runtime args.
-  auto bankBase = getBufferAddressFromRuntimeArg(srcTensor, loc, rewriter);
-  if (failed(bankBase)) {
-    return rewriter.notifyMatchFailure(
-        op, "tensor must be a function argument for runtime arg mapping");
-  }
-
-  // Create tensor accessor with actual buffer address.
-  // This derives page size from TTNNLayoutAttr encoding.
-  auto srcAccessor =
-      materializeTensorAccessor(srcTensor, *bankBase, op, rewriter);
-  if (failed(srcAccessor)) {
-    return failure(); // Error already emitted by materializeTensorAccessor
-  }
-
-  // Convert CB to TTKernel type and get write pointer.
-  auto cbConverted = utils::convertTTLCBToTTKernel(dstCB, rewriter, loc);
-  if (failed(cbConverted)) {
-    return rewriter.notifyMatchFailure(op, "failed to convert CB operand");
-  }
-  auto cbWritePtr = rewriter.create<ttk::GetWritePtrOp>(loc, *cbConverted);
-
-  auto tileGridShape = getTileGridShapeFromValue(srcTensor);
-  int64_t tilesY = tileGridShape.first;
-  int64_t tilesX = tileGridShape.second;
-
-  // Get page size for CB pointer arithmetic from TTNNLayoutAttr.
-  auto tensorTy = mlir::cast<RankedTensorType>(srcTensor.getType());
-  auto layoutAttr =
-      mlir::dyn_cast_or_null<ttnn::TTNNLayoutAttr>(tensorTy.getEncoding());
-  assert(layoutAttr &&
-         "lowerTensorToCB: srcTensor must have TTNNLayoutAttr encoding");
-  int64_t pageSizeBytes = layoutAttr.getElementSizeBytes();
-
-  // Cast cbWritePtr to index for address arithmetic.
-  auto indexTy = rewriter.getIndexType();
-  auto cbWritePtrIdx =
-      rewriter.create<arith::IndexCastOp>(loc, indexTy, cbWritePtr);
-
-  auto pageSizeIdx = rewriter.create<arith::ConstantIndexOp>(loc, pageSizeBytes);
-  auto i32Ty = rewriter.getI32Type();
-
-  // TODO(#138): Emit single block transfer for contiguous layouts instead of
-  // tile loop.
-  emitTileLoop(
-      rewriter, loc, tilesY, tilesX,
-      [&, tilesX](OpBuilder &b, Location bodyLoc, Value row, Value col) {
-        Value tileIdx = linearizeTileIndex(b, bodyLoc, row, col, tilesX);
-        // Compute CB address: cbWritePtr + tileIdx * pageSize
-        Value byteOffset = b.create<arith::MulIOp>(bodyLoc, tileIdx, pageSizeIdx);
-        Value cbAddrIdx = b.create<arith::AddIOp>(bodyLoc, cbWritePtrIdx, byteOffset);
-        // Cast to i32 for NOC operation.
-        Value tileIdx32 = b.create<arith::IndexCastOp>(bodyLoc, i32Ty, tileIdx);
-        Value cbAddr = b.create<arith::IndexCastOp>(bodyLoc, i32Ty, cbAddrIdx);
-        b.create<ttk::NocAsyncReadTileOp>(bodyLoc, tileIdx32, *srcAccessor, cbAddr);
-      });
-
-  rewriter.replaceOp(op, makeZeroI32(loc, rewriter));
-  return success();
-}
-
-/// Lower CB->tensor copy: write from circular buffer to DRAM/L1 tensor.
-static LogicalResult lowerCBToTensor(CopyOp op, Value srcCB, Value dstTensor,
-                                     ConversionPatternRewriter &rewriter,
-                                     const TypeConverter &typeConverter) {
-  auto loc = op.getLoc();
-
-  // Get tensor L1 address from runtime args.
-  auto bankBase = getBufferAddressFromRuntimeArg(dstTensor, loc, rewriter);
-  if (failed(bankBase)) {
-    return rewriter.notifyMatchFailure(
-        op, "tensor must be a function argument for runtime arg mapping");
-  }
-
-  // Create tensor accessor with actual buffer address.
-  // This derives page size from TTNNLayoutAttr encoding.
-  auto dstAccessor =
-      materializeTensorAccessor(dstTensor, *bankBase, op, rewriter);
-  if (failed(dstAccessor)) {
-    return failure(); // Error already emitted by materializeTensorAccessor
-  }
-
-  // Convert CB to TTKernel type and get read pointer.
-  auto cbConverted = utils::convertTTLCBToTTKernel(srcCB, rewriter, loc);
-  if (failed(cbConverted)) {
-    return rewriter.notifyMatchFailure(op, "failed to convert CB operand");
-  }
-  auto cbReadPtr = rewriter.create<ttk::GetReadPtrOp>(loc, *cbConverted);
-
-  auto tileGridShape = getTileGridShapeFromValue(dstTensor);
-  int64_t tilesY = tileGridShape.first;
-  int64_t tilesX = tileGridShape.second;
-
-  // Get page size for CB pointer arithmetic from TTNNLayoutAttr.
-  auto tensorTy = mlir::cast<RankedTensorType>(dstTensor.getType());
-  auto layoutAttr =
-      mlir::dyn_cast_or_null<ttnn::TTNNLayoutAttr>(tensorTy.getEncoding());
-  assert(layoutAttr &&
-         "lowerCBToTensor: dstTensor must have TTNNLayoutAttr encoding");
-  int64_t pageSizeBytes = layoutAttr.getElementSizeBytes();
-
-  // Cast cbReadPtr to index for address arithmetic.
-  auto indexTy = rewriter.getIndexType();
-  auto cbReadPtrIdx =
-      rewriter.create<arith::IndexCastOp>(loc, indexTy, cbReadPtr);
-
-  auto pageSizeIdx = rewriter.create<arith::ConstantIndexOp>(loc, pageSizeBytes);
-  auto i32Ty = rewriter.getI32Type();
-
-  // TODO(#138): Emit single block transfer for contiguous layouts instead of
-  // tile loop.
-  emitTileLoop(
-      rewriter, loc, tilesY, tilesX,
-      [&, tilesX](OpBuilder &b, Location bodyLoc, Value row, Value col) {
-        Value tileIdx = linearizeTileIndex(b, bodyLoc, row, col, tilesX);
-        // Compute CB address: cbReadPtr + tileIdx * pageSize
-        Value byteOffset = b.create<arith::MulIOp>(bodyLoc, tileIdx, pageSizeIdx);
-        Value cbAddrIdx = b.create<arith::AddIOp>(bodyLoc, cbReadPtrIdx, byteOffset);
-        // Cast to i32 for NOC operation.
-        Value tileIdx32 = b.create<arith::IndexCastOp>(bodyLoc, i32Ty, tileIdx);
-        Value cbAddr = b.create<arith::IndexCastOp>(bodyLoc, i32Ty, cbAddrIdx);
-        b.create<ttk::NocAsyncWriteTileOp>(bodyLoc, tileIdx32, *dstAccessor, cbAddr);
      });
-
-  rewriter.replaceOp(op, makeZeroI32(loc, rewriter));
-  return success();
-}
-
 /// Lower tensor_slice->CB copy: read tiles from tensor into CB.
 /// Loops over CB shape, reading tiles starting at slice offset.
 static LogicalResult lowerSliceToCB(CopyOp op, TensorSliceOp sliceOp,
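For context on the deleted bulk-copy helpers above (and the per-tile addressing the remaining slice paths share), a small standalone sketch of the arithmetic they emitted: tiles are visited row-major, each tile index is linearized as row * tilesX + col, and the CB byte address is the base read/write pointer plus tileIdx * pageSizeBytes (the page size taken from the TTNNLayoutAttr element size). The function below is illustrative plain C++, not the MLIR-emitting code.

#include <cstdint>
#include <vector>

// Illustrative model of the per-tile CB addressing: one page per tile,
// tiles walked row-major over a tilesY x tilesX grid.
std::vector<int64_t> tileCBAddresses(int64_t cbBasePtr, int64_t tilesY,
                                     int64_t tilesX, int64_t pageSizeBytes) {
  std::vector<int64_t> addrs;
  addrs.reserve(tilesY * tilesX);
  for (int64_t row = 0; row < tilesY; ++row) {
    for (int64_t col = 0; col < tilesX; ++col) {
      int64_t tileIdx = row * tilesX + col;            // linearizeTileIndex
      addrs.push_back(cbBasePtr + tileIdx * pageSizeBytes);
    }
  }
  // e.g. tileCBAddresses(0x1000, 2, 3, 0x800) ->
  //      {0x1000, 0x1800, 0x2000, 0x2800, 0x3000, 0x3800}
  return addrs;
}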
@@ -954,16 +786,26 @@ struct CopyLowering : OpConversionPattern<CopyOp> {
       return rewriter.notifyMatchFailure(op, "no type converter");
     }
 
-    // Use original operands for classification since lowering functions
-    // handle type conversion internally.
     Value src = op.getSrc();
     Value dst = op.getDst();
-    auto srcKind = classifySrc(src);
-    auto dstKind = classifyDst(dst);
+    auto srcKind = classifyOperand(src);
+    auto dstKind = classifyOperand(dst);
 
-    // TensorSlice -> CB: read a single tile from tensor into circular buffer.
-    if (srcKind == CopySourceKind::TensorSlice &&
-        dstKind == CopyDestKind::CircularBuffer) {
+    // Validate: copy requires exactly one TensorSlice and one CircularBuffer.
+    bool srcIsSlice = srcKind == CopyOperandKind::TensorSlice;
+    bool srcIsCB = srcKind == CopyOperandKind::CircularBuffer;
+    bool dstIsSlice = dstKind == CopyOperandKind::TensorSlice;
+    bool dstIsCB = dstKind == CopyOperandKind::CircularBuffer;
+
+    if (!((srcIsSlice && dstIsCB) || (srcIsCB && dstIsSlice))) {
+      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+        diag << "ttl.copy requires one tensor_slice and one circular_buffer, "
+             << "got src=" << src.getType() << " dst=" << dst.getType();
+      });
+    }
+
+    // TensorSlice -> CB: read tiles from tensor into circular buffer.
+    if (srcIsSlice && dstIsCB) {
       auto sliceOp = src.getDefiningOp<TensorSliceOp>();
       if (!sliceOp) {
         return rewriter.notifyMatchFailure(
@@ -973,36 +815,14 @@ struct CopyLowering : OpConversionPattern<CopyOp> {
                             *typeConverter);
     }
 
-    // CB -> TensorSlice: write a single tile from circular buffer to tensor.
-    if (srcKind == CopySourceKind::CircularBuffer &&
-        dstKind == CopyDestKind::TensorSlice) {
-      auto sliceOp = dst.getDefiningOp<TensorSliceOp>();
-      if (!sliceOp) {
-        return rewriter.notifyMatchFailure(
-            op, "tensor_slice destination must come from ttl.tensor_slice op");
-      }
-      return lowerCBToSlice(op, adaptor.getSrc(), sliceOp, rewriter,
-                            *typeConverter);
-    }
-
-    // Tensor -> CB: read all tiles from tensor into circular buffer (loop).
-    if (srcKind == CopySourceKind::TensorAccessor &&
-        dstKind == CopyDestKind::CircularBuffer) {
-      return lowerTensorToCB(op, src, adaptor.getDst(), rewriter,
-                             *typeConverter);
-    }
-
-    // CB -> Tensor: write all tiles from circular buffer to tensor (loop).
-    if (srcKind == CopySourceKind::CircularBuffer &&
-        dstKind == CopyDestKind::TensorAccessor) {
-      return lowerCBToTensor(op, adaptor.getSrc(), dst, rewriter,
-                             *typeConverter);
+    // CB -> TensorSlice: write tiles from circular buffer to tensor.
+    auto sliceOp = dst.getDefiningOp<TensorSliceOp>();
+    if (!sliceOp) {
+      return rewriter.notifyMatchFailure(
+          op, "tensor_slice destination must come from ttl.tensor_slice op");
     }
-
-    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
-      diag << "unsupported ttl.copy src/dst combination: src=" << src.getType()
-           << " dst=" << dst.getType();
-    });
+    return lowerCBToSlice(op, adaptor.getSrc(), sliceOp, rewriter,
+                          *typeConverter);
   }
 };
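Pulling the two CopyLowering hunks together, the updated matchAndRewrite reduces to the decision procedure sketched below. Everything here (CopyDecision, decideCopyLowering, sliceHasDefiningOp) is an illustrative stand-in for how the pass now branches, not code from the repository.

enum class CopyOperandKind { TensorSlice, CircularBuffer, Unknown };

enum class CopyDecision {
  LowerSliceToCB,     // tensor_slice -> circular_buffer: lowerSliceToCB
  LowerCBToSlice,     // circular_buffer -> tensor_slice: lowerCBToSlice
  RejectOperandKinds, // not exactly one tensor_slice and one circular_buffer
  RejectNoSliceOp     // slice operand not produced by ttl.tensor_slice
};

// sliceHasDefiningOp stands in for getDefiningOp<TensorSliceOp>() succeeding
// on whichever operand classified as a tensor_slice.
CopyDecision decideCopyLowering(CopyOperandKind src, CopyOperandKind dst,
                                bool sliceHasDefiningOp) {
  bool srcIsSlice = src == CopyOperandKind::TensorSlice;
  bool srcIsCB = src == CopyOperandKind::CircularBuffer;
  bool dstIsSlice = dst == CopyOperandKind::TensorSlice;
  bool dstIsCB = dst == CopyOperandKind::CircularBuffer;

  // Early rejection; the real pass emits a match-failure diagnostic here.
  if (!((srcIsSlice && dstIsCB) || (srcIsCB && dstIsSlice)))
    return CopyDecision::RejectOperandKinds;
  if (!sliceHasDefiningOp)
    return CopyDecision::RejectNoSliceOp;
  return srcIsSlice ? CopyDecision::LowerSliceToCB
                    : CopyDecision::LowerCBToSlice;
}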
