diff --git a/.gitignore b/.gitignore
index 1025596..67dde94 100644
--- a/.gitignore
+++ b/.gitignore
@@ -71,4 +71,4 @@ target/
 
 # pyenv
 .python-version
-
+*.swp
diff --git a/all2vec/__init__.py b/all2vec/__init__.py
index fe73271..d499235 100644
--- a/all2vec/__init__.py
+++ b/all2vec/__init__.py
@@ -82,6 +82,17 @@ def __iter__(self):
             vector=self.get_item_vector(entity_id)
         ) for entity_id in self._ann_map.keys())
 
+    def get_nfactor(self):
+        return self._nfactor
+
+    def load(self, pkl, filepath):
+        entity_type = pkl.get_entity_type(self._entity_type_id)
+        self.__dict__ = entity_type.__dict__
+        # initialize index
+        self._ann_obj = AnnoyIndex(pkl.get_nfactor(), entity_type._metric)
+        # mmap the file
+        self._ann_obj.load(filepath)
+
 
 class EntitySet(object):
     """Organize many EntityType instances for different entity types.
@@ -96,6 +107,10 @@ def __init__(self, nfactor):
         self._annoy_objects = {}
         self._entity_id_map = {}
 
+    def get_size(self):
+        """Return size of set."""
+        return len(self._annoy_objects)
+
     def create_entity_type(self, entity_type_id, entity_type,
                            ntrees, metric='angular'):
         """Create an entity type and populate its metadata."""
@@ -103,6 +118,15 @@
             self._nfactor, ntrees, metric, entity_type_id, entity_type)
         self._entity_id_map[entity_type_id] = entity_type
 
+    def get_entity_type(self, entity_type_id):
+        """Get entity-type for specified id."""
+        entity_type = self._entity_id_map[entity_type_id]
+        return self._annoy_objects[entity_type]
+
+    def get_nfactor(self):
+        """Get n-factor."""
+        return self._nfactor
+
     def add_item(self, entity_type_id, entity_id, factors):
         """Wrap annoy_object add_item."""
         self._annoy_objects[self._entity_id_map[entity_type_id]].add_item(
@@ -217,33 +241,25 @@ def save(self, folder):
             json.dump(enttypes, handle)
 
     @classmethod
-    def load(cls, folder):
-        """Load object."""
+    def load_pickle(cls, file_getter):
+        """Load pickled EntitySet."""
+        # grab pickle, replace the models with the mmapped saved annoy objects
+        with file_getter.get_binary_file('object.pickle') as f:
+            unpickled_class = dill.load(f)
+        return unpickled_class
+
+    @classmethod
+    def load_entity_info(cls, file_getter):
+        """Load entity info from file."""
         # Grab entity_info.json to ensure all the entity types exist and are
         # the right size
-        filepath = os.path.join(folder, 'entity_info.json')
-        with open(filepath) as f:
+        with file_getter.get_file('entity_info.json') as f:
             enttype_info = json.load(f)
+        return enttype_info
 
-        # grab pickle, replace the models with the mmapped saved annoy objects
-        pickle_filepath = os.path.join(folder, 'object.pickle')
-        with open(pickle_filepath, 'rb') as f:
-            unpickled_class = dill.load(f)
-        # annoy objects can't be pickled, so load these after pickle is loaded
-        for k in unpickled_class._annoy_objects:
-            annoy_filepath = os.path.join(folder, '{}.ann'.format(k))
-            unpickled_class._annoy_objects[k]._ann_obj = AnnoyIndex(
-                unpickled_class._nfactor,
-                unpickled_class._annoy_objects[k]._metric,
-            )
-            try:
-                unpickled_class._annoy_objects[k]._ann_obj.load(annoy_filepath)
-            except IOError as e:
-                raise IOError(
-                    "Error: cannot load file {0}, which was built "
-                    "with the model. '{1}'".format(annoy_filepath, e)
-                )
-
+    @classmethod
+    def check_load(cls, enttype_info, unpickled_class):
+        """Check loaded info against unpickled class."""
         # Check that sizes match up - this can be used to protect against files
         # overwriting each other during a transfer, for example
         enttype_sizes = {
@@ -269,4 +285,78 @@ def load(cls, folder):
                 'Entity type {0} exists in model_info.json '
                 'but was not loaded'.format(enttype['entity_type'])
             )
+
+    @classmethod
+    def load(cls, file_getter_or_folder, entities=None):
+        """Load object.
+
+        file_getter_or_folder -- a FileGetter or a folder string
+        entities -- optional subset of entities to load
+        """
+        # to preserve backwards compatibility
+        if isinstance(file_getter_or_folder, str):
+            file_getter = FileGetter(file_getter_or_folder)
+        else:
+            file_getter = file_getter_or_folder
+
+        unpickled_class = cls.load_pickle(file_getter)
+
+        enttype_info = cls.load_entity_info(file_getter)
+        if entities is None:
+            entities = unpickled_class._annoy_objects
+        else:
+            # filter unwanted entities
+            enttype_info = [v for v in enttype_info
+                            if v['entity_type'] in entities]
+            unpickled_class._annoy_objects = {
+                k: unpickled_class._annoy_objects[k]
+                for k in unpickled_class._annoy_objects
+                if k in entities
+            }
+
+        # annoy objects can't be pickled, so load these after pickle is loaded
+        for k in entities:
+            annoy_filepath = file_getter.get_file_path('{}.ann'.format(k))
+            try:
+                unpickled_class._annoy_objects[k].load(unpickled_class,
+                                                       annoy_filepath)
+            except IOError as e:
+                raise IOError(
+                    "Error: cannot load file {0}, which was built "
+                    "with the model. '{1}'".format(annoy_filepath, e)
+                )
+        cls.check_load(enttype_info, unpickled_class)
         return unpickled_class
+
+
+class FileGetter(object):
+    """Helper class used in EntitySet load methods."""
+
+    def __init__(self, folder=None):
+        self.folder = folder
+
+    def get_file_path(self, file_name):
+        return os.path.join(self.folder, file_name)
+
+    def get_file(self, file_name):
+        """Return file object."""
+        return open(self.get_file_path(file_name))
+
+    def get_binary_file(self, file_name):
+        """Return binary file object."""
+        return open(self.get_file_path(file_name), 'rb')
+
+
+class SparkFileGetter(FileGetter):
+
+    def __init__(self, sparkfiles):
+        """Initialize SparkFileGetter.
+
+        sparkfiles -- SparkFiles (imported from pyspark), must have
+        sc.addPyFile(path + file) prior to SparkFiles.get(file).
+        """
+        self.sparkfiles = sparkfiles
+        super(SparkFileGetter, self).__init__()
+
+    def get_file_path(self, file_name):
+        return self.sparkfiles.get(file_name)
diff --git a/tests/all2vec_test.py b/tests/all2vec_test.py
index ca1d5e4..9be0ec3 100644
--- a/tests/all2vec_test.py
+++ b/tests/all2vec_test.py
@@ -1,4 +1,4 @@
-from all2vec import EntitySet
+from all2vec import EntitySet, FileGetter
 
 def test_get_similar_vector():
     t = EntitySet(3)
@@ -53,3 +53,45 @@ def test_cross_entity_scores():
     sim = t.get_scores('type0', 0, 'type1', [0,1,2])
     assert [x['score'] for x in sim] == [1.0, 2.0, 3.0]
 
+def test_save_and_load(tmpdir):
+    t = EntitySet(3)
+    t.create_entity_type(entity_type_id=0, entity_type="type0", ntrees=10,
+                         metric="angular")
+    t.create_entity_type(entity_type_id=1, entity_type="type1", ntrees=10,
+                         metric="angular")
+
+    t.add_item(0, 0, [1, 2, 3])
+    t.add_item(1, 0, [1, 0, 0])
+    t.add_item(1, 1, [0, 1, 0])
+    t.add_item(1, 2, [0, 0, 1])
+    t.build()
+
+    a_dir = str(tmpdir)
+    t.save(a_dir)
+
+    loaded = EntitySet.load(FileGetter(a_dir))
+
+    sim = loaded.get_scores('type0', 0, 'type1', [0,1,2])
+    assert [x['score'] for x in sim] == [1.0, 2.0, 3.0]
+    assert loaded.get_size() == 2
+
+def test_save_and_load_subset(tmpdir):
+    t = EntitySet(3)
+    t.create_entity_type(entity_type_id=0, entity_type="type0", ntrees=10,
+                         metric="angular")
+    t.create_entity_type(entity_type_id=1, entity_type="type1", ntrees=10,
+                         metric="angular")
+
+    t.add_item(0, 0, [0, 0, 1])
+    t.add_item(0, 1, [0, 1, 0])
+    t.add_item(0, 2, [1, 0, 0])
+    t.build()
+
+    a_dir = str(tmpdir)
+    t.save(a_dir)
+
+    loaded = EntitySet.load(FileGetter(a_dir), ['type0'])
+
+    sim = loaded.get_similar_vector([3, 2, 1], "type0", 3, 1, False)
+    assert [x['entity_id'] for x in sim] == [2, 1, 0]
+    assert loaded.get_size() == 1