
Commit 341cce6

Add files via upload
Initial project upload

File tree

7 files changed: +516 lines, -0 lines

.gitignore

Lines changed: 45 additions & 0 deletions
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# Virtual Environment
venv/
ENV/

# Project specific
clip_search_index.pkl
static/uploads/*
!static/uploads/.gitkeep

# IDE
.idea/
.vscode/
*.swp
*.swo

# OS specific
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
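
The static/uploads/* and !static/uploads/.gitkeep pair above ignores every user upload while keeping the empty directory under version control. The placeholder only needs creating once; a minimal Python sketch, consistent with the os.makedirs call app.py makes (illustrative only, since the repository presumably ships .gitkeep already):

import os

# Create the uploads directory and an empty placeholder file; the
# ! rule above exempts .gitkeep from the ignore pattern.
os.makedirs('static/uploads', exist_ok=True)
open('static/uploads/.gitkeep', 'a').close()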

README.txt

Lines changed: 40 additions & 0 deletions
# One-Shot Image Search Engine

A semantic image search engine built with CLIP and FAISS that supports searching by text description or by a similar image.

## Features

- **Text-to-Image Search**: Find images by describing them in natural language
- **Image-to-Image Search**: Upload an image to find visually similar ones
- **Fast Vector Search**: Uses FAISS for efficient similarity search
- **Pre-trained AI Model**: Leverages OpenAI's CLIP for understanding image content
- **Web Interface**: Clean, responsive UI built with Flask and Bootstrap

## Technologies Used

- **CLIP**: OpenAI's Contrastive Language-Image Pre-training model
- **FAISS**: Facebook AI Similarity Search, a library for efficient vector search
- **PyTorch**: Deep learning framework
- **Flask**: Web application framework
- **Bootstrap**: Frontend styling

## Installation

1. Clone this repository:
   git clone https://github.com/shubhrat12/Image-search-engine.git
   cd Image-search-engine
2. Create a virtual environment and install dependencies:
   python -m venv venv
   source venv/bin/activate  # On Windows: venv\Scripts\activate
   pip install -r requirements.txt
3. Run the application:
   python app.py
4. Open your browser and go to http://127.0.0.1:5000

## How It Works

1. The application uses CLIP to convert images into vector embeddings
2. These embeddings capture the semantic meaning of each image
3. When searching with text, the query is converted into the same vector space
4. FAISS finds the image vectors most similar to the query vector
5. Results are ranked by cosine similarity score
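
Step 5 deserves a brief unpacking: because every embedding is L2-normalized, the inner product FAISS computes is exactly the cosine similarity. A minimal NumPy sketch of that ranking idea, using toy 4-dimensional vectors in place of the 512-dimensional CLIP embeddings (illustrative only, not part of the repository):

import numpy as np

# Toy stand-ins for CLIP embeddings (the real ones are 512-dimensional)
image_vecs = np.random.rand(5, 4).astype('float32')
query_vec = np.random.rand(1, 4).astype('float32')

# L2-normalize rows so that inner product == cosine similarity
image_vecs /= np.linalg.norm(image_vecs, axis=1, keepdims=True)
query_vec /= np.linalg.norm(query_vec, axis=1, keepdims=True)

# Score every image against the query and rank, best match first
scores = (image_vecs @ query_vec.T).ravel()
ranking = np.argsort(-scores)
print(ranking, scores[ranking])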

app.py

Lines changed: 80 additions & 0 deletions
import os
from flask import Flask, render_template, request, redirect, url_for
from werkzeug.utils import secure_filename
from clip_encoder import ClipEncoder
import time

app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = 'static/uploads'
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max upload size
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

# Initialize the CLIP encoder
encoder = ClipEncoder()

# Check if the index file exists, otherwise create it
if os.path.exists('clip_search_index.pkl'):
    encoder.load_index()
else:
    # Build the index from images in the static/images directory
    encoder.encode_images('static/images')
    encoder.save_index()

@app.route('/', methods=['GET'])
def index():
    return render_template('index.html')

@app.route('/search', methods=['POST'])
def search():
    if 'text_query' in request.form:
        # Text-based search
        query = request.form['text_query']
        if not query:
            return redirect(url_for('index'))

        start_time = time.time()
        results = encoder.search(query, k=12)
        search_time = time.time() - start_time

        # Convert image paths to web paths
        for result in results:
            result['image_url'] = '/' + result['image_path'].replace('\\', '/')

        return render_template('results.html',
                               query=query,
                               results=results,
                               search_type="text",
                               search_time=search_time)

    elif 'image_query' in request.files:
        # Image-based search
        file = request.files['image_query']
        if file and allowed_file(file.filename):
            filename = secure_filename(file.filename)
            file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            file.save(file_path)

            start_time = time.time()
            results = encoder.search_by_image(file_path, k=12)
            search_time = time.time() - start_time

            # Convert image paths to web paths
            query_image_url = url_for('static', filename=f'uploads/{filename}')
            for result in results:
                result['image_url'] = '/' + result['image_path'].replace('\\', '/')

            return render_template('results.html',
                                   query_image=query_image_url,
                                   results=results,
                                   search_type="image",
                                   search_time=search_time)

    return redirect(url_for('index'))

if __name__ == '__main__':
    app.run(debug=True)
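
With the server running (python app.py), the /search route can be smoke-tested outside the browser. A minimal sketch using the requests library: the form-field names text_query and image_query come from the handler above, while the base URL, query text, and image path are illustrative assumptions:

import requests

BASE = "http://127.0.0.1:5000"

# Text search: mirrors the 'text_query' form field handled above
resp = requests.post(f"{BASE}/search", data={"text_query": "a red car"})
print(resp.status_code)  # 200 -> rendered results.html

# Image search: mirrors the 'image_query' file field (query.jpg is hypothetical)
with open("query.jpg", "rb") as f:
    resp = requests.post(f"{BASE}/search", files={"image_query": f})
print(resp.status_code)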

clip_encoder.py

Lines changed: 138 additions & 0 deletions
import torch
from PIL import Image
import os
from transformers import CLIPProcessor, CLIPModel
import numpy as np
import faiss
import pickle
import time

class ClipEncoder:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load CLIP model
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        # Initialize FAISS index
        self.dimension = 512  # CLIP embedding dimension
        self.index = faiss.IndexFlatIP(self.dimension)  # Inner product for cosine similarity

        # Store mapping of indices to image paths
        self.image_paths = []

    def encode_images(self, image_dir):
        """Encode all images in the directory and build the FAISS index"""
        start_time = time.time()
        print(f"Starting to encode images from {image_dir}...")

        image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)
                       if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

        all_embeddings = []

        for img_path in image_paths:
            try:
                image = Image.open(img_path).convert('RGB')
                inputs = self.processor(images=image, return_tensors="pt").to(self.device)

                with torch.no_grad():
                    image_features = self.model.get_image_features(**inputs)
                    image_embeddings = image_features.cpu().numpy()

                # Normalize embeddings
                image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)

                all_embeddings.append(image_embeddings[0])
                self.image_paths.append(img_path)
            except Exception as e:
                print(f"Error processing {img_path}: {e}")

        # Add all embeddings to the FAISS index
        if all_embeddings:
            all_embeddings_array = np.array(all_embeddings).astype('float32')
            self.index.add(all_embeddings_array)
            print(f"Added {len(all_embeddings)} images to the index")
        else:
            print("No images were successfully encoded")

        elapsed_time = time.time() - start_time
        print(f"Encoding completed in {elapsed_time:.2f} seconds")

    def encode_text(self, text):
        """Encode text query using CLIP"""
        inputs = self.processor(text=text, return_tensors="pt", padding=True).to(self.device)

        with torch.no_grad():
            text_features = self.model.get_text_features(**inputs)
            text_embeddings = text_features.cpu().numpy()

        # Normalize embeddings
        text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)

        return text_embeddings

    def search(self, query, k=8):
        """Search for similar images using text query"""
        text_embedding = self.encode_text(query)
        scores, indices = self.index.search(text_embedding.astype('float32'), k)

        results = []
        for idx, score in zip(indices[0], scores[0]):
            if idx != -1:  # Valid index
                results.append({
                    'image_path': self.image_paths[idx],
                    'score': float(score),
                    'filename': os.path.basename(self.image_paths[idx])
                })

        return results

    def encode_query_image(self, image_path):
        """Encode query image using CLIP"""
        image = Image.open(image_path).convert('RGB')
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)

        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)
            image_embeddings = image_features.cpu().numpy()

        # Normalize embeddings
        image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)

        return image_embeddings

    def search_by_image(self, image_path, k=8):
        """Search for similar images using an image query"""
        image_embedding = self.encode_query_image(image_path)
        scores, indices = self.index.search(image_embedding.astype('float32'), k)

        results = []
        for idx, score in zip(indices[0], scores[0]):
            if idx != -1:  # Valid index
                results.append({
                    'image_path': self.image_paths[idx],
                    'score': float(score),
                    'filename': os.path.basename(self.image_paths[idx])
                })

        return results

    def save_index(self, filename="clip_search_index.pkl"):
        """Save the index and image paths to a file"""
        with open(filename, 'wb') as f:
            pickle.dump({
                'index': faiss.serialize_index(self.index),
                'image_paths': self.image_paths
            }, f)
        print(f"Index saved to {filename}")

    def load_index(self, filename="clip_search_index.pkl"):
        """Load the index and image paths from a file"""
        with open(filename, 'rb') as f:
            data = pickle.load(f)
        self.index = faiss.deserialize_index(data['index'])
        self.image_paths = data['image_paths']
        print(f"Loaded index with {len(self.image_paths)} images")
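
The encoder can also be exercised without the Flask layer, which is handy when rebuilding or inspecting the index. A minimal sketch, assuming the dependencies are installed and static/images contains a few images; the query text and example image path are made up:

from clip_encoder import ClipEncoder

encoder = ClipEncoder()
encoder.encode_images('static/images')  # embed every image in the folder
encoder.save_index()                    # writes clip_search_index.pkl for reuse

# Text-to-image search: top 3 matches with cosine-similarity scores
for hit in encoder.search("a sunset over mountains", k=3):
    print(f"{hit['score']:.3f}  {hit['filename']}")

# Image-to-image search against the same index
for hit in encoder.search_by_image('static/images/example.jpg', k=3):
    print(f"{hit['score']:.3f}  {hit['filename']}")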
