Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cosine similarity tests and allow data with an explicit schema (pyarrow tables) #2407

Merged
merged 3 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions splink/internals/testing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from typing import Any, Dict, List
from typing import Any, Dict, List, Union

import pyarrow as pa

from splink.internals.comparison_creator import ComparisonCreator
from splink.internals.comparison_level_creator import ComparisonLevelCreator
Expand All @@ -12,7 +14,7 @@

def is_in_level(
comparison_level: ComparisonLevelCreator,
literal_values: Dict[str, Any] | List[Dict[str, Any]],
literal_values: Union[Dict[str, Any], List[Dict[str, Any]], pa.Table],
db_api: DatabaseAPISubClass,
) -> bool | List[bool]:
sqlglot_dialect = db_api.sql_dialect.sqlglot_name
Expand All @@ -21,8 +23,11 @@ def is_in_level(
sql_cond = "TRUE"

table_name = f"__splink__temp_table_{ascii_uid(8)}"
literal_values_list = ensure_is_list(literal_values)
db_api._table_registration(literal_values_list, table_name)
if isinstance(literal_values, pa.Table):
db_api._table_registration(literal_values, table_name)
else:
literal_values_list = ensure_is_list(literal_values)
db_api._table_registration(literal_values_list, table_name)

sql_to_evaluate = f"SELECT {sql_cond} as result FROM {table_name}"

Expand All @@ -38,7 +43,7 @@ def is_in_level(

def comparison_vector_value(
comparison: ComparisonCreator,
literal_values: Dict[str, Any] | List[Dict[str, Any]],
literal_values: Union[Dict[str, Any], List[Dict[str, Any]], pa.Table],
db_api: DatabaseAPISubClass,
) -> Dict[str, Any] | List[Dict[str, Any]]:
sqlglot_dialect = db_api.sql_dialect.sqlglot_name
Expand All @@ -58,8 +63,11 @@ def comparison_vector_value(
case_statement = comparison_internal._case_statement

table_name = f"__splink__temp_table_{ascii_uid(8)}"
literal_values_list = ensure_is_list(literal_values)
db_api._table_registration(literal_values_list, table_name)
if isinstance(literal_values, pa.Table):
db_api._table_registration(literal_values, table_name)
else:
literal_values_list = ensure_is_list(literal_values)
db_api._table_registration(literal_values_list, table_name)

sql_to_evaluate = f"SELECT {case_statement} FROM {table_name}"

Expand Down Expand Up @@ -87,4 +95,4 @@ def comparison_vector_value(
for row in result_dicts
]

return output if isinstance(literal_values, list) else output[0]
return output if isinstance(literal_values, (list, pa.Table)) else output[0]
49 changes: 29 additions & 20 deletions tests/literal_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any, Dict, List

import pyarrow as pa

from splink import DuckDBAPI
from splink.internals.testing import comparison_vector_value, is_in_level

Expand All @@ -8,13 +10,16 @@

def run_is_in_level_tests(test_cases: List[Dict[str, Any]], db_api: Any) -> None:
for case in test_cases:
inputs = []
expected = []

for input_data in case["inputs"]:
input_dict = {k: v for k, v in input_data.items() if k != "expected"}
inputs.append(input_dict)
expected.append(input_data["expected"])
if isinstance(case["inputs"], pa.Table):
inputs = case["inputs"]
expected = inputs["expected"].to_pylist()
else:
inputs = []
expected = []
for input_data in case["inputs"]:
input_dict = {k: v for k, v in input_data.items() if k != "expected"}
inputs.append(input_dict)
expected.append(input_data["expected"])

results = is_in_level(case["level"], inputs, db_api)
assert (
Expand All @@ -26,19 +31,23 @@ def run_comparison_vector_value_tests(
test_cases: List[Dict[str, Any]], db_api: Any
) -> None:
for case in test_cases:
inputs = []
expected_values = []
expected_labels = []

for input_data in case["inputs"]:
input_dict = {
k: v
for k, v in input_data.items()
if k not in ["expected_value", "expected_label"]
}
inputs.append(input_dict)
expected_values.append(input_data["expected_value"])
expected_labels.append(input_data["expected_label"])
if isinstance(case["inputs"], pa.Table):
inputs = case["inputs"]
expected_values = inputs["expected_value"].to_pylist()
expected_labels = inputs["expected_label"].to_pylist()
else:
inputs = []
expected_values = []
expected_labels = []
for input_data in case["inputs"]:
input_dict = {
k: v
for k, v in input_data.items()
if k not in ["expected_value", "expected_label"]
}
inputs.append(input_dict)
expected_values.append(input_data["expected_value"])
expected_labels.append(input_data["expected_label"])

results = comparison_vector_value(case["comparison"], inputs, db_api)

Expand Down
96 changes: 95 additions & 1 deletion tests/test_comparison_level_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from splink import ColumnExpression
from tests.literal_utils import run_comparison_vector_value_tests, run_is_in_level_tests

from .decorator import mark_with_dialects_excluding
from .decorator import mark_with_dialects_excluding, mark_with_dialects_including


@mark_with_dialects_excluding()
Expand Down Expand Up @@ -363,3 +363,97 @@ def test_absolute_difference(test_helpers, dialect):
]

run_comparison_vector_value_tests(test_cases, db_api)


@mark_with_dialects_including("duckdb", pass_dialect=True)
def test_cosine_similarity_level(test_helpers, dialect):
    """Exercise the cosine-similarity comparison levels on typed vector data.

    The input rows are supplied as a pyarrow Table with an explicit schema
    (fixed-size lists of float32) instead of plain dicts.  The same table is
    evaluated twice: first against a hand-assembled ``CustomComparison`` built
    from ``CosineSimilarityLevel`` thresholds 0.9 / 0.7 / 0.5, then against
    ``CosineSimilarityAtThresholds`` configured with the same thresholds.
    """
    import pyarrow as pa

    db_api = test_helpers[dialect].extra_linker_args()["db_api"]

    vector_length = 4
    vector_type = pa.list_(pa.float32(), vector_length)
    schema = pa.schema(
        [
            ("text_vector_l", vector_type),
            ("text_vector_r", vector_type),
            ("expected_value", pa.int16()),
            ("expected_label", pa.string()),
        ]
    )

    # One row per expected outcome, from the strictest threshold downwards.
    rows = [
        # Cosine similarity: 0.9312
        {
            "text_vector_l": [0.5205, 0.4616, 0.3333, 0.2087],
            "text_vector_r": [0.4137, 0.5439, 0.0737, 0.2041],
            "expected_value": 3,
            "expected_label": "Cosine similarity of text_vector >= 0.9",
        },
        # Cosine similarity: 0.7639
        {
            "text_vector_l": [0.7026, 0.8887, 0.1711, 0.0525],
            "text_vector_r": [0.4549, 0.4891, 0.1555, 0.6263],
            "expected_value": 2,
            "expected_label": "Cosine similarity of text_vector >= 0.7",
        },
        # Cosine similarity: 0.6418
        {
            "text_vector_l": [0.8713, 0.3416, 0.4024, 0.1350],
            "text_vector_r": [0.2104, 0.5763, 0.0442, 0.0872],
            "expected_value": 1,
            "expected_label": "Cosine similarity of text_vector >= 0.5",
        },
        # Orthogonal vectors (dot product is zero): falls through to else.
        {
            "text_vector_l": [0.99, 0.00, 0.99, 0.00],
            "text_vector_r": [0.00, 0.99, 0.00, 0.99],
            "expected_value": 0,
            "expected_label": "All other comparisons",
        },
        # A missing vector on one side is expected to hit the null level.
        {
            "text_vector_l": None,
            "text_vector_r": [0.99, 0.99, 0.99, 0.99],
            "expected_value": -1,
            "expected_label": "text_vector is NULL",
        },
    ]

    # Pivot the row dicts into columns, in schema order, for pyarrow.
    table = pa.Table.from_pydict(
        {column: [row[column] for row in rows] for column in schema.names},
        schema=schema,
    )

    def check(comparison):
        # The expected_* columns travel inside the table; the test runner
        # reads them from there.
        run_comparison_vector_value_tests(
            [{"comparison": comparison, "inputs": table}],
            db_api,
        )

    check(
        cl.CustomComparison(
            comparison_description="text_vector",
            comparison_levels=[
                cll.NullLevel("text_vector"),
                cll.CosineSimilarityLevel("text_vector", 0.9),  # 3
                cll.CosineSimilarityLevel("text_vector", 0.7),  # 2
                cll.CosineSimilarityLevel("text_vector", 0.5),  # 1
                cll.ElseLevel(),
            ],
        )
    )

    check(cl.CosineSimilarityAtThresholds("text_vector", [0.9, 0.7, 0.5]))
Loading