Skip to content

Commit

Permalink
tests work with new framework
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinL committed Sep 18, 2024
1 parent 3c0c2e6 commit 2c16b23
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 120 deletions.
6 changes: 5 additions & 1 deletion splink/internals/connected_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,16 @@ def solve_connected_components(

pipeline = CTEPipeline([edges_table, nodes_table])

match_prob_expr = f"where match_probability >= {threshold_match_probability}"
if threshold_match_probability is None:
match_prob_expr = ""

sql = f"""
select
{edge_id_column_name_left} as node_id_l,
{edge_id_column_name_right} as node_id_r
from {edges_table.physical_name}
where match_probability >= {threshold_match_probability}
{match_prob_expr}
UNION
Expand Down
108 changes: 19 additions & 89 deletions tests/cc_testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,8 @@
import pandas as pd
from networkx.algorithms import connected_components as cc_nx

from splink.internals.connected_components import solve_connected_components
from splink.internals.clustering import cluster_pairwise_predictions_at_threshold
from splink.internals.duckdb.database_api import DuckDBAPI
from splink.internals.duckdb.dataframe import DuckDBDataFrame
from splink.internals.linker import Linker
from splink.internals.pipeline import CTEPipeline
from splink.internals.vertically_concatenate import compute_df_concat_with_tf


def generate_random_graph(graph_size, seed=None):
Expand All @@ -20,105 +16,39 @@ def generate_random_graph(graph_size, seed=None):
return graph


def nodes_and_edges_from_graph(G):
    """Convert a networkx graph into (nodes, edges) pandas DataFrames.

    Args:
        G: networkx graph whose nodes and edges define the clustering
            problem to solve.

    Returns:
        Tuple of (nodes, edges) DataFrames. ``nodes`` has a single
        ``unique_id`` column, one row per node; ``edges`` has
        ``unique_id_l`` / ``unique_id_r`` columns, one row per edge.
    """
    edges = nx.to_pandas_edgelist(G)
    # Rename networkx's default source/target columns to the id-column
    # names expected by splink's clustering API
    edges.columns = ["unique_id_l", "unique_id_r"]

    nodes = pd.DataFrame({"unique_id": G.nodes})

    return nodes, edges


def run_cc_implementation(nodes, edges):
    """Run splink's connected-components clustering over the given graph.

    Args:
        nodes: DataFrame with a ``unique_id`` column, one row per node.
        edges: DataFrame with ``unique_id_l`` / ``unique_id_r`` columns,
            one row per edge.

    Returns:
        DataFrame with columns ``node_id`` and ``representative`` (the
        cluster id), sorted by both columns so it can be compared
        row-wise against the networkx reference solution.
    """
    db_api = DuckDBAPI()
    # threshold_match_probability=None: treat every edge as a link
    # (no match_probability filtering)
    cc = cluster_pairwise_predictions_at_threshold(
        nodes=nodes,
        edges=edges,
        db_api=db_api,
        node_id_column_name="unique_id",
        edge_id_column_name_left="unique_id_l",
        edge_id_column_name_right="unique_id_r",
        threshold_match_probability=None,
    ).as_pandas_dataframe()

    # Normalise column names to match networkx_solve's output
    cc = cc.rename(columns={"unique_id": "node_id", "cluster_id": "representative"})
    cc = cc[["node_id", "representative"]]
    cc.sort_values(by=["node_id", "representative"], inplace=True)
    return cc


def benchmark_cc_implementation(linker_df):
# add a schema so we don't need to re-register our df
linker_df.db_api._con.execute(
"""
create schema if not exists con_comp;
set schema 'con_comp';
"""
)

df = run_cc_implementation(linker_df)
linker_df.db_api._con.execute("drop schema con_comp cascade")

return df


def networkx_solve(G):
    """Solve connected components with networkx as a reference implementation.

    Each node is labelled with its component's minimum node id as the
    cluster representative (the tests assume splink's cluster_id agrees
    with this choice).

    Args:
        G: networkx graph to solve.

    Returns:
        DataFrame with ``node_id`` and ``representative`` columns, sorted
        by both so it lines up row-for-row with the output of
        ``run_cc_implementation``.
    """
    rows = []
    for component in cc_nx(G):
        representative = min(component)
        for node in component:
            rows.append({"node_id": node, "representative": representative})
    return pd.DataFrame(rows).sort_values(by=["node_id", "representative"])
47 changes: 17 additions & 30 deletions tests/test_cc_random_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,54 +2,41 @@
import pytest

from tests.cc_testing_utils import (
check_df_equality,
generate_random_graph,
networkx_solve,
register_cc_df,
nodes_and_edges_from_graph,
run_cc_implementation,
)

###############################################################################
# Accuracy Testing
###############################################################################


@pytest.mark.parametrize("execution_number", range(20))
def test_small_erdos_renyi_graph(execution_number):
    """Check splink's connected components against the networkx reference.

    Both outputs are sorted identically by (node_id, representative), so a
    row-wise value comparison is sufficient. Repeated 20 times with fresh
    random graphs to cover a variety of component structures.
    """
    g = generate_random_graph(graph_size=500)
    df_nodes, df_edges = nodes_and_edges_from_graph(g)

    cc_df = run_cc_implementation(df_nodes, df_edges)
    nx_df = networkx_solve(g)

    assert (cc_df.values == nx_df.values).all()


@pytest.mark.parametrize("execution_number", range(10))
def test_medium_erdos_renyi_graph(execution_number):
    """Medium-sized accuracy check (10k nodes) against the networkx reference.

    Slower than the small-graph test, so run fewer repetitions.
    """
    g = generate_random_graph(graph_size=10000)
    df_nodes, df_edges = nodes_and_edges_from_graph(g)

    cc_df = run_cc_implementation(df_nodes, df_edges)
    nx_df = networkx_solve(g)
    assert (cc_df.values == nx_df.values).all()


@pytest.mark.parametrize("execution_number", range(2))
def test_large_erdos_renyi_graph(execution_number):
    """Large accuracy check (100k nodes) against the networkx reference.

    Expensive, so only run twice. NOTE(review): if this proves too slow in
    CI, reinstate ``@pytest.mark.skip(reason="Slow")`` rather than leaving
    a commented-out decorator in the file.
    """
    g = generate_random_graph(graph_size=100000)
    df_nodes, df_edges = nodes_and_edges_from_graph(g)

    cc_df = run_cc_implementation(df_nodes, df_edges)
    nx_df = networkx_solve(g)
    assert (cc_df.values == nx_df.values).all()

0 comments on commit 2c16b23

Please sign in to comment.