Skip to content

Commit dae9334

Browse files
author
Zohar Karnin
committed
feat: Similarity encoding
Adding a categorical encoder that maps categories into dense vectors based on their textual description. Categories with similar names will be mapped to vectors with similar numeric values. This is useful for datasets with noisy names (e.g. typos) or large number of categories
1 parent c5c5d58 commit dae9334

File tree

4 files changed

+189
-0
lines changed

4 files changed

+189
-0
lines changed

README.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,4 @@ Overview of Submodules
134134
* :code:`RobustLabelEncoder` encode labels for seen and unseen labels
135135
* :code:`RobustStandardScaler` standardization for dense and sparse inputs
136136
* :code:`WOEEncoder` weight of evidence supervised encoder
137+
* :code:`SimilarityEncoder` encode categorical values based on their descriptive string

src/sagemaker_sklearn_extension/preprocessing/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .encoders import RobustOrdinalEncoder
2525
from .encoders import ThresholdOneHotEncoder
2626
from .encoders import WOEEncoder
27+
from .encoders import SimilarityEncoder
2728

2829
__all__ = [
2930
"BaseExtremeValueTransformer",
@@ -39,4 +40,5 @@
3940
"log_transform",
4041
"quantile_transform_nonrandom",
4142
"WOEEncoder",
43+
"SimilarityEncoder",
4244
]

src/sagemaker_sklearn_extension/preprocessing/encoders.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,3 +916,141 @@ def fit_transform(self, X, y):
916916

917917
def _more_tags(self):
918918
return {"X_types": ["categorical"], "binary_only": True, "requires_y": True}
919+
920+
921+
class SimilarityAsserts(Enum):
922+
TARGET_DIM = "Target dimension must be a positive integer."
923+
924+
925+
class SimilarityEncoder(BaseEstimator, TransformerMixin):
926+
"""Similarity encoder: encodes categorical features as a numerical vector
927+
using their textual representation. Categories with similar descriptions are mapped to
928+
similar vectors.
929+
The underlying method used is locally sensitive hashing (LSH [2]) of the character level 3-gram
930+
tokens. The similarity between two category description is defined as the Jackard
931+
similarity between their corresponding bags of 3-grams. The known min-hash [3] embedding is
932+
then used to convert these token sets into vectors in a way that the l_0 distance, defined
933+
as the number of different entries, approximates the Jackard distance. This technique has
934+
been provided in [1] and shown to significantly outperform 1-hot encoding in scenarios where
935+
the number of categories is large.
936+
937+
Parameters
938+
----------
939+
target_dimension: int, default=30
940+
Dimension of the embedding. Small target dimension might not represent the categories in a descriptive enough
941+
way, and large target dimension take longer to compute and might result in over-fitting. For large datasets
942+
and a number of categories much larger than 30, consider raising this value.
943+
944+
seed: int, default=None
945+
seed for random number generation. Used when fitting and setting the hash functions
946+
947+
Example
948+
-------
949+
>>> import numpy as np
950+
>>> from sagemaker_sklearn_extension.preprocessing import SimilarityEncoder
951+
>>> category_data = np.array(['table', 'chair', 'table (red)', 'ladder', 'table (blue)', 'table'])
952+
>>> SimilarityEncoder(target_dimension=2, seed=112).fit_transform(category_data.reshape(-1, 1))
953+
array([[0.06143999, 0.08793556],
954+
[0.29021414, 0.29044514],
955+
[0.06143999, 0.08793556],
956+
[0.1312301 , 0.0455779 ],
957+
[0.06143999, 0.08793556],
958+
[0.06143999, 0.08793556]])
959+
960+
Attributes
961+
----------
962+
hash_prime_: prime used for hash funtions
963+
Hash functions operate on integers. A function consists of two numbers a,b and an integer x is hashed into
964+
x*a+b modulo hash_prime. To avoid overflows we use int64 and the largest prime p such that p*p < 2^63 -1, the
965+
maximum int64 value.
966+
967+
References
968+
----------
969+
[1] https://arxiv.org/abs/1907.01860
970+
[2] https://en.wikipedia.org/wiki/Locality-sensitive_hashing
971+
[3] https://en.wikipedia.org/wiki/MinHash
972+
"""
973+
974+
def __init__(self, target_dimension=30, seed=None):
975+
self.target_dimension = target_dimension
976+
self.seed = seed
977+
978+
def fit(self, X=None, y=None):
979+
"""Fit Similarity encoder.
980+
Ignores input data. This fixes the hash funtion(s) to be used for the minhash encoding
981+
982+
Parameters
983+
----------
984+
X: array-like, shape (n_samples, n_features)
985+
The data to encode.
986+
987+
y: array-like, shape (n_samples,)
988+
The binary target vector.
989+
990+
Returns
991+
-------
992+
self: SimilarityEncoder.
993+
"""
994+
# Validate parameters
995+
assert isinstance(self.target_dimension, int) and self.target_dimension > 0, SimilarityAsserts.TARGET_DIM
996+
997+
# prime to be used for hash function (largest prime p such that p**2 is still within int64 range)
998+
self.hash_prime_ = 2038074743
999+
# random numbers for hash functions
1000+
generator = np.random.RandomState(seed=self.seed)
1001+
self._mult = generator.randint(low=1, high=self.hash_prime_, size=(self.target_dimension, 1))
1002+
self._add = generator.randint(low=0, high=self.hash_prime_, size=(self.target_dimension, 1))
1003+
return self
1004+
1005+
def _minhash_index_sparse_vec(self, vec):
1006+
# prepare tokens as valid integers
1007+
ind = vec.indices.astype(np.int64)
1008+
ind %= self.hash_prime_
1009+
# if the vector was zero, ind is an empty list. In this case fill it with a single zero. This is needed to
1010+
# avoid an error below when taking a minimum along an axis
1011+
if ind.shape == (0,):
1012+
ind = np.zeros((1,), dtype=np.int64)
1013+
1014+
# compute for each token its hash values, create a matrix of dimensions (num_hash, num_tokens)
1015+
all_hash_values = self._mult * ind.reshape((1, -1)) + self._add
1016+
all_hash_values %= self.hash_prime_
1017+
1018+
# compute row-wise min to get vector of length num_hash
1019+
hash_values = np.min(all_hash_values, axis=1)
1020+
1021+
# normalize in [0,1)
1022+
return hash_values.astype(np.float64) / self.hash_prime_
1023+
1024+
def transform(self, X):
1025+
"""Transform each column of `X` using the Similarity encoding.
1026+
1027+
Returns
1028+
-------
1029+
X_encoded: array, shape (n_samples, n_encoded_features * target_dimension)
1030+
Array with each of the encoded columns.
1031+
"""
1032+
check_is_fitted(self, "hash_prime_")
1033+
X = check_array(X, dtype=str)
1034+
1035+
# remember shape, flatten X to be 1dim, and convert to string. Note - this makes sure all None values become
1036+
# the string 'None'. This is acceptible behavior
1037+
str_list = X.reshape((-1,)).astype("str")
1038+
# replace nones
1039+
# tokenize each string
1040+
# convert each token array into integers via hash function
1041+
from sklearn.feature_extraction.text import HashingVectorizer
1042+
1043+
# TODO: In the paper this function is based on the ngram number was fixed as 3. As a follow up, consider
1044+
# parametrizing this.
1045+
hv = HashingVectorizer(analyzer="char_wb", ngram_range=(3, 3), binary=True)
1046+
token_hash_matrix = hv.fit_transform(str_list)
1047+
# apply minhash
1048+
minhash_vectors = np.array([self._minhash_index_sparse_vec(row) for row in token_hash_matrix])
1049+
# reshape back
1050+
return minhash_vectors.reshape((X.shape[0], X.shape[1] * self.target_dimension))
1051+
1052+
def fit_transform(self, X, y=None):
1053+
return self.fit(X, y).transform(X)
1054+
1055+
def _more_tags(self):
1056+
return {"X_types": ["string"]}

