Skip to content

Commit b4f0707

Browse files
committed
Add support for multi-GPU optimization.
Add support for enabling and disabling directories.
1 parent 4e4b5dd commit b4f0707

25 files changed

+696
-630
lines changed

backend/core/embedders.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import torch
22
from timm import create_model, data
3-
43
from core.singleton import Singleton
54
from settings import settings
5+
import torch.nn as nn
66

77

88
class ImageEmbedder:
@@ -12,39 +12,45 @@ def __init__(self, name, model_name, weight, device=torch.device("cpu")):
1212
self._device = device
1313
self._weight = weight
1414

15-
self.model = create_model(model_name, pretrained=True, num_classes=0).to(device)
15+
# Create and move the model to the device.
16+
model = create_model(model_name, pretrained=True, num_classes=0).to(device)
17+
18+
# Wrap the model with DataParallel if more than one GPU is available.
19+
if torch.cuda.is_available() and torch.cuda.device_count() > 1 and settings.service.use_cuda:
20+
self.model = nn.DataParallel(model)
21+
else:
22+
self.model = model
23+
1624
self.model.eval()
1725

26+
# Use the unwrapped model for configuration
1827
self.preprocess = self.get_preprocess()
19-
self._weight = 1.0
2028
self._embedding_dim = self._determine_embedding_dim()
2129

2230
def get_preprocess(self):
23-
data_config = data.resolve_model_data_config(self.model)
31+
# Unwrap the model if wrapped in DataParallel
32+
model_for_config = self.model.module if hasattr(self.model, 'module') else self.model
33+
data_config = data.resolve_model_data_config(model_for_config)
2434
return data.create_transform(**data_config, is_training=False)
2535

26-
@property
27-
def name(self) -> str:
28-
return self._name
29-
30-
@property
31-
def model_name(self) -> str:
32-
return self._model_name
33-
34-
@property
35-
def device(self) -> torch.device:
36-
return self._device
37-
3836
def embed(self, img_binary):
39-
img_binary = self.preprocess(img_binary)
40-
img_binary = img_binary.unsqueeze(0).to(self.device)
37+
# Preprocess the image and add batch dimension.
38+
img_tensor = self.preprocess(img_binary)
39+
img_tensor = img_tensor.unsqueeze(0).to(self.device)
4140
with torch.no_grad():
42-
embedding = self.model(img_binary).squeeze(0).cpu().numpy()
41+
# DataParallel scatters inputs along the batch dimension; this batch holds a single image, so it cannot be split and runs on one GPU.
42+
embedding = self.model(img_tensor).squeeze(0).cpu().numpy()
4343
return embedding
4444

4545
def _determine_embedding_dim(self):
46-
# Generate a dummy image tensor to determine the output dimension of the embedder
47-
dummy_input = torch.zeros((3, 224, 224)).to(self.device) # Assuming the input size is 3x224x224
46+
# Unwrap the model to get the proper configuration.
47+
model_for_config = self.model.module if hasattr(self.model, 'module') else self.model
48+
data_config = data.resolve_model_data_config(model_for_config)
49+
# Get the expected input size from the configuration; defaults to (3,224,224)
50+
input_size = data_config.get("input_size", (3, 224, 224))
51+
52+
# Create a dummy input tensor with the correct size.
53+
dummy_input = torch.zeros(input_size).to(self.device)
4854
dummy_input = self.preprocess(dummy_input).unsqueeze(0).to(self.device)
4955
with torch.no_grad():
5056
embedding = self.model(dummy_input).squeeze(0).cpu().numpy()
@@ -58,6 +64,10 @@ def embedding_dim(self):
5864
def weight(self):
5965
return self._weight
6066

67+
@property
68+
def device(self):
69+
return self._device
70+
6171

6272
@Singleton
6373
class EmbedderManager:

