Skip to content

Commit 701000f

Browse files
committed
PR Feedback 1
1 parent d50f087 commit 701000f

File tree

9 files changed

+30
-81
lines changed

9 files changed

+30
-81
lines changed

sqlmesh/core/context.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ def __init__(
274274
deployability_index: t.Optional[DeployabilityIndex] = None,
275275
default_dialect: t.Optional[str] = None,
276276
default_catalog: t.Optional[str] = None,
277+
is_restatement_plan: t.Optional[bool] = None,
277278
variables: t.Optional[t.Dict[str, t.Any]] = None,
278279
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
279280
):
@@ -284,6 +285,7 @@ def __init__(
284285
self._default_dialect = default_dialect
285286
self._variables = variables or {}
286287
self._blueprint_variables = blueprint_variables or {}
288+
self._is_restatement_plan = is_restatement_plan
287289

288290
@property
289291
def default_dialect(self) -> t.Optional[str]:
@@ -308,6 +310,10 @@ def gateway(self) -> t.Optional[str]:
308310
"""Returns the gateway name."""
309311
return self.var(c.GATEWAY)
310312

313+
@property
314+
def is_restatement_plan(self) -> t.Optional[bool]:
315+
return self._is_restatement_plan
316+
311317
def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
312318
"""Returns a variable value."""
313319
return self._variables.get(var_name.lower(), default)
@@ -328,6 +334,7 @@ def with_variables(
328334
self.deployability_index,
329335
self._default_dialect,
330336
self._default_catalog,
337+
self._is_restatement_plan,
331338
variables=variables,
332339
blueprint_variables=blueprint_variables,
333340
)

sqlmesh/core/engine_adapter/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ class EngineAdapter:
119119
MAX_IDENTIFIER_LENGTH: t.Optional[int] = None
120120
ATTACH_CORRELATION_ID = True
121121
SUPPORTS_QUERY_EXECUTION_TRACKING = False
122-
SUPPORTS_EXTERNAL_MODEL_FRESHNESS = False
122+
SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS = False
123123

124124
def __init__(
125125
self,
@@ -2928,7 +2928,7 @@ def _check_identifier_length(self, expression: exp.Expression) -> None:
29282928
f"Identifier name '{name}' (length {name_length}) exceeds {self.dialect.capitalize()}'s max identifier limit of {self.MAX_IDENTIFIER_LENGTH} characters"
29292929
)
29302930

2931-
def get_external_model_freshness(self, table_names: t.List[TableName]) -> t.List[int]:
2931+
def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]:
29322932
raise NotImplementedError()
29332933

29342934

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ def table_exists(self, table_name: TableName) -> bool:
754754
except NotFound:
755755
return False
756756

757-
def get_external_model_freshness(self, table_names: t.List[TableName]) -> t.List[int]:
757+
def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]:
758758
from sqlmesh.utils.date import to_timestamp
759759

