Skip to content

Commit c981048

Browse files
added ability to cache matrix in queries across which master is constant
1 parent 558aa02 commit c981048

File tree

2 files changed

+82
-8
lines changed

2 files changed

+82
-8
lines changed

string_grouper/string_grouper.py

+48-7
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
# similar string index-columns with corresponding duplicates-index values
2727
DEFAULT_INCLUDE_ZEROES: bool = True # when the minimum cosine similarity <=0, determines whether zero-similarity
2828
# matches appear in the output
29+
DEFAULT_ENABLE_CACHE: bool = False # whether to cache the master tf-idf matrix between queries that reuse the same master (off by default)
2930
GROUP_REP_CENTROID: str = 'centroid' # Option value to select the string in each group with the largest
3031
# similarity aggregate as group-representative:
3132
GROUP_REP_FIRST: str = 'first' # Option value to select the first string in each group as group-representative:
@@ -185,6 +186,9 @@ class StringGrouperConfig(NamedTuple):
185186
before performing the string-comparisons block-wise. Defaults to 'guess', in which case the numbers of
186187
blocks are estimated based on previous empirical results. If n_blocks = 'auto', then splitting is done
187188
automatically in the event of an OverflowError.
189+
:param enable_cache: bool. Whether or not to cache the tf-idf matrix for ``master`` between queries which
190+
preserve ``master``. Defaults to False. Use with caution: setting this option to True may degrade
191+
performance when ``master`` is too large to fit into RAM.
188192
"""
189193

190194
ngram_size: int = DEFAULT_NGRAM_SIZE
@@ -200,6 +204,7 @@ class StringGrouperConfig(NamedTuple):
200204
group_rep: str = DEFAULT_GROUP_REP
201205
force_symmetries: bool = DEFAULT_FORCE_SYMMETRIES
202206
n_blocks: Tuple[int, int] = DEFAULT_N_BLOCKS
207+
enable_cache: bool = DEFAULT_ENABLE_CACHE
203208

204209

205210
def validate_is_fit(f):
@@ -242,6 +247,7 @@ def __init__(self, master: pd.Series,
242247
"""
243248
# private members:
244249
self.is_build = False
250+
self._cache = dict()
245251

