Skip to content

Commit

Permalink
Add formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
IanNod committed Jul 22, 2024
1 parent 420df89 commit 3416fe8
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion tests/e2e/attention/generate_e2e_fa2_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@ class ValueElemTypeId(enum.Enum):
NONE = ""
F16 = "f16"


# Data type of input entries. The string values must match MLIR data types.
@enum.unique
class ResultElemTypeId(enum.Enum):
NONE = ""
F16 = "f16"


# Enumerates of the collections of shapes that we can generate tests for.
# The values are the accepted values for the --shapes= flag.
@enum.unique
Expand Down Expand Up @@ -68,6 +70,7 @@ class KernelGenerator(enum.Enum):
ZERO = "zero" # Fill with zeros
RANDOM = "random" # Fill with (deterministic) pseudorandom values.


@dataclasses.dataclass
class TestShapeAndScale:
batch: int
Expand All @@ -77,10 +80,10 @@ class TestShapeAndScale:
n: int
scale: float


# Returns the list of TestShape's to use for the collection of shapes
# identified by shapes_id.
def get_test_shapes(shapes_id: ShapesId):

if shapes_id == ShapesId.SMALL:
return [
TestShapeAndScale(batch=4, m=1024, k1=64, k2=512, n=32, scale=1.0),
Expand All @@ -103,6 +106,7 @@ def get_test_shapes(shapes_id: ShapesId):
# in which shuffling testcases changes which random values are generated.
local_pseudorandom_state = 1


# Determines the shape of input and kernel tensors.
@dataclasses.dataclass
class TestInputTensorShapes:
Expand Down

0 comments on commit 3416fe8

Please sign in to comment.