From c900d02840df5a8c6b0ce53d4dac3af90584dc0b Mon Sep 17 00:00:00 2001 From: aniongithub Date: Fri, 5 Sep 2025 05:00:53 +0000 Subject: [PATCH 1/2] Enhance devcontainer setup and improve RAG functionality - Updated devcontainer configuration to include multiple docker-compose files. - Added new docker-compose.devcontainer.yml for devcontainer overriddes - Refined environment variables for text and image collections in .env file. - Improved prompt generation in rag.py for better user interaction. - Streamlined chat.sh for continuous question and image identification modes. - Adjusted docker-compose.yml to use model_cache for caching. --- .devcontainer/devcontainer.json | 6 +- .devcontainer/docker-compose.devcontainer.yml | 10 + .env | 3 +- api/memoryalpha/rag.py | 350 ++++++------------ chat.sh | 52 ++- docker-compose.yml | 4 +- 6 files changed, 180 insertions(+), 245 deletions(-) create mode 100644 .devcontainer/docker-compose.devcontainer.yml diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 1ef9694..4689b65 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,10 +1,14 @@ { "name": "Memoryalpha RAG API", - "dockerComposeFile": "../docker-compose.yml", + "dockerComposeFile": [ + "../docker-compose.yml", + "docker-compose.devcontainer.yml" + ], "service": "lcars", "overrideCommand": true, "remoteUser": "vscode", "postAttachCommand": "/workspace/memoryalpha-rag-api/wait-for-ollama.sh", + "postCreateCommand": "sudo chown -R vscode:users ${containerWorkspaceFolder}/.cache", "features": { "ghcr.io/nils-geistmann/devcontainers-features/create-remote-user:0": { "passwordLessSudo": true diff --git a/.devcontainer/docker-compose.devcontainer.yml b/.devcontainer/docker-compose.devcontainer.yml new file mode 100644 index 0000000..7ce13bf --- /dev/null +++ b/.devcontainer/docker-compose.devcontainer.yml @@ -0,0 +1,10 @@ +version: "3.9" + +services: + lcars: + build: + context: . + dockerfile: Dockerfile + user: vscode + volumes: + - ./.cache:/home/vscode/.cache # Override default cache location for devcontainer only \ No newline at end of file diff --git a/.env b/.env index 35bc98c..378b7b1 100644 --- a/.env +++ b/.env @@ -3,4 +3,5 @@ DEFAULT_IMAGE_MODEL="qwen2.5vl:3b" OLLAMA_URL="http://ollama:11434" DB_PATH="/data/enmemoryalpha_db" -COLLECTION_NAME="memoryalpha" \ No newline at end of file +TEXT_COLLECTION_NAME="memoryalpha_text" +IMAGE_COLLECTION_NAME="memoryalpha_images" \ No newline at end of file diff --git a/api/memoryalpha/rag.py b/api/memoryalpha/rag.py index c653d3e..6fddc3a 100644 --- a/api/memoryalpha/rag.py +++ b/api/memoryalpha/rag.py @@ -5,7 +5,6 @@ import re import logging import warnings -import numpy as np from typing import List, Dict, Any # External modules @@ -35,28 +34,43 @@ class ThinkingMode(Enum): def get_system_prompt(thinking_mode: ThinkingMode) -> str: """Generate the LCARS-style system prompt based on thinking mode""" - + + base_prompt = """You are an LCARS computer system with access to Star Trek Memory Alpha records. + +CRITICAL INSTRUCTIONS: +- You MUST answer ONLY using information from the provided records below +- If the records don't contain relevant information, say "I don't have information about that in my records" +- DO NOT make up information, invent characters, or hallucinate details +- DO NOT use external knowledge about Star Trek - only use the provided records +- If asked about something not in the records, be honest about the limitation + +""" + if thinking_mode == ThinkingMode.DISABLED: - return "You are an LCARS computer. Use the provided records to answer questions precisely in a single paragraph. Do not use thinking tags or analysis blocks." + return base_prompt + "Answer directly in a single paragraph without thinking tags." elif thinking_mode == ThinkingMode.QUIET: - return "You are an LCARS computer. Use tags for your analysis, then provide a precise answer in a single paragraph. Users will only see your final answer, not your thinking." + return base_prompt + "Use tags for internal analysis, then provide your final answer in a single paragraph." else: # VERBOSE - return "You are an LCARS computer. Use tags for your analysis, then provide a precise answer in a single paragraph. Your thinking process will be visible to users." + return base_prompt + "Use tags for analysis, then provide your final answer in a single paragraph." def get_user_prompt(context_text: str, query: str) -> str: """Format user prompt with context and query""" - - return f"""Records: + + if not context_text.strip(): + return f"I have no relevant records for this query. Please ask about Star Trek topics that are documented in Memory Alpha.\n\nQuery: {query}" + + return f"""MEMORY ALPHA RECORDS: {context_text} -Query: {query}""" +QUESTION: {query} + +Answer using ONLY the information in the records above. If the records don't contain the information needed to answer this question, say so clearly.""" class MemoryAlphaRAG: def __init__(self, chroma_db_path: str = os.getenv("DB_PATH"), ollama_url: str = os.getenv("OLLAMA_URL"), collection_name: str = os.getenv("COLLECTION_NAME", "memoryalpha"), - rerank_method: str = "cross-encoder", thinking_mode: ThinkingMode = ThinkingMode.DISABLED, max_history_turns: int = 5, thinking_text: str = "Processing..."): @@ -65,47 +79,29 @@ def __init__(self, raise ValueError("chroma_db_path must be provided or set in CHROMA_DB_PATH environment variable.") if not ollama_url: raise ValueError("ollama_url must be provided or set in OLLAMA_URL environment variable.") - if not collection_name: - raise ValueError("collection_name must be provided or set in COLLECTION_NAME environment variable.") self.chroma_db_path = chroma_db_path self.ollama_url = ollama_url self.collection_name = collection_name self.thinking_mode = thinking_mode self.max_history_turns = max_history_turns - self.rerank_method = rerank_method self.thinking_text = thinking_text self.conversation_history: List[Dict[str, str]] = [] - # Initialize conversation messages for ollama chat - self.messages = [] - - self.cross_encoder = None - self.embedding_model = None - - if rerank_method == "cross-encoder": - try: - logger.info("Loading cross-encoder model BAAI/bge-reranker-v2-m3...") - self.cross_encoder = CrossEncoder('BAAI/bge-reranker-v2-m3') - logger.info("Cross-encoder model loaded successfully") - except Exception: - logger.info("Loading fallback cross-encoder model BAAI/bge-reranker-base...") - self.cross_encoder = CrossEncoder('BAAI/bge-reranker-base') - logger.info("Fallback cross-encoder model loaded successfully") - elif rerank_method == "cosine": - logger.info("Loading embedding model all-MiniLM-L6-v2...") - self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2') - logger.info("Embedding model loaded successfully") + # Initialize Ollama client first + self.ollama_client = ollama.Client(host=self.ollama_url) + # Initialize ChromaDB self.client = chromadb.PersistentClient( path=self.chroma_db_path, settings=Settings(allow_reset=False) ) - # Initialize text embedding model and collection - logger.info("Loading text embedding model all-MiniLM-L6-v2 for text collection...") + # Initialize text collection + logger.info("Loading text embedding model...") self.text_model = SentenceTransformer('all-MiniLM-L6-v2') logger.info("Text model loaded successfully") + from chromadb.utils import embedding_functions class TextEmbeddingFunction(embedding_functions.EmbeddingFunction): def __init__(self, text_model): @@ -116,83 +112,91 @@ def __call__(self, input): embedding = self.text_model.encode(text) embeddings.append(embedding.tolist()) return embeddings + self.text_ef = TextEmbeddingFunction(self.text_model) self.text_collection = self.client.get_or_create_collection("memoryalpha_text", embedding_function=self.text_ef) - # Initialize CLIP model and image collection - logger.info("Loading CLIP model for image collection...") - self.clip_model = SentenceTransformer('clip-ViT-B-32') - logger.info("CLIP model loaded successfully") - class CLIPEmbeddingFunction(embedding_functions.EmbeddingFunction): - def __init__(self, clip_model): - self.clip_model = clip_model - def __call__(self, input): - embeddings = [] - for img in input: - embedding = self.clip_model.encode(img) - embeddings.append(embedding.tolist()) - return embeddings - self.clip_ef = CLIPEmbeddingFunction(self.clip_model) - self.image_collection = self.client.get_or_create_collection("memoryalpha_images", embedding_function=self.clip_ef) - - # Initialize Ollama client - self.ollama_client = ollama.Client(host=self.ollama_url) - - def _cosine_similarity(self, query_embedding: np.ndarray, doc_embeddings: np.ndarray) -> np.ndarray: - query_norm = query_embedding / np.linalg.norm(query_embedding) - doc_norms = doc_embeddings / np.linalg.norm(doc_embeddings, axis=1, keepdims=True) - return np.dot(doc_norms, query_norm) + # Initialize cross-encoder for reranking + try: + logger.info("Loading cross-encoder model...") + self.cross_encoder = CrossEncoder('BAAI/bge-reranker-v2-m3') + logger.info("Cross-encoder model loaded successfully") + except Exception: + logger.warning("Could not load cross-encoder, using basic search only") + self.cross_encoder = None def search(self, query: str, top_k: int = 10) -> List[Dict[str, Any]]: - # Search only text documents using the text collection and text embedding model - results = self.text_collection.query( - query_texts=[query], - n_results=top_k - ) - docs = [ - { - "content": doc, - "title": meta["title"], - "distance": dist - } - for doc, meta, dist in zip(results["documents"][0], results["metadatas"][0], results["distances"][0]) - ] + """Search the Memory Alpha database for relevant documents.""" + + try: + # Perform semantic search + results = self.text_collection.query( + query_texts=[query], + n_results=min(top_k * 2, 50) # Get more results for reranking + ) + + if not results["documents"] or not results["documents"][0]: + logger.warning(f"No documents found for query: {query}") + return [] + + docs = [] + for doc, meta, dist in zip(results["documents"][0], results["metadatas"][0], results["distances"][0]): + docs.append({ + "content": doc, + "title": meta.get("title", "Unknown"), + "distance": dist + }) + + # Rerank with cross-encoder if available + if self.cross_encoder and len(docs) > 1: + pairs = [[query, doc["content"][:500]] for doc in docs] + scores = self.cross_encoder.predict(pairs) + for doc, score in zip(docs, scores): + doc["score"] = float(score) + docs = sorted(docs, key=lambda d: d["score"], reverse=True) + + return docs[:top_k] + + except Exception as e: + logger.error(f"Search failed: {e}") + return [] - if self.cross_encoder: - pairs = [[query, d["content"][:300]] for d in docs] - scores = self.cross_encoder.predict(pairs) - for doc, score in zip(docs, scores): - doc["score"] = float(score) - return sorted(docs, key=lambda d: d["score"], reverse=True) + def build_prompt(self, query: str, docs: List[Dict[str, Any]]) -> tuple[str, str]: + """Build the prompt with retrieved documents.""" - elif self.embedding_model: - query_emb = self.embedding_model.encode([query])[0] - doc_embs = self.embedding_model.encode([d["content"][:300] for d in docs]) - sims = self._cosine_similarity(query_emb, np.array(doc_embs)) - for doc, score in zip(docs, sims): - doc["score"] = float(score) - return sorted(docs, key=lambda d: d["score"], reverse=True) + system_prompt = get_system_prompt(self.thinking_mode) - return sorted(docs, key=lambda d: d["distance"]) + if not docs: + context_text = "" + else: + # Format context with clear structure + context_parts = [] + for i, doc in enumerate(docs, 1): + content = doc['content'] + # Limit content length to avoid token limits + if len(content) > 1000: + content = content[:1000] + "..." + context_parts.append(f"DOCUMENT {i}: {doc['title']}\n{content}") + + context_text = "\n\n".join(context_parts) - def build_prompt(self, query: str, docs: List[Dict[str, Any]]) -> tuple[str, str]: - system_prompt = get_system_prompt(self.thinking_mode) - context_text = "\n\n".join( - f"=== {doc['title']} ===\n{doc['content']}" for doc in docs - ) user_prompt = get_user_prompt(context_text, query) return system_prompt, user_prompt - def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float = 0.8, temperature: float = 0.3, + def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float = 0.8, temperature: float = 0.3, model: str = os.getenv("DEFAULT_MODEL")) -> str: """ - Ask a question using the specified model (defaults to $DEFAULT_MODEL if not provided). + Ask a question using the Memory Alpha RAG system. """ if not model: raise ValueError("model must be provided or set in DEFAULT_MODEL environment variable.") + # Search for relevant documents docs = self.search(query, top_k=top_k) + logger.info(f"Found {len(docs)} documents for query: {query}") + + # Build prompts system_prompt, user_prompt = self.build_prompt(query, docs) # Build messages for chat @@ -200,34 +204,40 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float {"role": "system", "content": system_prompt} ] - # Add conversation history - for exchange in self.conversation_history[-3:]: # Last 3 exchanges + # Add conversation history (limited) + for exchange in self.conversation_history[-2:]: # Last 2 exchanges messages.append({"role": "user", "content": exchange["question"]}) messages.append({"role": "assistant", "content": exchange["answer"]}) # Add current query messages.append({"role": "user", "content": user_prompt}) - result = self.ollama_client.chat( - model=model, - messages=messages, - stream=False, - options={"temperature": temperature, "top_p": top_p, "num_predict": max_tokens} - ) - full_response = result['message']['content'] - - # Handle thinking mode response processing - if self.thinking_mode == ThinkingMode.DISABLED: - final_response = self._clean_response(full_response) - elif self.thinking_mode == ThinkingMode.QUIET: - final_response = self._replace_thinking_tags(full_response) - else: # VERBOSE - final_response = full_response.strip() - - self._update_history(query, final_response) - return final_response + try: + result = self.ollama_client.chat( + model=model, + messages=messages, + stream=False, + options={"temperature": temperature, "top_p": top_p, "num_predict": max_tokens} + ) + full_response = result['message']['content'] + + # Handle thinking mode response processing + if self.thinking_mode == ThinkingMode.DISABLED: + final_response = self._clean_response(full_response) + elif self.thinking_mode == ThinkingMode.QUIET: + final_response = self._replace_thinking_tags(full_response) + else: # VERBOSE + final_response = full_response.strip() + + self._update_history(query, final_response) + return final_response + + except Exception as e: + logger.error(f"Chat failed: {e}") + return f"Error processing query: {str(e)}" def _clean_response(self, answer: str) -> str: + """Clean response by removing ANSI codes and thinking tags.""" clean = re.sub(r"\033\[[0-9;]*m", "", answer).replace("LCARS: ", "").strip() while "" in clean and "" in clean: start = clean.find("") @@ -236,6 +246,7 @@ def _clean_response(self, answer: str) -> str: return clean.strip() def _replace_thinking_tags(self, answer: str) -> str: + """Replace thinking tags with processing text.""" clean = re.sub(r"\033\[[0-9;]*m", "", answer).replace("LCARS: ", "").strip() while "" in clean and "" in clean: start = clean.find("") @@ -244,121 +255,6 @@ def _replace_thinking_tags(self, answer: str) -> str: return clean.strip() def _update_history(self, question: str, answer: str): + """Update conversation history.""" self.conversation_history.append({"question": question, "answer": answer}) - self.conversation_history = self.conversation_history[-self.max_history_turns:] - - def search_image(self, image_path: str, top_k: int = 5, - model: str = os.getenv("DEFAULT_IMAGE_MODEL")) -> Dict[str, Any]: - """ - 1. Generates CLIP embedding for the provided image - 2. Searches image records, retrieves top_k - 3. Downloads actual images for image results - 4. Uses source page titles to fetch text context from text collection - 5. Passes all info to the model to guess the theme and image - """ - from PIL import Image - import requests - import tempfile - import os - - if not model: - raise ValueError("model must be provided or set in DEFAULT_IMAGE_MODEL environment variable.") - - # 1. Generate CLIP embedding for the image - image = Image.open(image_path).convert('RGB') - image_embedding = self.clip_model.encode(image) - image_embedding = image_embedding.tolist() - - # 2. Search image records only - image_results = self.image_collection.query( - query_embeddings=[image_embedding], - n_results=top_k - ) - - # 3. Download actual images for image results and prepare for attachment - downloaded_images = [] - image_binaries = [] - image_docs = image_results['documents'][0] - image_metas = image_results['metadatas'][0] - image_urls = [meta.get('image_url') for meta in image_metas] - for idx, url in enumerate(image_urls): - if url: - try: - resp = requests.get(url, timeout=30) - if resp.status_code == 200: - with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp: - tmp.write(resp.content) - downloaded_images.append(tmp.name) - image_binaries.append(resp.content) - else: - downloaded_images.append(None) - image_binaries.append(None) - except Exception: - downloaded_images.append(None) - image_binaries.append(None) - else: - downloaded_images.append(None) - image_binaries.append(None) - - # 4. Use source page titles to fetch text context from text collection - source_titles = [meta.get('source_page') for meta in image_metas if meta.get('source_page')] - text_contexts = [] - if source_titles: - # Query text collection for each source page title - for title in source_titles: - text_results = self.text_collection.query( - query_texts=[title], - n_results=1 - ) - if text_results['documents'][0]: - doc = text_results['documents'][0][0] - meta = text_results['metadatas'][0][0] - dist = text_results['distances'][0][0] - text_contexts.append(f"Text Context for '{title}':\nTitle: {meta.get('title', 'Unknown')}\nSimilarity: {1-dist:.4f}\nContent: {doc[:300]}\n") - - # 5. Number and format results, reference images as Image 1, Image 2, etc. - formatted_images = [] - image_indices = [] - for i, (doc, meta, dist, img_path, img_bin) in enumerate(zip(image_docs, image_metas, image_results['distances'][0], downloaded_images, image_binaries), 1): - if img_bin: - formatted_images.append(f"Image {i}:\nImage Name: {meta.get('image_name', 'Unknown')}\nSource Page: {meta.get('source_page', 'Unknown')}\nSimilarity: {1-dist:.4f}\nDescription: {doc}\n(Refer to attached Image {i})\n") - image_indices.append(i-1) # index in image_binaries - else: - formatted_images.append(f"Image {i}:\nImage Name: {meta.get('image_name', 'Unknown')}\nSource Page: {meta.get('source_page', 'Unknown')}\nSimilarity: {1-dist:.4f}\nDescription: {doc}\nImage download failed.\n") - - # 6. Pass all info to the model, attach images - prompt = ( - "You are an expert Star Trek analyst. Look at the first attached image and determine which of the retrieved images below most closely matches it. " - "Use the metadata (image name, source page, description, similarity score, and text context) of the closest match to identify what is shown. " - "Provide a direct identification without mentioning image numbers, matches, or references to user images. " - "If no close match is found, say so clearly.\n\n" - ) - prompt += "\n".join(formatted_images) - if text_contexts: - prompt += "\n".join(text_contexts) - prompt += "\nRespond with one or two lines directly identifying what is shown in the image, based on the closest match and its metadata." - - messages = [ - {"role": "system", "content": "You are an expert Star Trek analyst."}, - {"role": "user", "content": prompt, "images": [image_binaries[i] for i in image_indices]} - ] - - # Only attach images that were successfully downloaded - response = self.ollama_client.chat( - model=model, - messages=messages, - stream=False - ) - answer = response['message']['content'] - - # Clean up temp images - for img_path in downloaded_images: - if img_path and os.path.exists(img_path): - try: - os.remove(img_path) - except Exception: - pass - - return { - "model_answer": answer - } \ No newline at end of file + self.conversation_history = self.conversation_history[-self.max_history_turns:] \ No newline at end of file diff --git a/chat.sh b/chat.sh index e54bdc9..1d2c962 100755 --- a/chat.sh +++ b/chat.sh @@ -12,6 +12,40 @@ echo "🖖 Welcome to MemoryAlpha RAG Chat" echo "Type 'quit' or 'exit' to end the session" echo "----------------------------------------" +# Function to handle continuous text questions +ask_mode() { + echo "🤖 Entering Question Mode - Type 'q' to return to main menu" + echo "----------------------------------------" + while true; do + echo -n "❓ Enter your question (or 'q' to quit): " + read -r question + if [[ "$question" == "q" || "$question" == "quit" ]]; then + break + fi + if [[ -z "$question" ]]; then + continue + fi + ask_question "$question" + done +} + +# Function to handle continuous image identification +identify_mode() { + echo "🖼️ Entering Image Identification Mode - Type 'q' to return to main menu" + echo "----------------------------------------" + while true; do + echo -n "🖼️ Enter local image path or image URL (or 'q' to quit): " + read -r image_path + if [[ "$image_path" == "q" || "$image_path" == "quit" ]]; then + break + fi + if [[ -z "$image_path" ]]; then + continue + fi + identify_image "$image_path" + done +} + # Function to handle text question ask_question() { local question="$1" @@ -83,27 +117,17 @@ identify_image() { while true; do echo "Choose an option:" - echo " 1) Ask a Star Trek question" - echo " 2) Identify an image" + echo " 1) Ask Star Trek questions" + echo " 2) Identify images" echo " q) Quit" echo -n "Enter choice [1/2/q]: " read -r choice case "$choice" in 1) - echo -n "❓ Enter your question: " - read -r question - if [[ -z "$question" ]]; then - continue - fi - ask_question "$question" + ask_mode ;; 2) - echo -n "🖼️ Enter local image path or image URL: " - read -r image_path - if [[ -z "$image_path" ]]; then - continue - fi - identify_image "$image_path" + identify_mode ;; q|quit|exit) echo "🖖 Live long and prosper!" diff --git a/docker-compose.yml b/docker-compose.yml index 2e5e4e2..b1f85cd 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -20,7 +20,7 @@ services: - ollama volumes: - .:/workspace/memoryalpha-rag-api - - rag_cache:/root/.cache + - model_cache:/root/.cache ports: - "8000:8000" # for REST API networks: @@ -31,7 +31,7 @@ services: volumes: ollama_data: - rag_cache: + model_cache: networks: odn: From f657659997ce0ced0ad2f49fb8a8c530ae562344 Mon Sep 17 00:00:00 2001 From: aniongithub Date: Fri, 5 Sep 2025 05:15:44 +0000 Subject: [PATCH 2/2] Refactor prompts for clarity and implement lazy loading for models in MemoryAlphaRAG --- api/memoryalpha/rag.py | 187 +++++++++++++++++++++++++++++++++-------- 1 file changed, 152 insertions(+), 35 deletions(-) diff --git a/api/memoryalpha/rag.py b/api/memoryalpha/rag.py index 6fddc3a..f3a96ff 100644 --- a/api/memoryalpha/rag.py +++ b/api/memoryalpha/rag.py @@ -38,11 +38,12 @@ def get_system_prompt(thinking_mode: ThinkingMode) -> str: base_prompt = """You are an LCARS computer system with access to Star Trek Memory Alpha records. CRITICAL INSTRUCTIONS: -- You MUST answer ONLY using information from the provided records below +- You MUST answer ONLY using information from the provided records - If the records don't contain relevant information, say "I don't have information about that in my records" - DO NOT make up information, invent characters, or hallucinate details - DO NOT use external knowledge about Star Trek - only use the provided records - If asked about something not in the records, be honest about the limitation +- Stay in character as an LCARS computer system at all times """ @@ -57,14 +58,14 @@ def get_user_prompt(context_text: str, query: str) -> str: """Format user prompt with context and query""" if not context_text.strip(): - return f"I have no relevant records for this query. Please ask about Star Trek topics that are documented in Memory Alpha.\n\nQuery: {query}" + return f"Starfleet database records contain no relevant information for this inquiry. Please inquire about documented Star Trek topics.\n\nINQUIRY: {query}" return f"""MEMORY ALPHA RECORDS: {context_text} -QUESTION: {query} +INQUIRY: {query} -Answer using ONLY the information in the records above. If the records don't contain the information needed to answer this question, say so clearly.""" +Accessing Starfleet database records. Provide analysis using ONLY the information in the records above. If the records don't contain the information needed to answer this inquiry, state that the information is not available in current records.""" class MemoryAlphaRAG: def __init__(self, @@ -88,42 +89,92 @@ def __init__(self, self.thinking_text = thinking_text self.conversation_history: List[Dict[str, str]] = [] - # Initialize Ollama client first + # Initialize lightweight components self.ollama_client = ollama.Client(host=self.ollama_url) - - # Initialize ChromaDB self.client = chromadb.PersistentClient( path=self.chroma_db_path, settings=Settings(allow_reset=False) ) - # Initialize text collection - logger.info("Loading text embedding model...") - self.text_model = SentenceTransformer('all-MiniLM-L6-v2') - logger.info("Text model loaded successfully") - - from chromadb.utils import embedding_functions - class TextEmbeddingFunction(embedding_functions.EmbeddingFunction): - def __init__(self, text_model): - self.text_model = text_model - def __call__(self, input): - embeddings = [] - for text in input: - embedding = self.text_model.encode(text) - embeddings.append(embedding.tolist()) - return embeddings - - self.text_ef = TextEmbeddingFunction(self.text_model) - self.text_collection = self.client.get_or_create_collection("memoryalpha_text", embedding_function=self.text_ef) - - # Initialize cross-encoder for reranking - try: - logger.info("Loading cross-encoder model...") - self.cross_encoder = CrossEncoder('BAAI/bge-reranker-v2-m3') - logger.info("Cross-encoder model loaded successfully") - except Exception: - logger.warning("Could not load cross-encoder, using basic search only") - self.cross_encoder = None + # Lazy-loaded components + self._text_model = None + self._cross_encoder = None + self._clip_model = None + self._text_collection = None + self._image_collection = None + self._text_ef = None + self._clip_ef = None + + @property + def text_model(self): + """Lazy load text embedding model.""" + if self._text_model is None: + logger.info("Loading text embedding model...") + self._text_model = SentenceTransformer('all-MiniLM-L6-v2') + logger.info("Text model loaded successfully") + return self._text_model + + @property + def cross_encoder(self): + """Lazy load cross-encoder model.""" + if self._cross_encoder is None: + try: + logger.info("Loading cross-encoder model...") + self._cross_encoder = CrossEncoder('BAAI/bge-reranker-v2-m3') + logger.info("Cross-encoder model loaded successfully") + except Exception as e: + logger.warning(f"Could not load cross-encoder: {e}") + self._cross_encoder = None + return self._cross_encoder + + @property + def clip_model(self): + """Lazy load CLIP model for image search.""" + if self._clip_model is None: + logger.info("Loading CLIP model for image search...") + self._clip_model = SentenceTransformer('clip-ViT-B-32') + logger.info("CLIP model loaded successfully") + return self._clip_model + + @property + def text_collection(self): + """Lazy load text collection.""" + if self._text_collection is None: + from chromadb.utils import embedding_functions + + class TextEmbeddingFunction(embedding_functions.EmbeddingFunction): + def __init__(self, text_model): + self.text_model = text_model + def __call__(self, input): + embeddings = [] + for text in input: + embedding = self.text_model.encode(text) + embeddings.append(embedding.tolist()) + return embeddings + + self._text_ef = TextEmbeddingFunction(self.text_model) + self._text_collection = self.client.get_or_create_collection("memoryalpha_text", embedding_function=self._text_ef) + return self._text_collection + + @property + def image_collection(self): + """Lazy load image collection.""" + if self._image_collection is None: + from chromadb.utils import embedding_functions + + class CLIPEmbeddingFunction(embedding_functions.EmbeddingFunction): + def __init__(self, clip_model): + self.clip_model = clip_model + def __call__(self, input): + embeddings = [] + for img in input: + embedding = self.clip_model.encode(img) + embeddings.append(embedding.tolist()) + return embeddings + + self._clip_ef = CLIPEmbeddingFunction(self.clip_model) + self._image_collection = self.client.get_or_create_collection("memoryalpha_images", embedding_function=self._clip_ef) + return self._image_collection def search(self, query: str, top_k: int = 10) -> List[Dict[str, Any]]: """Search the Memory Alpha database for relevant documents.""" @@ -257,4 +308,70 @@ def _replace_thinking_tags(self, answer: str) -> str: def _update_history(self, question: str, answer: str): """Update conversation history.""" self.conversation_history.append({"question": question, "answer": answer}) - self.conversation_history = self.conversation_history[-self.max_history_turns:] \ No newline at end of file + self.conversation_history = self.conversation_history[-self.max_history_turns:] + + def search_image(self, image_path: str, top_k: int = 5, + model: str = os.getenv("DEFAULT_IMAGE_MODEL")) -> Dict[str, Any]: + """ + Search for images similar to the provided image. + """ + from PIL import Image + import requests + import tempfile + import os + + if not model: + raise ValueError("model must be provided or set in DEFAULT_IMAGE_MODEL environment variable.") + + try: + # Load image and generate embedding + image = Image.open(image_path).convert('RGB') + image_embedding = self.clip_model.encode(image) + image_embedding = image_embedding.tolist() + + # Search image collection + image_results = self.image_collection.query( + query_embeddings=[image_embedding], + n_results=top_k + ) + + # Process results + if not image_results["documents"] or not image_results["documents"][0]: + return {"model_answer": "No matching visual records found in Starfleet archives."} + + # Format results for the model + formatted_results = [] + for i, (doc, meta, dist) in enumerate(zip( + image_results['documents'][0], + image_results['metadatas'][0], + image_results['distances'][0] + ), 1): + record_name = meta.get('image_name', 'Unknown visual record') + formatted_results.append(f"Visual Record {i}: {record_name}") + + result_text = "\n".join(formatted_results) + + # Use LLM to provide a natural language summary + prompt = f"""You are an LCARS computer system analyzing visual records from Starfleet archives. + +Based on these visual record matches, identify what subject or scene is being depicted: + +{result_text} + +Provide a direct identification of the subject without referencing images, searches, or technical processes. Stay in character as an LCARS computer system.""" + + result = self.ollama_client.chat( + model=model, + messages=[ + {"role": "system", "content": "You are an LCARS computer system. Respond in character without breaking the Star Trek universe immersion. Do not reference images, searches, or technical processes."}, + {"role": "user", "content": prompt} + ], + stream=False, + options={"temperature": 0.3, "num_predict": 200} + ) + + return {"model_answer": result['message']['content']} + + except Exception as e: + logger.error(f"Image search failed: {e}") + return {"model_answer": "Error accessing visual records database."} \ No newline at end of file