Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,8 @@ def _get_model_for_singleton_value(
# Find the correct model for the discriminator value by unwrapping the Union and then the discriminator Literals
assert typing.get_origin(model) is typing.Union # For the type checker
for sub_model in typing.get_args(model):
if sub_model is type(None):
continue
sub_model_discr_value = sub_model.model_fields[discriminator].annotation
if typing.get_origin(sub_model_discr_value) is not typing.Literal:
raise NotImplementedError(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

from typing import Any, Literal, Union
from typing import Any, Literal, Optional, Union
from enum import Enum
from typing_extensions import Annotated
from pydantic import Field
Expand Down Expand Up @@ -1145,6 +1145,54 @@ class BaseModel(OpenJDModel):
# THEN
assert len(errors) == 0

@pytest.mark.parametrize(
"data",
[
pytest.param({"name": "Foo", "sub": {"kind": "INVALID"}}, id="invalid discriminator"),
pytest.param({"name": "Foo", "sub": {"kind": ""}}, id="empty discriminator"),
pytest.param({"name": "Foo", "sub": None}, id="None value"),
],
)
def test_optional_discriminated_union_invalid_discriminator(self, data: dict[str, Any]) -> None:
# Test that Optional discriminated unions with invalid discriminator values
# don't crash with AttributeError on NoneType.

# GIVEN
class Kind(str, Enum):
ONE = "ONE"
TWO = "TWO"

class SubModel1(OpenJDModel):
kind: Literal[Kind.ONE]
field1: FormatString
_template_variable_scope = ResolutionScope.TEMPLATE

class SubModel2(OpenJDModel):
kind: Literal[Kind.TWO]
field2: FormatString
_template_variable_scope = ResolutionScope.TEMPLATE

class BaseModel(OpenJDModel):
name: str
sub: Optional[
Annotated[Union[SubModel1, SubModel2], Field(..., discriminator="kind")]
] = None
_template_variable_definitions = DefinesTemplateVariables(
defines={TemplateVariableDef(prefix="|Param.", resolves=ResolutionScope.TEMPLATE)},
field="name",
)
_template_variable_sources = {
"sub": {"__self__"},
}

# WHEN
errors = prevalidate_model_template_variable_references(
BaseModel, data, context=ModelParsingContext_v2023_09()
)

# THEN
assert len(errors) == 0


class TestNonDiscriminatedUnion:
"""Test that if we have unions in the model that isn't a discriminated union then we handle them correctly.
Expand Down
47 changes: 47 additions & 0 deletions test/openjd/model/test_fuzz.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,42 @@ def fuzz_format_string():
)


def fuzz_cancelation():
return random.choice(
[
None,
{},
{"mode": "NOTIFY_THEN_TERMINATE"},
{"mode": "TERMINATE"},
{"mode": "INVALID_MODE"},
{"mode": random_string(1, 20)},
{"mode": None},
{"mode": 123},
{"mode": "NOTIFY_THEN_TERMINATE", "notifyPeriodInSeconds": 30},
{"mode": "NOTIFY_THEN_TERMINATE", "notifyPeriodInSeconds": "invalid"},
random_string(1, 10),
]
)


def fuzz_task_parameter():
return random.choice(
[
{},
{"name": "P", "type": "INT", "range": "1-10"},
{"name": "P", "type": "FLOAT", "range": [1.0, 2.0]},
{"name": "P", "type": "STRING", "values": ["a", "b"]},
{"name": "P", "type": "PATH", "values": ["/tmp"]},
{"name": "P", "type": "INVALID_TYPE", "range": "1-10"},
{"name": "P", "type": random_string(1, 15), "range": "1-10"},
{"name": "P", "type": None},
{"name": "P", "type": 123},
{"type": "INT", "range": "1-10"},
random_string(1, 10),
]
)


def fuzz_step():
base = {"name": "Step1", "script": {"actions": {"onRun": {"command": "echo"}}}}
mutations = [
Expand All @@ -81,8 +117,15 @@ def fuzz_step():
{"name": "Step1", "script": {"actions": {}}},
{"name": "Step1", "script": {"actions": {"onRun": {}}}},
{"name": "Step1", "script": {"actions": {"onRun": {"command": ""}}}},
{
"name": "Step1",
"script": {
"actions": {"onRun": {"command": "echo", "cancelation": fuzz_cancelation()}}
},
},
{**base, "parameterSpace": {}},
{**base, "parameterSpace": {"taskParameterDefinitions": []}},
{**base, "parameterSpace": {"taskParameterDefinitions": [fuzz_task_parameter()]}},
{
**base,
"parameterSpace": {
Expand All @@ -106,12 +149,16 @@ def fuzz_job_parameter():
{"name": "Param1", "type": "FLOAT"},
{"name": "Param1", "type": "PATH"},
{"name": "Param1", "type": "INVALID"},
{"name": "Param1", "type": random_string(1, 15)},
{"name": "Param1", "type": None},
{"name": "Param1", "type": 123},
{"name": "", "type": "STRING"},
{"name": 123, "type": "STRING"},
{"type": "STRING"},
{"name": "P", "type": "INT", "minValue": 10, "maxValue": 5},
{"name": "P", "type": "INT", "default": "notanint"},
{"name": "P", "type": "STRING", "allowedValues": []},
random_string(1, 10),
]
)

Expand Down
Loading