1
+ import torch
2
+ from PIL import Image
3
+ import os
4
+ from transformers import CLIPProcessor , CLIPModel
5
+ import numpy as np
6
+ import faiss
7
+ import pickle
8
+ import time
9
+
10
class ClipEncoder:
    """Text-to-image and image-to-image similarity search over a local folder.

    Images are embedded with CLIP (``openai/clip-vit-base-patch32``),
    L2-normalized, and stored in an inner-product FAISS index, so the scores
    returned by :meth:`search` / :meth:`search_by_image` are cosine
    similarities in [-1, 1].
    """

    def __init__(self):
        # Prefer GPU when available; CLIP inference is much faster on CUDA.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: { self.device } ")

        # Load CLIP model
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        # Initialize FAISS index
        self.dimension = 512  # CLIP ViT-B/32 embedding dimension
        # Inner product on unit-length vectors == cosine similarity.
        self.index = faiss.IndexFlatIP(self.dimension)

        # Row i of the FAISS index corresponds to self.image_paths[i].
        self.image_paths = []

    @staticmethod
    def _normalize(embeddings):
        """L2-normalize each row of a 2-D array.

        Zero-norm rows are divided by 1 instead of 0 so a degenerate
        embedding yields a zero vector rather than NaNs that would silently
        corrupt every subsequent similarity score.
        """
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        return embeddings / norms

    def _embed_image(self, image_path):
        """Return a (1, dim) L2-normalized CLIP embedding for one image file.

        Raises whatever PIL/torch raise for unreadable files; callers decide
        whether to skip (bulk indexing) or propagate (single query).
        """
        image = Image.open(image_path).convert('RGB')
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)
        return self._normalize(image_features.cpu().numpy())

    def encode_images(self, image_dir):
        """Encode all images in the directory and build the FAISS index"""
        start_time = time.time()
        print(f"Starting to encode images from { image_dir } ...")

        # Sort so index row order is reproducible across platforms
        # (os.listdir order is arbitrary).
        image_paths = sorted(
            os.path.join(image_dir, f)
            for f in os.listdir(image_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

        all_embeddings = []

        for img_path in image_paths:
            try:
                all_embeddings.append(self._embed_image(img_path)[0])
                self.image_paths.append(img_path)
            except Exception as e:
                # Best-effort bulk indexing: skip unreadable/corrupt files.
                print(f"Error processing { img_path } : { e } ")

        # Add all embeddings to the FAISS index in one batched call.
        if all_embeddings:
            all_embeddings_array = np.array(all_embeddings).astype('float32')
            self.index.add(all_embeddings_array)
            print(f"Added { len (all_embeddings )} images to the index")
        else:
            print("No images were successfully encoded")

        elapsed_time = time.time() - start_time
        print(f"Encoding completed in { elapsed_time :.2f} seconds")

    def encode_text(self, text):
        """Encode text query using CLIP"""
        inputs = self.processor(text=text, return_tensors="pt", padding=True).to(self.device)

        with torch.no_grad():
            text_features = self.model.get_text_features(**inputs)

        # Normalize so inner-product search equals cosine similarity.
        return self._normalize(text_features.cpu().numpy())

    def _format_results(self, scores, indices):
        """Map raw FAISS (scores, indices) output to result dicts.

        FAISS pads with index -1 when fewer than k vectors exist; those
        entries are dropped.
        """
        results = []
        for idx, score in zip(indices[0], scores[0]):
            if idx != -1:  # Valid index
                results.append({
                    'image_path': self.image_paths[idx],
                    'score': float(score),
                    'filename': os.path.basename(self.image_paths[idx])
                })
        return results

    def search(self, query, k=8):
        """Search for similar images using text query"""
        text_embedding = self.encode_text(query)
        scores, indices = self.index.search(text_embedding.astype('float32'), k)
        return self._format_results(scores, indices)

    def encode_query_image(self, image_path):
        """Encode query image using CLIP"""
        return self._embed_image(image_path)

    def search_by_image(self, image_path, k=8):
        """Search for similar images using an image query"""
        image_embedding = self.encode_query_image(image_path)
        scores, indices = self.index.search(image_embedding.astype('float32'), k)
        return self._format_results(scores, indices)

    def save_index(self, filename="clip_search_index.pkl"):
        """Save the index and image paths to a file"""
        with open(filename, 'wb') as f:
            pickle.dump({
                # serialize_index returns a numpy uint8 array that pickles safely.
                'index': faiss.serialize_index(self.index),
                'image_paths': self.image_paths
            }, f)
        print(f"Index saved to { filename } ")

    def load_index(self, filename="clip_search_index.pkl"):
        """Load the index and image paths from a file.

        SECURITY NOTE: pickle.load executes arbitrary code embedded in the
        file — only load index files you created yourself, never ones from
        untrusted sources.
        """
        with open(filename, 'rb') as f:
            data = pickle.load(f)
            self.index = faiss.deserialize_index(data['index'])
            self.image_paths = data['image_paths']
            print(f"Loaded index with { len (self .image_paths )} images")
0 commit comments