diff --git a/splink/comparison_level_library.py b/splink/comparison_level_library.py index 5e849144b..8fb13ed4a 100644 --- a/splink/comparison_level_library.py +++ b/splink/comparison_level_library.py @@ -3,6 +3,7 @@ AbsoluteTimeDifferenceLevel, And, ArrayIntersectLevel, + ArraySubsetLevel, ColumnsReversedLevel, CustomLevel, DamerauLevenshteinLevel, @@ -38,6 +39,7 @@ "AbsoluteDateDifferenceLevel", "DistanceInKMLevel", "ArrayIntersectLevel", + "ArraySubsetLevel", "PercentageDifferenceLevel", "And", "Not", diff --git a/splink/internals/comparison_level_library.py b/splink/internals/comparison_level_library.py index 788d98706..0eb1574a3 100644 --- a/splink/internals/comparison_level_library.py +++ b/splink/internals/comparison_level_library.py @@ -806,6 +806,40 @@ def create_label_for_charts(self) -> str: return f"Array intersection size >= {self.min_intersection}" +class ArraySubsetLevel(ComparisonLevelCreator): + def __init__(self, col_name: str | ColumnExpression): + """Represents a comparison level where the smaller array is an + exact subset of the larger array. + """ + self.col_expression = ColumnExpression.instantiate_if_str(col_name) + + @unsupported_splink_dialects(["sqlite"]) + def create_sql(self, sql_dialect: SplinkDialect) -> str: + sqlglot_dialect_name = sql_dialect.sqlglot_name + + sqlglot_base_dialect_sql = """ + LEAST(ARRAY_SIZE(___col____l), ARRAY_SIZE(___col____r))<>0 + AND + ARRAY_SIZE( + ARRAY_INTERSECT(___col____l, ___col____r)) / + LEAST(ARRAY_SIZE(___col____l), ARRAY_SIZE(___col____r)) + == 1 + """ + translated = _translate_sql_string( + sqlglot_base_dialect_sql, sqlglot_dialect_name + ) + + self.col_expression.sql_dialect = sql_dialect + col = self.col_expression + col = self.col_expression + translated = translated.replace("___col____l", col.name_l) + translated = translated.replace("___col____r", col.name_r) + return translated + + def create_label_for_charts(self) -> str: + return "Array subset" + + class PercentageDifferenceLevel(ComparisonLevelCreator): def __init__(self, col_name: str, percentage_threshold: float): """ diff --git a/tests/test_array_columns.py b/tests/test_array_columns.py index f52a77e13..d339c3de3 100644 --- a/tests/test_array_columns.py +++ b/tests/test_array_columns.py @@ -1,8 +1,10 @@ import pytest +import splink.internals.comparison_level_library as cll import splink.internals.comparison_library as cl from tests.decorator import mark_with_dialects_excluding from tests.literal_utils import ( + ComparisonLevelTestSpec, ComparisonTestSpec, LiteralTestValues, run_tests_with_args, @@ -72,3 +74,37 @@ def test_array_comparison_1(test_helpers, dialect): cl.ArrayIntersectAtSizes("postcode", [-1, 2]).get_comparison( db_api.sql_dialect.sqlglot_name ) + + +# No SQLite - no array comparisons in library +@mark_with_dialects_excluding("sqlite") +def test_array_subset(test_helpers, dialect): + helper = test_helpers[dialect] + db_api = helper.extra_linker_args()["db_api"] + + test_spec = ComparisonLevelTestSpec( + cll.ArraySubsetLevel("arr"), + tests=[ + LiteralTestValues( + {"arr_l": ["A", "B", "C", "D"], "arr_r": ["A", "B", "C", "D"]}, + expected_in_level=True, + ), + LiteralTestValues( + {"arr_l": ["A", "B", "C", "D"], "arr_r": ["A", "B", "C", "Z"]}, + expected_in_level=False, + ), + LiteralTestValues( + {"arr_l": ["A", "B"], "arr_r": ["A", "B", "C", "D"]}, + expected_in_level=True, + ), + LiteralTestValues( + {"arr_l": ["A", "B", "C", "D"], "arr_r": ["X", "Y", "Z"]}, + expected_in_level=False, + ), + LiteralTestValues( + {"arr_l": [], "arr_r": ["X", "Y", "Z"]}, + expected_in_level=False, + ), + ], + ) + run_tests_with_args(test_spec, db_api)