Skip to content

Commit

Permalink
add tests and allow schemad data
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinL committed Sep 16, 2024
1 parent 7d3327c commit b56d3ec
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 28 deletions.
24 changes: 16 additions & 8 deletions splink/internals/testing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from typing import Any, Dict, List
from typing import Any, Dict, List, Union

import pyarrow as pa

from splink.internals.comparison_creator import ComparisonCreator
from splink.internals.comparison_level_creator import ComparisonLevelCreator
Expand All @@ -12,7 +14,7 @@

def is_in_level(
comparison_level: ComparisonLevelCreator,
literal_values: Dict[str, Any] | List[Dict[str, Any]],
literal_values: Union[Dict[str, Any], List[Dict[str, Any]], pa.Table],
db_api: DatabaseAPISubClass,
) -> bool | List[bool]:
sqlglot_dialect = db_api.sql_dialect.sqlglot_name
Expand All @@ -21,8 +23,11 @@ def is_in_level(
sql_cond = "TRUE"

table_name = f"__splink__temp_table_{ascii_uid(8)}"
literal_values_list = ensure_is_list(literal_values)
db_api._table_registration(literal_values_list, table_name)
if isinstance(literal_values, pa.Table):
db_api._table_registration(literal_values, table_name)
else:
literal_values_list = ensure_is_list(literal_values)
db_api._table_registration(literal_values_list, table_name)

sql_to_evaluate = f"SELECT {sql_cond} as result FROM {table_name}"

Expand All @@ -38,7 +43,7 @@ def is_in_level(

def comparison_vector_value(
comparison: ComparisonCreator,
literal_values: Dict[str, Any] | List[Dict[str, Any]],
literal_values: Union[Dict[str, Any], List[Dict[str, Any]], pa.Table],
db_api: DatabaseAPISubClass,
) -> Dict[str, Any] | List[Dict[str, Any]]:
sqlglot_dialect = db_api.sql_dialect.sqlglot_name
Expand All @@ -58,8 +63,11 @@ def comparison_vector_value(
case_statement = comparison_internal._case_statement

table_name = f"__splink__temp_table_{ascii_uid(8)}"
literal_values_list = ensure_is_list(literal_values)
db_api._table_registration(literal_values_list, table_name)
if isinstance(literal_values, pa.Table):
db_api._table_registration(literal_values, table_name)
else:
literal_values_list = ensure_is_list(literal_values)
db_api._table_registration(literal_values_list, table_name)

sql_to_evaluate = f"SELECT {case_statement} FROM {table_name}"

Expand Down Expand Up @@ -87,4 +95,4 @@ def comparison_vector_value(
for row in result_dicts
]

return output if isinstance(literal_values, list) else output[0]
return output if isinstance(literal_values, (list, pa.Table)) else output[0]
49 changes: 29 additions & 20 deletions tests/literal_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any, Dict, List

import pyarrow as pa

from splink import DuckDBAPI
from splink.internals.testing import comparison_vector_value, is_in_level

Expand All @@ -8,13 +10,16 @@

def run_is_in_level_tests(test_cases: List[Dict[str, Any]], db_api: Any) -> None:
for case in test_cases:
inputs = []
expected = []

for input_data in case["inputs"]:
input_dict = {k: v for k, v in input_data.items() if k != "expected"}
inputs.append(input_dict)
expected.append(input_data["expected"])
if isinstance(case["inputs"], pa.Table):
inputs = case["inputs"]
expected = inputs["expected"].to_pylist()
else:
inputs = []
expected = []
for input_data in case["inputs"]:
input_dict = {k: v for k, v in input_data.items() if k != "expected"}
inputs.append(input_dict)
expected.append(input_data["expected"])

results = is_in_level(case["level"], inputs, db_api)
assert (
Expand All @@ -26,19 +31,23 @@ def run_comparison_vector_value_tests(
test_cases: List[Dict[str, Any]], db_api: Any
) -> None:
for case in test_cases:
inputs = []
expected_values = []
expected_labels = []

for input_data in case["inputs"]:
input_dict = {
k: v
for k, v in input_data.items()
if k not in ["expected_value", "expected_label"]
}
inputs.append(input_dict)
expected_values.append(input_data["expected_value"])
expected_labels.append(input_data["expected_label"])
if isinstance(case["inputs"], pa.Table):
inputs = case["inputs"]
expected_values = inputs["expected_value"].to_pylist()
expected_labels = inputs["expected_label"].to_pylist()
else:
inputs = []
expected_values = []
expected_labels = []
for input_data in case["inputs"]:
input_dict = {
k: v
for k, v in input_data.items()
if k not in ["expected_value", "expected_label"]
}
inputs.append(input_dict)
expected_values.append(input_data["expected_value"])
expected_labels.append(input_data["expected_label"])

results = comparison_vector_value(case["comparison"], inputs, db_api)

Expand Down
94 changes: 94 additions & 0 deletions tests/test_comparison_level_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,97 @@ def test_absolute_difference(test_helpers, dialect):
]

run_comparison_vector_value_tests(test_cases, db_api)


@mark_with_dialects_excluding()
def test_cosine_similarity_level(test_helpers, dialect):
import pyarrow as pa

helper = test_helpers[dialect]
db_api = helper.extra_linker_args()["db_api"]

EMBEDDING_DIMENSION = 4

cosine_similarity_comparison_using_levels = cl.CustomComparison(
comparison_description="text_vector",
comparison_levels=[
cll.NullLevel("text_vector"),
cll.CosineSimilarityLevel("text_vector", 0.9), # 3
cll.CosineSimilarityLevel("text_vector", 0.7), # 2
cll.CosineSimilarityLevel("text_vector", 0.5), # 1
cll.ElseLevel(),
],
)

input_dicts = [
{
"text_vector_l": [0.5205, 0.4616, 0.3333, 0.2087],
"text_vector_r": [0.4137, 0.5439, 0.0737, 0.2041],
"expected_value": 3,
"expected_label": "Cosine similarity of text_vector >= 0.9",
}, # Cosine similarity: 0.9312
{
"text_vector_l": [0.7026, 0.8887, 0.1711, 0.0525],
"text_vector_r": [0.4549, 0.4891, 0.1555, 0.6263],
"expected_value": 2,
"expected_label": "Cosine similarity of text_vector >= 0.7",
}, # Cosine similarity: 0.7639
{
"text_vector_l": [0.8713, 0.3416, 0.4024, 0.1350],
"text_vector_r": [0.2104, 0.5763, 0.0442, 0.0872],
"expected_value": 1,
"expected_label": "Cosine similarity of text_vector >= 0.5",
}, # Cosine similarity: 0.6418
{
"text_vector_l": [0.99, 0.00, 0.99, 0.00],
"text_vector_r": [0.00, 0.99, 0.00, 0.99],
"expected_value": 0,
"expected_label": "All other comparisons",
},
{
"text_vector_l": None,
"text_vector_r": [0.99, 0.99, 0.99, 0.99],
"expected_value": -1,
"expected_label": "text_vector is NULL",
},
]

# Convert input_dicts to a pyarrow Table
inputs_pa = pa.Table.from_pydict(
{
"text_vector_l": [d["text_vector_l"] for d in input_dicts],
"text_vector_r": [d["text_vector_r"] for d in input_dicts],
"expected_value": [d["expected_value"] for d in input_dicts],
"expected_label": [d["expected_label"] for d in input_dicts],
},
schema=pa.schema(
[
("text_vector_l", pa.list_(pa.float32(), EMBEDDING_DIMENSION)),
("text_vector_r", pa.list_(pa.float32(), EMBEDDING_DIMENSION)),
("expected_value", pa.int16()),
("expected_label", pa.string()),
]
),
)

test_cases = [
{
"comparison": cosine_similarity_comparison_using_levels,
"inputs": inputs_pa,
},
]

run_comparison_vector_value_tests(test_cases, db_api)

cosine_similarity_comparison = cl.CosineSimilarityAtThresholds(
"text_vector", [0.9, 0.7, 0.5]
)

test_cases = [
{
"comparison": cosine_similarity_comparison,
"inputs": inputs_pa,
},
]

run_comparison_vector_value_tests(test_cases, db_api)

0 comments on commit b56d3ec

Please sign in to comment.