
Commit e6cc77f

Merge pull request #39 from zkarnin/sim_encode
feat: Similarity encoding
2 parents: c1d0945 + dae9334

5 files changed: +190 −0 lines changed


README.rst

Lines changed: 1 addition & 0 deletions
@@ -134,3 +134,4 @@ Overview of Submodules
 * :code:`RobustLabelEncoder` encode labels for seen and unseen labels
 * :code:`RobustStandardScaler` standardization for dense and sparse inputs
 * :code:`WOEEncoder` weight of evidence supervised encoder
+* :code:`SimilarityEncoder` encode categorical values based on their descriptive string

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@ scikit-learn==0.23.2
 python-dateutil==2.8.0
 pandas==1.2.4
 tsfresh==0.18.0
+statsmodels==0.12.2

src/sagemaker_sklearn_extension/preprocessing/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@
 from .encoders import RobustOrdinalEncoder
 from .encoders import ThresholdOneHotEncoder
 from .encoders import WOEEncoder
+from .encoders import SimilarityEncoder
 
 __all__ = [
     "BaseExtremeValueTransformer",
@@ -39,4 +40,5 @@
     "log_transform",
     "quantile_transform_nonrandom",
     "WOEEncoder",
+    "SimilarityEncoder",
 ]

src/sagemaker_sklearn_extension/preprocessing/encoders.py

Lines changed: 138 additions & 0 deletions
@@ -916,3 +916,141 @@ def fit_transform(self, X, y):
 
     def _more_tags(self):
         return {"X_types": ["categorical"], "binary_only": True, "requires_y": True}
+
+
+class SimilarityAsserts(Enum):
+    TARGET_DIM = "Target dimension must be a positive integer."
+
+
+class SimilarityEncoder(BaseEstimator, TransformerMixin):
+    """Similarity encoder: encodes categorical features as a numerical vector
+    using their textual representation. Categories with similar descriptions are mapped to
+    similar vectors.
+
+    The underlying method is locality-sensitive hashing (LSH [2]) of the character-level 3-gram
+    tokens. The similarity between two category descriptions is defined as the Jaccard
+    similarity between their corresponding bags of 3-grams. The well-known min-hash [3] embedding
+    is then used to convert these token sets into vectors in a way that the l_0 distance, defined
+    as the number of differing entries, approximates the Jaccard distance. This technique was
+    introduced in [1] and shown to significantly outperform one-hot encoding in scenarios where
+    the number of categories is large.
+
+    Parameters
+    ----------
+    target_dimension: int, default=30
+        Dimension of the embedding. A small target dimension might not represent the categories
+        descriptively enough, while a large target dimension takes longer to compute and might
+        result in over-fitting. For large datasets with a number of categories much larger than
+        30, consider raising this value.
+
+    seed: int, default=None
+        Seed for random number generation. Used when fitting to set the hash functions.
+
+    Example
+    -------
+    >>> import numpy as np
+    >>> from sagemaker_sklearn_extension.preprocessing import SimilarityEncoder
+    >>> category_data = np.array(['table', 'chair', 'table (red)', 'ladder', 'table (blue)', 'table'])
+    >>> SimilarityEncoder(target_dimension=2, seed=112).fit_transform(category_data.reshape(-1, 1))
+    array([[0.06143999, 0.08793556],
+           [0.29021414, 0.29044514],
+           [0.06143999, 0.08793556],
+           [0.1312301 , 0.0455779 ],
+           [0.06143999, 0.08793556],
+           [0.06143999, 0.08793556]])
+
+    Attributes
+    ----------
+    hash_prime_: prime used for the hash functions
+        Hash functions operate on integers. A function consists of two numbers a, b, and an
+        integer x is hashed into x*a+b modulo hash_prime_. To avoid overflows we use int64 and a
+        prime p such that p*p < 2^63 - 1, the maximum int64 value.
+
+    References
+    ----------
+    [1] https://arxiv.org/abs/1907.01860
+    [2] https://en.wikipedia.org/wiki/Locality-sensitive_hashing
+    [3] https://en.wikipedia.org/wiki/MinHash
+    """
+
+    def __init__(self, target_dimension=30, seed=None):
+        self.target_dimension = target_dimension
+        self.seed = seed
+
+    def fit(self, X=None, y=None):
+        """Fit the similarity encoder.
+
+        Ignores the input data; fitting only fixes the hash functions to be used for the
+        min-hash encoding.
+
+        Parameters
+        ----------
+        X: array-like, shape (n_samples, n_features)
+            The data to encode.
+
+        y: array-like, shape (n_samples,)
+            Ignored; present for API consistency.
+
+        Returns
+        -------
+        self: SimilarityEncoder.
+        """
+        # Validate parameters
+        assert isinstance(self.target_dimension, int) and self.target_dimension > 0, SimilarityAsserts.TARGET_DIM
+
+        # prime to be used for the hash functions (a prime p such that p**2 is still within int64 range)
+        self.hash_prime_ = 2038074743
+        # random numbers for the hash functions
+        generator = np.random.RandomState(seed=self.seed)
+        self._mult = generator.randint(low=1, high=self.hash_prime_, size=(self.target_dimension, 1))
+        self._add = generator.randint(low=0, high=self.hash_prime_, size=(self.target_dimension, 1))
+        return self
+
+    def _minhash_index_sparse_vec(self, vec):
+        # prepare tokens as valid integers
+        ind = vec.indices.astype(np.int64)
+        ind %= self.hash_prime_
+        # if the vector was zero, ind is an empty array. In this case fill it with a single zero;
+        # this is needed to avoid an error below when taking a minimum along an axis
+        if ind.shape == (0,):
+            ind = np.zeros((1,), dtype=np.int64)
+
+        # compute for each token its hash values, creating a matrix of dimensions (num_hash, num_tokens)
+        all_hash_values = self._mult * ind.reshape((1, -1)) + self._add
+        all_hash_values %= self.hash_prime_
+
+        # compute the row-wise min to get a vector of length num_hash
+        hash_values = np.min(all_hash_values, axis=1)
+
+        # normalize into [0, 1)
+        return hash_values.astype(np.float64) / self.hash_prime_
+
+    def transform(self, X):
+        """Transform each column of `X` using the similarity encoding.
+
+        Returns
+        -------
+        X_encoded: array, shape (n_samples, n_encoded_features * target_dimension)
+            Array with each of the encoded columns.
+        """
+        check_is_fitted(self, "hash_prime_")
+        X = check_array(X, dtype=str)
+
+        # flatten X to one dimension and convert to string. Note: this makes sure all None values
+        # become the string 'None', which is acceptable behavior
+        str_list = X.reshape((-1,)).astype("str")
+        # tokenize each string into character 3-grams and hash each token into an integer index
+        from sklearn.feature_extraction.text import HashingVectorizer
+
+        # TODO: In the paper this function is based on, the ngram number was fixed at 3. As a
+        # follow-up, consider parametrizing this.
+        hv = HashingVectorizer(analyzer="char_wb", ngram_range=(3, 3), binary=True)
+        token_hash_matrix = hv.fit_transform(str_list)
+        # apply min-hash to every row
+        minhash_vectors = np.array([self._minhash_index_sparse_vec(row) for row in token_hash_matrix])
+        # reshape back so each input column contributes target_dimension output columns
+        return minhash_vectors.reshape((X.shape[0], X.shape[1] * self.target_dimension))
+
+    def fit_transform(self, X, y=None):
+        return self.fit(X, y).transform(X)
+
+    def _more_tags(self):
+        return {"X_types": ["string"]}
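The encoder above leans on the classic min-hash guarantee: for a random hash function, the probability that two token sets share the same minimum hash value equals their Jaccard similarity, so the fraction of agreeing coordinates across target_dimension independent hashes estimates that similarity. Below is a minimal standalone sketch of that property using the same a*x + b mod p hash family as the diff; the toy char_3grams tokenizer and all names are illustrative, not part of the package.

import numpy as np

P = 2038074743  # same prime as in the diff above; any prime with p*p inside int64 range works

def char_3grams(s):
    # toy tokenizer: the set of character-level 3-grams, hashed to integers below P
    return {hash(s[i:i + 3]) % P for i in range(len(s) - 2)}

def minhash(tokens, mult, add):
    ind = np.fromiter(tokens, dtype=np.int64).reshape((1, -1))
    # one hash function per row: h(x) = (a*x + b) mod P; keep the per-row minimum
    return np.min((mult * ind + add) % P, axis=1)

rng = np.random.RandomState(0)
d = 2000  # number of hash functions; more hashes give a tighter estimate
mult = rng.randint(1, P, size=(d, 1), dtype=np.int64)
add = rng.randint(0, P, size=(d, 1), dtype=np.int64)

a, b = char_3grams("table (red)"), char_3grams("table (blue)")
jaccard = len(a & b) / len(a | b)
# fraction of coordinates on which the two min-hash vectors agree
agreement = np.mean(minhash(a, mult, add) == minhash(b, mult, add))
print(jaccard, agreement)  # the two numbers should be close for large d

Note that Python's built-in hash is salted per process, so this sketch is only self-consistent within a single run; the encoder itself avoids that by tokenizing through HashingVectorizer, whose hashing is deterministic.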

test/test_preprocessing_encoders.py

Lines changed: 48 additions & 0 deletions
@@ -20,6 +20,7 @@
 from sagemaker_sklearn_extension.preprocessing import ThresholdOneHotEncoder
 from sagemaker_sklearn_extension.preprocessing import RobustOrdinalEncoder
 from sagemaker_sklearn_extension.preprocessing import WOEEncoder
+from sagemaker_sklearn_extension.preprocessing import SimilarityEncoder
 
 
 X = np.array([["hot dog", 1], ["hot dog", 1], ["apple", 2], ["hot dog", 3], ["hot dog", 1], ["banana", 3]])
@@ -578,3 +579,50 @@ def test_woe_multi_cols():
     Xe = enc.fit_transform(X, titanic_y)
     assert len(np.unique(Xe[:, 0])) == 4
     assert len(np.unique(Xe[:, 1])) == 4
+
+
+def test_similarity_consistent():
+    X = np.array(
+        [
+            "cat1",
+            "cat2",
+            "cat1",
+            "abcdefghijkkjihgfedcba",
+            "abcdefghijkkjihgfedcab",
+            "lmnopqrstuvwxyzzyxwvutsrqponml",
+            "a",
+            "b",
+        ]
+    ).reshape((-1, 1))
+    se = SimilarityEncoder(target_dimension=300, seed=5)
+    out = se.fit_transform(X)
+    # exactly equal strings should get equal vectors
+    assert np.array_equal(out[0], out[2])
+    # completely different strings should get different vectors
+    assert not np.array_equal(out[0], out[1])
+    # outputs 3 and 4 come from similar strings, so their vectors should be closer to each other
+    # than the vectors of 3 and 5, which come from very different strings
+    assert np.linalg.norm(out[3] - out[4]) < np.linalg.norm(out[3] - out[5])
+    # make sure single-character inputs also get different outputs
+    assert not np.array_equal(out[6], out[7])
+
+
+def test_similarity_multicol():
+    X = np.array([["cat1a", "cat1b"], ["cat2a", "cat2b"], ["cat1a", "cat1b"]])
+    se = SimilarityEncoder(target_dimension=3, seed=5)
+    out = se.fit_transform(X)
+    assert out.shape[1] == 6
+
+
+def test_similarity_fails_illegal_target_dim():
+    X = np.array(["cat1", "cat2", "cat1"]).reshape((-1, 1))
+    se = SimilarityEncoder(target_dimension=0, seed=5)
+    with pytest.raises(Exception):
+        se.fit_transform(X)
+
+
+def test_similarity_handles_empty_string():
+    X = np.array(["", " ", " ", "-1201230()*&(*&%$#!", None, np.nan]).reshape((-1, 1))
+    se = SimilarityEncoder(target_dimension=3, seed=5)
+    out = se.fit_transform(X)
+    assert out.shape == (6, 3)
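Since the new class follows the scikit-learn transformer contract (fit/transform plus _more_tags), it composes with standard tooling. A hypothetical usage sketch, not part of the commit; the column layout and parameter values are made up for illustration.

import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler
from sagemaker_sklearn_extension.preprocessing import SimilarityEncoder

X = np.array([["table", 1.0], ["chair", 2.0], ["table (red)", 3.0]], dtype=object)
pre = ColumnTransformer(
    [
        # 4 similarity dimensions for the string column
        ("category", SimilarityEncoder(target_dimension=4, seed=0), [0]),
        # pass the numeric column through a standard scaler
        ("numeric", StandardScaler(), [1]),
    ]
)
print(pre.fit_transform(X).shape)  # (3, 5): 4 encoded dims + 1 scaled numeric column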