246252
self._master: pd.DataFrame = pd.DataFrame()
247253
self._duplicates: Optional[pd.Series] = None
@@ -323,8 +329,24 @@ def reset_data(self,
323329
:param duplicates_id: pandas.Series. If set, contains ID values for each row in duplicates Series.
324330
:param kwargs: All other keyword arguments are passed to StringGrouperConfig
325331
"""
332+
self._cache.clear()
326333
self._set_data(master, duplicates, master_id, duplicates_id)
327334

335+
def _reset_duplicates_only(self, duplicates: pd.Series = None, duplicates_id: Optional[pd.Series] = None):
336+
# Validate input strings data
337+
self.duplicates = duplicates
338+
339+
# Validate optional IDs input
340+
if not StringGrouper._is_input_data_combination_valid(duplicates, self._master_id, duplicates_id):
341+
raise Exception('List of data Series options is invalid')
342+
StringGrouper._validate_id_data(self._master, duplicates, self._master_id, duplicates_id)
343+
self._duplicates_id = duplicates_id
344+
345+
# Set some private members
346+
self._left_Series = self._duplicates
347+
348+
self.is_build = False
349+
328350
def clear_data(self):
329351
self._master = None
330352
self._duplicates = None
@@ -333,6 +355,7 @@ def clear_data(self):
333355
self._matches_list = None
334356
self._left_Series = None
335357
self._right_Series = None
358+
self._cache.clear()
336359
self.is_build = False
337360

338361
def update_options(self, **kwargs):
@@ -718,7 +741,7 @@ def get_groups(self,
718741
return self._get_nearest_matches(ignore_index=ignore_index, replace_na=replace_na)
719742

720743
def match_strings(self,
721-
master: pd.Series,
744+
master: Optional[pd.Series] = None,
722745
duplicates: Optional[pd.Series] = None,
723746
master_id: Optional[pd.Series] = None,
724747
duplicates_id: Optional[pd.Series] = None,
@@ -729,14 +752,19 @@ def match_strings(self,
729752
This can be seen as a self-join. If both master and duplicates are given, it will return highly similar strings
730753
between master and duplicates. This can be seen as an inner-join.
731754
732-
:param master: pandas.Series. Series of strings against which matches are calculated.
755+
:param master: pandas.Series. Series of strings against which matches are calculated. If not set, or if set to
756+
``None``, then the currently stored ``master`` Series will be reused.
733757
:param duplicates: pandas.Series. Series of strings that will be matched with master if given (Optional).
734758
:param master_id: pandas.Series. Series of values that are IDs for master column rows (Optional).
735759
:param duplicates_id: pandas.Series. Series of values that are IDs for duplicates column rows (Optional).
736760
:param kwargs: All other keyword arguments are passed to StringGrouperConfig.
737761
:return: pandas.Dataframe.
738762
"""
739-
self.reset_data(master, duplicates, master_id, duplicates_id)
763+
if master is None:
764+
self._reset_duplicates_only(duplicates, duplicates_id)
765+
else:
766+
self.reset_data(master, duplicates, master_id, duplicates_id)
767+
740768
self.update_options(**kwargs)
741769
self = self.fit()
742770
return self.get_matches()
@@ -761,14 +789,18 @@ def match_most_similar(self,
761789
If IDs (both 'master_id' and 'duplicates_id') are also given, returns a DataFrame of the same strings
762790
output in the above case with their corresponding IDs.
763791
764-
:param master: pandas.Series. Series of strings that the duplicates will be matched with.
792+
:param master: pandas.Series. Series of strings that the duplicates will be matched with. If it is
793+
set to ``None``, then the currently stored ``master`` Series will be reused.
765794
:param duplicates: pandas.Series. Series of strings that will be matched with the master.
766795
:param master_id: pandas.Series. Series of values that are IDs for master column rows. (Optional)
767796
:param duplicates_id: pandas.Series. Series of values that are IDs for duplicates column rows. (Optional)
768797
:param kwargs: All other keyword arguments are passed to StringGrouperConfig. (Optional)
769798
:return: pandas.Series or pandas.DataFrame.
770799
"""
771-
self.reset_data(master, duplicates, master_id, duplicates_id)
800+
if master is None:
801+
self._reset_duplicates_only(duplicates, duplicates_id)
802+
else:
803+
self.reset_data(master, duplicates, master_id, duplicates_id)
772804

773805
old_max_n_matches = self._max_n_matches
774806
new_max_n_matches = None
@@ -875,8 +907,17 @@ def _get_right_tf_idf_matrix(self, partition=(None, None)):
875907
# unlike _get_tf_idf_matrices(), _get_right_tf_idf_matrix
876908
# does not set the corpus but rather
877909
# builds a matrix using the existing corpus
878-
return self._vectorizer.transform(
879-
self._right_Series.iloc[slice(*partition)])
910+
key = tuple(partition)
911+
if self._config.enable_cache and key in self._cache:
912+
matrix = self._cache[key]
913+
else:
914+
matrix = self._vectorizer.transform(
915+
self._right_Series.iloc[slice(*partition)])
916+
917+
if self._config.enable_cache:
918+
self._cache[key] = matrix
919+
920+
return matrix
880921

881922
def _fit_vectorizer(self) -> TfidfVectorizer:
882923
# if both dupes and master string series are set - we concat them to fit the vectorizer on all

string_grouper/test/test_string_grouper.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from scipy.sparse.csr import csr_matrix
55
from string_grouper.string_grouper import DEFAULT_MIN_SIMILARITY, \
66
DEFAULT_REGEX, DEFAULT_NGRAM_SIZE, DEFAULT_N_PROCESSES, DEFAULT_IGNORE_CASE, \
7+
DEFAULT_ENABLE_CACHE, \
78
StringGrouperConfig, StringGrouper, StringGrouperNotFitException, \
89
match_most_similar, group_similar_strings, match_strings, \
910
compute_pairwise_similarities
@@ -100,6 +101,7 @@ def test_config_defaults(self):
100101
self.assertEqual(config.ngram_size, DEFAULT_NGRAM_SIZE)
101102
self.assertEqual(config.number_of_processes, DEFAULT_N_PROCESSES)
102103
self.assertEqual(config.ignore_case, DEFAULT_IGNORE_CASE)
104+
self.assertEqual(config.enable_cache, DEFAULT_ENABLE_CACHE)
103105

104106
def test_config_immutable(self):
105107
"""Configurations should be immutable"""
@@ -117,6 +119,35 @@ def test_config_non_default_values(self):
117119

118120
class StringGrouperTest(unittest.TestCase):
119121

122+
def test_cache(self):
123+
"""tests caching when the option is enabled"""
124+
125+
sort_cols = ['right_index', 'left_index']
126+
127+
def fix_row_order(df):
128+
return df.sort_values(sort_cols).reset_index(drop=True)
129+
130+
simple_example = SimpleExample()
131+
df1 = simple_example.customers_df2['Customer Name']
132+
133+
sg = StringGrouper(df1, min_similarity=0.1)
134+
assert sg._cache == dict()
135+
matches = fix_row_order(sg.match_strings(df1)) # no cache
136+
assert sg._cache == dict()
137+
138+
matches_ = fix_row_order(sg.match_strings(duplicates=df1, enable_cache=True))
139+
assert len(sg._cache) > 0
140+
for _, value in sg._cache.items():
141+
assert isinstance(value, csr_matrix)
142+
pd.testing.assert_frame_equal(matches_, matches)
143+
matches__ = fix_row_order(sg.match_strings(duplicates=df1))
144+
assert len(sg._cache) > 0
145+
for _, value in sg._cache.items():
146+
assert isinstance(value, csr_matrix)
147+
pd.testing.assert_frame_equal(matches__, matches)
148+
with self.assertRaises(Exception):
149+
_ = sg.match_strings(duplicates=df1, duplicates_id=simple_example.customers_df2['Customer ID'])
150+
120151
def test_auto_blocking_single_Series(self):
121152
"""tests whether automatic blocking yields consistent results"""
122153
# This function will force an OverflowError to occur when
@@ -870,8 +901,10 @@ def test_get_groups_two_df(self):
870901
result = sg.get_groups()
871902
expected_result = pd.Series(['foooo', 'bar', 'baz', 'foooo'], name='most_similar_master')
872903
pd.testing.assert_series_equal(expected_result, result)
873-
result = sg.match_most_similar(test_series_1, test_series_2, max_n_matches=3)
904+
result = sg.match_most_similar(test_series_1, test_series_2, max_n_matches=3, enable_cache=True)
874905
pd.testing.assert_series_equal(expected_result, result)
906+
result2 = sg.match_most_similar(None, test_series_2, max_n_matches=3)
907+
pd.testing.assert_series_equal(expected_result, result2)
875908

876909
def test_get_groups_2_string_series_2_id_series(self):
877910
"""Should return a pd.DataFrame object with the length of the dupes. The series will contain the master string

0 commit comments

Comments
 (0)