From 467d29d3c4cc4b6a1e2defc86be8ca8ea20e4e5d Mon Sep 17 00:00:00 2001 From: Stephen Crowe <6042774+crowecawcaw@users.noreply.github.com> Date: Fri, 19 Dec 2025 08:24:32 -0800 Subject: [PATCH] fix: fix crash on invalid discriminator in Optional discriminated unions Signed-off-by: Stephen Crowe <6042774+crowecawcaw@users.noreply.github.com> --- .../_variable_reference_validation.py | 2 + .../test_variable_reference_validation.py | 50 ++++++++++++++++++- test/openjd/model/test_fuzz.py | 47 +++++++++++++++++ 3 files changed, 98 insertions(+), 1 deletion(-) diff --git a/src/openjd/model/_internal/_variable_reference_validation.py b/src/openjd/model/_internal/_variable_reference_validation.py index 06ba6b3..4ee8c48 100644 --- a/src/openjd/model/_internal/_variable_reference_validation.py +++ b/src/openjd/model/_internal/_variable_reference_validation.py @@ -459,6 +459,8 @@ def _get_model_for_singleton_value( # Find the correct model for the discriminator value by unwrapping the Union and then the discriminator Literals assert typing.get_origin(model) is typing.Union # For the type checker for sub_model in typing.get_args(model): + if sub_model is type(None): + continue sub_model_discr_value = sub_model.model_fields[discriminator].annotation if typing.get_origin(sub_model_discr_value) is not typing.Literal: raise NotImplementedError( diff --git a/test/openjd/model/_internal/test_variable_reference_validation.py b/test/openjd/model/_internal/test_variable_reference_validation.py index fa9f161..f3978db 100644 --- a/test/openjd/model/_internal/test_variable_reference_validation.py +++ b/test/openjd/model/_internal/test_variable_reference_validation.py @@ -1,6 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -from typing import Any, Literal, Union +from typing import Any, Literal, Optional, Union from enum import Enum from typing_extensions import Annotated from pydantic import Field @@ -1145,6 +1145,54 @@ class BaseModel(OpenJDModel): # THEN assert len(errors) == 0 + @pytest.mark.parametrize( + "data", + [ + pytest.param({"name": "Foo", "sub": {"kind": "INVALID"}}, id="invalid discriminator"), + pytest.param({"name": "Foo", "sub": {"kind": ""}}, id="empty discriminator"), + pytest.param({"name": "Foo", "sub": None}, id="None value"), + ], + ) + def test_optional_discriminated_union_invalid_discriminator(self, data: dict[str, Any]) -> None: + # Test that Optional discriminated unions with invalid discriminator values + # don't crash with AttributeError on NoneType. + + # GIVEN + class Kind(str, Enum): + ONE = "ONE" + TWO = "TWO" + + class SubModel1(OpenJDModel): + kind: Literal[Kind.ONE] + field1: FormatString + _template_variable_scope = ResolutionScope.TEMPLATE + + class SubModel2(OpenJDModel): + kind: Literal[Kind.TWO] + field2: FormatString + _template_variable_scope = ResolutionScope.TEMPLATE + + class BaseModel(OpenJDModel): + name: str + sub: Optional[ + Annotated[Union[SubModel1, SubModel2], Field(..., discriminator="kind")] + ] = None + _template_variable_definitions = DefinesTemplateVariables( + defines={TemplateVariableDef(prefix="|Param.", resolves=ResolutionScope.TEMPLATE)}, + field="name", + ) + _template_variable_sources = { + "sub": {"__self__"}, + } + + # WHEN + errors = prevalidate_model_template_variable_references( + BaseModel, data, context=ModelParsingContext_v2023_09() + ) + + # THEN + assert len(errors) == 0 + class TestNonDiscriminatedUnion: """Test that if we have unions in the model that isn't a discriminated union then we handle them correctly. diff --git a/test/openjd/model/test_fuzz.py b/test/openjd/model/test_fuzz.py index 4afa2b0..ba4d9ed 100644 --- a/test/openjd/model/test_fuzz.py +++ b/test/openjd/model/test_fuzz.py @@ -71,6 +71,42 @@ def fuzz_format_string(): ) +def fuzz_cancelation(): + return random.choice( + [ + None, + {}, + {"mode": "NOTIFY_THEN_TERMINATE"}, + {"mode": "TERMINATE"}, + {"mode": "INVALID_MODE"}, + {"mode": random_string(1, 20)}, + {"mode": None}, + {"mode": 123}, + {"mode": "NOTIFY_THEN_TERMINATE", "notifyPeriodInSeconds": 30}, + {"mode": "NOTIFY_THEN_TERMINATE", "notifyPeriodInSeconds": "invalid"}, + random_string(1, 10), + ] + ) + + +def fuzz_task_parameter(): + return random.choice( + [ + {}, + {"name": "P", "type": "INT", "range": "1-10"}, + {"name": "P", "type": "FLOAT", "range": [1.0, 2.0]}, + {"name": "P", "type": "STRING", "values": ["a", "b"]}, + {"name": "P", "type": "PATH", "values": ["/tmp"]}, + {"name": "P", "type": "INVALID_TYPE", "range": "1-10"}, + {"name": "P", "type": random_string(1, 15), "range": "1-10"}, + {"name": "P", "type": None}, + {"name": "P", "type": 123}, + {"type": "INT", "range": "1-10"}, + random_string(1, 10), + ] + ) + + def fuzz_step(): base = {"name": "Step1", "script": {"actions": {"onRun": {"command": "echo"}}}} mutations = [ @@ -81,8 +117,15 @@ def fuzz_step(): {"name": "Step1", "script": {"actions": {}}}, {"name": "Step1", "script": {"actions": {"onRun": {}}}}, {"name": "Step1", "script": {"actions": {"onRun": {"command": ""}}}}, + { + "name": "Step1", + "script": { + "actions": {"onRun": {"command": "echo", "cancelation": fuzz_cancelation()}} + }, + }, {**base, "parameterSpace": {}}, {**base, "parameterSpace": {"taskParameterDefinitions": []}}, + {**base, "parameterSpace": {"taskParameterDefinitions": [fuzz_task_parameter()]}}, { **base, "parameterSpace": { @@ -106,12 +149,16 @@ def fuzz_job_parameter(): {"name": "Param1", "type": "FLOAT"}, {"name": "Param1", "type": "PATH"}, {"name": "Param1", "type": "INVALID"}, + {"name": "Param1", "type": random_string(1, 15)}, + {"name": "Param1", "type": None}, + {"name": "Param1", "type": 123}, {"name": "", "type": "STRING"}, {"name": 123, "type": "STRING"}, {"type": "STRING"}, {"name": "P", "type": "INT", "minValue": 10, "maxValue": 5}, {"name": "P", "type": "INT", "default": "notanint"}, {"name": "P", "type": "STRING", "allowedValues": []}, + random_string(1, 10), ] )