Add conversion from f32 to f16 in the generated IR
IanNod committed Jul 22, 2024
1 parent 3436482 commit c5a78ec
Showing 2 changed files with 44 additions and 33 deletions.
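In short: the generated test functions previously took the softmax scale directly as an f16 argument; after this commit the scale crosses the function boundary as f32 and the generated IR truncates it to f16 with arith.truncf before it reaches the attention op. A minimal, self-contained sketch of the emitted function body after this change, mirroring the generator's f-string style (the tensor type, function name, and op spelling below are illustrative stand-ins; real values come from get_test_shapes):

    # Hedged sketch of the post-commit emitted IR; tensor_type and func_name
    # are hypothetical, not generator output.
    F32, F16 = "f32", "f16"
    tensor_type = "tensor<1x4x8xf16>"
    func_name = "attention_example"
    op_name = "iree_linalg_ext.attention"

    ir = (
        f"func.func @{func_name}(%query: {tensor_type}, %key: {tensor_type}, "
        f"%value: {tensor_type}, %scale: {F32}) -> {tensor_type} {{\n"
        f"  %result0 = tensor.empty(): {tensor_type}\n"
        f"  %scale_f16 = arith.truncf %scale : {F32} to {F16}\n"
        f"  %result1 = {op_name} ins(%query, %key, %value, %scale_f16: "
        f"{tensor_type}, {tensor_type}, {tensor_type}, {F16}) "
        f"outs(%result0: {tensor_type}) -> {tensor_type}\n"
        f"  return %result1: {tensor_type}\n"
        f"}}\n"
    )
    print(ir)  # the f32 scale is truncated to f16 before the attention op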
3 changes: 1 addition & 2 deletions tests/e2e/attention/CMakeLists.txt
@@ -34,5 +34,4 @@ iree_generated_e2e_runner_test(
TARGET_CPU_FEATURES_VARIANTS
"default"
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
74 changes: 43 additions & 31 deletions tests/e2e/attention/generate_e2e_fa2_tests.py
@@ -117,7 +117,7 @@ class TestInputTensorShapes:
# the set of shapes to be used in a test function's input tensors.
def generate_shapes_and_scale(shape: TestShapeAndScale):
    batch = shape.batch
    m = shape.m
    k1 = shape.k1
    k2 = shape.k2
    n = shape.n
@@ -145,7 +145,12 @@ def get_tensor_shapes(
    n = shapes_scale.n
    scale = shapes_scale.scale

-    query_tensor_shape, key_tensor_shape, value_tensor_shape, result_tensor_shape = [], [], [], []
+    query_tensor_shape, key_tensor_shape, value_tensor_shape, result_tensor_shape = (
+        [],
+        [],
+        [],
+        [],
+    )

    query_tensor_shape = [batch, m, k1]
    key_tensor_shape = [batch, k2, k1]
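As a worked example (sizes hypothetical; element type assumed f16 as elsewhere in this generator), the two assignments above render to MLIR tensor type strings like this:

    # Illustrative only: the concrete sizes are hypothetical.
    batch, m, k1, k2 = 2, 16, 64, 32
    query_tensor_shape = [batch, m, k1]  # [2, 16, 64]
    key_tensor_shape = [batch, k2, k1]   # [2, 32, 64]
    q = query_tensor_shape
    print(f"tensor<{q[0]}x{q[1]}x{q[2]}xf16>")  # tensor<2x16x64xf16>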
@@ -207,27 +212,34 @@ def generate_function(
        shapes_scale,
    )

-    query_shape, key_shape, value_shape, result_shape = get_tensor_shapes(
-        shapes_scale
-    )
-    query_tensor_type = f"tensor<{query_shape[0]}x{query_shape[1]}x{query_shape[2]}x{query_type.value}>"
-    key_tensor_type = f"tensor<{key_shape[0]}x{key_shape[1]}x{key_shape[2]}x{key_type.value}>"
-    value_tensor_type = f"tensor<{value_shape[0]}x{value_shape[1]}x{value_shape[2]}x{value_type.value}>"
-    result_tensor_type = f"tensor<{result_shape[0]}x{result_shape[1]}x{result_shape[2]}x{value_type.value}>"
-    F16="f16"
+    query_shape, key_shape, value_shape, result_shape = get_tensor_shapes(shapes_scale)
+    query_tensor_type = (
+        f"tensor<{query_shape[0]}x{query_shape[1]}x{query_shape[2]}x{query_type.value}>"
+    )
+    key_tensor_type = (
+        f"tensor<{key_shape[0]}x{key_shape[1]}x{key_shape[2]}x{key_type.value}>"
+    )
+    value_tensor_type = (
+        f"tensor<{value_shape[0]}x{value_shape[1]}x{value_shape[2]}x{value_type.value}>"
+    )
+    result_tensor_type = (
+        f"tensor<{result_shape[0]}x{result_shape[1]}x{result_shape[2]}x{value_type.value}>"
+    )
+    F32 = "f32"
+    F16 = "f16"
    op_name = "iree_linalg_ext.attention"


    # Compilation info is optional; prints empty string by default.
    func_definition = ""

    signature = f"({query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {result_tensor_type}) -> {result_tensor_type}"
-    import_declaration = f"func.func private @module.{func_name}(%query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %scale: f16) -> !hal.buffer_view"
+    import_declaration = f"func.func private @module.{func_name}(%query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %scale: f32) -> !hal.buffer_view"
    func_definition = func_definition + (
-        f"func.func @{func_name}(%query: {query_tensor_type}, %key: {key_tensor_type}, %value: {value_tensor_type}, %scale: {F16}) -> {result_tensor_type} {{\n"
+        f"func.func @{func_name}(%query: {query_tensor_type}, %key: {key_tensor_type}, %value: {value_tensor_type}, %scale: {F32}) -> {result_tensor_type} {{\n"
        f" %result0 = tensor.empty(): {result_tensor_type}\n"
-        #f" %scale = arith.constant {shapes_scale.scale} : f16 \n"
-        f" %result1 = {op_name} ins(%query, %key, %value, %scale: {query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {F16}) outs(%result0: {result_tensor_type}) -> {result_tensor_type}\n"
+        f" %scale_f16 = arith.truncf %scale : f32 to f16 \n"
+        f" %result1 = {op_name} ins(%query, %key, %value, %scale_f16: {query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {F16}) outs(%result0: {result_tensor_type}) -> {result_tensor_type}\n"
        f" return %result1: {result_tensor_type}\n"
        f"}}\n"
    )
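Numerically, the new arith.truncf is a lossy f32-to-f16 conversion of the scale alone. A sketch of its effect in numpy (assumption: numpy's float16 cast rounds to nearest, which models the usual truncf default closely enough for these test scales):

    import numpy as np

    # Model of what `arith.truncf %scale : f32 to f16` does to a typical
    # attention scale; 1/sqrt(k1) with k1=64 is exactly representable in f16.
    scale_f32 = np.float32(1.0 / np.sqrt(64.0))
    scale_f16 = np.float16(scale_f32)
    print(scale_f32, scale_f16)  # 0.125 0.125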
@@ -321,8 +333,8 @@ def generate_call(
    global pseudorandom_generator_seed
    pseudorandom_generator_seed = pseudorandom_generator_seed - 1
    op = op + (
-        f" %scale = arith.constant {shapes_scale.scale} : f16\n"
-        f" %result = call @module.{function.name}(%query, %key, %value, %scale) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, f16) -> !hal.buffer_view\n"
+        f" %scale = arith.constant {shapes_scale.scale} : f32\n"
+        f" %result = call @module.{function.name}(%query, %key, %value, %scale) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, f32) -> !hal.buffer_view\n"
    )

    op = op + (
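On the call side, the constant and the call signature both move to f32 so they agree with the updated import declaration; the f16 value the kernel ultimately consumes is produced inside the callee. A hedged rendering of the snippet above (the scale value and function name are hypothetical):

    # Illustrative rendering of the call-side IR after this commit.
    scale = 0.125
    name = "attention_example"
    print(
        f" %scale = arith.constant {scale} : f32\n"
        f" %result = call @module.{name}(%query, %key, %value, %scale) : "
        f"(!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, f32) -> !hal.buffer_view\n"
    )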
@@ -351,23 +363,23 @@ def generate(
    calls = []

    for shape in get_test_shapes(shapes_id):
        function = generate_function(
            query_type,
            key_type,
            value_type,
            shape,
        )
        if function.name not in functions:
            functions[function.name] = function
        calls.append(
            generate_call(
                function,
                query_type,
                key_type,
                value_type,
                shape,
            )
        )

    return (functions, calls)

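Note that the dictionary keyed on function.name deduplicates generated functions when two test shapes render to the same signature, while calls keeps one entry per shape. A fragment showing the intended use inside this module (not standalone; it relies on the script's own enums):

    # Fragment: query_type, key_type, value_type, shapes_id are module names.
    functions, calls = generate(query_type, key_type, value_type, shapes_id)
    assert len(calls) >= len(functions)  # shapes may share one function body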
@@ -473,10 +485,10 @@ def main(args):
    shapes_id = ShapesId(args.shapes_scale)

    (functions, calls) = generate(
        query_type,
        key_type,
        value_type,
        shapes_id,
    )

    write_code_file(functions, args.output_fa2_mlir)
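A hypothetical invocation, inferring flag names from the args attributes visible here (the actual argparse definitions sit outside this diff):

    python generate_e2e_fa2_tests.py --shapes_scale=<id> --output_fa2_mlir=attention_fa2.mlir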
