Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
pashu123 authored Feb 4, 2025
1 parent d444ab4 commit 2733115
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 12 deletions.
3 changes: 2 additions & 1 deletion compiler/bindings/python/iree/compiler/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
SIMPLE_MUL_ASM = """
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
"""
Expand Down
9 changes: 5 additions & 4 deletions tests/e2e/regression/layernorm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@
// %x = util.unfoldable_constant dense<5.0> : tensor<128x384xf32>
// %c384 = util.unfoldable_constant dense<384.0> : tensor<128x1xf32>
// %sum = tosa.reduce_sum %x {axis = 1 : i64} : (tensor<128x384xf32>) -> tensor<128x1xf32>
// %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// %r384 = tosa.reciprocal %c384 : (tensor<128x1xf32>) -> tensor<128x1xf32>
// %mean = tosa.mul %sum, %r384 {shift = 0 : i8} : (tensor<128x1xf32>, tensor<128x1xf32>) -> tensor<128x1xf32>
// %mean = tosa.mul %sum, %r384, %shift : (tensor<128x1xf32>, tensor<128x1xf32>, tensor<1xi8>) -> tensor<128x1xf32>
// %x_sub_mean = tosa.sub %x, %mean : (tensor<128x384xf32>, tensor<128x1xf32>) -> tensor<128x384xf32>
// %square = tosa.mul %x_sub_mean, %x_sub_mean {shift = 0 : i8} : (tensor<128x384xf32>, tensor<128x384xf32>) -> tensor<128x384xf32>
// %square = tosa.mul %x_sub_mean, %x_sub_mean, %shift : (tensor<128x384xf32>, tensor<128x384xf32>, tensor<1xi8>) -> tensor<128x384xf32>
// %square_sum = tosa.reduce_sum %square {axis = 1 : i64} : (tensor<128x384xf32>) -> tensor<128x1xf32>
// %variance = tosa.mul %square_sum, %r384 {shift = 0 : i8} : (tensor<128x1xf32>, tensor<128x1xf32>) -> tensor<128x1xf32>
// %variance = tosa.mul %square_sum, %r384, %shift : (tensor<128x1xf32>, tensor<128x1xf32>, tensor<1xi8>) -> tensor<128x1xf32>
// %epsilon = util.unfoldable_constant dense<9.99999996E-13> : tensor<128x1xf32>
// %var_eps = tosa.add %variance, %epsilon : (tensor<128x1xf32>, tensor<128x1xf32>) -> tensor<128x1xf32>
// %rsigma = tosa.rsqrt %var_eps : (tensor<128x1xf32>) -> tensor<128x1xf32>
// %norm = tosa.mul %x_sub_mean, %rsigma {shift = 0 : i8} : (tensor<128x384xf32>, tensor<128x1xf32>) -> tensor<128x384xf32>
// %norm = tosa.mul %x_sub_mean, %rsigma, %shift : (tensor<128x384xf32>, tensor<128x1xf32>, tensor<1xi8>) -> tensor<128x384xf32>
// check.expect_almost_eq_const(%norm, dense<0.0> : tensor<128x384xf32>) : tensor<128x384xf32>
// return
// }
Expand Down
3 changes: 2 additions & 1 deletion tests/e2e/regression/softmax.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
// %exp = tosa.exp %sub : (tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
// %sum = tosa.reduce_sum %exp {axis = 2 : i64} : (tensor<12x128x128xf32>) -> tensor<12x128x1xf32>
// %rec = tosa.reciprocal %sum : (tensor<12x128x1xf32>) -> tensor<12x128x1xf32>
// %mul = tosa.mul %exp, %rec {shift = 0 : i8} : (tensor<12x128x128xf32>, tensor<12x128x1xf32>) -> tensor<12x128x128xf32>
// %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
// %mul = tosa.mul %exp, %rec, %shift : (tensor<12x128x128xf32>, tensor<12x128x1xf32>, tensor<1xi8>) -> tensor<12x128x128xf32>
// check.expect_almost_eq_const(%mul, dense<0.0078125> : tensor<12x128x128xf32>) : tensor<12x128x128xf32>
// return
// }
Expand Down
6 changes: 4 additions & 2 deletions tests/e2e/tosa_ops/mul.mlir
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
func.func @tensor_float() {
%0 = util.unfoldable_constant dense<[1.0, 0.0, 3.0, 4.0]> : tensor<4xf32>
%1 = util.unfoldable_constant dense<[5.0, 6.0, -3.0, 8.0]> : tensor<4xf32>
%result = tosa.mul %0, %1 {shift = 0 : i8} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%result = tosa.mul %0, %1, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>
check.expect_almost_eq_const(%result, dense<[5.0, 0.0, -9.0, 32.0]> : tensor<4xf32>) : tensor<4xf32>
return
}

func.func @tensor_int() {
%0 = util.unfoldable_constant dense<[1, 0, 3, 4]> : tensor<4xi32>
%1 = util.unfoldable_constant dense<[5, 6, -3, 8]> : tensor<4xi32>
%result = tosa.mul %0, %1 {shift = 0 : i8} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%result = tosa.mul %0, %1, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
check.expect_eq_const(%result, dense<[5, 0, -9, 32]> : tensor<4xi32>) : tensor<4xi32>
return
}
3 changes: 2 additions & 1 deletion tests/e2e/tosa_ops/mul_shift.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
func.func @tensor_int_shifted() {
%0 = util.unfoldable_constant dense<[1, 0, 3, 4, 4]> : tensor<5xi32>
%1 = util.unfoldable_constant dense<[5, 6, -3, 8, 8]> : tensor<5xi32>
%result = tosa.mul %0, %1 {shift = 1 : i8} : (tensor<5xi32>, tensor<5xi32>) -> tensor<5xi32>
%shift = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1xi8>
%result = tosa.mul %0, %1, %shift : (tensor<5xi32>, tensor<5xi32>, tensor<1xi8>) -> tensor<5xi32>
check.expect_eq_const(%result, dense<[3, 0, -4, 16, 16]> : tensor<5xi32>) : tensor<5xi32>
return
}
2 changes: 1 addition & 1 deletion third_party/llvm-project
Submodule llvm-project updated 715 files
2 changes: 1 addition & 1 deletion third_party/stablehlo

0 comments on commit 2733115

Please sign in to comment.