Skip to content

Commit

Permalink
Merge branch 'acsands13-multiword' into 'master'
Browse files: browse the repository at this point in the history
Fix fts3 special characters

See merge request Plasticity/magnitude!1
  • Loading branch information
alexsands committed Feb 28, 2018
2 parents 1688800 + a637706 commit 18ba50c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 11 deletions.
40 changes: 30 additions & 10 deletions pymagnitude/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import gc
import os
Expand Down Expand Up @@ -59,6 +60,7 @@ class Magnitude(object):
BOW = BOW
EOW = EOW
RARE_CHAR = u"\uF002".encode('utf-8')
FTS_SPECIAL = set('*^')
MMAP_THREAD_LOCK = {}
OOV_RNG_LOCK = threading.Lock()

Expand Down Expand Up @@ -421,12 +423,12 @@ def _db_query_similar_keys_vector(self, key, orig_key, topn = 3):
+ """ DESC,
"""
if true_key_len <= 5 and key_shrunk != key:
exact_match = list(char_ngrams(key_shrunk,
true_key_len, true_key_len))
exact_match = list(char_ngrams(
key_shrunk, true_key_len, true_key_len))
search_query = """
SELECT magnitude.*
FROM magnitude_subword, magnitude
WHERE char_ngrams MATCH ?
WHERE char_ngrams {0}
AND magnitude.rowid = magnitude_subword.rowid
ORDER BY
((
Expand All @@ -439,17 +441,35 @@ def _db_query_similar_keys_vector(self, key, orig_key, topn = 3):
LIMIT ?;
"""
if len(exact_match) > 0:
results = self._db().execute(search_query,
(' OR '.join(exact_match), topn)).fetchall()
# Handle fts3 special characters
if any((c in Magnitude.FTS_SPECIAL)
for c in ''.join(exact_match)):
q = search_query.format(
'IN (' + ', '.join('?' * len(exact_match)) + ')')
params = exact_match + [topn]
else:
q = search_query.format('MATCH ?')
params = (' OR '.join('"{0}"'.format(
e.replace('"', '""')) for e in exact_match), topn)
results = self._db().execute(q, params).fetchall()
else:
results = []
if len(results) == 0:
while (len(results) < topn and
current_subword_start >= self.subword_start):
results = self._db().execute(search_query,
(' OR '.join(char_ngrams(key,
current_subword_start, self.subword_end)),
topn)).fetchall()
current_subword_start >= self.subword_start):
ngrams = list(char_ngrams(
key, current_subword_start, self.subword_end))
# Handle fts3 special characters
if any((c in Magnitude.FTS_SPECIAL)
for c in ''.join(ngrams)):
q = search_query.format(
'IN (' + ', '.join('?' * len(ngrams)) + ')')
params = ngrams + [topn]
else:
q = search_query.format('MATCH ?')
params = (' OR '.join('"{0}"'.format(
n.replace('"', '""')) for n in ngrams), topn)
results = self._db().execute(q, params).fetchall()
current_subword_start -= 1
# if current_subword_start > self.subword_start:
# results = self._db().execute(search_query,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
setup(
name='pymagnitude',
packages=find_packages(exclude=['tests', 'tests.*']),
version='0.1.2',
version='0.1.3',
description='A fast, efficient universal vector embedding utility package.',
long_description="""
About
Expand Down
11 changes: 11 additions & 0 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,17 @@ def test_contains(self):
def test_contains_false(self):
    """Check that the key "blah123" is not reported as contained."""
    # assertNotIn names the offending key when it fails, whereas
    # assertTrue(... not in ...) only reports "False is not true".
    self.assertNotIn("blah123", self.vectors)

def test_special_characters(self):
    """Exercise membership tests and queries on keys with punctuation.

    Covers keys containing hyphens, a slash, single and double quotes
    and a semicolon. One such key is expected to be contained in the
    vectors and the others are expected not to be; querying any of
    them must still return a result of the usual shape.
    """
    self.assertIn("Wilkes-Barre/Scranton", self.vectors)

    # assertNotIn includes the key in its failure message, so a failing
    # iteration identifies itself without needing subTest.
    missing_keys = (
        "out-of-vocabulary",
        'quotation"s',
        "quotation's",
        "colon;s",
    )
    for key in missing_keys:
        self.assertNotIn(key, self.vectors)

    # A key that is not contained must produce a result with the same
    # shape as one that is contained.
    self.assertEqual(
        self.vectors.query("out-of-vocabulary").shape,
        self.vectors.query("Wilkes-Barre/Scranton").shape,
    )
    # A double quote inside the key must not break the lookup.
    self.assertEqual(
        self.vectors.query("cat").shape,
        self.vectors.query('quotation"s').shape,
    )

def test_oov_dim(self):
self.assertEqual(self.vectors.query("*<<<<").shape,
self.vectors.query("cat").shape)
Expand Down

0 comments on commit 18ba50c

Please sign in to comment.