test/test_preprocessing_encoders.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker_sklearn_extension.preprocessing import ThresholdOneHotEncoder
2121
from sagemaker_sklearn_extension.preprocessing import RobustOrdinalEncoder
2222
from sagemaker_sklearn_extension.preprocessing import WOEEncoder
23+
from sagemaker_sklearn_extension.preprocessing import SimilarityEncoder
2324

2425

2526
X = np.array([["hot dog", 1], ["hot dog", 1], ["apple", 2], ["hot dog", 3], ["hot dog", 1], ["banana", 3]])
@@ -578,3 +579,50 @@ def test_woe_multi_cols():
578579
Xe = enc.fit_transform(X, titanic_y)
579580
assert len(np.unique(Xe[:, 0])) == 4
580581
assert len(np.unique(Xe[:, 1])) == 4
582+
583+
584+
def test_similarity_consistent():
585+
X = np.array(
586+
[
587+
"cat1",
588+
"cat2",
589+
"cat1",
590+
"abcdefghijkkjihgfedcba",
591+
"abcdefghijkkjihgfedcab",
592+
"lmnopqrstuvwxyzzyxwvutsrqponml",
593+
"a",
594+
"b",
595+
]
596+
).reshape((-1, 1))
597+
se = SimilarityEncoder(target_dimension=300, seed=5)
598+
out = se.fit_transform(X)
599+
# exact equal strings should get equal vectors
600+
assert np.array_equal(out[0], out[2])
601+
# completely different strings should get different vectors
602+
assert not np.array_equal(out[0], out[1])
603+
# output 3,4 should be similar vectors, meaning closer than the vectors of 3 and 5 since these are very
604+
# different strings
605+
assert np.linalg.norm(out[3] - out[4]) < np.linalg.norm(out[3] - out[5])
606+
# make sure single character inputs also get different outputs
607+
assert not np.array_equal(out[6], out[7])
608+
609+
610+
def test_similarity_multicol():
611+
X = np.array([["cat1a", "cat1b"], ["cat2a", "cat2b"], ["cat1a", "cat1b"]])
612+
se = SimilarityEncoder(target_dimension=3, seed=5)
613+
out = se.fit_transform(X)
614+
assert out.shape[1] == 6
615+
616+
617+
def test_similarity_fails_ilegal_target_dim():
618+
X = np.array(["cat1", "cat2", "cat1"]).reshape((-1, 1))
619+
se = SimilarityEncoder(target_dimension=0, seed=5)
620+
with pytest.raises(Exception):
621+
se.fit_transform(X)
622+
623+
624+
def test_similarity_handles_empty_string():
625+
X = np.array(["", " ", " ", "-1201230()*&(*&%$#!", None, np.nan]).reshape((-1, 1))
626+
se = SimilarityEncoder(target_dimension=3, seed=5)
627+
out = se.fit_transform(X)
628+
assert out.shape == (6, 3)

0 commit comments

Comments
 (0)