backend/core/singleton.py

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,12 @@
11
class Singleton:
    """Decorator class that makes the decorated class a lazily created singleton.

    Use it as a decorator (not a metaclass) on the class that should be a
    singleton. The decorated class is instantiated with no arguments, so its
    ``__init__`` may only take ``self``. Obtain the instance through
    :meth:`instance`; calling the decorated name directly raises ``TypeError``.

    Instance creation is not synchronized between threads.
    """

    def __init__(self, decorated):
        # Keep the wrapped class; it is only instantiated on first access.
        self._decorated = decorated

    def instance(self):
        """Return the singleton instance, creating it on the first call."""
        # Look only in the decorated class's own namespace. ``hasattr`` would
        # also find an ``_instance`` attribute inherited from a base class and
        # return that unrelated object instead of creating the instance.
        if '_instance' not in vars(self._decorated):
            self._decorated._instance = self._decorated()
        return self._decorated._instance

    def __call__(self):
        raise TypeError('Singletons must be accessed through `instance()`.')

backend/database/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

backend/indexing/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .services.image_indexing_service import ImageIndexingService
2+
3+
image_indexing_service = ImageIndexingService.instance()
4+
5+
__all__ = ["image_indexing_service"]

backend/indexing/consistency/__init__.py

Whitespace-only changes.
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import os
2+
import threading
3+
import time
4+
5+
from sqlalchemy.orm import Session
6+
7+
from core import embedder_manager
8+
from models.models import SessionLocal, Directory, Image
9+
from indexing.repositories.repositories import DirectoryRepository, ImageRepository, MilvusRepository
10+
from monitoring import logger
11+
from settings import settings
12+
13+
14+
class ConsistencyChecker:
    """Periodic job that reconciles stored directories with the filesystem.

    A daemon thread sleeps for ``interval`` seconds and then checks every
    directory row: a directory whose path no longer exists is deleted, image
    files found on disk but not in the database are added as un-indexed rows,
    and images in the database that are gone from disk are removed from both
    the database and every image embedder's Milvus collection.
    """

    # Lower-case file suffixes that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, interval: int = 3600):
        """Prepare (but do not start) the background thread.

        Args:
            interval: Seconds to sleep before each consistency check.
        """
        self.interval = interval
        self.thread = threading.Thread(target=self.run, daemon=True)

    def start(self):
        """Start the background checking thread."""
        self.thread.start()

    def run(self):
        """Thread body: sleep for ``interval`` seconds, check, repeat forever."""
        while True:
            # The sleep comes first, so the first check runs one interval
            # after start().
            time.sleep(self.interval)
            self.check_consistency()

    def check_consistency(self):
        """Check every stored directory using a fresh database session.

        Any exception is logged and the session rolled back, so a failure
        never propagates into (and terminates) the background loop.
        """
        logger.info("Running system-wide consistency check")
        session = SessionLocal()
        try:
            directory_repo = DirectoryRepository(session)
            directories = directory_repo.get_all()
            for directory in directories:
                self.check_directory(session, directory)
            logger.info("Consistency check completed")
        except Exception as e:
            logger.error(f"Consistency check error: {e}", exc_info=True)
            session.rollback()
        finally:
            session.close()

    def check_directory(self, session: "Session", directory: "Directory"):
        """Bring the database and Milvus in line with one directory on disk.

        Args:
            session: Open database session used for all queries and commits.
            directory: Directory row to verify.
        """
        logger.info(f"Checking consistency for directory {directory.path} (ID: {directory.id})")
        if not os.path.exists(directory.path):
            logger.warning(f"Directory missing: {directory.path}. Removing from system.")
            DirectoryRepository(session).delete(directory)
            return

        # Gather filesystem image paths
        fs_paths = self._scan_image_paths(directory.path, settings.directory.recursive_indexing)

        # Get database image paths
        image_repo = ImageRepository(session)
        db_images = session.query(Image).filter(Image.directory_id == directory.id).all()
        db_paths = {img.path for img in db_images}

        new_paths = fs_paths - db_paths
        deleted_paths = db_paths - fs_paths
        logger.info(f"Directory {directory.path}: {len(new_paths)} new images, {len(deleted_paths)} missing images")

        # Add new images to DB
        for path in new_paths:
            if not image_repo.get_by_path(path):
                session.add(Image(path=path, directory_id=directory.id, is_indexed=False))
        session.commit()

        # Remove deleted images from DB and Milvus
        if deleted_paths:
            # Loop-invariant: one repository and one embedder-name list serve
            # every deleted path.
            milvus_repo = MilvusRepository()
            embedder_names = list(embedder_manager.get_image_embedders().keys())
            for path in deleted_paths:
                image = image_repo.get_by_path(path)
                if image:
                    expr = self._path_filter_expr(path)
                    for embedder_name in embedder_names:
                        milvus_repo.delete_entries(embedder_name, expr)
                    image_repo.delete(image)
        session.commit()

    @classmethod
    def _scan_image_paths(cls, root: str, recursive: bool) -> set:
        """Return the paths of image files under ``root``.

        Files directly inside ``root`` are always considered; sub-directories
        are walked only when ``recursive`` is true.
        """
        found = set()
        # The context manager closes the scandir iterator (and its OS
        # directory handle) as soon as the scan is finished.
        with os.scandir(root) as entries:
            for entry in entries:
                if entry.is_file() and entry.name.lower().endswith(cls.IMAGE_EXTENSIONS):
                    found.add(entry.path)
                elif entry.is_dir() and recursive:
                    for walk_root, _, files in os.walk(entry.path):
                        for file_name in files:
                            if file_name.lower().endswith(cls.IMAGE_EXTENSIONS):
                                found.add(os.path.join(walk_root, file_name))
        return found

    @staticmethod
    def _path_filter_expr(path: str) -> str:
        """Build the Milvus filter expression matching ``image_path == path``.

        Backslashes and single quotes are backslash-escaped so that a path
        containing them (e.g. ``it's.png`` or a Windows path) cannot terminate
        or corrupt the quoted string literal.
        """
        escaped = path.replace('\\', '\\\\').replace("'", "\\'")
        return f"image_path == '{escaped}'"

