From b0adf6a3a256ffb2cf207d06911522ebcd7413ea Mon Sep 17 00:00:00 2001 From: mihir-packmoose Date: Wed, 4 Sep 2024 15:02:46 -0400 Subject: [PATCH 1/5] ArrayIntersectPercentage defined, first draft --- splink/internals/comparison_level_library.py | 39 ++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/splink/internals/comparison_level_library.py b/splink/internals/comparison_level_library.py index f65bf9df8b..c50646f6b9 100644 --- a/splink/internals/comparison_level_library.py +++ b/splink/internals/comparison_level_library.py @@ -793,6 +793,45 @@ def create_label_for_charts(self) -> str: return f"Array intersection size >= {self.min_intersection}" +class ArrayIntersectPercentage(ComparisonLevelCreator): + def __init__(self, col_name: str | ColumnExpression, percentage_threshold: int): + """Represents a comparison level where the difference between two array + sizes is within a specified percentage threshold. + + Args: + col_name (str): Input column name + percentage_threshold (int): The threshold percentage to use + to assess similarity e.g. 0.1 for 10%. + """ + if not 0 <= percentage_threshold <= 1: + raise ValueError("percentage_threshold must be between 0 and 1") + self.col_expression = ColumnExpression.instantiate_if_str(col_name) + self.percentage_threshold = percentage_threshold + + @unsupported_splink_dialects(["sqlite"]) + def create_sql(self, sql_dialect: SplinkDialect) -> str: + if hasattr(sql_dialect, "array_intersect"): + return sql_dialect.array_intersect(self) + + sqlglot_dialect_name = sql_dialect.sqlglot_name + + sqlglot_base_dialect_sql = f""" + ARRAY_SIZE(ARRAY_INTERSECT(___col____l, ___col____r)) / + GREATEST(ARRAY_SIZE(___col____l), ARRAY_SIZE(___col____r)) + >= {self.percentage_threshold} + """ + translated = _translate_sql_string( + sqlglot_base_dialect_sql, sqlglot_dialect_name + ) + + self.col_expression.sql_dialect = sql_dialect + col = self.col_expression + col = self.col_expression + translated = translated.replace("___col____l", col.name_l) + translated = translated.replace("___col____r", col.name_r) + return translated + + class PercentageDifferenceLevel(ComparisonLevelCreator): def __init__(self, col_name: str, percentage_threshold: float): """ From bc18a044cdee713f3278240de6cbd99a27232b6e Mon Sep 17 00:00:00 2001 From: mihir-packmoose Date: Fri, 6 Sep 2024 15:27:50 -0400 Subject: [PATCH 2/5] Replacing ArrayIntersectPercentage with ArraySubsetLevel as an alternative --- splink/internals/comparison_level_library.py | 27 ++++++++------------ 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/splink/internals/comparison_level_library.py b/splink/internals/comparison_level_library.py index c50646f6b9..1c0e660b8c 100644 --- a/splink/internals/comparison_level_library.py +++ b/splink/internals/comparison_level_library.py @@ -793,20 +793,12 @@ def create_label_for_charts(self) -> str: return f"Array intersection size >= {self.min_intersection}" -class ArrayIntersectPercentage(ComparisonLevelCreator): - def __init__(self, col_name: str | ColumnExpression, percentage_threshold: int): - """Represents a comparison level where the difference between two array - sizes is within a specified percentage threshold. - - Args: - col_name (str): Input column name - percentage_threshold (int): The threshold percentage to use - to assess similarity e.g. 0.1 for 10%. +class ArraySubsetLevel(ComparisonLevelCreator): + def __init__(self, col_name: str | ColumnExpression): + """Represents a comparison level where the smaller array is an + exact subset of the larger array. """ - if not 0 <= percentage_threshold <= 1: - raise ValueError("percentage_threshold must be between 0 and 1") self.col_expression = ColumnExpression.instantiate_if_str(col_name) - self.percentage_threshold = percentage_threshold @unsupported_splink_dialects(["sqlite"]) def create_sql(self, sql_dialect: SplinkDialect) -> str: @@ -815,11 +807,12 @@ def create_sql(self, sql_dialect: SplinkDialect) -> str: sqlglot_dialect_name = sql_dialect.sqlglot_name - sqlglot_base_dialect_sql = f""" - ARRAY_SIZE(ARRAY_INTERSECT(___col____l, ___col____r)) / - GREATEST(ARRAY_SIZE(___col____l), ARRAY_SIZE(___col____r)) - >= {self.percentage_threshold} - """ + sqlglot_base_dialect_sql = """ + ARRAY_SIZE( + ARRAY_INTERSECT(___col____l, ___col____r)) / + SMALLEST(ARRAY_SIZE(___col____l), ARRAY_SIZE(___col____r)) + == 1 + """ translated = _translate_sql_string( sqlglot_base_dialect_sql, sqlglot_dialect_name ) From 80806005f015472ba16b48aa26a1fc8fe0620ed8 Mon Sep 17 00:00:00 2001 From: mihir-packmoose Date: Wed, 11 Sep 2024 10:45:15 -0400 Subject: [PATCH 3/5] Minor fixes and exposing to user-facing API --- splink/comparison_level_library.py | 2 ++ splink/internals/comparison_level_library.py | 10 +++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/splink/comparison_level_library.py b/splink/comparison_level_library.py index 5e849144bc..8fb13ed4a1 100644 --- a/splink/comparison_level_library.py +++ b/splink/comparison_level_library.py @@ -3,6 +3,7 @@ AbsoluteTimeDifferenceLevel, And, ArrayIntersectLevel, + ArraySubsetLevel, ColumnsReversedLevel, CustomLevel, DamerauLevenshteinLevel, @@ -38,6 +39,7 @@ "AbsoluteDateDifferenceLevel", "DistanceInKMLevel", "ArrayIntersectLevel", + "ArraySubsetLevel", "PercentageDifferenceLevel", "And", "Not", diff --git a/splink/internals/comparison_level_library.py b/splink/internals/comparison_level_library.py index 1c0e660b8c..c8f544176f 100644 --- a/splink/internals/comparison_level_library.py +++ b/splink/internals/comparison_level_library.py @@ -802,15 +802,12 @@ def __init__(self, col_name: str | ColumnExpression): @unsupported_splink_dialects(["sqlite"]) def create_sql(self, sql_dialect: SplinkDialect) -> str: - if hasattr(sql_dialect, "array_intersect"): - return sql_dialect.array_intersect(self) - sqlglot_dialect_name = sql_dialect.sqlglot_name sqlglot_base_dialect_sql = """ ARRAY_SIZE( - ARRAY_INTERSECT(___col____l, ___col____r)) / - SMALLEST(ARRAY_SIZE(___col____l), ARRAY_SIZE(___col____r)) + ARRAY_INTERSECT(___col____l, ___col____r)) / + LEAST(ARRAY_SIZE(___col____l), ARRAY_SIZE(___col____r)) == 1 """ translated = _translate_sql_string( @@ -824,6 +821,9 @@ def create_sql(self, sql_dialect: SplinkDialect) -> str: translated = translated.replace("___col____r", col.name_r) return translated + def create_label_for_charts(self) -> str: + return "Array subset" + class PercentageDifferenceLevel(ComparisonLevelCreator): def __init__(self, col_name: str, percentage_threshold: float): From 55b59440b0aec46bca9534d8553ca1a6bf9999f4 Mon Sep 17 00:00:00 2001 From: mihir-packmoose Date: Wed, 11 Sep 2024 16:14:51 -0400 Subject: [PATCH 4/5] Added testing and edgecase handling for empty arrays --- splink/internals/comparison_level_library.py | 2 ++ tests/test_array_columns.py | 36 ++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/splink/internals/comparison_level_library.py b/splink/internals/comparison_level_library.py index 973bd8d0b9..0eb1574a39 100644 --- a/splink/internals/comparison_level_library.py +++ b/splink/internals/comparison_level_library.py @@ -818,6 +818,8 @@ def create_sql(self, sql_dialect: SplinkDialect) -> str: sqlglot_dialect_name = sql_dialect.sqlglot_name sqlglot_base_dialect_sql = """ + LEAST(ARRAY_SIZE(___col____l), ARRAY_SIZE(___col____r))<>0 + AND ARRAY_SIZE( ARRAY_INTERSECT(___col____l, ___col____r)) / LEAST(ARRAY_SIZE(___col____l), ARRAY_SIZE(___col____r)) diff --git a/tests/test_array_columns.py b/tests/test_array_columns.py index f52a77e136..8e7f40f5f5 100644 --- a/tests/test_array_columns.py +++ b/tests/test_array_columns.py @@ -1,9 +1,11 @@ import pytest import splink.internals.comparison_library as cl +import splink.internals.comparison_level_library as cll from tests.decorator import mark_with_dialects_excluding from tests.literal_utils import ( ComparisonTestSpec, + ComparisonLevelTestSpec, LiteralTestValues, run_tests_with_args, ) @@ -72,3 +74,37 @@ def test_array_comparison_1(test_helpers, dialect): cl.ArrayIntersectAtSizes("postcode", [-1, 2]).get_comparison( db_api.sql_dialect.sqlglot_name ) + + +# No SQLite - no array comparisons in library +@mark_with_dialects_excluding("sqlite") +def test_array_subset(test_helpers, dialect): + helper = test_helpers[dialect] + db_api = helper.extra_linker_args()["db_api"] + + test_spec = ComparisonLevelTestSpec( + cll.ArraySubsetLevel("arr"), + tests=[ + LiteralTestValues( + {"arr_l": ["A", "B", "C", "D"], "arr_r": ["A", "B", "C", "D"]}, + expected_in_level=True, + ), + LiteralTestValues( + {"arr_l": ["A", "B", "C", "D"], "arr_r": ["A", "B", "C", "Z"]}, + expected_in_level=False, + ), + LiteralTestValues( + {"arr_l": ["A", "B"], "arr_r": ["A", "B", "C", "D"]}, + expected_in_level=True, + ), + LiteralTestValues( + {"arr_l": ["A", "B", "C", "D"], "arr_r": ["X", "Y", "Z"]}, + expected_in_level=False, + ), + LiteralTestValues( + {"arr_l": [], "arr_r": ["X", "Y", "Z"]}, + expected_in_level=False, + ), + ], + ) + run_tests_with_args(test_spec, db_api) From 8d5f912ec246350d3d5df38429b2b530e23500c5 Mon Sep 17 00:00:00 2001 From: mihir-packmoose Date: Wed, 11 Sep 2024 16:15:18 -0400 Subject: [PATCH 5/5] Reordered imports to pass ruff --- tests/test_array_columns.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_array_columns.py b/tests/test_array_columns.py index 8e7f40f5f5..d339c3de38 100644 --- a/tests/test_array_columns.py +++ b/tests/test_array_columns.py @@ -1,11 +1,11 @@ import pytest -import splink.internals.comparison_library as cl import splink.internals.comparison_level_library as cll +import splink.internals.comparison_library as cl from tests.decorator import mark_with_dialects_excluding from tests.literal_utils import ( - ComparisonTestSpec, ComparisonLevelTestSpec, + ComparisonTestSpec, LiteralTestValues, run_tests_with_args, )