Skip to content

Commit

Permalink
adding update_index() method for similarity indexes
Browse files Browse the repository at this point in the history
  • Loading branch information
brimoor committed Dec 17, 2024
1 parent e330043 commit 54fa81f
Showing 1 changed file with 135 additions and 16 deletions.
151 changes: 135 additions & 16 deletions fiftyone/brain/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,49 @@ def label_ids(self):
"""
return None

def get_index_ids(self):
"""Returns the list of IDs in the full index.
All backends support this method. If the backend supports
:meth:`sample_ids` and :meth:`label_ids`, then the appropriate primary
keys are returned. For other backends, this operation can take some
time as we must query the backend sequentially to retrieve these.
Returns:
the list of sample IDs (or label IDs for patch indexes) in the full
index
"""
if self.config.patches_field is not None:
index_ids = self.label_ids
else:
index_ids = self.sample_ids

if index_ids is not None:
return index_ids

# Unfortunately for this index, the only way to infer the available IDs
# is to download all embeddings

logger.info("Retrieving IDs from index. This can take awhile...")

sample_ids, label_ids = fbu.get_ids(
self._samples, patches_field=self.config.patches_field
)

_, sample_ids, label_ids = self.get_embeddings(
sample_ids=sample_ids,
label_ids=label_ids,
allow_missing=True,
warn_missing=False,
)

if self.config.patches_field is not None:
index_ids = label_ids
else:
index_ids = sample_ids

return index_ids

@property
def total_index_size(self):
"""The total number of data points in the index.
Expand Down Expand Up @@ -948,22 +991,12 @@ def compute_embeddings(
model = self.get_model()

if skip_existing:
if self.config.patches_field is not None:
index_ids = self.label_ids
else:
index_ids = self.sample_ids

if index_ids is not None:
samples = fbu.skip_ids(
samples,
index_ids,
patches_field=self.config.patches_field,
warn_existing=warn_existing,
)
else:
logger.warning(
"This index does not support skipping existing IDs"
)
samples = fbu.skip_ids(
samples,
self.get_index_ids(),
patches_field=self.config.patches_field,
warn_existing=warn_existing,
)

if self.config.roi_field is not None:
patches_field = self.config.roi_field
Expand All @@ -988,6 +1021,92 @@ def compute_embeddings(
progress=progress,
)

def update_index(
self,
samples=None,
model=None,
overwrite=False,
batch_size=None,
num_workers=None,
skip_failures=True,
force_square=False,
alpha=None,
progress=None,
reload=True,
):
"""Updates the index, if necessary, by adding embeddings for any
samples that are not already present in the index.
Args:
samples (None): a
:class:`fiftyone.core.collections.SampleCollection` for which
to update the index. By default, :meth:`samples` is used
model (None): a :class:`fiftyone.core.models.Model` to use to
generate embeddings. If not provided, these results must have
been created with a stored model, which will be used by default
overwrite (False): whether to regenerate embeddings for
sample/label IDs that are already in the index
batch_size (None): an optional batch size to use when computing
embeddings. Only applicable when a ``model`` is provided
num_workers (None): the number of workers to use when loading
images. Only applicable when a Torch-based model is being used
to compute embeddings
skip_failures (True): whether to gracefully continue without
raising an error if embeddings cannot be generated for a sample
force_square (False): whether to minimally manipulate the patch
bounding boxes into squares prior to extraction. Only
applicable when a ``model`` and ``patches_field`` are specified
alpha (None): an optional expansion/contraction to apply to the
patches before extracting them, in ``[-1, inf)``. If provided,
the length and width of the box are expanded (or contracted,
when ``alpha < 0``) by ``(100 * alpha)%``. For example, set
``alpha = 0.1`` to expand the boxes by 10%, and set
``alpha = -0.1`` to contract the boxes by 10%. Only applicable
when a ``model`` and ``patches_field`` are specified
progress (None): whether to render a progress bar (True/False), use
the default value ``fiftyone.config.show_progress_bars``
(None), or a progress callback function to invoke instead
reload (True): whether to call :meth:`reload` to refresh the
current view after the update
"""
if samples is None:
samples = self._samples

embeddings, sample_ids, label_ids = self.compute_embeddings(
samples,
model=model,
batch_size=batch_size,
num_workers=num_workers,
skip_failures=skip_failures,
skip_existing=not overwrite,
warn_existing=False,
force_square=force_square,
alpha=alpha,
progress=progress,
)

num_added = len(embeddings)
if num_added == 0:
logger.info("Index is already up to date")
return

logger.info(f"Adding {num_added} embeddings to the index...")
self.add_to_index(
embeddings,
sample_ids,
label_ids=label_ids,
overwrite=overwrite,
allow_existing=True,
warn_existing=False,
reload=reload,
)

if (
self.config.method == "sklearn"
and self.config.embeddings_field is None
):
self.save()

@classmethod
def _from_dict(cls, d, samples, config, brain_key):
"""Builds a :class:`SimilarityIndex` from a JSON representation of it.
Expand Down

0 comments on commit 54fa81f

Please sign in to comment.