Deeplabv3 Conv2d Shapes #559

Open
yzhang93 opened this issue Jul 16, 2024 · 4 comments

Comments

@yzhang93 (Contributor) commented Jul 16, 2024

1. Stride 2 conv2d:

   ```mlir
   %8 = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins(%3, %4, %c0_i32, %c0_i32 : tensor<1x515x515x3xi8>, tensor<3x3x3x32xi8>, i32, i32) outs(%7 : tensor<1x257x257x32xi32>) -> tensor<1x257x257x32xi32>
   ```

2. Stride 1 conv2d with 1x1 filter:

   ```mlir
   %8 = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4, %c0_i32, %c0_i32 : tensor<1x257x257x32xi8>, tensor<1x1x32x16xi8>, i32, i32) outs(%7 : tensor<1x257x257x16xi32>) -> tensor<1x257x257x16xi32>
   ```

   which can be converted to matmul_transpose_b (see the NumPy sketch after the table below):

   ```mlir
   %8 = linalg.matmul_transpose_b ins(%3, %4 : tensor<66049x32xi8>, tensor<16x32xi8>) outs(%7 : tensor<66049x16xi32>) -> tensor<66049x16xi32>
   ```

| Input | Weight | Output | Stride | Convert to matmul (MxNxK) |
| --- | --- | --- | --- | --- |
| 1x515x515x3xi8 | 3x3x3x32xi8 | 1x257x257x32xi32 | 2 | NA |
| 1x257x257x32xi8 | 1x1x32x16xi8 | 1x257x257x16xi32 | 1 | 66049x16x32 |
| 1x257x257x16xi8 | 1x1x16x96xi8 | 1x257x257x96xi32 | 1 | 66049x96x16 |
| 1x129x129x96xi8 | 1x1x96x24xi8 | 1x129x129x24xi32 | 1 | 16641x24x96 |
| 1x129x129x24xi8 | 1x1x24x144xi8 | 1x129x129x144xi32 | 1 | 16641x144x24 |
| 1x129x129x144xi8 | 1x1x144x24xi8 | 1x129x129x24xi32 | 1 | 16641x24x144 |
| 1x65x65x144xi8 | 1x1x144x32xi8 | 1x65x65x32xi32 | 1 | 4225x32x144 |
| 1x65x65x32xi8 | 1x1x32x192xi8 | 1x65x65x192xi32 | 1 | 4225x192x32 |
| 1x65x65x192xi8 | 1x1x192x32xi8 | 1x65x65x32xi32 | 1 | 4225x32x192 |
| 1x65x65x192xi8 | 1x1x192x64xi8 | 1x65x65x64xi32 | 1 | 4225x64x192 |
| 1x65x65x64xi8 | 1x1x64x384xi8 | 1x65x65x384xi32 | 1 | 4225x384x64 |
| 1x65x65x384xi8 | 1x1x384x64xi8 | 1x65x65x64xi32 | 1 | 4225x64x384 |
| 1x65x65x384xi8 | 1x1x384x96xi8 | 1x65x65x96xi32 | 1 | 4225x96x384 |
| 1x65x65x96xi8 | 1x1x96x576xi8 | 1x65x65x576xi32 | 1 | 4225x576x96 |
| 1x65x65x576xi8 | 1x1x576x96xi8 | 1x65x65x96xi32 | 1 | 4225x96x576 |
| 1x65x65x576xi8 | 1x1x576x160xi8 | 1x65x65x160xi32 | 1 | 4225x160x576 |
| 1x65x65x160xi8 | 1x1x160x960xi8 | 1x65x65x960xi32 | 1 | 4225x960x160 |
| 1x65x65x960xi8 | 1x1x960x160xi8 | 1x65x65x160xi32 | 1 | 4225x160x960 |
| 1x65x65x960xi8 | 1x1x960x320xi8 | 1x65x65x320xi32 | 1 | 4225x320x960 |
| 1x65x65x320xi8 | 1x1x320x256xi8 | 1x65x65x256xi32 | 1 | 4225x256x320 |
| 1x1x1x320xi8 | 1x1x320x256xi8 | 1x1x1x256xi32 | 1 | 1x256x320 |
| 1x65x65x512xi8 | 1x1x512x256xi8 | 1x65x65x256xi32 | 1 | 4225x256x512 |
| 1x65x65x256xi8 | 1x1x256x21xi8 | 1x65x65x21xi32 | 1 | 4225x21x256 |
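
To sanity-check the conv-to-matmul rewrite, here is a minimal NumPy sketch (an illustration, not from the original issue) using the 66049x16x32 row above. With both zero points at 0 (as with `%c0_i32` in the IR), the quantized conv reduces to a plain integer matmul:

```python
import numpy as np

# Shapes from the 66049x16x32 row: input 1x257x257x32, filter 1x1x32x16.
N, H, W, C, F = 1, 257, 257, 32, 16
rng = np.random.default_rng(0)
x = rng.integers(-128, 128, size=(N, H, W, C), dtype=np.int8)
w = rng.integers(-128, 128, size=(1, 1, C, F), dtype=np.int8)

# 1x1 stride-1 conv in NHWC x HWCF layout, accumulating in i32 (zero points = 0).
conv = np.einsum("nhwc,cf->nhwf", x.astype(np.int32), w[0, 0].astype(np.int32))

# The same computation as matmul_transpose_b: A[M,K] @ B[N,K]^T with
# M = N*H*W = 66049, K = C = 32, N = F = 16.
a = x.reshape(N * H * W, C).astype(np.int32)   # tensor<66049x32xi8>
b = w[0, 0].T.astype(np.int32)                 # tensor<16x32xi8>
out = (a @ b.T).reshape(N, H, W, F)            # tensor<66049x16xi32>, reshaped

assert np.array_equal(conv, out)
```
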
3. Depthwise Conv2d:

   ```mlir
   %7 = linalg.conv_2d_ngchw_gfchw_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4, %c0_i32, %c0_i32 : tensor<1x32x1x259x259xi8>, tensor<32x1x1x3x3xi8>, i32, i32) outs(%6 : tensor<1x32x1x257x257xi32>) -> tensor<1x32x1x257x257xi32>
   ```

| Input | Weight | Output | Stride |
| --- | --- | --- | --- |
| 1x32x1x259x259xi8 | 32x1x1x3x3xi8 | 1x32x1x257x257xi32 | 1 |
| 1x96x1x259x259xi8 | 96x1x1x3x3xi8 | 1x96x1x129x129xi32 | 2 |
| 1x144x1x131x131xi8 | 144x1x1x3x3xi8 | 1x144x1x129x129xi32 | 1 |
| 1x144x1x131x131xi8 | 144x1x1x3x3xi8 | 1x144x1x65x65xi32 | 2 |
| 1x192x1x67x67xi8 | 192x1x1x3x3xi8 | 1x192x1x65x65xi32 | 1 |
| 1x384x1x69x69xi8 | 384x1x1x3x3xi8 | 1x384x1x65x65xi32 | 1 |
| 1x576x1x69x69xi8 | 576x1x1x3x3xi8 | 1x576x1x65x65xi32 | 1 |
| 1x960x1x73x73xi8 | 960x1x1x3x3xi8 | 1x960x1x65x65xi32 | 1 |
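
For reference, a hand-written NumPy sketch (mine, not from the issue) of what `conv_2d_ngchw_gfchw_q` computes for the first row, assuming zero points of 0: each of the G groups convolves only its own single channel with its own 3x3 filter.

```python
import numpy as np

# Shapes from the first depthwise row: NGCHW input 1x32x1x259x259,
# GFCHW weights 32x1x1x3x3, stride 1.
N, G, C, H, W, KH, KW, S = 1, 32, 1, 259, 259, 3, 3, 1
rng = np.random.default_rng(0)
x = rng.integers(-128, 128, size=(N, G, C, H, W), dtype=np.int8)
w = rng.integers(-128, 128, size=(G, 1, C, KH, KW), dtype=np.int8)

OH, OW = (H - KH) // S + 1, (W - KW) // S + 1
out = np.zeros((N, G, 1, OH, OW), dtype=np.int32)
for g in range(G):            # each group sees only its own channel
    for kh in range(KH):
        for kw in range(KW):
            out[0, g, 0] += (
                x[0, g, 0, kh:kh + S * OH:S, kw:kw + S * OW:S].astype(np.int32)
                * int(w[g, 0, 0, kh, kw])
            )
print(out.shape)  # (1, 32, 1, 257, 257), matching tensor<1x32x1x257x257xi32>
```
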
@yzhang93 (Contributor, Author)

@newling @erwei-xilinx The above is a list of all the original conv shapes in the model without padding.

@erwei-xilinx (Contributor)

Do they all have stride = 1?

@yzhang93 (Contributor, Author)

> Do they all have stride = 1?

Good point. I've updated the table to include stride.

@yzhang93 (Contributor, Author) commented Jul 17, 2024

@newling The depthwise ops didn't get transposed to channels-last, because the pass only supports linalg::Conv2DNchwFchwOp conversion. https://github.com/iree-org/iree/blob/4de493af31e370ca2eb1bb590469ebbf76fc8d5b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp#L452

We have to extend that pass if we want to work on the channels-last version; otherwise we can directly try lowering linalg.depthwise_conv_2d_nchw_chw or linalg.conv_2d_ngchw_gfchw_q.
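
To illustrate numerically what such a channels-last rewrite would do (a small hand-written sketch with hypothetical shapes, not the IREE pass itself): transposing input and output layouts leaves the depthwise result unchanged, so the op can be computed in either layout.

```python
import numpy as np

# Hypothetical small shapes; the point is only the layout equivalence.
C, H, W, K = 4, 8, 8, 3
OH, OW = H - K + 1, W - K + 1
rng = np.random.default_rng(0)
x = rng.integers(-128, 128, size=(1, C, H, W)).astype(np.int32)  # NCHW
w = rng.integers(-128, 128, size=(C, K, K)).astype(np.int32)     # one KxK filter per channel

def dw_nchw(x, w):
    out = np.zeros((1, C, OH, OW), dtype=np.int32)
    for c in range(C):
        for kh in range(K):
            for kw in range(K):
                out[0, c] += x[0, c, kh:kh + OH, kw:kw + OW] * w[c, kh, kw]
    return out

def dw_nhwc(x, w):  # same math, channels-last input/output
    out = np.zeros((1, OH, OW, C), dtype=np.int32)
    for kh in range(K):
        for kw in range(K):
            out[0] += x[0, kh:kh + OH, kw:kw + OW, :] * w[:, kh, kw]
    return out

# Transpose in, compute channels-last, compare against the transposed
# channels-first result: identical values.
assert np.array_equal(
    dw_nchw(x, w).transpose(0, 2, 3, 1),
    dw_nhwc(x.transpose(0, 2, 3, 1), w),
)
```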
