Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
ModelTestMetadata,
generate_test,
run_tests,
filter_tests_by_patterns,
)
from sqlmesh.core.user import User
from sqlmesh.utils import UniqueKeyDict, Verbosity
Expand Down Expand Up @@ -146,8 +147,8 @@
from typing_extensions import Literal

from sqlmesh.core.engine_adapter._typing import (
BigframeSession,
DF,
BigframeSession,
PySparkDataFrame,
PySparkSession,
SnowparkSession,
Expand Down Expand Up @@ -390,6 +391,8 @@ def __init__(
self._standalone_audits: UniqueKeyDict[str, StandaloneAudit] = UniqueKeyDict(
"standaloneaudits"
)
self._models_with_tests: t.Set[str] = set()
self._model_test_metadata: t.List[ModelTestMetadata] = []
self._macros: UniqueKeyDict[str, ExecutableOrMacro] = UniqueKeyDict("macros")
self._metrics: UniqueKeyDict[str, Metric] = UniqueKeyDict("metrics")
self._jinja_macros = JinjaMacroRegistry()
Expand Down Expand Up @@ -639,6 +642,8 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
self._requirements.update(project.requirements)
self._excluded_requirements.update(project.excluded_requirements)
self._environment_statements.extend(project.environment_statements)
self._models_with_tests.update(project.models_with_tests)
self._model_test_metadata.extend(project.model_test_metadata)

config = loader.config
self._linters[config.project] = Linter.from_rules(
Expand Down Expand Up @@ -1041,6 +1046,11 @@ def standalone_audits(self) -> MappingProxyType[str, StandaloneAudit]:
"""Returns all registered standalone audits in this context."""
return MappingProxyType(self._standalone_audits)

@property
def models_with_tests(self) -> t.Set[str]:
"""Returns all models with tests in this context."""
return self._models_with_tests

@property
def snapshots(self) -> t.Dict[str, Snapshot]:
"""Generates and returns snapshots based on models registered in this context.
Expand Down Expand Up @@ -2212,7 +2222,9 @@ def test(

pd.set_option("display.max_columns", None)

test_meta = self.load_model_tests(tests=tests, patterns=match_patterns)
test_meta = self._filter_preloaded_tests(
test_meta=self._model_test_metadata, tests=tests, patterns=match_patterns
)

result = run_tests(
model_test_metadata=test_meta,
Expand Down Expand Up @@ -2773,6 +2785,35 @@ def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
raise SQLMeshError(f"Gateway '{gateway}' not found in the available engine adapters.")
return self.engine_adapter

def _filter_preloaded_tests(
self,
test_meta: t.List[ModelTestMetadata],
tests: t.Optional[t.List[str]] = None,
patterns: t.Optional[t.List[str]] = None,
) -> t.List[ModelTestMetadata]:
"""Filter pre-loaded test metadata based on tests and patterns."""

if tests:
filtered_tests = []
for test in tests:
if "::" in test:
filename, test_name = test.split("::", maxsplit=1)
filtered_tests.extend(
[
t
for t in test_meta
if str(t.path) == filename and t.test_name == test_name
]
)
else:
filtered_tests.extend([t for t in test_meta if str(t.path) == test])
test_meta = filtered_tests

if patterns:
test_meta = filter_tests_by_patterns(test_meta, patterns)

return test_meta

def _snapshots(
self, models_override: t.Optional[UniqueKeyDict[str, Model]] = None
) -> t.Dict[str, Snapshot]:
Expand Down
15 changes: 15 additions & 0 deletions sqlmesh/core/linter/rules/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,21 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]:
return self.violation()


class NoMissingUnitTest(Rule):
"""All models must have a unit test found in the test/ directory yaml files"""

def check_model(self, model: Model) -> t.Optional[RuleViolation]:
# External models cannot have unit tests
if isinstance(model, ExternalModel):
return None

if model.name not in self.context.models_with_tests:
return self.violation(
violation_msg=f"Model {model.name} is missing unit test(s). Please add in the tests/ directory."
)
return None


class NoMissingExternalModels(Rule):
"""All external models must be registered in the external_models.yaml file"""

Expand Down
10 changes: 10 additions & 0 deletions sqlmesh/core/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class LoadedProject:
excluded_requirements: t.Set[str]
environment_statements: t.List[EnvironmentStatements]
user_rules: RuleSet
model_test_metadata: t.List[ModelTestMetadata]
models_with_tests: t.Set[str]


class CacheBase(abc.ABC):
Expand Down Expand Up @@ -243,6 +245,12 @@ def load(self) -> LoadedProject:

user_rules = self._load_linting_rules()

model_test_metadata = self.load_model_tests()

models_with_tests = {
model_test_metadata.model_name for model_test_metadata in model_test_metadata
}

project = LoadedProject(
macros=macros,
jinja_macros=jinja_macros,
Expand All @@ -254,6 +262,8 @@ def load(self) -> LoadedProject:
excluded_requirements=excluded_requirements,
environment_statements=environment_statements,
user_rules=user_rules,
model_test_metadata=model_test_metadata,
models_with_tests=models_with_tests,
)
return project

Expand Down
4 changes: 4 additions & 0 deletions sqlmesh/core/test/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ class ModelTestMetadata(PydanticModel):
def fully_qualified_test_name(self) -> str:
return f"{self.path}::{self.test_name}"

@property
def model_name(self) -> str:
return self.body["model"]

def __hash__(self) -> int:
return self.fully_qualified_test_name.__hash__()

Expand Down
60 changes: 60 additions & 0 deletions tests/core/linter/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,63 @@ def test_no_missing_external_models_with_existing_file_not_ending_in_newline(
)
fix_path = sushi_path / "external_models.yaml"
assert edit.path == fix_path


def test_no_missing_unit_tests(tmp_path, copy_to_temp_path):
"""
Tests that the NoMissingUnitTest linter rule correctly identifies models
without corresponding unit tests in the tests/ directory

This test checks the sushi example project, enables the linter,
and verifies that the linter raises a rule violation for the models
that do not have a unit test
"""
sushi_paths = copy_to_temp_path("examples/sushi")
sushi_path = sushi_paths[0]

# Override the config.py to turn on lint
with open(sushi_path / "config.py", "r") as f:
read_file = f.read()

before = """ linter=LinterConfig(
enabled=False,
rules=[
"ambiguousorinvalidcolumn",
"invalidselectstarexpansion",
"noselectstar",
"nomissingaudits",
"nomissingowner",
"nomissingexternalmodels",
],
),"""
after = """linter=LinterConfig(enabled=True, rules=["nomissingunittest"]),"""
read_file = read_file.replace(before, after)
assert after in read_file
with open(sushi_path / "config.py", "w") as f:
f.writelines(read_file)

# Load the context with the temporary sushi path
context = Context(paths=[sushi_path])

# Lint the models
lints = context.lint_models(raise_on_error=False)

# Should have violations for models without tests (most models except customers)
assert len(lints) >= 1

# Check that we get violations for models without tests
violation_messages = [lint.violation_msg for lint in lints]
assert any("is missing unit test(s)" in msg for msg in violation_messages)

# Check that models with existing tests don't have violations
models_with_tests = ["customer_revenue_by_day", "customer_revenue_lifetime", "order_items"]

for model_name in models_with_tests:
model_violations = [
lint
for lint in lints
if model_name in lint.violation_msg and "is missing unit test(s)" in lint.violation_msg
]
assert len(model_violations) == 0, (
f"Model {model_name} should not have a violation since it has a test"
)