Skip to content

Commit 6daf2b9

Browse files
authored
[MLIR][Linalg] Remove elemwise_unary and elemwise_binary (#147082)
RFC: https://discourse.llvm.org/t/rfc-deprecate-linalg-elemwise-unary-and-elemwise-binary/87144 Remove the two operations and fix the tests by: * Cleaning simple operation tests of the old ops * Changing `linalg.elemwise_{u|bi}nary` with `linalg.{exp|add}` on transform tests * Changing some of the tests with `linalg.elementwise` instead, to broaden test coverage * Surgically removing the `elemwise_*` part in the Python tests * Update MLIR transform examples (text and tests) with `linalg.elementwise` instead Nothing else changed.
1 parent 0aab8e4 commit 6daf2b9

File tree

26 files changed

+168
-667
lines changed

26 files changed

+168
-667
lines changed

mlir/docs/Tutorials/transform/Ch1.md

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
1919
outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
2020
2121
// Elementwise addition.
22-
%biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
22+
%biased = linalg.elementwise kind=#linalg.elementwise_kind<add>
2323
ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
2424
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
2525
2626
// Elementwise max with 0 (ReLU).
2727
%c0f = arith.constant 0.0 : f32
28-
%relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
28+
%relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
2929
ins(%biased, %c0f : tensor<512x512xf32>, f32)
3030
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
3131
func.return %relued : tensor<512x512xf32>
@@ -41,7 +41,7 @@ module attributes {transform.with_named_sequence} {
4141
transform.named_sequence @__transform_main(
4242
%arg0: !transform.any_op,
4343
%arg1: !transform.op<"linalg.matmul">,
44-
%arg2: !transform.op<"linalg.elemwise_binary">):
44+
%arg2: !transform.op<"linalg.elementwise">):
4545
transform.yield
4646
}
4747
}
@@ -72,11 +72,11 @@ To check or debug a transform sequence, it is possible to print various entities
7272
transform.sequence failures(propagate) {
7373
^bb0(%arg0: !transform.any_op,
7474
%arg1: !transform.op<"linalg.matmul">,
75-
%arg2: !transform.op<"linalg.elemwise_binary">):
75+
%arg2: !transform.op<"linalg.elementwise">):
7676
transform.debug.emit_remark_at %arg1, "matmul"
7777
: !transform.op<"linalg.matmul">
7878
transform.debug.emit_remark_at %arg2, "elemwise_binaries"
79-
: !transform.op<"linalg.elemwise_binary">
79+
: !transform.op<"linalg.elementwise">
8080
transform.yield
8181
}
8282
```
@@ -89,24 +89,24 @@ Since we don’t want to recompile the compiler every time we change a transform
8989
```sh
9090
$ mlir-opt sequence.mlir --pass-pipeline="
9191
builtin.module(transform-interpreter{
92-
debug-bind-trailing-args=linalg.matmul,linalg.elemwise_binary})"
92+
debug-bind-trailing-args=linalg.matmul,linalg.elementwise})"
9393
```
9494

95-
The `sequence.mlir` file contains _both_ the payload IR function _and_ the transform IR sequence nested in the same module. The transform interpreter pass will apply the `@__transform_main` named sequence to the anchor operation of the pass. In our case, we also asked the interpreter pass to associate the two extra arguments of the top-level sequence with all `linalg.matmul` and `linalg.elemwise_binary` payload operations through the respective pass options. Running this pass results in the expected remarks:
95+
The `sequence.mlir` file contains _both_ the payload IR function _and_ the transform IR sequence nested in the same module. The transform interpreter pass will apply the `@__transform_main` named sequence to the anchor operation of the pass. In our case, we also asked the interpreter pass to associate the two extra arguments of the top-level sequence with all `linalg.matmul` and `linalg.elementwise` payload operations through the respective pass options. Running this pass results in the expected remarks:
9696

9797
```sh
9898
sequence.mlir:7:13: remark: matmul
9999
%matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
100100
^
101101
sequence.mlir:7:13: note: see current operation: %0 = linalg.matmul ins(%arg0, %arg1 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
102102
sequence.mlir:10:13: remark: elemwise_binaries
103-
%biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
103+
%biased = linalg.elementwise kind=#linalg.elementwise_kind<add>
104104
^
105-
sequence.mlir:10:13: note: see current operation: %1 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
105+
sequence.mlir:10:13: note: see current operation: %1 = linalg.elementwise kind=#linalg.elementwise_kind<add>> ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
106106
sequence.mlir:14:13: remark: elemwise_binaries
107-
%relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
107+
%relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
108108
^
109-
sequence.mlir:14:13: note: see current operation: %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<max_signed>} ins(%1, %cst : tensor<512x512xf32>, f32) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
109+
sequence.mlir:14:13: note: see current operation: %2 = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>> ins(%1, %cst : tensor<512x512xf32>, f32) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
110110
```
111111

112112
Note that `%arg2` is associated with both elementwise payload operations. Any handle is associated with a list of entities. Individual transformations may or may not care about the order of elements in that list.
@@ -121,7 +121,7 @@ module attributes {transform.with_named_sequence} {
121121
transform.named_sequence @__transform_main(
122122
%arg0: !transform.any_op,
123123
%arg1: !transform.op<"linalg.matmul">,
124-
%arg2: !transform.op<"linalg.elemwise_binary">) {
124+
%arg2: !transform.op<"linalg.elementwise">) {
125125
// The actual tiling transformation takes tile sizes as attributes.
126126
%loop, %tiled = transform.structured.tile_using_forall %arg1
127127
tile_sizes [4, 32]
@@ -163,10 +163,10 @@ func.func @fc_relu(%arg0: tensor<512x512xf32>,
163163
: tensor<4x32xf32> into tensor<512x512xf32>
164164
}
165165
}
166-
%1 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
166+
%1 = linalg.elementwise kind=#linalg.elementwise_kind<add>>
167167
ins(%0, %arg2 : tensor<512x512xf32>, tensor<512x512xf32>)
168168
outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
169-
%2 = linalg.elemwise_binary {fun = #linalg.binary_fn<max_signed>}
169+
%2 = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>>
170170
ins(%1, %cst : tensor<512x512xf32>, f32)
171171
outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32>
172172
return %2 : tensor<512x512xf32>
@@ -185,7 +185,7 @@ module attributes {transform.with_named_sequence} {
185185
transform.named_sequence @__transform_main(
186186
%arg0: !transform.any_op,
187187
%arg1: !transform.op<"linalg.matmul">,
188-
%arg2: !transform.op<"linalg.elemwise_binary">) {
188+
%arg2: !transform.op<"linalg.elementwise">) {
189189
// The actual tiling transformation takes tile sizes as attributes.
190190
%loop, %tiled = transform.structured.tile_using_forall %arg1 tile_sizes [4, 32]
191191
: (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op)
@@ -219,7 +219,7 @@ module attributes {transform.with_named_sequence} {
219219
transform.named_sequence @__transform_main
220220
%arg0: !transform.any_op,
221221
%arg1: !transform.op<"linalg.matmul">,
222-
%arg2: !transform.op<"linalg.elemwise_binary">) {
222+
%arg2: !transform.op<"linalg.elementwise">) {
223223
// We can cast one type to another as long as operations are compatible
224224
// with both types. This creates "aliasing" handles.
225225
%casted = transform.cast %arg1 : !transform.op<"linalg.matmul">
@@ -248,7 +248,7 @@ sequence.mlir:28:3: error: op uses a handle invalidated by a previously executed
248248
transform.debug.emit_remark_at %matmul, "elemwise_binaries" : !transform.op<"linalg.matmul">
249249
^
250250
sequence.mlir:21:29: note: handle to invalidated ops
251-
^bb0(%root: !transform.any_op, %matmul: !transform.op<"linalg.matmul">, %elemwise: !transform.op<"linalg.elemwise_binary">):
251+
^bb0(%root: !transform.any_op, %matmul: !transform.op<"linalg.matmul">, %elemwise: !transform.op<"linalg.elementwise">):
252252
^
253253
sequence.mlir:27:19: note: invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them
254254
%loop, %tiled = transform.structured.tile_using_forall %mm tile_sizes [4, 32]
@@ -263,12 +263,12 @@ module attributes {transform.with_named_sequence} {
263263
transform.named_sequence @__transform_main(
264264
%arg0: !transform.any_op,
265265
%arg1: !transform.op<"linalg.matmul">,
266-
%arg2: !transform.op<"linalg.elemwise_binary">) {
266+
%arg2: !transform.op<"linalg.elementwise">) {
267267
// Since the %arg2 handle is associated with both elementwise operations,
268268
// we need to split it into two handles so we can target only the second
269269
// elementwise operation.
270270
%add, %max = transform.split_handle %arg2
271-
: (!transform.op<"linalg.elemwise_binary">)
271+
: (!transform.op<"linalg.elementwise">)
272272
-> (!transform.any_op, !transform.any_op)
273273
274274
// The actual tiling transformation takes tile sizes as attributes. It
@@ -308,12 +308,12 @@ module attributes {transform.with_named_sequence} {
308308
transform.named_sequence @__transform_main(
309309
%arg0: !transform.any_op,
310310
%arg1: !transform.op<"linalg.matmul">,
311-
%arg2: !transform.op<"linalg.elemwise_binary">) {
311+
%arg2: !transform.op<"linalg.elementwise">) {
312312
// Since the %arg2 handle is associated with both elementwise operations,
313313
// we need to split it into two handles so we can target only the second
314314
// elementwise operation.
315315
%add, %max = transform.split_handle %arg2
316-
: (!transform.op<"linalg.elemwise_binary">)
316+
: (!transform.op<"linalg.elementwise">)
317317
-> (!transform.any_op, !transform.any_op)
318318
319319
// The actual tiling transformation takes tile sizes as attributes. It
@@ -384,7 +384,7 @@ test/Examples/transform/Ch1/invalidation-2.mlir:106:18: note: invalidated by thi
384384
%func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
385385
^
386386
test/Examples/transform/Ch1/invalidation-2.mlir:24:13: note: ancestor payload op
387-
%biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
387+
%biased = linalg.elementwise kind=#linalg.elementwise_kind<add>
388388
^
389389
test/Examples/transform/Ch1/invalidation-2.mlir:24:13: note: nested payload op
390390
%matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)

mlir/docs/Tutorials/transform/Ch2.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,12 +290,12 @@ module attributes {transform.with_named_sequence} {
290290
transform.named_sequence @__transform_main(
291291
%arg0: !transform.any_op,
292292
%arg1: !transform.op<"linalg.matmul">,
293-
%arg2: !transform.op<"linalg.elemwise_binary">) {
293+
%arg2: !transform.op<"linalg.elementwise">) {
294294
// Since the %arg2 handle is associated with both elementwise operations,
295295
// we need to split it into two handles so we can target only the second
296296
// elementwise operation.
297297
%add, %max = transform.split_handle %arg2
298-
: (!transform.op<"linalg.elemwise_binary">)
298+
: (!transform.op<"linalg.elementwise">)
299299
-> (!transform.any_op, !transform.any_op)
300300
301301
// The actual tiling transformation takes tile sizes as attributes. It

mlir/docs/Tutorials/transform/Ch4.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,13 @@ func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
4242
outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
4343
4444
// Elementwise addition.
45-
%biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
45+
%biased = linalg.elementwise kind=#linalg.elementwise_kind<add>
4646
ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
4747
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
4848
4949
// Elementwise max with 0 (ReLU).
5050
%c0f = arith.constant 0.0 : f32
51-
%relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
51+
%relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed>
5252
ins(%biased, %c0f : tensor<512x512xf32>, f32)
5353
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
5454
func.return %relued : tensor<512x512xf32>
@@ -59,7 +59,7 @@ func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
5959

6060
In Chapter 1, we were calling the test transform interpreter pass with
6161
additional arguments, `bind-first-extra-to-ops=linalg.matmul
62-
bind-second-extra-to-ops=linalg.elemwise_binary`, to provide initial
62+
bind-second-extra-to-ops=linalg.elementwise`, to provide initial
6363
associations for operation handles. Instead, we can use match operations to
6464
discover relevant operations in the payload IR. Match operations can be combined
6565
with “regular” transform operations using, e.g., the
@@ -97,7 +97,7 @@ module @transforms attributes { transform.with_named_sequence } {
9797
// rewriter sequence on success.
9898
transform.named_sequence @match_elemwise(
9999
%entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
100-
transform.match.operation_name %entry ["linalg.elemwise_binary"]
100+
transform.match.operation_name %entry ["linalg.elementwise"]
101101
: !transform.any_op
102102
transform.yield %entry : !transform.any_op
103103
}
@@ -127,7 +127,7 @@ module @transforms attributes { transform.with_named_sequence } {
127127
This script can be executed using the non-test interpreter pass running on the
128128
root operation of the translation unit without additional flags: `mlir-opt
129129
--transform-interpreter`. It will emit corresponding remarks at
130-
`linalg.elemwise_binary` and `linalg.matmul` operations. In debug builds, the
130+
`linalg.elementwise` and `linalg.matmul` operations. In debug builds, the
131131
infrastructure provides a convenient method to understand the matching process
132132
by passing `-debug-only=transform-matcher` to `mlir-opt` or a derived tool. It
133133
will print the silenceable failure messages produced by the match operations
@@ -169,7 +169,7 @@ transform.named_sequence @match_matmul_elemwise(
169169
%last: !transform.any_op {transform.readonly})
170170
-> (!transform.any_op, !transform.any_op, !transform.any_op) {
171171
// The last operation must be an elementwise binary.
172-
transform.match.operation_name %last ["linalg.elemwise_binary"]
172+
transform.match.operation_name %last ["linalg.elementwise"]
173173
: !transform.any_op
174174
// Its first operand must be defined by another operation, to which we
175175
// will get a handle here. We are guaranteed that the first operand exists
@@ -179,7 +179,7 @@ transform.named_sequence @match_matmul_elemwise(
179179
%middle = transform.get_producer_of_operand %last[0]
180180
: (!transform.any_op) -> !transform.any_op
181181
// The defining operation must itself be an elementwise binary.
182-
transform.match.operation_name %middle ["linalg.elemwise_binary"]
182+
transform.match.operation_name %middle ["linalg.elementwise"]
183183
: !transform.any_op
184184
// And the first operand of that operation must be defined by yet another
185185
// operation.
@@ -399,7 +399,7 @@ transform.named_sequence @match_matmul_elemwise(
399399
-> (!transform.any_op, !transform.any_op, !transform.any_op,
400400
!transform.param<i32>) {
401401
// The last operation must be an elementwise binary.
402-
transform.match.operation_name %last ["linalg.elemwise_binary"]
402+
transform.match.operation_name %last ["linalg.elementwise"]
403403
: !transform.any_op
404404
405405
// One of its operands must be defined by another operation, to which we
@@ -413,7 +413,7 @@ transform.named_sequence @match_matmul_elemwise(
413413
%def = transform.get_defining_op %operand
414414
: (!transform.any_value) -> !transform.any_op
415415
// The defining operation must itself be an elementwise binary.
416-
transform.match.operation_name %def ["linalg.elemwise_binary"]
416+
transform.match.operation_name %def ["linalg.elementwise"]
417417
: !transform.any_op
418418
transform.yield %def : !transform.any_op
419419
}

mlir/docs/Tutorials/transform/ChH.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ scf.forall (%co) in (2) {
290290
scf.forall (%n, %y, %xo) in (5, 80, 20) {
291291
tensor.extract_slice
292292
// Implicit dimensions [ni=0:1, y=0:1, xi=0:5, ci=0:64]
293-
%relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } // ...
293+
%relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed> // ...
294294
scf.forall.in_parallel {
295295
tensor.parallel_insert_slice // ...
296296
}

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 0 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -44,56 +44,6 @@ structured_op: !LinalgStructuredOpConfig
4444
- !ScalarExpression
4545
scalar_arg: I
4646
--- !LinalgOpConfig
47-
metadata: !LinalgOpMetadata
48-
name: elemwise_unary
49-
cpp_class_name: ElemwiseUnaryOp
50-
doc: |-
51-
Applies the unary function fun elementwise.
52-
53-
Numeric casting is performed on the input operand, promoting it to the same
54-
data type as the accumulator/output.
55-
structured_op: !LinalgStructuredOpConfig
56-
args:
57-
- !LinalgOperandDefConfig
58-
name: I
59-
kind: input_tensor
60-
type_var: T1
61-
shape_map: affine_map<() -> ()>
62-
- !LinalgOperandDefConfig
63-
name: O
64-
kind: output_tensor
65-
type_var: U
66-
shape_map: affine_map<() -> ()>
67-
- !LinalgOperandDefConfig
68-
name: fun
69-
kind: unary_fn_attr
70-
default_fn: exp
71-
- !LinalgOperandDefConfig
72-
name: cast
73-
kind: type_fn_attr
74-
default_fn: cast_signed
75-
indexing_maps: !LinalgIndexingMapsConfig
76-
static_indexing_maps:
77-
- affine_map<() -> ()>
78-
- affine_map<() -> ()>
79-
iterator_types: []
80-
assignments:
81-
- !ScalarAssign
82-
arg: O
83-
value: !ScalarExpression
84-
scalar_fn:
85-
kind: unary
86-
attr_name: fun
87-
operands:
88-
- !ScalarExpression
89-
scalar_fn:
90-
kind: type
91-
attr_name: cast
92-
type_var: U
93-
operands:
94-
- !ScalarExpression
95-
scalar_arg: I
96-
--- !LinalgOpConfig
9747
metadata: !LinalgOpMetadata
9848
name: exp
9949
cpp_class_name: ExpOp
@@ -549,70 +499,6 @@ structured_op: !LinalgStructuredOpConfig
549499
- !ScalarExpression
550500
scalar_arg: I
551501
--- !LinalgOpConfig
552-
metadata: !LinalgOpMetadata
553-
name: elemwise_binary
554-
cpp_class_name: ElemwiseBinaryOp
555-
doc: |-
556-
Applies the binary function fun elementwise.
557-
558-
Numeric casting is performed on the input operand, promoting it to the same
559-
data type as the accumulator/output.
560-
structured_op: !LinalgStructuredOpConfig
561-
args:
562-
- !LinalgOperandDefConfig
563-
name: lhs
564-
kind: input_tensor
565-
type_var: T1
566-
shape_map: affine_map<() -> ()>
567-
- !LinalgOperandDefConfig
568-
name: rhs
569-
kind: input_tensor
570-
type_var: T2
571-
shape_map: affine_map<() -> ()>
572-
- !LinalgOperandDefConfig
573-
name: O
574-
kind: output_tensor
575-
type_var: U
576-
shape_map: affine_map<() -> ()>
577-
- !LinalgOperandDefConfig
578-
name: fun
579-
kind: binary_fn_attr
580-
default_fn: add
581-
- !LinalgOperandDefConfig
582-
name: cast
583-
kind: type_fn_attr
584-
default_fn: cast_signed
585-
indexing_maps: !LinalgIndexingMapsConfig
586-
static_indexing_maps:
587-
- affine_map<() -> ()>
588-
- affine_map<() -> ()>
589-
- affine_map<() -> ()>
590-
iterator_types: []
591-
assignments:
592-
- !ScalarAssign
593-
arg: O
594-
value: !ScalarExpression
595-
scalar_fn:
596-
kind: binary
597-
attr_name: fun
598-
operands:
599-
- !ScalarExpression
600-
scalar_fn:
601-
kind: type
602-
attr_name: cast
603-
type_var: U
604-
operands:
605-
- !ScalarExpression
606-
scalar_arg: lhs
607-
- !ScalarExpression
608-
scalar_fn:
609-
kind: type
610-
attr_name: cast
611-
type_var: U
612-
operands:
613-
- !ScalarExpression
614-
scalar_arg: rhs
615-
--- !LinalgOpConfig
616502
metadata: !LinalgOpMetadata
617503
name: add
618504
cpp_class_name: AddOp

0 commit comments

Comments
 (0)