Skip to content

Commit 3ee64da

Browse files
authored
Add fail/warn mechanism to ColumnObjects and ParamFunctions (#16)
- As discussed in [GETTSIM #1055](ttsim-dev/gettsim#1055), we sometimes want to fail or issue a warning if a certain node is visited in the graph. This PR adds a mechanism for that along with corresponding tests - Implementation: - All `ColumnObjects` and `ParamFunctions` are endowed with attributes `warn_msg_if_included: str | None = None` and `fail_msg_if_included: str | None = None`. - These are checked by a corresponding fail/warn function each, which depends on the TT DAG and the specialised environment without tree logic and with derived functions (guaranteed to catch the correct type of object) - For warn functions, only unique messages are displayed. E.g., in the GETTSIM example triggering this, we have at least three functions, which all have the same message. - As we visited that code, we realised that `tt_dag` is a better name for `tax_transfer_dag` as we are using that abbreviation much more extensively that 'tax_transfer' by now. Renamed, same for `tax_transfer_function` -> `tt_function`.
1 parent 4d049b1 commit 3ee64da

File tree

12 files changed

+310
-46
lines changed

12 files changed

+310
-46
lines changed

CHANGES.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ releases are available on [Anaconda.org](https://anaconda.org/conda-forge/ttsim)
66

77
## v1.0a2 — 2025-07-xx
88

9+
- {gh}`16` Add fail/warn mechanism to ColumnObjects and ParamFunctions.
10+
({ghuser}`hmgaudecker`)
11+
912
- {gh}`15` Do not call len() on unsized arrays. ({ghuser}`hmgaudecker`)
1013

1114
- {gh}`14` Do not loop over the attributes of Jax arrays in

src/ttsim/interface_dag_elements/fail_if.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
)
3939

4040
if TYPE_CHECKING:
41-
from collections.abc import Callable
41+
from collections.abc import Callable, Iterable
4242

4343
from ttsim.interface_dag_elements.input_data import FlatData
4444
from ttsim.typing import (
@@ -704,9 +704,42 @@ def backend_has_changed(
704704
)
705705

706706

707+
@fail_function()
708+
def tt_dag_includes_function_with_fail_msg_if_included_set(
709+
specialized_environment__without_tree_logic_and_with_derived_functions: SpecEnvWithoutTreeLogicAndWithDerivedFunctions,
710+
specialized_environment__tt_dag: nx.DiGraph,
711+
labels__processed_data_columns: UnorderedQNames,
712+
) -> None:
713+
"""Fail if the TT DAG includes functions with `fail_msg_if_included` set."""
714+
715+
env = specialized_environment__without_tree_logic_and_with_derived_functions
716+
issues = ""
717+
for node in specialized_environment__tt_dag:
718+
if (
719+
# This may run before 'fail_if.root_nodes_are_missing'
720+
node not in env
721+
or
722+
# ColumnObjects overridden by data are fine
723+
(
724+
not isinstance(env[node], PolicyInput)
725+
and node in labels__processed_data_columns
726+
)
727+
):
728+
continue
729+
# Check because ParamObjects can be overridden by ColumnObjects down the road.
730+
if hasattr(env[node], "fail_msg_if_included"): # noqa: SIM102
731+
if msg := env[node].fail_msg_if_included:
732+
issues += f"{node}:\n\n{msg}\n\n\n"
733+
if issues:
734+
raise ValueError(
735+
"The TT DAG includes the following functions with `fail_msg_if_included` "
736+
f"set.\n\n{issues}"
737+
)
738+
739+
707740
@fail_function()
708741
def tt_root_nodes_are_missing(
709-
specialized_environment__tax_transfer_dag: nx.DiGraph,
742+
specialized_environment__tt_dag: nx.DiGraph,
710743
specialized_environment__with_partialled_params_and_scalars: SpecEnvWithPartialledParamsAndScalars,
711744
processed_data: QNameData,
712745
labels__grouping_levels: OrderedQNames,
@@ -715,7 +748,7 @@ def tt_root_nodes_are_missing(
715748
716749
Parameters
717750
----------
718-
specialized_environment__tax_transfer_dag
751+
specialized_environment__tt_dag
719752
The DAG of taxes and transfers functions.
720753
specialized_environment__with_partialled_params_and_scalars
721754
The specialized environment with partialled params and scalars.
@@ -738,9 +771,8 @@ def tt_root_nodes_are_missing(
738771
)
739772
# Obtain root nodes
740773
root_nodes = nx.subgraph_view(
741-
specialized_environment__tax_transfer_dag,
742-
filter_node=lambda n: specialized_environment__tax_transfer_dag.in_degree(n)
743-
== 0,
774+
specialized_environment__tt_dag,
775+
filter_node=lambda n: specialized_environment__tt_dag.in_degree(n) == 0,
744776
).nodes
745777

746778
missing_nodes = [
@@ -848,7 +880,7 @@ def format_errors_and_warnings(text: str, width: int = 79) -> str:
848880
return "\n\n".join(wrapped_paragraphs)
849881

850882

851-
def format_list_linewise(some_list: list[Any]) -> str: # type: ignore[type-arg, unused-ignore]
883+
def format_list_linewise(some_list: Iterable[Any]) -> str: # type: ignore[type-arg, unused-ignore]
852884
formatted_list = '",\n "'.join(some_list)
853885
return textwrap.dedent(
854886
"""

src/ttsim/interface_dag_elements/labels.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,14 @@ def input_columns(
142142

143143
@interface_function()
144144
def root_nodes(
145-
specialized_environment__tax_transfer_dag: nx.DiGraph,
145+
specialized_environment__tt_dag: nx.DiGraph,
146146
processed_data_columns: UnorderedQNames,
147147
) -> UnorderedQNames:
148148
"""Names of the columns in `processed_data` required for the tax transfer function.
149149
150150
Parameters
151151
----------
152-
specialized_environment__tax_transfer_dag:
152+
specialized_environment__tt_dag:
153153
The tax transfer DAG.
154154
processed_data:
155155
The processed data.
@@ -161,9 +161,8 @@ def root_nodes(
161161
"""
162162
# Obtain root nodes
163163
root_nodes = nx.subgraph_view(
164-
specialized_environment__tax_transfer_dag,
165-
filter_node=lambda n: specialized_environment__tax_transfer_dag.in_degree(n)
166-
== 0,
164+
specialized_environment__tt_dag,
165+
filter_node=lambda n: specialized_environment__tt_dag.in_degree(n) == 0,
167166
).nodes
168167

169168
# Restrict the passed data to the subset that is actually used.

src/ttsim/interface_dag_elements/raw_results.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
def columns(
2020
labels__root_nodes: UnorderedQNames,
2121
processed_data: QNameData,
22-
specialized_environment__tax_transfer_function: Callable[[QNameData], QNameData],
22+
specialized_environment__tt_function: Callable[[QNameData], QNameData],
2323
) -> QNameData:
24-
return specialized_environment__tax_transfer_function(
24+
return specialized_environment__tt_function(
2525
{k: v for k, v in processed_data.items() if k in labels__root_nodes},
2626
)
2727

src/ttsim/interface_dag_elements/specialized_environment.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def _apply_rounding(element: ColumnFunction, xnp: ModuleType) -> ColumnFunction:
340340

341341

342342
@interface_function()
343-
def tax_transfer_dag(
343+
def tt_dag(
344344
with_partialled_params_and_scalars: SpecEnvWithPartialledParamsAndScalars,
345345
labels__column_targets: OrderedQNames,
346346
) -> nx.DiGraph:
@@ -361,16 +361,16 @@ def tt_function_set_annotations() -> bool:
361361

362362

363363
@interface_function()
364-
def tax_transfer_function(
365-
tax_transfer_dag: nx.DiGraph,
364+
def tt_function(
365+
tt_dag: nx.DiGraph,
366366
with_partialled_params_and_scalars: SpecEnvWithPartialledParamsAndScalars,
367367
labels__column_targets: OrderedQNames,
368368
backend: Literal["numpy", "jax"],
369369
tt_function_set_annotations: bool,
370370
) -> Callable[[QNameData], QNameData]:
371371
"""Returns a function that takes a dictionary of arrays and unpacks them as keyword arguments."""
372372
ttf_with_keyword_args = concatenate_functions(
373-
dag=tax_transfer_dag,
373+
dag=tt_dag,
374374
functions=with_partialled_params_and_scalars,
375375
targets=list(labels__column_targets),
376376
return_type="dict",

src/ttsim/interface_dag_elements/warn_if.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
if TYPE_CHECKING:
1616
import datetime
1717

18+
import networkx as nx
19+
1820
from ttsim.typing import (
1921
PolicyEnvironment,
2022
QNameData,
23+
SpecEnvWithoutTreeLogicAndWithDerivedFunctions,
2124
UnorderedQNames,
2225
)
2326

@@ -114,3 +117,37 @@ def evaluation_date_set_in_multiple_places(
114117
`evaluation_year`.
115118
"""
116119
warnings.warn(UserWarning(msg), stacklevel=2)
120+
121+
122+
@warn_function()
123+
def tt_dag_includes_function_with_warn_msg_if_included_set(
124+
specialized_environment__without_tree_logic_and_with_derived_functions: SpecEnvWithoutTreeLogicAndWithDerivedFunctions, # noqa: E501
125+
specialized_environment__tt_dag: nx.DiGraph,
126+
labels__processed_data_columns: UnorderedQNames,
127+
) -> None:
128+
"""Warn if the TT DAG includes functions with `warn_msg_if_included` set."""
129+
130+
env = specialized_environment__without_tree_logic_and_with_derived_functions
131+
my_warnings = set()
132+
for node in specialized_environment__tt_dag:
133+
if (
134+
# This may run before 'fail_if.root_nodes_are_missing'
135+
node not in env
136+
or
137+
# ColumnObjects overridden by data are fine
138+
(
139+
not isinstance(env[node], PolicyInput)
140+
and node in labels__processed_data_columns
141+
)
142+
):
143+
continue
144+
# Check because ParamObjects can be overridden by ColumnObjects down the road.
145+
if hasattr(env[node], "fail_msg_if_included"): # noqa: SIM102
146+
if msg := env[node].warn_msg_if_included:
147+
my_warnings |= {f"{msg}\n\n\n"}
148+
if my_warnings:
149+
msg = (
150+
"The TT DAG includes elements with `warn_msg_if_included` set to the "
151+
f"following values:\n\n{format_list_linewise(my_warnings)}"
152+
)
153+
warnings.warn(UserWarning(msg), stacklevel=2)

src/ttsim/main_args.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,5 +166,5 @@ class SpecializedEnvironment(MainArg):
166166
with_partialled_params_and_scalars: SpecEnvWithPartialledParamsAndScalars | None = (
167167
None
168168
)
169-
tax_transfer_dag: nx.DiGraph | None = None
170-
tax_transfer_function: Callable[[QNameData], QNameData] | None = None
169+
tt_dag: nx.DiGraph | None = None
170+
tt_function: Callable[[QNameData], QNameData] | None = None

src/ttsim/main_target.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ class WarnIf(MainTargetABC):
2626
evaluation_date_set_in_multiple_places: str = (
2727
"warn_if__evaluation_date_set_in_multiple_places"
2828
)
29+
tt_dag_includes_function_with_warn_msg_if_included_set: str = (
30+
"warn_if__tt_dag_includes_function_with_warn_msg_if_included_set"
31+
)
2932

3033

3134
@dataclass(frozen=True)
@@ -66,6 +69,9 @@ class FailIf(MainTargetABC):
6669
"fail_if__targets_are_not_in_specialized_environment_or_data"
6770
)
6871
targets_tree_is_invalid: str = "fail_if__targets_tree_is_invalid"
72+
tt_dag_includes_function_with_fail_msg_if_included_set: str = (
73+
"fail_if__tt_dag_includes_function_with_fail_msg_if_included_set"
74+
)
6975

7076

7177
@dataclass(frozen=True)
@@ -94,8 +100,8 @@ class SpecializedEnvironment(MainTargetABC):
94100
with_partialled_params_and_scalars: str = (
95101
"specialized_environment__with_partialled_params_and_scalars"
96102
)
97-
tax_transfer_dag: str = "specialized_environment__tax_transfer_dag"
98-
tax_transfer_function: str = "specialized_environment__tax_transfer_function"
103+
tt_dag: str = "specialized_environment__tt_dag"
104+
tt_function: str = "specialized_environment__tt_function"
99105

100106

101107
@dataclass(frozen=True)

0 commit comments

Comments
 (0)