photomap/backend/embeddings.py: 142 additions & 64 deletions
@@ -9,6 +9,7 @@

import asyncio
import functools
+import gc
import logging
import os
import sys
@@ -256,6 +257,35 @@ def __init__(self, **data):
        data["embeddings_path"] = Path(data["embeddings_path"]).resolve()
        super().__init__(**data)

+    @staticmethod
+    def _cleanup_cuda_memory(device: str) -> None:
+        """
+        Clean up CUDA memory by clearing cache and forcing garbage collection.
+
+        This completely frees GPU VRAM to ensure it returns to zero (or a minimal
+        baseline) after operations. The model will need to be reloaded on subsequent
+        operations, but this ensures GPU memory is available for other processes.
+
+        Note: A baseline CUDA context (~188 MiB) may remain after first GPU use.
+        This is a PyTorch/CUDA limitation and cannot be freed without ending the process.
+
+        Args:
+            device: The device string ("cuda" or "cpu")
+        """
+        if device == "cuda":
+            try:
+                # Synchronize to ensure all CUDA operations are complete
+                torch.cuda.synchronize()
+                # Empty the CUDA cache
+                torch.cuda.empty_cache()
+                # Force garbage collection to clean up Python references
+                gc.collect()
+                # Empty cache again after GC to catch any newly freed memory
+                torch.cuda.empty_cache()
+            except RuntimeError as e:
+                # Log but don't crash if CUDA operations fail
+                logger.warning(f"CUDA cleanup failed: {e}")
+
    def get_image_files_from_directory(
        self,
        directory: Path,
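
A quick way to sanity-check this cleanup pattern outside the PR is to watch PyTorch's allocator counters before and after; the following is my own minimal sketch (not project code), assuming a CUDA-capable machine. `torch.cuda.memory_allocated` and `torch.cuda.memory_reserved` are standard PyTorch APIs:

    import gc
    import torch

    def report(label: str) -> None:
        # allocated = bytes held by live tensors; reserved = cached blocks kept by the allocator
        print(f"{label}: allocated={torch.cuda.memory_allocated() / 2**20:.1f} MiB, "
              f"reserved={torch.cuda.memory_reserved() / 2**20:.1f} MiB")

    if torch.cuda.is_available():
        x = torch.randn(4096, 4096, device="cuda")  # ~64 MiB of fp32
        report("after alloc")
        del x                      # drop the last Python reference
        gc.collect()               # collect any cycles still pinning tensors
        torch.cuda.empty_cache()   # hand cached blocks back to the driver
        report("after cleanup")    # both counters drop; the CUDA context itself remains
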
@@ -447,15 +477,22 @@ def _process_images_batch(
            umap_embeddings = self.create_umap_index(
                np.array(embeddings) if embeddings else np.empty((0, 512))
            )

-        return IndexResult(
+        result = IndexResult(
            embeddings=np.array(embeddings) if embeddings else np.empty((0, 512)),
            filenames=np.array(filenames),
            modification_times=np.array(modification_times),
            metadata=np.array(metadatas, dtype=object),
            umap_embeddings=umap_embeddings,
            bad_files=bad_files,
        )

+        # Clean up GPU memory after batch processing
+        # Delete model references to completely free VRAM
+        del model, preprocess
+        self._cleanup_cuda_memory(device)
+
+        return result

    async def _process_images_batch_async(
        self, image_paths: list[Path], album_key: str, yield_interval: int = 10
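
The ordering in this hunk matters: `empty_cache` can only release blocks whose tensors are no longer referenced, which is why the model references are deleted before `_cleanup_cuda_memory` runs. A tiny illustration of that behavior (again my own sketch, not code from this repo):

    import torch

    if torch.cuda.is_available():
        t = torch.randn(1024, 1024, device="cuda")  # ~4 MiB of fp32
        torch.cuda.empty_cache()                    # no effect: t still owns its block
        assert torch.cuda.memory_allocated() > 0
        del t                                       # block becomes free inside the cache
        torch.cuda.empty_cache()                    # cached block returned to the driver
        assert torch.cuda.memory_allocated() == 0
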
@@ -498,13 +535,20 @@ async def _process_images_batch_async(
            if i % yield_interval == 0:
                await asyncio.sleep(0.01)

-        return IndexResult(
+        result = IndexResult(
            embeddings=np.array(embeddings) if embeddings else np.empty((0, 512)),
            filenames=np.array(filenames),
            modification_times=np.array(modification_times),
            metadata=np.array(metadatas, dtype=object),
            bad_files=bad_files,
        )

+        # Clean up GPU memory after async batch processing
+        # Delete model references to completely free VRAM
+        del model, preprocess
+        self._cleanup_cuda_memory(device)
+
+        return result

    def _save_embeddings(self, index_result: IndexResult) -> None:
        """Save embeddings to disk and clear cache."""
@@ -983,71 +1027,105 @@ def search_images_by_text_and_image(
            "ViT-B/32", device=device, download_root=self._clip_root()
        )

-        # Handle None queries: set weight to zero and skip embedding
-        if query_image_data is None:
-            image_weight = 0.0
-            image_embedding = None
-        else:
-            pil_image = ImageOps.exif_transpose(query_image_data)
-            pil_image = pil_image.convert("RGB")
-            image_tensor: torch.Tensor = preprocess(pil_image)  # type: ignore
-            image_tensor = image_tensor.unsqueeze(0).to(device)
-            with torch.no_grad():
-                image_embedding = model.encode_image(image_tensor).squeeze(0)
-
-        if not positive_query:
-            positive_weight = 0.0
-            pos_emb = None
-        else:
-            tokens = clip.tokenize([positive_query]).to(device)
-            with torch.no_grad():
-                pos_emb = model.encode_text(tokens).squeeze(0)
-
-        if not negative_query:
-            negative_weight = 0.0
-            neg_emb = None
-        else:
-            tokens = clip.tokenize([negative_query]).to(device)
-            with torch.no_grad():
-                neg_emb = model.encode_text(tokens).squeeze(0)
-
-        # If all weights are zero, return empty result
-        if image_weight == 0.0 and positive_weight == 0.0 and negative_weight == 0.0:
-            return [], []
-
-        # Weighted combination: image + positive - negative
-        combined_embedding = None
-        if image_weight > 0.0 and image_embedding is not None:
-            combined_embedding = image_weight * image_embedding
-        if positive_weight > 0.0 and pos_emb is not None:
-            if combined_embedding is None:
-                combined_embedding = positive_weight * pos_emb
-            else:
-                combined_embedding += positive_weight * pos_emb
-        if negative_weight > 0.0 and neg_emb is not None:
-            if combined_embedding is None:
-                combined_embedding = -negative_weight * neg_emb
-            else:
-                combined_embedding -= negative_weight * neg_emb
-
-        # Normalize
-        embeddings_tensor = torch.tensor(embeddings, dtype=torch.float32, device=device)
-        norm_embeddings = F.normalize(embeddings_tensor, dim=-1).to(torch.float32)
-        assert combined_embedding is not None
-        combined_embedding_norm = F.normalize(combined_embedding, dim=-1).to(
-            torch.float32
-        )
-
-        # Similarity
-        similarities = (norm_embeddings @ combined_embedding_norm).cpu().numpy()
-        top_indices = similarities.argsort()[-top_k:][::-1]
-        top_indices = [i for i in top_indices if similarities[i] >= minimum_score]
-        if not top_indices:
-            return [], []
-
-        # Translate from filename array indices to sorted filename top_indices
-        global_indices = [int(filename_map[filenames[i]]) for i in top_indices]
-        return global_indices, similarities[top_indices].tolist()
+        try:
+            # Handle None queries: set weight to zero and skip embedding
+            if query_image_data is None:
+                image_weight = 0.0
+                image_embedding = None
+            else:
+                pil_image = ImageOps.exif_transpose(query_image_data)
+                pil_image = pil_image.convert("RGB")
+                image_tensor: torch.Tensor = preprocess(pil_image)  # type: ignore
+                image_tensor = image_tensor.unsqueeze(0).to(device)
+                with torch.no_grad():
+                    image_embedding = model.encode_image(image_tensor).squeeze(0)
+
+            if not positive_query:
+                positive_weight = 0.0
+                pos_emb = None
+            else:
+                tokens = clip.tokenize([positive_query]).to(device)
+                with torch.no_grad():
+                    pos_emb = model.encode_text(tokens).squeeze(0)
+
+            if not negative_query:
+                negative_weight = 0.0
+                neg_emb = None
+            else:
+                tokens = clip.tokenize([negative_query]).to(device)
+                with torch.no_grad():
+                    neg_emb = model.encode_text(tokens).squeeze(0)
+
+            # If all weights are zero, return empty result
+            if image_weight == 0.0 and positive_weight == 0.0 and negative_weight == 0.0:
+                return [], []
+
+            # Weighted combination: image + positive - negative
+            combined_embedding = None
+            if image_weight > 0.0 and image_embedding is not None:
+                combined_embedding = image_weight * image_embedding
+            if positive_weight > 0.0 and pos_emb is not None:
+                if combined_embedding is None:
+                    combined_embedding = positive_weight * pos_emb
+                else:
+                    combined_embedding += positive_weight * pos_emb
+            if negative_weight > 0.0 and neg_emb is not None:
+                if combined_embedding is None:
+                    combined_embedding = -negative_weight * neg_emb
+                else:
+                    combined_embedding -= negative_weight * neg_emb
+
+            # Normalize
+            embeddings_tensor = torch.tensor(embeddings, dtype=torch.float32, device=device)
+            norm_embeddings = F.normalize(embeddings_tensor, dim=-1).to(torch.float32)
+            assert combined_embedding is not None
+            combined_embedding_norm = F.normalize(combined_embedding, dim=-1).to(
+                torch.float32
+            )
+
+            # Similarity
+            similarities = (norm_embeddings @ combined_embedding_norm).cpu().numpy()
+            top_indices = similarities.argsort()[-top_k:][::-1]
+            top_indices = [i for i in top_indices if similarities[i] >= minimum_score]
+
+            if not top_indices:
+                return [], []
+
+            # Translate from filename array indices to global sorted-filename indices
+            result_indices = [int(filename_map[filenames[i]]) for i in top_indices]
+            result_similarities = similarities[top_indices].tolist()
+
+            return result_indices, result_similarities
+        finally:
+            # Clean up GPU memory after search (always executed)
+            # Delete all GPU tensors and model references to completely free VRAM
+            try:
+                del model, preprocess
+                # Delete any tensors that may still be around
+                if 'image_tensor' in locals():
+                    del image_tensor
+                if 'tokens' in locals():
+                    del tokens
+                if 'embeddings_tensor' in locals():
+                    del embeddings_tensor
+                if 'norm_embeddings' in locals():
+                    del norm_embeddings
+                if 'combined_embedding' in locals():
+                    del combined_embedding
+                if 'combined_embedding_norm' in locals():
+                    del combined_embedding_norm
+                if 'similarities' in locals():
+                    del similarities
+                if 'image_embedding' in locals():
+                    del image_embedding
+                if 'pos_emb' in locals():
+                    del pos_emb
+                if 'neg_emb' in locals():
+                    del neg_emb
+            except (NameError, UnboundLocalError):
+                # Variables may not be defined if an early return was taken
+                pass
+            self._cleanup_cuda_memory(device)
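
The arithmetic in the hunk above is standard CLIP-style retrieval: build one query vector as a weighted sum (image plus positive text minus negative text), then rank the gallery by cosine similarity. A self-contained sketch with random stand-in embeddings (illustration only, not project code; real vectors would come from `encode_image`/`encode_text`):

    import torch
    import torch.nn.functional as F

    # Random 512-d stand-ins for CLIP embeddings
    e_img, e_pos, e_neg = (torch.randn(512) for _ in range(3))
    gallery = torch.randn(1000, 512)          # one row per indexed photo

    w_img, w_pos, w_neg = 1.0, 0.5, 0.5
    query = w_img * e_img + w_pos * e_pos - w_neg * e_neg

    # Cosine similarity = dot product of L2-normalized vectors
    sims = F.normalize(gallery, dim=-1) @ F.normalize(query, dim=-1)
    top = sims.argsort(descending=True)[:5]   # indices of the 5 best matches
    print(top, sims[top])
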

    def find_duplicate_clusters(self, similarity_threshold=0.995):
        """