When converting a simple linear module with bias into a StableHLO graph, a constant comes out in the arith dialect. Is this expected?
Here is the code to produce it:
import torch
import torch.nn as nn
from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType


def test_linear_with_bias():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear_a = nn.Linear(32, 32)

        def forward(self, x):
            x = self.linear_a(x)
            return x

    mod = fx.export_and_import(Basic(), torch.randint(0, 100, (1, 32)), output_type=OutputType.STABLEHLO)
    mod.dump()

and here is the result:
test/test_basic.py::test_linear_with_bias
module {
  func.func @main(%arg0: tensor<1x32xi64>) -> tensor<1x32xf32> {
    %cst = stablehlo.constant dense_resource<torch_tensor_32_torch.float32> : tensor<32xf32>
    %cst_0 = stablehlo.constant dense_resource<torch_tensor_32_32_torch.float32> : tensor<32x32xf32>
    %cst_1 = arith.constant dense<1> : tensor<1xi64>
    %0 = stablehlo.transpose %cst_0, dims = [1, 0] : (tensor<32x32xf32>) -> tensor<32x32xf32>
    %1 = stablehlo.convert %arg0 : (tensor<1x32xi64>) -> tensor<1x32xf32>
    %2 = stablehlo.dot_general %1, %0, contracting_dims = [1] x [0] : (tensor<1x32xf32>, tensor<32x32xf32>) -> tensor<1x32xf32>
    %3 = stablehlo.convert %cst_1 : (tensor<1xi64>) -> tensor<1xf32>
    %4 = stablehlo.reshape %3 : (tensor<1xf32>) -> tensor<f32>
    %5 = stablehlo.broadcast_in_dim %2, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<1x32xf32>
    %6 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<f32>) -> tensor<1x32xf32>
    %7 = stablehlo.multiply %5, %6 : tensor<1x32xf32>
    %8 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<f32>) -> tensor<32xf32>
    %9 = stablehlo.broadcast_in_dim %cst, dims = [0] : (tensor<32xf32>) -> tensor<32xf32>
    %10 = stablehlo.multiply %8, %9 : tensor<32xf32>
    %11 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<1x32xf32>
    %12 = stablehlo.broadcast_in_dim %10, dims = [1] : (tensor<32xf32>) -> tensor<1x32xf32>
    %13 = stablehlo.add %11, %12 : tensor<1x32xf32>
    return %13 : tensor<1x32xf32>
  }
}
I'm using torch-mlir to interface with a compiler that expects only StableHLO dialect input. I'm happy to either add support for arith constants in my compiler, or change how the constant is emitted from torch-mlir and submit a PR; I just wanted to know what the expected behaviour is.
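In case it helps, this is the kind of post-processing I was considering on my side in the meantime. It is only a sketch: it assumes the MLIR Python bindings bundled with torch-mlir (torch_mlir.passmanager.PassManager) and that a plain canonicalize pipeline is enough to fold the converted arith.constant away, which I have not verified.

import torch
from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType
from torch_mlir.passmanager import PassManager

# Basic is the nn.Module defined in the repro above.
mod = fx.export_and_import(Basic(), torch.randint(0, 100, (1, 32)), output_type=OutputType.STABLEHLO)

with mod.context:
    # Attempt to clean up the module after conversion. Whether canonicalize
    # actually folds the arith.constant into a StableHLO-only form is an
    # assumption, not something I have confirmed.
    PassManager.parse("builtin.module(canonicalize)").run(mod.operation)

mod.dump()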