Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding update_index() method for similarity indexes #205

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 135 additions & 16 deletions fiftyone/brain/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,49 @@ def label_ids(self):
"""
return None

def get_index_ids(self):
    """Returns the list of IDs in the full index.

    Every backend supports this method. When the backend exposes
    :meth:`sample_ids` and :meth:`label_ids`, the relevant primary keys are
    returned directly. Otherwise the IDs must be inferred by querying the
    backend for embeddings, which can take some time.

    Returns:
        the list of sample IDs (or label IDs for patch indexes) in the full
        index
    """
    patches_field = self.config.patches_field
    is_patch_index = patches_field is not None

    # Fast path: the backend can report its own primary keys
    known_ids = self.label_ids if is_patch_index else self.sample_ids
    if known_ids is not None:
        return known_ids

    # Slow path: this index cannot list its IDs, so the only way to infer
    # them is to request embeddings for every candidate ID and keep the
    # ones that the backend actually returns

    logger.info("Retrieving IDs from index. This can take awhile...")

    candidate_sample_ids, candidate_label_ids = fbu.get_ids(
        self._samples, patches_field=patches_field
    )

    # Missing IDs are expected here, so tolerate them silently
    _, found_sample_ids, found_label_ids = self.get_embeddings(
        sample_ids=candidate_sample_ids,
        label_ids=candidate_label_ids,
        allow_missing=True,
        warn_missing=False,
    )

    return found_label_ids if is_patch_index else found_sample_ids

@property
def total_index_size(self):
"""The total number of data points in the index.
Expand Down Expand Up @@ -948,22 +991,12 @@ def compute_embeddings(
model = self.get_model()

if skip_existing:
if self.config.patches_field is not None:
index_ids = self.label_ids
else:
index_ids = self.sample_ids

if index_ids is not None:
samples = fbu.skip_ids(
samples,
index_ids,
patches_field=self.config.patches_field,
warn_existing=warn_existing,
)
else:
logger.warning(
"This index does not support skipping existing IDs"
)
samples = fbu.skip_ids(
samples,
self.get_index_ids(),
patches_field=self.config.patches_field,
warn_existing=warn_existing,
)

if self.config.roi_field is not None:
patches_field = self.config.roi_field
Expand All @@ -988,6 +1021,92 @@ def compute_embeddings(
progress=progress,
)

def update_index(
    self,
    samples=None,
    model=None,
    overwrite=False,
    batch_size=None,
    num_workers=None,
    skip_failures=True,
    force_square=False,
    alpha=None,
    progress=None,
    reload=True,
):
    """Brings the index up to date by computing and adding embeddings for
    any samples that it does not yet contain.

    If ``overwrite`` is True, embeddings are also regenerated for IDs that
    are already in the index.

    Args:
        samples (None): a
            :class:`fiftyone.core.collections.SampleCollection` for which
            to update the index. By default, :meth:`samples` is used
        model (None): a :class:`fiftyone.core.models.Model` to use to
            generate embeddings. If not provided, these results must have
            been created with a stored model, which will be used by default
        overwrite (False): whether to regenerate embeddings for
            sample/label IDs that are already in the index
        batch_size (None): an optional batch size to use when computing
            embeddings. Only applicable when a ``model`` is provided
        num_workers (None): the number of workers to use when loading
            images. Only applicable when a Torch-based model is being used
            to compute embeddings
        skip_failures (True): whether to gracefully continue without
            raising an error if embeddings cannot be generated for a sample
        force_square (False): whether to minimally manipulate the patch
            bounding boxes into squares prior to extraction. Only
            applicable when a ``model`` and ``patches_field`` are specified
        alpha (None): an optional expansion/contraction to apply to the
            patches before extracting them, in ``[-1, inf)``. If provided,
            the length and width of the box are expanded (or contracted,
            when ``alpha < 0``) by ``(100 * alpha)%``. For example, set
            ``alpha = 0.1`` to expand the boxes by 10%, and set
            ``alpha = -0.1`` to contract the boxes by 10%. Only applicable
            when a ``model`` and ``patches_field`` are specified
        progress (None): whether to render a progress bar (True/False), use
            the default value ``fiftyone.config.show_progress_bars``
            (None), or a progress callback function to invoke instead
        reload (True): whether to call :meth:`reload` to refresh the
            current view after the update
    """
    target_samples = self._samples if samples is None else samples

    # Existing IDs are skipped unless the caller asked to overwrite them
    embed_kwargs = dict(
        model=model,
        batch_size=batch_size,
        num_workers=num_workers,
        skip_failures=skip_failures,
        skip_existing=not overwrite,
        warn_existing=False,
        force_square=force_square,
        alpha=alpha,
        progress=progress,
    )
    embeddings, sample_ids, label_ids = self.compute_embeddings(
        target_samples, **embed_kwargs
    )

    num_added = len(embeddings)
    if num_added == 0:
        logger.info("Index is already up to date")
        return

    logger.info(f"Adding {num_added} embeddings to the index...")
    self.add_to_index(
        embeddings,
        sample_ids,
        label_ids=label_ids,
        overwrite=overwrite,
        allow_existing=True,
        warn_existing=False,
        reload=reload,
    )

    # Only sklearn indexes that do not keep their embeddings in a dataset
    # field are explicitly saved after the update
    if self.config.method != "sklearn":
        return

    if self.config.embeddings_field is None:
        self.save()

@classmethod
def _from_dict(cls, d, samples, config, brain_key):
"""Builds a :class:`SimilarityIndex` from a JSON representation of it.
Expand Down