Commit ba412e1

[nfc] Automate OpTileLowering and pattern registration using .def file (part of #132) (#180)
* generate patterns from .def
* add docs
* not yet: automate defs in TTLOps.td

Does not use a YAML file as suggested in #132, as that seemed overkill for this change; it may still be worth doing in the future.
1 parent 7cb4fe5 commit ba412e1

5 files changed: +161 -77 lines changed


docs/sphinx/contributor-guide.md

Lines changed: 44 additions & 0 deletions
@@ -14,3 +14,47 @@
 - Add new user-facing pages under `docs/sphinx` and link them in `index.rst`.
 - Keep contributor-only instructions in this guide or `guidelines.md`.
 - Build docs with `cmake --build build --target ttlang-docs`.
+
+## Adding Elementwise Operations
+
+To add a new elementwise operation (unary or binary), update these files:
+
+### 1. `include/ttlang/Dialect/TTL/TTLElementwiseOps.def`
+
+Add an entry with the TTL op name, tile op name, and TTKernel init/compute op names:
+
+```cpp
+// Binary op (3-arg form: DST[odst] = op(DST[src0], DST[src1]))
+TTL_BINARY_TILE_OP(NewOp, NewOpTileOp, NewOpBinaryTilesInitOp, NewOpBinaryTilesOp)
+
+// Unary op (in-place form: DST[dst_idx] = op(DST[dst_idx]))
+TTL_UNARY_TILE_OP(NewOp, NewOpTileOp, NewOpTileInitOp, NewOpTileOp)
+
+// Special binary op (2-arg in-place form, like Max)
+TTL_BINARY_TILE_OP_SPECIAL(NewOp, NewOpTileOp, NewOpTilesInitOp, NewOpTilesOp)
+```
+
+This automatically generates:
+- C++ lowering patterns (`ConvertTTLToCompute.cpp`, `ConvertTTLTileOpsToTTKernel.cpp`)
+- Python bindings (`_generated_elementwise.py`)
+
+### 2. `include/ttlang/Dialect/TTL/IR/TTLOps.td`
+
+Add the TableGen op definitions using the multiclass:
+
+```tablegen
+// Binary op
+defm TTL_NewOp : TTL_BinaryElementwisePair<"newop", "newop_tiles">;
+
+// Unary op
+defm TTL_NewOp : TTL_UnaryElementwisePair<"newop", "newop_tile">;
+```
+
+### 3. Verify the TTKernel ops exist in tt-mlir
+
+The TTKernel init and compute ops must exist in `tt-mlir/include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td`. If they don't, they need to be added to tt-mlir first.
+
+### 4. Add tests
+
+- Add a lit test in `test/ttlang/Dialect/TTL/Transforms/` for the lowering
+- Add to `test/python/test_elementwise_ops.py` for end-to-end verification

include/ttlang/Dialect/TTL/TTLElementwiseOps.def

Lines changed: 48 additions & 23 deletions
@@ -6,47 +6,72 @@
 // TTL Elementwise Operation Definitions
 //===----------------------------------------------------------------------===//
 //
-// This file defines mappings for TTL elementwise operations lowering to
-// ttl.compute with tile ops.
+// This file defines mappings for TTL elementwise operations, supporting:
+// 1. TTL tensor ops -> ttl.compute with tile ops (ConvertTTLToCompute)
+// 2. TTL tile ops -> TTKernel ops (ConvertTTLTileOpsToTTKernel)
+// 3. Python binding generation (gen_elementwise.py)
 //
 // It uses the X-macro pattern to generate boilerplate code for pattern
 // definitions and registration.
 //
+// Macro signatures:
+// TTL_BINARY_TILE_OP(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE)
+// TTL_UNARY_TILE_OP(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE)
+// TTL_BINARY_TILE_OP_SPECIAL(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE)
+//
+// Parameters:
+// TTL_OP - TTL tensor op name (e.g., Add -> ttl::AddOp)
+// TILE_OP - TTL tile op name (e.g., AddTileOp -> ttl::AddTileOp)
+// TTK_INIT - TTKernel init op (e.g., AddBinaryTilesInitOp)
+// TTK_COMPUTE - TTKernel compute op (e.g., AddBinaryTilesOp)
+//
 // Usage:
-// #define TTL_BINARY_TILE_OP(TTL_OP, TILE_OP) // your code here
-// #define TTL_UNARY_TILE_OP(TTL_OP, TILE_OP) // your code here
+// #define TTL_BINARY_TILE_OP(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE) ...
+// #define TTL_UNARY_TILE_OP(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE) ...
 // #include "ttlang/Dialect/TTL/TTLElementwiseOps.def"
 //
 //===----------------------------------------------------------------------===//
 
 //===----------------------------------------------------------------------===//
-// Elementwise operation mappings (TTL tensor ops -> ttl.compute with tile ops)
+// Elementwise operation mappings
 //===----------------------------------------------------------------------===//
 
 #ifndef TTL_BINARY_TILE_OP
-#define TTL_BINARY_TILE_OP(TTL_OP, TILE_OP)
+#define TTL_BINARY_TILE_OP(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE)
 #endif
 
 #ifndef TTL_UNARY_TILE_OP
-#define TTL_UNARY_TILE_OP(TTL_OP, TILE_OP)
+#define TTL_UNARY_TILE_OP(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE)
+#endif
+
+// Special macro for binary ops that use a different lowering template.
+// Max uses 2-arg in-place form: DST[dst0] = max(DST[dst0], DST[dst1])
+#ifndef TTL_BINARY_TILE_OP_SPECIAL
+#define TTL_BINARY_TILE_OP_SPECIAL(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE)
 #endif
 
-// Binary operations: TTL tensor op -> TTL tile op
-TTL_BINARY_TILE_OP(Add, AddTileOp)
-TTL_BINARY_TILE_OP(Sub, SubTileOp)
-TTL_BINARY_TILE_OP(Mul, MulTileOp)
-TTL_BINARY_TILE_OP(Max, MaxTileOp)
-
-// Unary operations: TTL tensor op -> TTL tile op
-TTL_UNARY_TILE_OP(Exp, ExpTileOp)
-TTL_UNARY_TILE_OP(Log, LogTileOp)
-TTL_UNARY_TILE_OP(Sqrt, SqrtTileOp)
-TTL_UNARY_TILE_OP(Rsqrt, RsqrtTileOp)
-TTL_UNARY_TILE_OP(Tanh, TanhTileOp)
-TTL_UNARY_TILE_OP(Abs, AbsTileOp)
-TTL_UNARY_TILE_OP(Neg, NegTileOp)
-TTL_UNARY_TILE_OP(Relu, ReluTileOp)
-TTL_UNARY_TILE_OP(Sigmoid, SigmoidTileOp)
+// Binary operations: TTL tensor op -> TTL tile op -> TTKernel init/compute ops
+// These use the standard 3-arg binary template: DST[odst] = op(DST[src0], DST[src1])
+TTL_BINARY_TILE_OP(Add, AddTileOp, AddBinaryTilesInitOp, AddBinaryTilesOp)
+TTL_BINARY_TILE_OP(Sub, SubTileOp, SubBinaryTilesInitOp, SubBinaryTilesOp)
+TTL_BINARY_TILE_OP(Mul, MulTileOp, MulBinaryTilesInitOp, MulBinaryTilesOp)
+
+// Special binary ops with non-standard lowering
+// Max uses 2-arg in-place form (TTLTileMaxToTTKernel template)
+TTL_BINARY_TILE_OP_SPECIAL(Max, MaxTileOp, BinaryMaxTileInitOp, BinaryMaxTileOp)
+
+// Unary operations: TTL tensor op -> TTL tile op -> TTKernel init/compute ops
+// These use the standard unary template: DST[dst_idx] = op(DST[dst_idx])
+TTL_UNARY_TILE_OP(Exp, ExpTileOp, ExpTileInitOp, ExpTileOp)
+TTL_UNARY_TILE_OP(Log, LogTileOp, LogTileInitOp, LogTileOp)
+TTL_UNARY_TILE_OP(Sqrt, SqrtTileOp, SqrtTileInitOp, SqrtTileOp)
+TTL_UNARY_TILE_OP(Rsqrt, RsqrtTileOp, RsqrtTileInitOp, RsqrtTileOp)
+TTL_UNARY_TILE_OP(Tanh, TanhTileOp, TanhTileInitOp, TanhTileOp)
+TTL_UNARY_TILE_OP(Abs, AbsTileOp, AbsTileInitOp, AbsTileOp)
+TTL_UNARY_TILE_OP(Neg, NegTileOp, NegativeTileInitOp, NegativeTileOp)
+TTL_UNARY_TILE_OP(Relu, ReluTileOp, ReluTileInitOp, ReluTileOp)
+TTL_UNARY_TILE_OP(Sigmoid, SigmoidTileOp, SigmoidTileInitOp, SigmoidTileOp)
 
 #undef TTL_BINARY_TILE_OP
 #undef TTL_UNARY_TILE_OP
+#undef TTL_BINARY_TILE_OP_SPECIAL

lib/Dialect/TTL/Transforms/ConvertTTLTileOpsToTTKernel.cpp

Lines changed: 31 additions & 44 deletions
@@ -353,46 +353,26 @@ struct TTLTileCopyToTTKernel : OpConversionPattern<CopyTileOp> {
 };
 
 //===----------------------------------------------------------------------===//
-// Unary Tile Op Lowerings (LLVM-style type aliases)
+// Tile Op Lowerings - Generated from TTLElementwiseOps.def
 //===----------------------------------------------------------------------===//
 
-using ExpTileLowering =
-    TTLTileUnaryToTTKernel<ExpTileOp, ttk::ExpTileInitOp, ttk::ExpTileOp>;
-using LogTileLowering =
-    TTLTileUnaryToTTKernel<LogTileOp, ttk::LogTileInitOp, ttk::LogTileOp>;
-using SqrtTileLowering =
-    TTLTileUnaryToTTKernel<SqrtTileOp, ttk::SqrtTileInitOp, ttk::SqrtTileOp>;
-using RsqrtTileLowering =
-    TTLTileUnaryToTTKernel<RsqrtTileOp, ttk::RsqrtTileInitOp, ttk::RsqrtTileOp>;
-using TanhTileLowering =
-    TTLTileUnaryToTTKernel<TanhTileOp, ttk::TanhTileInitOp, ttk::TanhTileOp>;
-using SigmoidTileLowering =
-    TTLTileUnaryToTTKernel<SigmoidTileOp, ttk::SigmoidTileInitOp,
-                           ttk::SigmoidTileOp>;
-using AbsTileLowering =
-    TTLTileUnaryToTTKernel<AbsTileOp, ttk::AbsTileInitOp, ttk::AbsTileOp>;
-using NegTileLowering =
-    TTLTileUnaryToTTKernel<NegTileOp, ttk::NegativeTileInitOp,
-                           ttk::NegativeTileOp>;
-using ReluTileLowering =
-    TTLTileUnaryToTTKernel<ReluTileOp, ttk::ReluTileInitOp, ttk::ReluTileOp>;
+// Generate type aliases for unary tile op lowerings
+#define TTL_UNARY_TILE_OP(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE) \
+  using TTL_OP##TileLowering = \
+      TTLTileUnaryToTTKernel<TILE_OP, ttk::TTK_INIT, ttk::TTK_COMPUTE>;
+#include "ttlang/Dialect/TTL/TTLElementwiseOps.def"
 
-//===----------------------------------------------------------------------===//
-// Binary Tile Op Lowerings
-//===----------------------------------------------------------------------===//
+// Generate type aliases for binary tile op lowerings (standard 3-arg form)
+#define TTL_BINARY_TILE_OP(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE) \
+  using TTL_OP##TileLowering = \
+      TTLTileBinaryToTTKernel<TILE_OP, ttk::TTK_INIT, ttk::TTK_COMPUTE>;
+#include "ttlang/Dialect/TTL/TTLElementwiseOps.def"
 
-using AddTileLowering =
-    TTLTileBinaryToTTKernel<AddTileOp, ttk::AddBinaryTilesInitOp,
-                            ttk::AddBinaryTilesOp>;
-using SubTileLowering =
-    TTLTileBinaryToTTKernel<SubTileOp, ttk::SubBinaryTilesInitOp,
-                            ttk::SubBinaryTilesOp>;
-using MulTileLowering =
-    TTLTileBinaryToTTKernel<MulTileOp, ttk::MulBinaryTilesInitOp,
-                            ttk::MulBinaryTilesOp>;
-using MaxTileLowering =
-    TTLTileMaxToTTKernel<MaxTileOp, ttk::BinaryMaxTileInitOp,
-                         ttk::BinaryMaxTileOp>;
+// Generate type aliases for special binary tile op lowerings (2-arg in-place)
+#define TTL_BINARY_TILE_OP_SPECIAL(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE) \
+  using TTL_OP##TileLowering = \
+      TTLTileMaxToTTKernel<TILE_OP, ttk::TTK_INIT, ttk::TTK_COMPUTE>;
+#include "ttlang/Dialect/TTL/TTLElementwiseOps.def"
 
 } // namespace
 
@@ -409,14 +389,21 @@ void populateTTLTileOpsToTTKernelPatterns(TypeConverter *typeConverter,
   patterns.add<TTLTileRegsAcquireToTTKernel, TTLTileRegsCommitToTTKernel,
                TTLTileRegsWaitToTTKernel, TTLTileRegsReleaseToTTKernel>(ctx);
 
-  // Tile op lowerings (ttl.tile_* → ttkernel.*_tile)
-  patterns.add<
-      // Unary ops
-      ExpTileLowering, LogTileLowering, SqrtTileLowering, RsqrtTileLowering,
-      TanhTileLowering, SigmoidTileLowering, AbsTileLowering, NegTileLowering,
-      ReluTileLowering,
-      // Binary ops
-      AddTileLowering, SubTileLowering, MulTileLowering, MaxTileLowering>(ctx);
+  // Tile op lowerings - generated from TTLElementwiseOps.def
+  // Unary ops (ttl.tile_* → ttkernel.*_tile)
+#define TTL_UNARY_TILE_OP(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE) \
+  patterns.add<TTL_OP##TileLowering>(ctx);
+#include "ttlang/Dialect/TTL/TTLElementwiseOps.def"
+
+  // Binary ops (ttl.tile_* → ttkernel.*_tiles)
+#define TTL_BINARY_TILE_OP(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE) \
+  patterns.add<TTL_OP##TileLowering>(ctx);
+#include "ttlang/Dialect/TTL/TTLElementwiseOps.def"
+
+  // Special binary ops (non-standard lowering template)
+#define TTL_BINARY_TILE_OP_SPECIAL(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE) \
+  patterns.add<TTL_OP##TileLowering>(ctx);
+#include "ttlang/Dialect/TTL/TTLElementwiseOps.def"
 
   // Copy op needs the type converter.
   patterns.add<TTLTileCopyToTTKernel>(*typeConverter, ctx);
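
Reviewer note: each `#include` above defines only one of the three macros, so the fallback empty `#define`s in the `.def` file make the other two kinds of entries expand to nothing, and the trailing `#undef`s reset everything for the next include. The sketch below imitates that behaviour in Python purely as an illustration; `expand`, `DEF_TEXT`, and `TEMPLATES` are hypothetical names for this note, not part of the commit, and the real work is done by the C preprocessor.

```python
import re

# A subset of entries copied from TTLElementwiseOps.def, for brevity.
DEF_TEXT = """\
TTL_BINARY_TILE_OP(Add, AddTileOp, AddBinaryTilesInitOp, AddBinaryTilesOp)
TTL_BINARY_TILE_OP_SPECIAL(Max, MaxTileOp, BinaryMaxTileInitOp, BinaryMaxTileOp)
TTL_UNARY_TILE_OP(Neg, NegTileOp, NegativeTileInitOp, NegativeTileOp)
"""

# Lowering template chosen by each macro in ConvertTTLTileOpsToTTKernel.cpp.
TEMPLATES = {
    "TTL_UNARY_TILE_OP": "TTLTileUnaryToTTKernel",
    "TTL_BINARY_TILE_OP": "TTLTileBinaryToTTKernel",
    "TTL_BINARY_TILE_OP_SPECIAL": "TTLTileMaxToTTKernel",
}

# One entry per line: MACRO(arg, arg, arg, arg)
ENTRY = re.compile(r"^(TTL_\w+)\(([^)]*)\)$", re.MULTILINE)


def expand(def_text, defined):
    """Imitate one #include of the .def file with only `defined` macros set."""
    aliases = []
    for macro, args in ENTRY.findall(def_text):
        if macro not in defined:
            continue  # the fallback empty #define expands to nothing
        ttl_op, tile_op, ttk_init, ttk_compute = [a.strip() for a in args.split(",")]
        aliases.append(
            f"using {ttl_op}TileLowering = "
            f"{TEMPLATES[macro]}<{tile_op}, ttk::{ttk_init}, ttk::{ttk_compute}>;"
        )
    return aliases


unary_aliases = expand(DEF_TEXT, {"TTL_UNARY_TILE_OP"})
special_aliases = expand(DEF_TEXT, {"TTL_BINARY_TILE_OP_SPECIAL"})
print("\n".join(unary_aliases + special_aliases))
```

The generated alias text for `Neg` and `Max` matches the hand-written aliases this commit removes, which is one way to sanity-check that the change is non-functional.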

lib/Dialect/TTL/Transforms/ConvertTTLToCompute.cpp

Lines changed: 13 additions & 4 deletions
@@ -281,10 +281,13 @@ struct LowerUnaryToCompute : OpRewritePattern<TTLOp> {
 //===----------------------------------------------------------------------===//
 
 // Generate type aliases for binary operations using tile ops
-#define TTL_BINARY_TILE_OP(TTL_OP, TILE_OP) \
+// (TTK_INIT and TTK_COMPUTE are unused here, only needed for TTKernel lowering)
+#define TTL_BINARY_TILE_OP(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE) \
+  using Lower##TTL_OP = LowerBinaryToCompute<TTL_OP##Op, TILE_OP>;
+#define TTL_BINARY_TILE_OP_SPECIAL(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE) \
   using Lower##TTL_OP = LowerBinaryToCompute<TTL_OP##Op, TILE_OP>;
 // Generate type aliases for unary operations using tile ops
-#define TTL_UNARY_TILE_OP(TTL_OP, TILE_OP) \
+#define TTL_UNARY_TILE_OP(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE) \
   using Lower##TTL_OP = LowerUnaryToCompute<TTL_OP##Op, TILE_OP>;
 #include "ttlang/Dialect/TTL/TTLElementwiseOps.def"
 
@@ -319,8 +322,14 @@ void populateTTLToComputePatterns(RewritePatternSet &patterns) {
 
   // Register patterns for lowering to ttl.compute with tile ops.
   // These are generated from TTLElementwiseOps.def using tile-based mappings.
-#define TTL_BINARY_TILE_OP(TTL_OP, TILE_OP) patterns.add<Lower##TTL_OP>(ctx);
-#define TTL_UNARY_TILE_OP(TTL_OP, TILE_OP) patterns.add<Lower##TTL_OP>(ctx);
+  // (TTK_INIT and TTK_COMPUTE are unused here, only needed for TTKernel
+  // lowering)
+#define TTL_BINARY_TILE_OP(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE) \
+  patterns.add<Lower##TTL_OP>(ctx);
+#define TTL_BINARY_TILE_OP_SPECIAL(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE) \
+  patterns.add<Lower##TTL_OP>(ctx);
+#define TTL_UNARY_TILE_OP(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE) \
+  patterns.add<Lower##TTL_OP>(ctx);
 #include "ttlang/Dialect/TTL/TTLElementwiseOps.def"
 }
 

python/gen_elementwise.py

Lines changed: 25 additions & 6 deletions
@@ -64,24 +64,43 @@ def {name}(input: "TensorBlock") -> "TensorBlock":
 
 
 def parse_def_file(def_path: Path) -> tuple[list[str], list[str]]:
-    """Parse TTLElementwiseOps.def and extract operation names."""
+    """Parse TTLElementwiseOps.def and extract operation names.
+
+    Handles both standard and special binary ops:
+    - TTL_BINARY_TILE_OP(Name, TileOp, TTKInit, TTKCompute)
+    - TTL_BINARY_TILE_OP_SPECIAL(Name, TileOp, TTKInit, TTKCompute)
+    - TTL_UNARY_TILE_OP(Name, TileOp, TTKInit, TTKCompute)
+    """
     content = def_path.read_text()
 
     binary_ops = []
     unary_ops = []
 
-    # Match TTL_BINARY_TILE_OP(Name, TileOp) but skip #define lines
-    for match in re.finditer(r"^TTL_BINARY_TILE_OP\((\w+),", content, re.MULTILINE):
+    # Match TTL_BINARY_TILE_OP(Name, ...) and TTL_BINARY_TILE_OP_SPECIAL(Name, ...)
+    # but skip #define lines
+    for match in re.finditer(
+        r"^TTL_BINARY_TILE_OP(?:_SPECIAL)?\((\w+),", content, re.MULTILINE
+    ):
         name = match.group(1).lower()
         # Skip macro parameter names (lowercase indicates it's a parameter)
-        if name[0].isupper() or name not in ("ttl_op", "tile_op"):
+        if name[0].isupper() or name not in (
+            "ttl_op",
+            "tile_op",
+            "ttk_init",
+            "ttk_compute",
+        ):
             binary_ops.append(name)
 
-    # Match TTL_UNARY_TILE_OP(Name, TileOp) but skip #define lines
+    # Match TTL_UNARY_TILE_OP(Name, ...) but skip #define lines
     for match in re.finditer(r"^TTL_UNARY_TILE_OP\((\w+),", content, re.MULTILINE):
         name = match.group(1).lower()
         # Skip macro parameter names
-        if name[0].isupper() or name not in ("ttl_op", "tile_op"):
+        if name[0].isupper() or name not in (
+            "ttl_op",
+            "tile_op",
+            "ttk_init",
+            "ttk_compute",
+        ):
             unary_ops.append(name)
 
     return binary_ops, unary_ops
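
Reviewer note: because both regexes are anchored with `^` under `re.MULTILINE`, lines that start with `#define` or `//` can never match, so the parameter-name filter looks redundant as far as this diff shows (and `name[0].isupper()` is always false after `.lower()`). The standalone sketch below checks that reading on a small sample; `SAMPLE` and `parse_def_text` are hypothetical names for this note, and the function takes text rather than a `Path` so it can run without the repository.

```python
import re

# A trimmed stand-in for TTLElementwiseOps.def, including the kinds of
# lines the parser is meant to skip (comments, #define, #undef).
SAMPLE = """\
// Usage:
//   #define TTL_BINARY_TILE_OP(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE) ...
#ifndef TTL_BINARY_TILE_OP
#define TTL_BINARY_TILE_OP(TTL_OP, TILE_OP, TTK_INIT, TTK_COMPUTE)
#endif
TTL_BINARY_TILE_OP(Add, AddTileOp, AddBinaryTilesInitOp, AddBinaryTilesOp)
TTL_BINARY_TILE_OP_SPECIAL(Max, MaxTileOp, BinaryMaxTileInitOp, BinaryMaxTileOp)
TTL_UNARY_TILE_OP(Exp, ExpTileOp, ExpTileInitOp, ExpTileOp)
#undef TTL_BINARY_TILE_OP
"""

# Same patterns as gen_elementwise.py: only entries at the start of a line.
BINARY = re.compile(r"^TTL_BINARY_TILE_OP(?:_SPECIAL)?\((\w+),", re.MULTILINE)
UNARY = re.compile(r"^TTL_UNARY_TILE_OP\((\w+),", re.MULTILINE)


def parse_def_text(content):
    """Return (binary_ops, unary_ops) as lowercase names, with no extra filter."""
    binary_ops = [m.group(1).lower() for m in BINARY.finditer(content)]
    unary_ops = [m.group(1).lower() for m in UNARY.finditer(content)]
    return binary_ops, unary_ops


binary_ops, unary_ops = parse_def_text(SAMPLE)
print(binary_ops, unary_ops)
```

Special binary ops such as `Max` land in the same `binary_ops` list as the standard ones, which matches the commit's intent that the Python bindings do not distinguish between the two lowering templates.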
