Commit e6a63fa
Update our RAG to use CLIP embeddings to support the new multi-modal data. (#4)
1 parent 49a962b commit e6a63fa

File tree: 2 files changed, +31 −3 lines


api/memoryalpha/rag.py

Lines changed: 31 additions & 3 deletions
@@ -115,7 +115,29 @@ def __init__(self,
             path=self.chroma_db_path,
             settings=Settings(allow_reset=False)
         )
-        self.collection = self.client.get_collection(self.collection_name)
+
+        # Initialize CLIP model for consistent embeddings with the database
+        logger.info("Loading CLIP model for embedding compatibility...")
+        self.clip_model = SentenceTransformer('clip-ViT-B-32')
+        logger.info("CLIP model loaded successfully")
+
+        # Create CLIP embedding function to match the one used during data creation
+        from chromadb.utils import embedding_functions
+
+        class CLIPEmbeddingFunction(embedding_functions.EmbeddingFunction):
+            def __init__(self, clip_model):
+                self.clip_model = clip_model
+
+            def __call__(self, input):
+                """Generate embeddings using CLIP model"""
+                embeddings = []
+                for text in input:
+                    embedding = self.clip_model.encode(text)
+                    embeddings.append(embedding.tolist())
+                return embeddings
+
+        self.clip_ef = CLIPEmbeddingFunction(self.clip_model)
+        self.collection = self.client.get_collection(self.collection_name, embedding_function=self.clip_ef)
 
         # Initialize Ollama client
         self.ollama_client = ollama.Client(host=self.ollama_url)
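
The query-time embedding function has to match the one used at ingestion time: CLIP ViT-B/32 produces 512-dimensional vectors, while Chroma's default MiniLM model produces 384-dimensional ones, so mixing them would at best compare vectors from different spaces and at worst fail on a dimension mismatch. For context, here is a minimal sketch of what the matching ingestion side could look like; the script, collection name, file names, and metadata values are illustrative assumptions, not part of this commit:

    # Sketch of the ingestion side (hypothetical, not part of this commit).
    # The point: the SAME CLIP model embeds text and images into one
    # 512-dimensional space, so a text query can later retrieve images.
    import chromadb
    from chromadb.config import Settings
    from chromadb.utils import embedding_functions
    from sentence_transformers import SentenceTransformer
    from PIL import Image

    class CLIPEmbeddingFunction(embedding_functions.EmbeddingFunction):
        def __init__(self, clip_model):
            self.clip_model = clip_model

        def __call__(self, input):
            return [self.clip_model.encode(text).tolist() for text in input]

    clip_model = SentenceTransformer('clip-ViT-B-32')
    client = chromadb.PersistentClient(path="./chroma_db", settings=Settings(allow_reset=False))
    collection = client.get_or_create_collection(
        "memoryalpha",  # collection name is an assumption
        embedding_function=CLIPEmbeddingFunction(clip_model),
    )

    # Text documents: the collection's embedding function embeds them.
    collection.add(
        ids=["doc-1"],
        documents=["The USS Enterprise (NCC-1701-D) is a Galaxy-class starship."],
        metadatas=[{"title": "USS Enterprise-D", "content_type": "text"}],
    )

    # Image documents: precompute the vector with the same CLIP model,
    # which accepts PIL images directly for CLIP checkpoints.
    image_vec = clip_model.encode(Image.open("enterprise.jpg"))
    collection.add(
        ids=["img-1"],
        embeddings=[image_vec.tolist()],
        documents=["enterprise.jpg"],  # store a reference, not pixels
        metadatas=[{"title": "USS Enterprise-D", "content_type": "image"}],
    )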
@@ -139,12 +161,18 @@ def _cosine_similarity(self, query_embedding: np.ndarray, doc_embeddings: np.ndarray)
         return np.dot(doc_norms, query_norm)
 
     def search(self, query: str, top_k: int = 10) -> List[Dict[str, Any]]:
-        results = self.collection.query(query_texts=[query], n_results=top_k)
+        # Search only text documents (filter out image documents for now)
+        results = self.collection.query(
+            query_texts=[query],
+            n_results=top_k,
+            where={"content_type": "text"}  # Only search text documents
+        )
         docs = [
             {
                 "content": doc,
                 "title": meta["title"],
-                "distance": dist
+                "distance": dist,
+                "content_type": meta.get("content_type", "text")
             }
             for doc, meta, dist in zip(results["documents"][0], results["metadatas"][0], results["distances"][0])
         ]
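
With the `where` filter, image documents remain in the collection but are excluded from retrieval until the pipeline can render them. A small usage sketch of the updated method follows; the MemoryAlphaRAG class name and constructor arguments are assumptions inferred from the file paths, not shown in this diff:

    # Hypothetical caller; class name and constructor args are assumed.
    rag = MemoryAlphaRAG(collection_name="memoryalpha")

    for hit in rag.search("Who commanded the USS Enterprise-D?", top_k=3):
        # content_type falls back to "text" for documents ingested
        # before the field existed (via meta.get above).
        print(f'{hit["title"]} ({hit["content_type"]}): distance={hit["distance"]:.3f}')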

api/memoryalpha/rag/query.py

Whitespace-only changes.
