Skip to content

Commit

Permalink
Add new dims into the func call
Browse files Browse the repository at this point in the history
  • Loading branch information
IanNod committed Jul 17, 2024
1 parent 9dc718d commit 8f7cfff
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tests/e2e/attention/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,4 @@ py_binary(
"small",
"medium",
"large",
]]
]]
4 changes: 2 additions & 2 deletions tests/e2e/attention/generate_e2e_fa2_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def generate_function(
func_definition = ""

signature = f"({query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {result_tensor_type}) -> {result_tensor_type}"
import_declaration = f"func.func private @module.{func_name}(%query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %result: !hal.buffer_view) -> !hal.buffer_view"
import_declaration = f"func.func private @module.{func_name}(%query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %scale: f16) -> !hal.buffer_view"
func_definition = func_definition + (
f"func.func @{func_name}(%query: {query_tensor_type}, %key: {key_tensor_type}, %value: {value_tensor_type}, %scale: {F16}) -> {result_tensor_type} {{\n"
f" %result0 = tensor.empty(): {result_tensor_type}\n"
Expand Down Expand Up @@ -447,7 +447,7 @@ def write_calls_file(functions, calls, filename, requirements):
# Declare the custom module that generates arguments.
module_definition = module_definition + (
"func.func private @fa2_test.generate_random_tensor(%device: !hal.device, %dim0: i64, %dim1: i64, %dim2: i64, %element_type: i32, %seed: i32) -> !hal.buffer_view\n"
"func.func private @fa2_test.check_fa2_results(%device: !hal.device, %numHeads: i64, %seqLen: i64, %headDim: i64, %query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %result: !hal.buffer_view)\n"
"func.func private @fa2_test.check_fa2_results(%device: !hal.device, %batch: i64, %m: i64, %k1: i64, %k2: i64, %n: i64, %query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %result: !hal.buffer_view)\n"
"\n"
)

Expand Down

0 comments on commit 8f7cfff

Please sign in to comment.