arith constant in StableHLO graph #3803

@AleksKnezevic

Description

When converting a simple linear module with bias to a StableHLO graph, one of the constants comes out in the arith dialect. Is this expected?

Here is the code to produce it:

import torch
import torch.nn as nn

from torch_mlir import fx
# In recent torch-mlir builds OutputType is exposed from compiler_utils.
from torch_mlir.compiler_utils import OutputType


def test_linear_with_bias():
  class Basic(nn.Module):
    def __init__(self):
      super().__init__()
      self.linear_a = nn.Linear(32, 32)

    def forward(self, x):
      x = self.linear_a(x)
      return x

  mod = fx.export_and_import(Basic(), torch.randint(0, 100, (1, 32)), output_type=OutputType.STABLEHLO)
  mod.dump()

and here is the result:

test/test_basic.py::test_linear_with_bias
module {
  func.func @main(%arg0: tensor<1x32xi64>) -> tensor<1x32xf32> {
    %cst = stablehlo.constant dense_resource<torch_tensor_32_torch.float32> : tensor<32xf32>
    %cst_0 = stablehlo.constant dense_resource<torch_tensor_32_32_torch.float32> : tensor<32x32xf32>
    %cst_1 = arith.constant dense<1> : tensor<1xi64>
    %0 = stablehlo.transpose %cst_0, dims = [1, 0] : (tensor<32x32xf32>) -> tensor<32x32xf32>
    %1 = stablehlo.convert %arg0 : (tensor<1x32xi64>) -> tensor<1x32xf32>
    %2 = stablehlo.dot_general %1, %0, contracting_dims = [1] x [0] : (tensor<1x32xf32>, tensor<32x32xf32>) -> tensor<1x32xf32>
    %3 = stablehlo.convert %cst_1 : (tensor<1xi64>) -> tensor<1xf32>
    %4 = stablehlo.reshape %3 : (tensor<1xf32>) -> tensor<f32>
    %5 = stablehlo.broadcast_in_dim %2, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<1x32xf32>
    %6 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<f32>) -> tensor<1x32xf32>
    %7 = stablehlo.multiply %5, %6 : tensor<1x32xf32>
    %8 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<f32>) -> tensor<32xf32>
    %9 = stablehlo.broadcast_in_dim %cst, dims = [0] : (tensor<32xf32>) -> tensor<32xf32>
    %10 = stablehlo.multiply %8, %9 : tensor<32xf32>
    %11 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<1x32xf32>
    %12 = stablehlo.broadcast_in_dim %10, dims = [1] : (tensor<32xf32>) -> tensor<1x32xf32>
    %13 = stablehlo.add %11, %12 : tensor<1x32xf32>
    return %13 : tensor<1x32xf32>
  }
}

I'm using torch-mlir to interface with a compiler that expects input purely in the StableHLO dialect. I'm happy either to add support for arith constants in my compiler or to change how the constant is emitted from torch-mlir and submit a PR; I just wanted to know what the expected behaviour is.
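
For reference, here is a minimal sketch of the first option (handling it on my side): a post-export cleanup that rewrites any arith.constant into an equivalent stablehlo.constant using the MLIR Python bindings bundled with torch-mlir (torch_mlir.ir). The helper name is mine, it is not a torch-mlir API, and it assumes a flat func.func body like the dump above:

from torch_mlir import ir


def arith_constants_to_stablehlo(module: ir.Module) -> None:
  # Walk every block of every top-level function and re-emit arith.constant
  # ops as stablehlo.constant ops carrying the same dense value attribute.
  with module.context:
    for func in module.body.operations:
      for block in func.regions[0].blocks:
        for op in list(block.operations):
          if op.operation.name != "arith.constant":
            continue
          new_op = ir.Operation.create(
              "stablehlo.constant",
              results=[op.results[0].type],
              attributes={"value": op.attributes["value"]},
              loc=op.location,
              ip=ir.InsertionPoint(op),
          )
          op.results[0].replace_all_uses_with(new_op.results[0])
          op.operation.erase()

Calling arith_constants_to_stablehlo(mod) between fx.export_and_import and mod.dump() should leave only stablehlo ops in the dump above, since the one arith.constant here is a plain tensor constant whose uses don't care which dialect produced it.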
