
[ROCM] Support dynamic key-value cache sequence length. #19015

Open

JamesMBartlett opened this issue Nov 4, 2024 · 0 comments
Labels
enhancement ➕ New feature or request

Comments

JamesMBartlett (Contributor) commented Nov 4, 2024

Request description

I would like to be able to run Llama3 attention layers in a decode style, where the KV cache has a dynamic sequence length and the queries have sequence length 1. I've attached an example IR below showing essentially what I'm trying to achieve. (The example IR has batch_size == 1, but ideally I would be able to run with a dynamic batch size as well; a rough sketch of the dynamic-batch imports follows the IR.)

example.mlir
module attributes {hal.device.targets = [#hal.device.target<"hip", {legacy_sync}, [#hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<WMMA_F32_16x16x16_F16>, <WMMA_F16_16x16x16_F16>, <WMMA_I32_16x16x16_I8>], subgroup_size_choices = [32, 64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>, ukernels = "none"}>]> : !hal.device]} {
  util.func public @main$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.fence, %arg5: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
    %cst = arith.constant 8.837890e-02 : f16
    %0 = hal.tensor.import wait(%arg4) => %arg0 : !hal.buffer_view -> tensor<1x32x1x128xf16>
    %1 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[2] : index
    %2 = hal.tensor.import wait(%arg4) => %arg1 : !hal.buffer_view -> tensor<1x32x?x128xf16>{%1}
    %3 = hal.buffer_view.dim<%arg2 : !hal.buffer_view>[2] : index
    %4 = hal.tensor.import wait(%arg4) => %arg2 : !hal.buffer_view -> tensor<1x32x?x128xf16>{%3}
    %5 = hal.buffer_view.dim<%arg3 : !hal.buffer_view>[3] : index
    %6 = hal.tensor.import wait(%arg4) => %arg3 : !hal.buffer_view -> tensor<1x32x1x?xf16>{%5}
    %collapsed = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<1x32x1x128xf16> into tensor<32x1x128xf16>
    %collapsed_0 = tensor.collapse_shape %2 [[0, 1], [2], [3]] : tensor<1x32x?x128xf16> into tensor<32x?x128xf16>
    %collapsed_1 = tensor.collapse_shape %4 [[0, 1], [2], [3]] : tensor<1x32x?x128xf16> into tensor<32x?x128xf16>
    %collapsed_2 = tensor.collapse_shape %6 [[0, 1], [2], [3]] : tensor<1x32x1x?xf16> into tensor<32x1x?xf16>
    %7 = tensor.empty() : tensor<32x1x128xf16>
    %8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)>, affine_map<(d0, d1, d2, d3, d4) -> ()>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>]} ins(%collapsed, %collapsed_0, %collapsed_1, %cst, %collapsed_2 : tensor<32x1x128xf16>, tensor<32x?x128xf16>, tensor<32x?x128xf16>, f16, tensor<32x1x?xf16>) outs(%7 : tensor<32x1x128xf16>) {
    ^bb0(%arg6: f32):
      iree_linalg_ext.yield %arg6 : f32
    } -> tensor<32x1x128xf16>
    %expanded = tensor.expand_shape %8 [[0, 1], [2], [3]] output_shape [1, 32, 1, 128] : tensor<32x1x128xf16> into tensor<1x32x1x128xf16>
    %9 = hal.tensor.barrier join(%expanded : tensor<1x32x1x128xf16>) => %arg5 : !hal.fence
    %10 = hal.tensor.export %9 : tensor<1x32x1x128xf16> -> !hal.buffer_view
    util.return %10 : !hal.buffer_view
  }
  util.func public @main(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
    %0 = util.null : !hal.fence
    %c-1_i32 = arith.constant -1 : i32
    %c0 = arith.constant 0 : index
    %device_0 = hal.devices.get %c0 : !hal.device
    %fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
    %1 = util.call @main$async(%arg0, %arg1, %arg2, %arg3, %0, %fence) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view
    %status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) : i32
    util.return %1 : !hal.buffer_view
  }
}
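
For the dynamic batch case mentioned above, here is a minimal sketch (untested, not part of the attached IR) of how the imports might change: the batch dimension is also queried from the buffer view and marked dynamic on the imported tensors, following the same hal.tensor.import pattern as the IR above.

// Hypothetical sketch only: treat dim 0 (batch) as dynamic in addition to the
// KV sequence length (dim 2), passing both dynamic dims to hal.tensor.import.
%b = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
%q = hal.tensor.import wait(%arg4) => %arg0 : !hal.buffer_view -> tensor<?x32x1x128xf16>{%b}
%kv_len = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[2] : index
%k = hal.tensor.import wait(%arg4) => %arg1 : !hal.buffer_view -> tensor<?x32x?x128xf16>{%b, %kv_len}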

What component(s) does this issue relate to?

Compiler

Additional context

No response

JamesMBartlett added the enhancement ➕ New feature or request label on Nov 4, 2024