Remote load (#6)
* Added loading methods for pyspark.

* Fixed indent.

* Update __init__.py

  PEP8 fixes

* Travis fixes

* More PEP8 Travis fixes

* Travis fix

* Refactor load_remote
sinemetu1 authored and acompa committed Jan 25, 2017
1 parent 84b1454 commit 2e87e56
Showing 3 changed files with 157 additions and 25 deletions.
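The headline change is the refactored loader: EntitySet.load now accepts either a folder string (the old behavior) or a FileGetter, plus an optional list of entity types to load. A minimal sketch of both call styles, using a placeholder model path that is not part of this commit:

from all2vec import EntitySet, FileGetter

# Old call style still works: a plain folder string is wrapped
# in a FileGetter internally.
model = EntitySet.load('/path/to/saved/model')

# New call style: pass a FileGetter explicitly, optionally restricting
# the load to a subset of entity types.
model = EntitySet.load(FileGetter('/path/to/saved/model'),
                       entities=['type0'])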
2 changes: 1 addition & 1 deletion .gitignore
@@ -71,4 +71,4 @@ target/
# pyenv
.python-version


*.swp
136 changes: 113 additions & 23 deletions all2vec/__init__.py
@@ -82,6 +82,17 @@ def __iter__(self):
vector=self.get_item_vector(entity_id)
) for entity_id in self._ann_map.keys())

def get_nfactor(self):
    """Return the number of factors."""
    return self._nfactor

def load(self, pkl, filepath):
    """Initialize and mmap the saved annoy index for this entity type."""
    entity_type = pkl.get_entity_type(self._entity_type_id)
    self.__dict__ = entity_type.__dict__
    # initialize index
    self._ann_obj = AnnoyIndex(pkl.get_nfactor(), entity_type._metric)
    # mmap the file
    self._ann_obj.load(filepath)


class EntitySet(object):
"""Organize many EntityType instances for different entity types.
@@ -96,13 +107,26 @@ def __init__(self, nfactor):
self._annoy_objects = {}
self._entity_id_map = {}

def get_size(self):
    """Return the number of entity types in the set."""
    return len(self._annoy_objects)

def create_entity_type(self, entity_type_id, entity_type,
ntrees, metric='angular'):
"""Create an entity type and populate its metadata."""
self._annoy_objects[entity_type] = EntityType(
self._nfactor, ntrees, metric, entity_type_id, entity_type)
self._entity_id_map[entity_type_id] = entity_type

def get_entity_type(self, entity_type_id):
"""Get entity-type for specified id."""
entity_type = self._entity_id_map[entity_type_id]
return self._annoy_objects[entity_type]

def get_nfactor(self):
"""Get n-factor."""
return self._nfactor

def add_item(self, entity_type_id, entity_id, factors):
"""Wrap annoy_object add_item."""
self._annoy_objects[self._entity_id_map[entity_type_id]].add_item(
@@ -217,33 +241,25 @@ def save(self, folder):
json.dump(enttypes, handle)

@classmethod
def load(cls, folder):
"""Load object."""
def load_pickle(cls, file_getter):
"""Load pickled EntitySet."""
# grab pickle, replace the models with the mmapped saved annoy objects
with file_getter.get_binary_file('object.pickle') as f:
unpickled_class = dill.load(f)
return unpickled_class

@classmethod
def load_entity_info(cls, file_getter):
"""Load entity info from file."""
# Grab entity_info.json to ensure all the entity types exist and are
# the right size
filepath = os.path.join(folder, 'entity_info.json')
with open(filepath) as f:
with file_getter.get_file('entity_info.json') as f:
enttype_info = json.load(f)
return enttype_info

# grab pickle, replace the models with the mmapped saved annoy objects
pickle_filepath = os.path.join(folder, 'object.pickle')
with open(pickle_filepath, 'rb') as f:
unpickled_class = dill.load(f)
# annoy objects can't be pickled, so load these after pickle is loaded
for k in unpickled_class._annoy_objects:
annoy_filepath = os.path.join(folder, '{}.ann'.format(k))
unpickled_class._annoy_objects[k]._ann_obj = AnnoyIndex(
unpickled_class._nfactor,
unpickled_class._annoy_objects[k]._metric,
)
try:
unpickled_class._annoy_objects[k]._ann_obj.load(annoy_filepath)
except IOError as e:
raise IOError(
"Error: cannot load file {0}, which was built "
"with the model. '{1}'".format(annoy_filepath, e)
)

@classmethod
def check_load(cls, enttype_info, unpickled_class):
"""Check loaded info against unpickled class."""
# Check that sizes match up - this can be used to protect against files
# overwriting each other during a transfer, for example
enttype_sizes = {
@@ -269,4 +285,78 @@ def load(cls, folder):
'Entity type {0} exists in model_info.json '
'but was not loaded'.format(enttype['entity_type'])
)

@classmethod
def load(cls, file_getter_or_folder, entities=None):
"""Load object.
file_getter_or_folder -- a FileGetter or a folder string
entities -- optional subset of entities to load
"""
# to preserve backwards compatibility
if isinstance(file_getter_or_folder, str):
file_getter = FileGetter(file_getter_or_folder)
else:
file_getter = file_getter_or_folder

unpickled_class = cls.load_pickle(file_getter)

enttype_info = cls.load_entity_info(file_getter)
if entities is None:
entities = unpickled_class._annoy_objects
else:
# filter unwanted entities
enttype_info = [v for v in enttype_info
if v['entity_type'] in entities]
unpickled_class._annoy_objects = {
k: unpickled_class._annoy_objects[k]
for k in unpickled_class._annoy_objects
if k in entities
}

# annoy objects can't be pickled, so load these after pickle is loaded
for k in entities:
annoy_filepath = file_getter.get_file_path('{}.ann'.format(k))
try:
unpickled_class._annoy_objects[k].load(unpickled_class,
annoy_filepath)
except IOError as e:
raise IOError(
"Error: cannot load file {0}, which was built "
"with the model. '{1}'".format(annoy_filepath, e)
)
cls.check_load(enttype_info, unpickled_class)
return unpickled_class


class FileGetter(object):
"""Helper class used in EntitySet load methods."""

def __init__(self, folder=None):
self.folder = folder

def get_file_path(self, file_name):
return os.path.join(self.folder, file_name)

def get_file(self, file_name):
"""Return file object."""
return open(self.get_file_path(file_name))

def get_binary_file(self, file_name):
"""Return binary file object."""
return open(self.get_file_path(file_name), 'rb')


class SparkFileGetter(FileGetter):
    """FileGetter that resolves files distributed through SparkFiles."""

def __init__(self, sparkfiles):
"""Initialize SparkFileGetter.
sparkfiles -- SparkFiles (imported from pyspark), must have
sc.addPyFile(path + file) prior to SparkFiles.get(file).
"""
self.sparkfiles = sparkfiles
super(SparkFileGetter, self).__init__()

def get_file_path(self, file_name):
return self.sparkfiles.get(file_name)
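For the Spark path, a hedged usage sketch: it assumes a live SparkContext sc, an RDD rdd of type0 entity ids, and a model_dir holding the files EntitySet.save() wrote; none of these names come from this commit. Files are shipped with sc.addPyFile as the docstring above prescribes (sc.addFile pairs with SparkFiles.get the same way):

from pyspark import SparkFiles
from all2vec import EntitySet, SparkFileGetter

# Ship each saved model file to the executors first.
for fname in ['object.pickle', 'entity_info.json', 'type0.ann', 'type1.ann']:
    sc.addPyFile('model_dir/' + fname)

def score_partition(entity_ids):
    # On an executor, SparkFileGetter resolves shipped files by name
    # via SparkFiles.get(), so load() mmaps the local copies.
    model = EntitySet.load(SparkFileGetter(SparkFiles))
    for entity_id in entity_ids:
        yield model.get_scores('type0', entity_id, 'type1', [0, 1, 2])

results = rdd.mapPartitions(score_partition).collect()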
44 changes: 43 additions & 1 deletion tests/all2vec_test.py
@@ -1,4 +1,4 @@
from all2vec import EntitySet
from all2vec import EntitySet, FileGetter

def test_get_similar_vector():
t = EntitySet(3)
@@ -53,3 +53,45 @@ def test_cross_entity_scores():
sim = t.get_scores('type0', 0, 'type1', [0,1,2])
assert [x['score'] for x in sim] == [1.0, 2.0, 3.0]

def test_save_and_load(tmpdir):
t = EntitySet(3)
t.create_entity_type(entity_type_id=0, entity_type="type0", ntrees=10,
metric="angular")
t.create_entity_type(entity_type_id=1, entity_type="type1", ntrees=10,
metric="angular")

t.add_item(0, 0, [1, 2, 3])
t.add_item(1, 0, [1, 0, 0])
t.add_item(1, 1, [0, 1, 0])
t.add_item(1, 2, [0, 0, 1])
t.build()

a_dir = str(tmpdir)
t.save(a_dir)

loaded = EntitySet.load(FileGetter(a_dir))

sim = loaded.get_scores('type0', 0, 'type1', [0, 1, 2])
assert [x['score'] for x in sim] == [1.0, 2.0, 3.0]
assert loaded.get_size() == 2

def test_save_and_load_subset(tmpdir):
t = EntitySet(3)
t.create_entity_type(entity_type_id=0, entity_type="type0", ntrees=10,
metric="angular")
t.create_entity_type(entity_type_id=1, entity_type="type1", ntrees=10,
metric="angular")

t.add_item(0, 0, [0, 0, 1])
t.add_item(0, 1, [0, 1, 0])
t.add_item(0, 2, [1, 0, 0])
t.build()

a_dir = str(tmpdir)
t.save(a_dir)

loaded = EntitySet.load(FileGetter(a_dir), ['type0'])

sim = loaded.get_similar_vector([3, 2, 1], "type0", 3, 1, False)
assert [x['entity_id'] for x in sim] == [2, 1, 0]
assert loaded.get_size() == 1