760760
datasets_to_tables: t.DefaultDict[str, t.List[str]] = defaultdict(list)

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi
5454
SUPPORTS_MANAGED_MODELS = True
5555
CURRENT_CATALOG_EXPRESSION = exp.func("current_database")
5656
SUPPORTS_CREATE_DROP_CATALOG = True
57-
SUPPORTS_EXTERNAL_MODEL_FRESHNESS = True
57+
SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS = True
5858
SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA", "TABLE"]
5959
SCHEMA_DIFFER_KWARGS = {
6060
"parameterized_type_defaults": {
@@ -668,7 +668,7 @@ def close(self) -> t.Any:
668668

669669
return super().close()
670670

671-
def get_external_model_freshness(self, table_names: t.List[TableName]) -> t.List[int]:
671+
def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]:
672672
from sqlmesh.utils.date import to_timestamp
673673

674674
num_tables = len(table_names)

sqlmesh/core/scheduler.py

Lines changed: 1 addition & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@
5454

5555
if t.TYPE_CHECKING:
5656
from sqlmesh.core.context import ExecutionContext
57-
from sqlmesh.core._typing import TableName
58-
from sqlmesh.core.engine_adapter import EngineAdapter
5957

6058
logger = logging.getLogger(__name__)
6159
SnapshotToIntervals = t.Dict[Snapshot, Intervals]
@@ -190,46 +188,6 @@ def merged_missing_intervals(
190188
}
191189
return snapshots_to_intervals
192190

193-
def can_skip_evaluation(self, snapshot: Snapshot, snapshots: t.Dict[str, Snapshot]) -> bool:
194-
if not snapshot.last_altered_ts:
195-
return False
196-
197-
from collections import defaultdict
198-
199-
parent_snapshots = {p for p in snapshots.values() if p.name != snapshot.name}
200-
if len(parent_snapshots) != len(snapshot.node.depends_on):
201-
# The mismatch can happen if e.g an external model is not registered in the project
202-
return False
203-
204-
adapter_to_parent_snapshots: t.Dict[EngineAdapter, t.List[Snapshot]] = defaultdict(list)
205-
206-
for parent_snapshot in parent_snapshots:
207-
if not parent_snapshot.is_external:
208-
return False
209-
210-
adapter = self.snapshot_evaluator.get_adapter(parent_snapshot.model_gateway)
211-
if not adapter.SUPPORTS_EXTERNAL_MODEL_FRESHNESS:
212-
return False
213-
214-
adapter_to_parent_snapshots[adapter].append(parent_snapshot)
215-
216-
if not adapter_to_parent_snapshots:
217-
return False
218-
219-
external_models_freshness: t.List[int] = []
220-
221-
for adapter, adapter_snapshots in adapter_to_parent_snapshots.items():
222-
table_names: t.List[TableName] = [
223-
exp.to_table(parent_snapshot.name, parent_snapshot.node.dialect)
224-
for parent_snapshot in adapter_snapshots
225-
]
226-
external_models_freshness.extend(adapter.get_external_model_freshness(table_names))
227-
228-
return all(
229-
snapshot.last_altered_ts > external_model_freshness
230-
for external_model_freshness in external_models_freshness
231-
)
232-
233191
def evaluate(
234192
self,
235193
snapshot: Snapshot,
@@ -413,6 +371,7 @@ def batch_intervals(
413371
deployability_index,
414372
default_dialect=adapter.dialect,
415373
default_catalog=self.default_catalog,
374+
is_restatement_plan=is_restatement_plan,
416375
)
417376

418377
intervals = self._check_ready_intervals(
@@ -989,25 +948,13 @@ def _check_ready_intervals(
989948

990949
signal_names = signals.signals_to_kwargs.keys()
991950

992-
if (
993-
is_restatement_plan
994-
and len(signal_names) == 1
995-
and next(iter(signal_names)) == "freshness"
996-
):
997-
# Freshness signal is not checked for restatement plans to allow users
998-
# for an escape hatch in reevaluating models
999-
return intervals
1000-
1001951
self.console.start_signal_progress(
1002952
snapshot,
1003953
self.default_catalog,
1004954
environment_naming_info or EnvironmentNamingInfo(),
1005955
)
1006956

1007957
for signal_idx, (signal_name, kwargs) in enumerate(signals.signals_to_kwargs.items()):
1008-
if is_restatement_plan and signal_name == "freshness":
1009-
continue
1010-
1011958
# Capture intervals before signal check for display
1012959
intervals_to_check = merge_intervals(intervals)
1013960

sqlmesh/core/signal.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,13 @@ class signal(registry_decorator):
4242

4343
@signal()
4444
def freshness(batch: DatetimeRanges, snapshot: Snapshot, context: ExecutionContext) -> bool:
45+
if context.is_restatement_plan:
46+
return True
47+
4548
deployability_index = context.deployability_index
4649
adapter = context.engine_adapter
4750

48-
if not deployability_index or not adapter.SUPPORTS_EXTERNAL_MODEL_FRESHNESS:
51+
if not deployability_index or not adapter.SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS:
4952
return True
5053

5154
last_altered_ts = (
@@ -67,7 +70,7 @@ def freshness(batch: DatetimeRanges, snapshot: Snapshot, context: ExecutionConte
6770
# since the last time the model was evaluated
6871
upstream_dep_has_new_data = any(
6972
upstream_last_altered_ts > last_altered_ts
70-
for upstream_last_altered_ts in adapter.get_external_model_freshness(
73+
for upstream_last_altered_ts in adapter.get_table_last_modified_ts(
7174
[p.name for p in parent_snapshots]
7275
)
7376
)

sqlmesh/core/snapshot/definition.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -207,11 +207,11 @@ def add_dev_interval(self, start: int, end: int) -> None:
207207
def add_pending_restatement_interval(self, start: int, end: int) -> None:
208208
self._add_interval(start, end, "pending_restatement_intervals")
209209

210-
def add_last_altered_ts(self, last_altered_ts: t.Optional[int]) -> None:
211-
self._add_last_altered_ts(last_altered_ts, "last_altered_ts")
210+
def update_last_altered_ts(self, last_altered_ts: t.Optional[int]) -> None:
211+
self._update_last_altered_ts(last_altered_ts, "last_altered_ts")
212212

213-
def add_dev_last_altered_ts(self, last_altered_ts: t.Optional[int]) -> None:
214-
self._add_last_altered_ts(last_altered_ts, "dev_last_altered_ts")
213+
def update_dev_last_altered_ts(self, last_altered_ts: t.Optional[int]) -> None:
214+
self._update_last_altered_ts(last_altered_ts, "dev_last_altered_ts")
215215

216216
def remove_interval(self, start: int, end: int) -> None:
217217
self._remove_interval(start, end, "intervals")
@@ -232,12 +232,12 @@ def _add_interval(self, start: int, end: int, interval_attr: str) -> None:
232232
target_intervals = merge_intervals([*target_intervals, (start, end)])
233233
setattr(self, interval_attr, target_intervals)
234234

235-
def _add_last_altered_ts(
235+
def _update_last_altered_ts(
236236
self, last_altered_ts: t.Optional[int], last_altered_attr: str
237237
) -> None:
238238
if last_altered_ts:
239239
existing_last_altered_ts = getattr(self, last_altered_attr)
240-
setattr(self, last_altered_attr, max(existing_last_altered_ts or -1, last_altered_ts))
240+
setattr(self, last_altered_attr, max(existing_last_altered_ts or 0, last_altered_ts))
241241

242242
def _remove_interval(self, start: int, end: int, interval_attr: str) -> None:
243243
target_intervals = getattr(self, interval_attr)
@@ -978,7 +978,7 @@ def merge_intervals(self, other: t.Union[Snapshot, SnapshotIntervals]) -> None:
978978
self.add_interval(start, end)
979979

980980
if other.last_altered_ts:
981-
self.last_altered_ts = max(self.last_altered_ts or -1, other.last_altered_ts)
981+
self.last_altered_ts = max(self.last_altered_ts or 0, other.last_altered_ts)
982982

983983
if self.dev_version == other.dev_version:
984984
# Merge dev intervals if the dev versions match which would mean
@@ -988,7 +988,7 @@ def merge_intervals(self, other: t.Union[Snapshot, SnapshotIntervals]) -> None:
988988

989989
if other.dev_last_altered_ts:
990990
self.dev_last_altered_ts = max(
991-
self.dev_last_altered_ts or -1, other.dev_last_altered_ts
991+
self.dev_last_altered_ts or 0, other.dev_last_altered_ts
992992
)
993993

994994
self.pending_restatement_intervals = merge_intervals(

sqlmesh/core/state_sync/db/interval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,10 +331,10 @@ def _get_snapshot_intervals(
331331
else:
332332
if is_dev:
333333
intervals[merge_key].add_dev_interval(start, end)
334-
intervals[merge_key].add_dev_last_altered_ts(last_altered_ts)
334+
intervals[merge_key].update_dev_last_altered_ts(last_altered_ts)
335335
else:
336336
intervals[merge_key].add_interval(start, end)
337-
intervals[merge_key].add_last_altered_ts(last_altered_ts)
337+
intervals[merge_key].update_last_altered_ts(last_altered_ts)
338338
# Remove all pending restatement intervals recorded before the current interval has been added
339339
intervals[
340340
pending_restatement_interval_merge_key

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3845,8 +3845,8 @@ def _assert_mview_value(value: int):
38453845
@use_terminal_console
38463846
def test_external_model_freshness(ctx: TestContext, mocker: MockerFixture, tmp_path: pathlib.Path):
38473847
adapter = ctx.engine_adapter
3848-
if not adapter.SUPPORTS_EXTERNAL_MODEL_FRESHNESS:
3849-
pytest.skip("This test only runs for engines that support external model freshness")
3848+
if not adapter.SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS:
3849+
pytest.skip("This test only runs for engines that support metadata-based freshness")
38503850

38513851
def _assert_snapshot_last_altered_ts(
38523852
context: Context,
@@ -3880,15 +3880,7 @@ def _assert_model_evaluation(lambda_func, was_evaluated, day_delta=0):
38803880

38813881
evaluate_function_called = spy.call_count == 1
38823882
signal_was_checked = "Checking signals for" in output.stdout
3883-
restatement_plan = isinstance(plan_or_run_result, Plan) and plan_or_run_result.restatements
3884-
if restatement_plan:
3885-
# Restatement plans exclude this signal so we expect the actual evaluation
3886-
# to happen but not through the signal
3887-
assert evaluate_function_called
3888-
assert not signal_was_checked
3889-
return
38903883

3891-
# All other cases (e.g normal plans or runs) will check the freshness signal
38923884
assert signal_was_checked
38933885
if was_evaluated:
38943886
assert "All ready" in output.stdout

0 commit comments

Comments
 (0)