Feature/add single comparison column validation check #2160

Open · wants to merge 12 commits into base: master
4 changes: 4 additions & 0 deletions splink/exceptions.py
@@ -36,6 +36,10 @@ class SplinkDeprecated(DeprecationWarning):
    pass


+class InvalidSplinkInput(SplinkException):
+    pass
+
+
class InvalidDialect(SplinkException):
    pass
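Side note: because the new exception subclasses SplinkException, any existing handler that catches the base class will also catch it. A minimal sketch (the message text is illustrative):

from splink.exceptions import InvalidSplinkInput, SplinkException

try:
    raise InvalidSplinkInput("input has a single comparison column")
except SplinkException as err:
    # Prints: InvalidSplinkInput input has a single comparison column
    print(type(err).__name__, err)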
32 changes: 26 additions & 6 deletions splink/linker.py
@@ -10,6 +10,7 @@
from copy import copy, deepcopy
from pathlib import Path
from statistics import median
+from typing import Any, Dict

import sqlglot

@@ -19,8 +20,9 @@
    SettingsColumnCleaner,
)
from splink.settings_validation.valid_types import (
+    _check_input_dataframes_for_single_comparison_column,
+    _log_comparison_errors,
    _validate_dialect,
-    log_comparison_errors,
)

from .accuracy import (
@@ -497,7 +499,26 @@ def _check_for_valid_settings(self):
        else:
            return True

-    def _validate_settings(self, validate_settings):
+    def _validate_settings_dictionary(
+        self, validate_settings: bool, settings_dict: Dict[str, Any]
+    ):
+        if settings_dict is None:
+            return
+
+        if validate_settings:
+            _check_input_dataframes_for_single_comparison_column(
+                self._input_tables_dict,
+                source_dataset_column_name=settings_dict.get(
+                    "source_dataset_column_name"
+                ),
+                unique_id_column_name=settings_dict.get("unique_id_column_name"),
+            )
+            # Check the user's comparisons (if they exist)
+            _log_comparison_errors(
+                settings_dict.get("comparisons"), settings_dict.get("sql_dialect")
+            )
+
+    def _validate_settings_object(self, validate_settings: bool):
        # Validate our settings after plugging them through
        # `Settings(<settings>)`
        if not self._check_for_valid_settings():
@@ -515,7 +536,7 @@ def _validate_settings(self, validate_settings):
        # Constructs output logs for our various settings inputs
        cleaned_settings = SettingsColumnCleaner(
            settings_object=self._settings_obj,
-            input_columns=self._input_tables_dict,
+            splink_input_table_dfs=self._input_tables_dict,
        )
        InvalidColumnsLogger(cleaned_settings).construct_output_logs(validate_settings)

@@ -1133,11 +1154,10 @@ def load_settings(
        settings_dict["sql_dialect"] = sql_dialect
        settings_dict["linker_uid"] = settings_dict.get("linker_uid", cache_uid)

-        # Check the user's comparisons (if they exist)
-        log_comparison_errors(settings_dict.get("comparisons"), sql_dialect)
+        self._validate_settings_dictionary(validate_settings, settings_dict)
        self._settings_obj_ = Settings(settings_dict)
        # Check the final settings object
-        self._validate_settings(validate_settings)
+        self._validate_settings_object(validate_settings)

    def load_model(self, model_path: Path):
        """
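A usage sketch of the new flow, assuming the Splink 3 DuckDB API and that the linker constructor forwards validate_settings through load_settings; the table, column, and settings values are illustrative:

import pandas as pd
from splink.duckdb.linker import DuckDBLinker
from splink.exceptions import InvalidSplinkInput

# Besides unique_id, "company_name" is the only column available for
# matching, so the new dictionary-level check should reject this input.
df = pd.DataFrame({"unique_id": [1, 2], "company_name": ["Acme", "Acme Ltd"]})

settings = {"link_type": "dedupe_only"}  # minimal, illustrative settings

try:
    DuckDBLinker(df, settings, validate_settings=True)
except InvalidSplinkInput as err:
    print(err)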
2 changes: 1 addition & 1 deletion splink/predict.py
@@ -55,7 +55,7 @@ def predict_from_comparison_vectors_sqls(
        thres_prob_as_weight = prob_to_match_weight(threshold_match_probability)
    else:
        thres_prob_as_weight = None
-    if threshold_match_probability or threshold_match_weight:
+    if threshold_match_probability is not None or threshold_match_weight is not None:
        thresholds = [
            thres_prob_as_weight,
            threshold_match_weight,
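The predict.py change fixes a truthiness bug: a threshold of 0 (or 0.0) is falsy, so the old check silently skipped threshold handling when a caller explicitly passed zero. A minimal illustration:

threshold_match_probability = 0.0  # caller explicitly asks for a zero threshold
threshold_match_weight = None

# Old check: 0.0 is falsy, so this branch was skipped despite the
# caller having supplied a threshold.
if threshold_match_probability or threshold_match_weight:
    print("old check: threshold applied")  # never reached

# New check: only treats a genuinely omitted argument (None) as absent.
if threshold_match_probability is not None or threshold_match_weight is not None:
    print("new check: threshold applied")  # reached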
13 changes: 8 additions & 5 deletions splink/settings_validation/settings_column_cleaner.py
@@ -5,7 +5,7 @@
from copy import deepcopy
from functools import reduce
from operator import and_
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, Dict, List

import sqlglot

@@ -15,6 +15,7 @@

if TYPE_CHECKING:
    from ..settings import Settings
+    from ..splink_dataframe import SplinkDataFrame


def remove_suffix(c):
@@ -28,7 +29,7 @@ def find_columns_not_in_input_dfs(
    does not apply any cleaning to the input column(s).
    """
    # the key to use when producing our warning logs
-    if type(columns_to_check) == str:
+    if isinstance(columns_to_check, str):
        columns_to_check = [columns_to_check]

    return {col for col in columns_to_check if col not in valid_input_dataframe_columns}
@@ -81,7 +82,9 @@ def clean_list_of_column_names(col_list: List[InputColumn]):
    return set((c.unquote().name for c in col_list))


-def clean_user_input_columns(input_columns: dict, return_as_single_column: bool = True):
+def clean_user_input_columns(
+    input_columns: Dict[str, "SplinkDataFrame"], return_as_single_column: bool = True
+):
    """A dictionary containing all input dataframes and the columns located
    within.

@@ -104,11 +107,11 @@ class SettingsColumnCleaner:
    cleaned up settings columns and SQL strings.
    """

-    def __init__(self, settings_object: Settings, input_columns: dict):
+    def __init__(self, settings_object: Settings, splink_input_table_dfs: dict):
        self.sql_dialect = settings_object._sql_dialect
        self._settings_obj = settings_object
        self.input_columns = clean_user_input_columns(
-            input_columns.items(), return_as_single_column=True
+            splink_input_table_dfs.items(), return_as_single_column=True
        )

    @property
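The switch from type(columns_to_check) == str to isinstance is the idiomatic form and, unlike an exact type comparison, also accepts str subclasses. A quick illustration (ColumnName is a hypothetical subclass, not from the PR):

class ColumnName(str):
    """Hypothetical str subclass a caller might pass."""

c = ColumnName("city")
print(type(c) == str)      # False: exact type match fails for subclasses
print(isinstance(c, str))  # True: isinstance handles subclasses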
20 changes: 19 additions & 1 deletion splink/settings_validation/settings_validation_log_strings.py
@@ -1,5 +1,5 @@
from functools import partial
-from typing import List, NamedTuple, Tuple
+from typing import Dict, List, NamedTuple, Tuple



def indent_error_message(message):
@@ -200,3 +200,21 @@ def create_incorrect_dialect_import_log_string(
"for your specified linker.\n"
)
return indent_error_message(log_message)


def construct_single_dataframe_log_str(input_columns: Dict[str, str]) -> str:
if len(input_columns) == 1:
df_txt = "dataframe is"
else:
df_txt = "dataframes are"

log_message = (
f"\nThe provided {df_txt} unsuitable for linkage with Splink as\n"
"it contains only a single column for matching.\n"
"Splink is not designed for linking based on a single 'bag of words'\n"
"column, such as a table with only a 'company name' column and\n"
"no other details.\n\nFor more information see: \n"
"https://github.com/moj-analytical-services/splink/issues/1362"
)

return log_message
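A quick sanity check of the pluralisation, assuming the function is imported from the module above (the dict values are illustrative stand-ins for column sets):

from splink.settings_validation.settings_validation_log_strings import (
    construct_single_dataframe_log_str,
)

print(construct_single_dataframe_log_str({"df": {"company_name"}}))  # "dataframe is"
print(construct_single_dataframe_log_str({"a": {"x"}, "b": {"y"}}))  # "dataframes are"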
107 changes: 70 additions & 37 deletions splink/settings_validation/valid_types.py
@@ -1,18 +1,30 @@
from __future__ import annotations

import logging
-from typing import Dict, Union
+from typing import TYPE_CHECKING, Dict, List, Union

from ..comparison import Comparison
from ..comparison_level import ComparisonLevel
-from ..exceptions import ComparisonSettingsException, ErrorLogger, InvalidDialect
+from ..default_from_jsonschema import default_value_from_schema
+from ..exceptions import (
+    ComparisonSettingsException,
+    ErrorLogger,
+    InvalidDialect,
+    InvalidSplinkInput,
+)
+from .settings_column_cleaner import clean_user_input_columns
from .settings_validation_log_strings import (
+    construct_single_dataframe_log_str,
    create_incorrect_dialect_import_log_string,
    create_invalid_comparison_level_log_string,
    create_invalid_comparison_log_string,
    create_no_comparison_levels_error_log_string,
)

+if TYPE_CHECKING:
+    from ..splink_dataframe import SplinkDataFrame
+
+
logger = logging.getLogger(__name__)


@@ -24,9 +36,6 @@ def extract_sql_dialect_from_cll(cll):


def _validate_dialect(settings_dialect: str, linker_dialect: str, linker_type: str):
-    # settings_dialect = self.linker._settings_obj._sql_dialect
-    # linker_dialect = self.linker._sql_dialect
-    # linker_type = self.linker.__class__.__name__
    if settings_dialect != linker_dialect:
        raise ValueError(
            f"Incompatible SQL dialect! `settings` dictionary uses "
@@ -35,6 +44,34 @@ def _validate_dialect(settings_dialect: str, linker_dialect: str, linker_type: str):
        )


+def _check_input_dataframes_for_single_comparison_column(
+    input_columns: Dict[str, "SplinkDataFrame"],
+    source_dataset_column_name: str = None,
+    unique_id_column_name: str = None,
+):
+    if source_dataset_column_name is None:
+        source_dataset_column_name = default_value_from_schema(
+            "source_dataset_column_name", "root"
+        )
+    if unique_id_column_name is None:
+        unique_id_column_name = default_value_from_schema(
+            "unique_id_column_name", "root"
+        )
+
+    input_columns = clean_user_input_columns(
+        input_columns.items(), return_as_single_column=False
+    )
+
+    required_cols = (source_dataset_column_name, unique_id_column_name)
+
+    # Loop and exit if any dataframe has only one possible comparison column
+    for columns in input_columns.values():
+        unique_columns = set(columns) - set(required_cols)
+
+        if len(unique_columns) == 1:
+            raise InvalidSplinkInput(construct_single_dataframe_log_str(input_columns))


def validate_comparison_levels(
    error_logger: ErrorLogger, comparisons: list, linker_dialect: str
):
@@ -53,40 +90,12 @@
        # If no error is found, append won't do anything
        error_logger.log_error(evaluate_comparison_dtype_and_contents(c_dict))
        error_logger.log_error(
-            evaluate_comparisons_for_imports_from_incorrect_dialects(
-                c_dict, linker_dialect
-            )
+            check_comparison_imported_for_correct_dialect(c_dict, linker_dialect)
        )

    return error_logger


-def log_comparison_errors(comparisons, linker_dialect):
-    """
-    Log any errors arising from `validate_comparison_levels`.
-    """
-
-    # Check for empty inputs - Expecting None or []
-    if not comparisons:
-        return
-
-    error_logger = ErrorLogger()
-
-    error_logger = validate_comparison_levels(error_logger, comparisons, linker_dialect)
-
-    # Raise and log any errors identified
-    plural_this = "this" if len(error_logger.raw_errors) == 1 else "these"
-    comp_hyperlink_txt = (
-        f"\nFor more info on how to construct comparisons and avoid {plural_this} "
-        "error, please visit:\n"
-        "https://moj-analytical-services.github.io/splink/topic_guides/comparisons/customising_comparisons.html"
-    )
-
-    error_logger.raise_and_log_all_errors(
-        exception=ComparisonSettingsException, additional_txt=comp_hyperlink_txt
-    )
-
-
def check_comparison_level_types(
    comparison_levels: Union[Comparison, Dict], comparison_str: str
):
@@ -146,9 +155,7 @@ def evaluate_comparison_dtype_and_contents(comparison_dict):
    return check_comparison_level_types(comp_levels, comp_str)


-def evaluate_comparisons_for_imports_from_incorrect_dialects(
-    comparison_dict, sql_dialect
-):
+def check_comparison_imported_for_correct_dialect(comparison_dict, sql_dialect):
"""
Given a comparison_dict, assess whether the sql dialect is valid for
your selected linker.
Expand Down Expand Up @@ -198,3 +205,29 @@ def evaluate_comparisons_for_imports_from_incorrect_dialects(
comp_str, sorted(invalid_dialects)
)
return InvalidDialect(error_message)


+def _log_comparison_errors(comparisons: List[Comparison], linker_dialect: str):
+    """
+    Log any errors arising from various comparison validation checks.
+    """
+
+    # Check for empty inputs - Expecting None or []
+    if not comparisons:
+        return
+
+    error_logger = ErrorLogger()
+
+    error_logger = validate_comparison_levels(error_logger, comparisons, linker_dialect)
+
+    # Raise and log any errors identified
+    plural_this = "this" if len(error_logger.raw_errors) == 1 else "these"
+    comp_hyperlink_txt = (
+        f"\nFor more info on how to construct comparisons and avoid {plural_this} "
+        "error, please visit:\n"
+        "https://moj-analytical-services.github.io/splink/topic_guides/comparisons/customising_comparisons.html"
+    )
+
+    error_logger.raise_and_log_all_errors(
+        exception=ComparisonSettingsException, additional_txt=comp_hyperlink_txt
+    )
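The core of the new check is a set difference between each table's columns and the required id columns. A standalone sketch of that logic (table and column names are illustrative, not from the PR):

input_columns = {
    "companies": {"unique_id", "source_dataset", "company_name"},
    "people": {"unique_id", "source_dataset", "first_name", "surname"},
}
required_cols = {"source_dataset", "unique_id"}

for table, columns in input_columns.items():
    comparison_columns = set(columns) - required_cols
    if len(comparison_columns) == 1:
        # Mirrors the point at which the PR raises InvalidSplinkInput
        print(f"{table}: only one column left to compare on")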
11 changes: 7 additions & 4 deletions tests/helpers.py
@@ -38,7 +38,7 @@ def Linker(self):
        pass

    def extra_linker_args(self):
-        return {}
+        return {"validate_settings": False}

    @property
    def date_format(self):
@@ -113,7 +113,8 @@ def Linker(self):
        return SparkLinker

    def extra_linker_args(self):
-        return {"spark": self.spark, "num_partitions_on_repartition": 1}
+        core_args = super().extra_linker_args()
+        return {"spark": self.spark, "num_partitions_on_repartition": 1, **core_args}

    def convert_frame(self, df):
        spark_frame = self.spark.createDataFrame(df)
@@ -159,7 +160,8 @@ def Linker(self):
        return SQLiteLinker

    def extra_linker_args(self):
-        return {"connection": self.con}
+        core_args = super().extra_linker_args()
+        return {"connection": self.con, **core_args}

    @classmethod
    def _get_input_name(cls):
@@ -208,7 +210,8 @@ def Linker(self):
        return PostgresLinker

    def extra_linker_args(self):
-        return {"engine": self.engine}
+        core_args = super().extra_linker_args()
+        return {"engine": self.engine, **core_args}

    @classmethod
    def _get_input_name(cls):
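On the test helpers: because **core_args is unpacked last, the base class's {"validate_settings": False} wins any key collision with the backend-specific literals. A quick illustration of that merge order (values are illustrative):

core_args = {"validate_settings": False}

# Later entries win on duplicate keys, so core_args overrides the literal.
args = {"engine": "my-engine", "validate_settings": True, **core_args}
print(args)  # {'engine': 'my-engine', 'validate_settings': False}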