backend/indexing/queue_manager/__init__.py

Whitespace-only changes.
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import queue
2+
import threading
3+
from monitoring import logger
4+
5+
from concurrent.futures import ThreadPoolExecutor
6+
7+
from core.singleton import Singleton
8+
from models.models import SessionLocal
9+
from indexing.repositories.repositories import MilvusRepository
10+
from indexing.services.directory_indexer import DirectoryIndexer
11+
from indexing.services.embedder_service import EmbedderService
12+
from settings import settings
13+
14+
15+
@Singleton
class IndexQueueManager:
    """Singleton that queues directories for indexing and drains the queue on a thread pool.

    Entries are ``(priority, (directory_id, path))`` tuples in a
    ``queue.PriorityQueue``, so the lowest priority value is processed first.
    ``processing_paths`` tracks every directory that is queued or being
    indexed so the same directory is not queued twice.
    """

    def __init__(self):
        self.index_queue = queue.PriorityQueue()
        # (directory_id, path) pairs that are queued or currently being indexed.
        self.processing_paths = set()
        # Guards processing_paths together with the matching queue insertion.
        self.queue_lock = threading.Lock()
        self.index_workers = ThreadPoolExecutor(max_workers=settings.directory.num_watcher_workers)
        self.embedder_service = EmbedderService()
        self.milvus_repo = MilvusRepository()
        self.directory_indexer = DirectoryIndexer(self.embedder_service, self.milvus_repo)

    def add_to_queue(self, directory_id: int, path: str, priority: int = 0):
        """Queue a directory for indexing unless it is already queued or in progress.

        Args:
            directory_id: Database ID of the directory.
            path: Filesystem path of the directory.
            priority: Queue priority; lower values are processed first.
        """
        with self.queue_lock:
            if (directory_id, path) in self.processing_paths:
                return
            self.index_queue.put((priority, (directory_id, path)))
            self.processing_paths.add((directory_id, path))
        logger.debug(f"Queued directory {path} (ID: {directory_id}) with priority {priority}")
        self.index_workers.submit(self._process_queue)

    def _process_queue(self):
        """Worker body: index queued directories until the queue is empty."""
        while True:
            # A non-blocking get avoids the race in "check empty(), then
            # get()": another worker may take the last item between the two
            # calls, leaving this worker blocked forever.
            try:
                priority, (directory_id, path) = self.index_queue.get_nowait()
            except queue.Empty:
                return
            try:
                session = SessionLocal()
                try:
                    self.directory_indexer.index_directory(directory_id, path, session)
                finally:
                    session.close()
            except Exception:
                # The future returned by submit() is never inspected, so an
                # escaping exception would vanish silently and stop this
                # worker from draining the remaining entries.
                logger.error(f"Indexing failed for directory {path} (ID: {directory_id})", exc_info=True)
            else:
                logger.debug(f"Finished processing directory {path} (ID: {directory_id})")
            finally:
                # Always release the path so the directory can be queued again.
                with self.queue_lock:
                    self.processing_paths.discard((directory_id, path))

