Skip to content

Commit

Permalink
Merge branch 'ajayp/limit_sqlite_max_number_of_variables' into 'master'
Browse files Browse the repository at this point in the history
Batch multi-db lookups to respect SQLITE_MAX_VARIABLE_NUMBER

See merge request Plasticity/magnitude!4
  • Loading branch information
AjayP13 committed Mar 4, 2018
2 parents 7283948 + 4132b0d commit 0c7fd2a
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 40 deletions.
120 changes: 81 additions & 39 deletions pymagnitude/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@

DEFAULT_LRU_CACHE_SIZE = 1000

def _sqlite_try_max_variable_number(num):
""" Tests whether SQLite can handle num variables """
db = sqlite3.connect(':memory:')
try:
db.cursor().execute(
"SELECT 1 IN ("+",".join(["?"]*num)+")",
([0]*num)
).fetchall()
return num
except:
return -1
finally:
db.close()


class Magnitude(object):

NGRAM_BEG = 1
Expand All @@ -63,9 +78,11 @@ class Magnitude(object):
FTS_SPECIAL = set('*^')
MMAP_THREAD_LOCK = {}
OOV_RNG_LOCK = threading.Lock()

SQLITE_MAX_VARIABLE_NUMBER = max(max((_sqlite_try_max_variable_number(n)
for n in [99, 999, 9999, 99999])), 1)

def __new__(cls, *args, **kwargs):
''' Returns a concatenated magnitude object, if Magnitude parameters '''
""" Returns a concatenated magnitude object, if Magnitude parameters """
if len(args) > 0 and isinstance(args[0], Magnitude):
obj = object.__new__(ConcatenatedMagnitude, *args, **kwargs)
obj.__init__(*args, **kwargs)
Expand Down Expand Up @@ -536,6 +553,19 @@ def _out_of_vocab_vector(self, key):
else:
return final_vector.tolist()

def _db_batch_generator(self, params):
    """Yield `params` in chunks that respect SQLite's
    MAX_VARIABLE_NUMBER limit.

    If `params` already fits in a single statement it is yielded
    unchanged; otherwise it is consumed lazily and emitted as
    tuples of at most SQLITE_MAX_VARIABLE_NUMBER elements.
    """
    limit = Magnitude.SQLITE_MAX_VARIABLE_NUMBER
    if len(params) <= limit:
        # Small enough for one query — no chunking needed.
        yield params
        return
    source = iter(params)
    while True:
        chunk = tuple(islice(source, limit))
        if not chunk:
            break
        yield chunk

def _db_result_to_vec(self, result):
"""Converts a database result to a vector."""
if self.use_numpy:
Expand Down Expand Up @@ -581,27 +611,33 @@ def _vectors_for_keys(self, keys):
if len(unseen_keys) > 0:
unseen_keys_map = {self._key_t(k): i for i, k in
enumerate(unseen_keys)}
results = self._db().execute(
"""
SELECT *
FROM `magnitude`
WHERE key
IN ("""+ ' ,'.join(['?'] * len(unseen_keys)) + """);
""",
unseen_keys)
unseen_vectors = [None] * len(unseen_keys)
seen_keys = set()
for result in results:
result_key, vec = self._db_full_result_to_vec(result)
result_key_t = self._key_t(result_key)
if result_key_t in unseen_keys_map:
i = unseen_keys_map[result_key_t]
if ((result_key_t not in seen_keys
or result_key==unseen_keys[i])
and
(self.case_insensitive or result_key==unseen_keys[i])):
seen_keys.add(result_key_t)
unseen_vectors[i] = vec
for unseen_keys_batch in self._db_batch_generator(unseen_keys):
results = self._db().execute(
"""
SELECT *
FROM `magnitude`
WHERE key
IN (""" + ' ,'.join(['?'] * len(unseen_keys_batch)) \
+ """);
""",
unseen_keys_batch)
for result in results:
result_key, vec = self._db_full_result_to_vec(result)
result_key_t = self._key_t(result_key)
if result_key_t in unseen_keys_map:
i = unseen_keys_map[result_key_t]
if (
(result_key_t not in seen_keys
or result_key == unseen_keys[i])
and
(
self.case_insensitive or
result_key == unseen_keys[i])
):
seen_keys.add(result_key_t)
unseen_vectors[i] = vec
for i in range(len(unseen_vectors)):
self._vector_for_key_cached._cache.put((unseen_keys[i],),
unseen_vectors[i])
Expand Down Expand Up @@ -646,25 +682,31 @@ def _keys_for_indices(self, indices, return_vector=True):
columns = "*"
unseen_indices_map = {(index - 1): i for i, index in
enumerate(unseen_indices)}
results = self._db().execute(
"""
SELECT rowid, """+ columns + """
FROM `magnitude`
WHERE rowid IN ("""+
' ,'.join(['?'] * len(unseen_indices)) +
""");""",
unseen_indices)
unseen_keys = [None] * len(unseen_indices)
for result in results:
i = unseen_indices_map[result[0] - 1]
result_key = result[1]
if return_vector:
unseen_keys[i] = self._db_full_result_to_vec(result[1:])
else:
unseen_keys[i] = result_key
self._key_for_index_cached._cache.put(
((unseen_indices[i] - 1,), frozenset(
[('return_vector', return_vector)])), unseen_keys[i])
for unseen_indices_batch in \
self._db_batch_generator(unseen_indices):
results = self._db().execute(
"""
SELECT rowid, """+ columns + """
FROM `magnitude`
WHERE rowid IN ("""+
' ,'.join(['?'] * len(unseen_indices_batch)) +
""");""",
unseen_indices_batch)
for result in results:
i = unseen_indices_map[result[0] - 1]
result_key = result[1]
if return_vector:
unseen_keys[i] = self._db_full_result_to_vec(result[1:])
else:
unseen_keys[i] = result_key
self._key_for_index_cached._cache.put(
(
(unseen_indices[i] - 1,),
frozenset([('return_vector', return_vector)])
),
unseen_keys[i]
)
for i in range(len(unseen_keys)):
if unseen_keys[i] is None:
raise IndexError("The index %d is out-of-range" % \
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
setup(
name='pymagnitude',
packages=find_packages(exclude=['tests', 'tests.*']),
version='0.1.6',
version='0.1.7',
description='A fast, efficient universal vector embedding utility package.',
long_description="""
About
Expand Down

0 comments on commit 0c7fd2a

Please sign in to comment.