backend/indexing/repositories/__init__.py

Whitespace-only changes.
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from monitoring import logger
2+
3+
from typing import List, Dict
4+
5+
from pymilvus import Collection
6+
from sqlalchemy.orm import Session
7+
8+
from models.models import Directory, Image
9+
10+
11+
class DirectoryRepository:
    """Data-access helper for ``Directory`` rows, bound to one database session."""

    def __init__(self, session: Session):
        self.session = session

    def get_by_path(self, path: str) -> Directory:
        """Return the first directory row whose path equals ``path``, if any."""
        matching = self.session.query(Directory).filter(Directory.path == path)
        return matching.first()

    def create(self, path: str) -> Directory:
        """Insert and commit a new, not-yet-indexed directory row and return it."""
        record = Directory(path=path, is_indexed=False)
        db = self.session
        db.add(record)
        db.commit()
        # Reload the row's attributes from the database before logging its ID.
        db.refresh(record)
        logger.debug(f"Created directory entry with ID {record.id} for path {path}")
        return record

    def get_all(self) -> List[Directory]:
        """Return every directory row."""
        every_directory = self.session.query(Directory)
        return every_directory.all()

    def delete(self, directory: Directory):
        """Delete ``directory`` and commit immediately."""
        db = self.session
        db.delete(directory)
        db.commit()
32+
33+
34+
class ImageRepository:
    """Data-access helper for ``Image`` rows, bound to one database session."""

    def __init__(self, session: Session):
        self.session = session

    def get_by_path(self, path: str) -> Image:
        """Return the first image row whose path equals ``path``, if any."""
        matching = self.session.query(Image).filter(Image.path == path)
        return matching.first()

    def add_new_images(self, directory_id: int, image_paths: List[str]) -> List[Image]:
        """Add an un-indexed row for every path not yet stored, then commit.

        Returns:
            The newly created ``Image`` objects (paths already present are skipped).
        """
        created = []
        for candidate in image_paths:
            if self.get_by_path(candidate):
                # Already known; nothing to add.
                continue
            record = Image(path=candidate, directory_id=directory_id, is_indexed=False)
            self.session.add(record)
            created.append(record)
        self.session.commit()
        logger.info(f"Added {len(created)} new images to database for directory {directory_id}")
        return created

    def get_unindexed_images(self, directory_id: int) -> List[Image]:
        """Return the images of ``directory_id`` that are not yet indexed."""
        pending = self.session.query(Image).filter(
            Image.directory_id == directory_id,
            Image.is_indexed == False,  # noqa: E712 - builds a SQL comparison, not a Python truth test
        )
        return pending.all()

    def delete(self, image: Image):
        """Delete ``image`` and commit immediately."""
        db = self.session
        db.delete(image)
        db.commit()
61+
62+
63+
class MilvusRepository:
    """Thin wrapper around the Milvus collection named after each embedder."""

    def delete_entries(self, embedder_name: str, expr: str):
        """Delete the entries matching ``expr``, flush, and return the delete result."""
        target = Collection(embedder_name)
        outcome = target.delete(expr)
        target.flush()
        logger.info(f"Deleted {outcome.delete_count} entries in Milvus collection '{embedder_name}' using expr {expr}")
        return outcome

    def insert_entries(self, embedder_name: str, entries: List[Dict]):
        """Insert ``entries`` into the embedder's collection and flush."""
        target = Collection(embedder_name)
        target.insert(entries)
        target.flush()
        logger.debug(f"Inserted {len(entries)} entries into Milvus collection '{embedder_name}'")

    def query_entries(self, embedder_name: str, expr: str, output_fields: List[str], batch_size: int = 1000):
        """Return a query iterator over the entries matching ``expr``.

        Args:
            embedder_name: Name of the collection to query.
            expr: Filter expression selecting the entries.
            output_fields: Fields to include in each returned entry.
            batch_size: Batch size passed to the iterator.
        """
        target = Collection(embedder_name)
        iterator = target.query_iterator(
            expr=expr,
            output_fields=output_fields,
            batch_size=batch_size,
        )
        return iterator

0 commit comments

Comments
 (0)