diff --git a/CHANGELOG.md b/CHANGELOG.md index 17ed1a98e90..11dfc2b7ba1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,26 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.6.32] - 2025-09-29 + +### Added + +- 🗝️ Permission toggle for public sharing of notes was added, allowing note owners to quickly enable or disable public access from the note settings interface. +- ⚠️ A warning is now displayed in the user edit modal if conflicting group permissions are detected, helping administrators resolve access control ambiguities before saving changes. + +### Fixed + +- 🧰 Fixed regression where External Tool servers (OpenAPI) were nonfunctional after the 0.6.31 update; external tools integration is now restored and reliable. +- 🚑 Resolved a critical bug causing Streamable HTTP OAuth 2.1 (MCP server) integrations to throw a 500 error on first invocation due to missing 'SessionMiddleware'. OAuth 2.1 registration now succeeds and works on subsequent requests as expected. +- 🐛 The "Set as default" option is now reliably clickable in model and filter selection menus, fixing cases where the interface appeared unresponsive. +- 🛠️ Embed UI now works seamlessly with both default and native function calling flows, ensuring the tool embedding experience is consistent regardless of invocation method. +- 🧹 Addressed various minor UI bugs and inconsistencies for a cleaner user experience. + +### Changed + +- 🧬 MCP tool result handling code was refactored for improved parsing and robustness of tool outputs. +- 🧩 The user edit modal was overhauled for clarity and usability, improving the organization of group, permission, and public sharing controls. + ## [0.6.31] - 2025-09-25 ### Added diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 7e5c35a4512..95e79d891d4 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1098,6 +1098,32 @@ def feishu_oauth_register(client: OAuth): ), ) + +# PATCH EXTRA LOGIN INFO +SYSTEM_REGISTER_URL = PersistentConfig( + "SYSTEM_REGISTER_URL", + "ui.SYSTEM_REGISTER_URL", + os.environ.get("SYSTEM_REGISTER_URL", ""), +) + + +SYSTEM_REGISTER_GUIDE_URL = PersistentConfig( + "SYSTEM_REGISTER_GUIDE_URL", + "ui.SYSTEM_REGISTER_GUIDE_URL", + os.environ.get("SYSTEM_REGISTER_GUIDE_URL", ""), +) +# /PATCH EXTRA LOGIN INFO + + +# PATCH ADD LOGO TO SIDEBAR +LOGO_URL = PersistentConfig( + "LOGO_URL", + "ui.LOGO_URL", + os.environ.get("LOGO_URL", ""), +) +# /PATCH ADD LOGO TO SIDEBAR + + ENABLE_LOGIN_FORM = PersistentConfig( "ENABLE_LOGIN_FORM", "ui.ENABLE_LOGIN_FORM", @@ -1217,6 +1243,11 @@ def feishu_oauth_register(client: OAuth): == "true" ) +USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING = ( + os.environ.get("USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING", "False").lower() + == "true" +) + USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING = ( os.environ.get( "USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING", "False" @@ -1354,6 +1385,7 @@ def feishu_oauth_register(client: OAuth): "public_knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING, "public_prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING, "public_tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING, + "public_notes": USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING, }, "chat": { "controls": USER_PERMISSIONS_CHAT_CONTROLS, @@ -1999,16 +2031,23 @@ class BannerModel(BaseModel): # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) # Milvus - MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db") MILVUS_DB = os.environ.get("MILVUS_DB", "default") MILVUS_TOKEN = os.environ.get("MILVUS_TOKEN", None) - MILVUS_INDEX_TYPE = os.environ.get("MILVUS_INDEX_TYPE", "HNSW") MILVUS_METRIC_TYPE = os.environ.get("MILVUS_METRIC_TYPE", "COSINE") MILVUS_HNSW_M = int(os.environ.get("MILVUS_HNSW_M", "16")) MILVUS_HNSW_EFCONSTRUCTION = int(os.environ.get("MILVUS_HNSW_EFCONSTRUCTION", "100")) MILVUS_IVF_FLAT_NLIST = int(os.environ.get("MILVUS_IVF_FLAT_NLIST", "128")) +MILVUS_DISKANN_MAX_DEGREE = int(os.environ.get("MILVUS_DISKANN_MAX_DEGREE", "56")) +MILVUS_DISKANN_SEARCH_LIST_SIZE = int( + os.environ.get("MILVUS_DISKANN_SEARCH_LIST_SIZE", "100") +) +ENABLE_MILVUS_MULTITENANCY_MODE = ( + os.environ.get("ENABLE_MILVUS_MULTITENANCY_MODE", "true").lower() == "true" +) +# Hyphens not allowed, need to use underscores in collection names +MILVUS_COLLECTION_PREFIX = os.environ.get("MILVUS_COLLECTION_PREFIX", "open_webui") # Qdrant QDRANT_URI = os.environ.get("QDRANT_URI", None) diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index d102263cb34..316efe18e7f 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -86,6 +86,10 @@ async def get_function_models(request): try: function_module = get_function_module_by_id(request, pipe.id) + has_user_valves = False + if hasattr(function_module, "UserValves"): + has_user_valves = True + # Check if function is a manifold if hasattr(function_module, "pipes"): sub_pipes = [] @@ -124,6 +128,7 @@ async def get_function_models(request): "created": pipe.created_at, "owned_by": "openai", "pipe": pipe_flag, + "has_user_valves": has_user_valves, } ) else: @@ -141,6 +146,7 @@ async def get_function_models(request): "created": pipe.created_at, "owned_by": "openai", "pipe": pipe_flag, + "has_user_valves": has_user_valves, } ) except Exception as e: diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index f38bd471097..53ecc09de97 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -348,6 +348,13 @@ PENDING_USER_OVERLAY_TITLE, DEFAULT_PROMPT_SUGGESTIONS, DEFAULT_MODELS, + # PATCH EXTRA LOGIN INFO + SYSTEM_REGISTER_URL, + SYSTEM_REGISTER_GUIDE_URL, + # /PATCH EXTRA LOGIN INFO + # PATCH ADD LOGO TO SIDEBAR + LOGO_URL, + # /PATCH ADD LOGO TO SIDEBAR DEFAULT_ARENA_MODEL, MODEL_ORDER_LIST, EVALUATION_ARENA_MODELS, @@ -728,6 +735,18 @@ async def lifespan(app: FastAPI): app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE + +# PATCH EXTRA LOGIN INFO +app.state.config.SYSTEM_REGISTER_URL = SYSTEM_REGISTER_URL +app.state.config.SYSTEM_REGISTER_GUIDE_URL = SYSTEM_REGISTER_GUIDE_URL +# /PATCH EXTRA LOGIN INFO + + +# PATCH ADD LOGO TO SIDEBAR +app.state.config.LOGO_URL = LOGO_URL +# /PATCH ADD LOGO TO SIDEBAR + + app.state.config.PENDING_USER_OVERLAY_CONTENT = PENDING_USER_OVERLAY_CONTENT app.state.config.PENDING_USER_OVERLAY_TITLE = PENDING_USER_OVERLAY_TITLE @@ -1552,7 +1571,7 @@ async def process_chat(request, form_data, user, metadata, model): finally: try: if mcp_clients := metadata.get("mcp_clients"): - for client in mcp_clients: + for client in mcp_clients.values(): await client.disconnect() except Exception as e: log.debug(f"Error cleaning up: {e}") @@ -1698,6 +1717,16 @@ async def get_app_config(request: Request): for name, config in OAUTH_PROVIDERS.items() } }, + # Environment variables for patches + "extended_features": { + # PATCH EXTRA LOGIN INFO + "system_register_url": app.state.config.SYSTEM_REGISTER_URL, + "system_register_guide_url": app.state.config.SYSTEM_REGISTER_GUIDE_URL, + # /PATCH EXTRA LOGIN INFO + # PATCH ADD LOGO TO SIDEBAR + "logo_url": app.state.config.LOGO_URL, + # /PATCH ADD LOGO TO SIDEBAR + }, "features": { "auth": WEBUI_AUTH, "auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER), @@ -1908,37 +1937,32 @@ async def get_current_usage(user=Depends(get_verified_user)): f"mcp:{server_id}", OAuthClientInformationFull(**oauth_client_info) ) +try: + if REDIS_URL: + redis_session_store = RedisStore( + url=REDIS_URL, + prefix=(f"{REDIS_KEY_PREFIX}:session:" if REDIS_KEY_PREFIX else "session:"), + ) -# SessionMiddleware is used by authlib for oauth -if len(OAUTH_PROVIDERS) > 0: - try: - if REDIS_URL: - redis_session_store = RedisStore( - url=REDIS_URL, - prefix=( - f"{REDIS_KEY_PREFIX}:session:" if REDIS_KEY_PREFIX else "session:" - ), - ) - - app.add_middleware(SessionAutoloadMiddleware) - app.add_middleware( - StarSessionsMiddleware, - store=redis_session_store, - cookie_name="oui-session", - cookie_same_site=WEBUI_SESSION_COOKIE_SAME_SITE, - cookie_https_only=WEBUI_SESSION_COOKIE_SECURE, - ) - log.info("Using Redis for session") - else: - raise ValueError("No Redis URL provided") - except Exception as e: + app.add_middleware(SessionAutoloadMiddleware) app.add_middleware( - SessionMiddleware, - secret_key=WEBUI_SECRET_KEY, - session_cookie="oui-session", - same_site=WEBUI_SESSION_COOKIE_SAME_SITE, - https_only=WEBUI_SESSION_COOKIE_SECURE, + StarSessionsMiddleware, + store=redis_session_store, + cookie_name="owui-session", + cookie_same_site=WEBUI_SESSION_COOKIE_SAME_SITE, + cookie_https_only=WEBUI_SESSION_COOKIE_SECURE, ) + log.info("Using Redis for session") + else: + raise ValueError("No Redis URL provided") +except Exception as e: + app.add_middleware( + SessionMiddleware, + secret_key=WEBUI_SECRET_KEY, + session_cookie="owui-session", + same_site=WEBUI_SESSION_COOKIE_SAME_SITE, + https_only=WEBUI_SESSION_COOKIE_SECURE, + ) @app.get("/oauth/clients/{client_id}/authorize") diff --git a/backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py b/backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py new file mode 100644 index 00000000000..dd2b7d1a680 --- /dev/null +++ b/backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py @@ -0,0 +1,34 @@ +"""Add reply_to_id column to message + +Revision ID: a5c220713937 +Revises: 38d63c18f30f +Create Date: 2025-09-27 02:24:18.058455 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "a5c220713937" +down_revision: Union[str, None] = "38d63c18f30f" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add 'reply_to_id' column to the 'message' table for replying to messages + op.add_column( + "message", + sa.Column("reply_to_id", sa.Text(), nullable=True), + ) + pass + + +def downgrade() -> None: + # Remove 'reply_to_id' column from the 'message' table + op.drop_column("message", "reply_to_id") + + pass diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 97fd9b6256d..98b1166ce47 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -366,6 +366,15 @@ def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: except Exception: return False + def unarchive_all_chats_by_user_id(self, user_id: str) -> bool: + try: + with get_db() as db: + db.query(Chat).filter_by(user_id=user_id).update({"archived": False}) + db.commit() + return True + except Exception: + return False + def update_chat_share_id_by_id( self, id: str, share_id: Optional[str] ) -> Optional[ChatModel]: @@ -810,7 +819,7 @@ def get_chats_by_user_id_and_search_text( return [ChatModel.model_validate(chat) for chat in all_chats] def get_chats_by_folder_id_and_user_id( - self, folder_id: str, user_id: str + self, folder_id: str, user_id: str, skip: int = 0, limit: int = 60 ) -> list[ChatModel]: with get_db() as db: query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id) @@ -819,6 +828,11 @@ def get_chats_by_folder_id_and_user_id( query = query.order_by(Chat.updated_at.desc()) + if skip: + query = query.offset(skip) + if limit: + query = query.limit(limit) + all_chats = query.all() return [ChatModel.model_validate(chat) for chat in all_chats] diff --git a/backend/open_webui/models/folders.py b/backend/open_webui/models/folders.py index c8766457507..45f82470809 100644 --- a/backend/open_webui/models/folders.py +++ b/backend/open_webui/models/folders.py @@ -50,6 +50,20 @@ class FolderModel(BaseModel): model_config = ConfigDict(from_attributes=True) +class FolderMetadataResponse(BaseModel): + icon: Optional[str] = None + + +class FolderNameIdResponse(BaseModel): + id: str + name: str + meta: Optional[FolderMetadataResponse] = None + parent_id: Optional[str] = None + is_expanded: bool = False + created_at: int + updated_at: int + + #################### # Forms #################### diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index ff4553ee9dd..8b0027b8e78 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -5,6 +5,7 @@ from open_webui.internal.db import Base, get_db from open_webui.models.tags import TagModel, Tag, Tags +from open_webui.models.users import Users, UserNameResponse from pydantic import BaseModel, ConfigDict @@ -43,6 +44,7 @@ class Message(Base): user_id = Column(Text) channel_id = Column(Text, nullable=True) + reply_to_id = Column(Text, nullable=True) parent_id = Column(Text, nullable=True) content = Column(Text) @@ -60,6 +62,7 @@ class MessageModel(BaseModel): user_id: str channel_id: Optional[str] = None + reply_to_id: Optional[str] = None parent_id: Optional[str] = None content: str @@ -77,6 +80,7 @@ class MessageModel(BaseModel): class MessageForm(BaseModel): content: str + reply_to_id: Optional[str] = None parent_id: Optional[str] = None data: Optional[dict] = None meta: Optional[dict] = None @@ -88,7 +92,15 @@ class Reactions(BaseModel): count: int -class MessageResponse(MessageModel): +class MessageUserResponse(MessageModel): + user: Optional[UserNameResponse] = None + + +class MessageReplyToResponse(MessageUserResponse): + reply_to_message: Optional[MessageUserResponse] = None + + +class MessageResponse(MessageReplyToResponse): latest_reply_at: Optional[int] reply_count: int reactions: list[Reactions] @@ -107,6 +119,7 @@ def insert_new_message( "id": id, "user_id": user_id, "channel_id": channel_id, + "reply_to_id": form_data.reply_to_id, "parent_id": form_data.parent_id, "content": form_data.content, "data": form_data.data, @@ -128,19 +141,32 @@ def get_message_by_id(self, id: str) -> Optional[MessageResponse]: if not message: return None + reply_to_message = ( + self.get_message_by_id(message.reply_to_id) + if message.reply_to_id + else None + ) + reactions = self.get_reactions_by_message_id(id) - replies = self.get_replies_by_message_id(id) + thread_replies = self.get_thread_replies_by_message_id(id) - return MessageResponse( - **{ + user = Users.get_user_by_id(message.user_id) + return MessageResponse.model_validate( + { **MessageModel.model_validate(message).model_dump(), - "latest_reply_at": replies[0].created_at if replies else None, - "reply_count": len(replies), + "user": user.model_dump() if user else None, + "reply_to_message": ( + reply_to_message.model_dump() if reply_to_message else None + ), + "latest_reply_at": ( + thread_replies[0].created_at if thread_replies else None + ), + "reply_count": len(thread_replies), "reactions": reactions, } ) - def get_replies_by_message_id(self, id: str) -> list[MessageModel]: + def get_thread_replies_by_message_id(self, id: str) -> list[MessageReplyToResponse]: with get_db() as db: all_messages = ( db.query(Message) @@ -148,7 +174,27 @@ def get_replies_by_message_id(self, id: str) -> list[MessageModel]: .order_by(Message.created_at.desc()) .all() ) - return [MessageModel.model_validate(message) for message in all_messages] + + messages = [] + for message in all_messages: + reply_to_message = ( + self.get_message_by_id(message.reply_to_id) + if message.reply_to_id + else None + ) + messages.append( + MessageReplyToResponse.model_validate( + { + **MessageModel.model_validate(message).model_dump(), + "reply_to_message": ( + reply_to_message.model_dump() + if reply_to_message + else None + ), + } + ) + ) + return messages def get_reply_user_ids_by_message_id(self, id: str) -> list[str]: with get_db() as db: @@ -159,7 +205,7 @@ def get_reply_user_ids_by_message_id(self, id: str) -> list[str]: def get_messages_by_channel_id( self, channel_id: str, skip: int = 0, limit: int = 50 - ) -> list[MessageModel]: + ) -> list[MessageReplyToResponse]: with get_db() as db: all_messages = ( db.query(Message) @@ -169,11 +215,31 @@ def get_messages_by_channel_id( .limit(limit) .all() ) - return [MessageModel.model_validate(message) for message in all_messages] + + messages = [] + for message in all_messages: + reply_to_message = ( + self.get_message_by_id(message.reply_to_id) + if message.reply_to_id + else None + ) + messages.append( + MessageReplyToResponse.model_validate( + { + **MessageModel.model_validate(message).model_dump(), + "reply_to_message": ( + reply_to_message.model_dump() + if reply_to_message + else None + ), + } + ) + ) + return messages def get_messages_by_parent_id( self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50 - ) -> list[MessageModel]: + ) -> list[MessageReplyToResponse]: with get_db() as db: message = db.get(Message, parent_id) @@ -193,7 +259,26 @@ def get_messages_by_parent_id( if len(all_messages) < limit: all_messages.append(message) - return [MessageModel.model_validate(message) for message in all_messages] + messages = [] + for message in all_messages: + reply_to_message = ( + self.get_message_by_id(message.reply_to_id) + if message.reply_to_id + else None + ) + messages.append( + MessageReplyToResponse.model_validate( + { + **MessageModel.model_validate(message).model_dump(), + "reply_to_message": ( + reply_to_message.model_dump() + if reply_to_message + else None + ), + } + ) + ) + return messages def update_message_by_id( self, id: str, form_data: MessageForm diff --git a/backend/open_webui/retrieval/vector/dbs/chroma.py b/backend/open_webui/retrieval/vector/dbs/chroma.py index 9675e141e7b..1fdb064c51f 100755 --- a/backend/open_webui/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/retrieval/vector/dbs/chroma.py @@ -11,7 +11,7 @@ SearchResult, GetResult, ) -from open_webui.retrieval.vector.utils import stringify_metadata +from open_webui.retrieval.vector.utils import process_metadata from open_webui.config import ( CHROMA_DATA_PATH, @@ -146,7 +146,7 @@ def insert(self, collection_name: str, items: list[VectorItem]): ids = [item["id"] for item in items] documents = [item["text"] for item in items] embeddings = [item["vector"] for item in items] - metadatas = [stringify_metadata(item["metadata"]) for item in items] + metadatas = [process_metadata(item["metadata"]) for item in items] for batch in create_batches( api=self.client, @@ -166,7 +166,7 @@ def upsert(self, collection_name: str, items: list[VectorItem]): ids = [item["id"] for item in items] documents = [item["text"] for item in items] embeddings = [item["vector"] for item in items] - metadatas = [stringify_metadata(item["metadata"]) for item in items] + metadatas = [process_metadata(item["metadata"]) for item in items] collection.upsert( ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas diff --git a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py index 727d831cff3..6de0d859f8a 100644 --- a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py +++ b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py @@ -3,7 +3,7 @@ import ssl from elasticsearch.helpers import bulk, scan -from open_webui.retrieval.vector.utils import stringify_metadata +from open_webui.retrieval.vector.utils import process_metadata from open_webui.retrieval.vector.main import ( VectorDBBase, VectorItem, @@ -245,7 +245,7 @@ def insert(self, collection_name: str, items: list[VectorItem]): "collection": collection_name, "vector": item["vector"], "text": item["text"], - "metadata": stringify_metadata(item["metadata"]), + "metadata": process_metadata(item["metadata"]), }, } for item in batch @@ -266,7 +266,7 @@ def upsert(self, collection_name: str, items: list[VectorItem]): "collection": collection_name, "vector": item["vector"], "text": item["text"], - "metadata": stringify_metadata(item["metadata"]), + "metadata": process_metadata(item["metadata"]), }, "doc_as_upsert": True, } diff --git a/backend/open_webui/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py index 059ea43cc0c..98f8e335f21 100644 --- a/backend/open_webui/retrieval/vector/dbs/milvus.py +++ b/backend/open_webui/retrieval/vector/dbs/milvus.py @@ -6,7 +6,7 @@ import logging from typing import Optional -from open_webui.retrieval.vector.utils import stringify_metadata +from open_webui.retrieval.vector.utils import process_metadata from open_webui.retrieval.vector.main import ( VectorDBBase, VectorItem, @@ -22,6 +22,8 @@ MILVUS_HNSW_M, MILVUS_HNSW_EFCONSTRUCTION, MILVUS_IVF_FLAT_NLIST, + MILVUS_DISKANN_MAX_DEGREE, + MILVUS_DISKANN_SEARCH_LIST_SIZE, ) from open_webui.env import SRC_LOG_LEVELS @@ -131,12 +133,18 @@ def _create_collection(self, collection_name: str, dimension: int): elif index_type == "IVF_FLAT": index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST} log.info(f"IVF_FLAT params: {index_creation_params}") + elif index_type == "DISKANN": + index_creation_params = { + "max_degree": MILVUS_DISKANN_MAX_DEGREE, + "search_list_size": MILVUS_DISKANN_SEARCH_LIST_SIZE, + } + log.info(f"DISKANN params: {index_creation_params}") elif index_type in ["FLAT", "AUTOINDEX"]: log.info(f"Using {index_type} index with no specific build-time params.") else: log.warning( f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. " - f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX. " + f"Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. " f"Milvus will use its default for the collection if this type is not directly supported for index creation." ) # For unsupported types, pass the type directly to Milvus; it might handle it or use a default. @@ -189,7 +197,7 @@ def search( ) return self._result_to_search_result(result) - def query(self, collection_name: str, filter: dict, limit: Optional[int] = None): + def query(self, collection_name: str, filter: dict, limit: int = -1): connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB) # Construct the filter string for querying @@ -222,7 +230,7 @@ def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) "data", "metadata", ], - limit=limit, # Pass the limit directly; None means no limit. + limit=limit, # Pass the limit directly; -1 means no limit. ) while True: @@ -249,7 +257,7 @@ def get(self, collection_name: str) -> Optional[GetResult]: ) # Using query with a trivial filter to get all items. # This will use the paginated query logic. - return self.query(collection_name=collection_name, filter={}, limit=None) + return self.query(collection_name=collection_name, filter={}, limit=-1) def insert(self, collection_name: str, items: list[VectorItem]): # Insert the items into the collection, if the collection does not exist, it will be created. @@ -281,7 +289,7 @@ def insert(self, collection_name: str, items: list[VectorItem]): "id": item["id"], "vector": item["vector"], "data": {"text": item["text"]}, - "metadata": stringify_metadata(item["metadata"]), + "metadata": process_metadata(item["metadata"]), } for item in items ], @@ -317,7 +325,7 @@ def upsert(self, collection_name: str, items: list[VectorItem]): "id": item["id"], "vector": item["vector"], "data": {"text": item["text"]}, - "metadata": stringify_metadata(item["metadata"]), + "metadata": process_metadata(item["metadata"]), } for item in items ], diff --git a/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py new file mode 100644 index 00000000000..5c80d155d35 --- /dev/null +++ b/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py @@ -0,0 +1,282 @@ +import logging +from typing import Optional, Tuple, List, Dict, Any + +from open_webui.config import ( + MILVUS_URI, + MILVUS_TOKEN, + MILVUS_DB, + MILVUS_COLLECTION_PREFIX, + MILVUS_INDEX_TYPE, + MILVUS_METRIC_TYPE, + MILVUS_HNSW_M, + MILVUS_HNSW_EFCONSTRUCTION, + MILVUS_IVF_FLAT_NLIST, +) +from open_webui.env import SRC_LOG_LEVELS +from open_webui.retrieval.vector.main import ( + GetResult, + SearchResult, + VectorDBBase, + VectorItem, +) +from pymilvus import ( + connections, + utility, + Collection, + CollectionSchema, + FieldSchema, + DataType, +) + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + +RESOURCE_ID_FIELD = "resource_id" + + +class MilvusClient(VectorDBBase): + def __init__(self): + # Milvus collection names can only contain numbers, letters, and underscores. + self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_") + connections.connect( + alias="default", + uri=MILVUS_URI, + token=MILVUS_TOKEN, + db_name=MILVUS_DB, + ) + + # Main collection types for multi-tenancy + self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories" + self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge" + self.FILE_COLLECTION = f"{self.collection_prefix}_files" + self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search" + self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based" + self.shared_collections = [ + self.MEMORY_COLLECTION, + self.KNOWLEDGE_COLLECTION, + self.FILE_COLLECTION, + self.WEB_SEARCH_COLLECTION, + self.HASH_BASED_COLLECTION, + ] + + def _get_collection_and_resource_id(self, collection_name: str) -> Tuple[str, str]: + """ + Maps the traditional collection name to multi-tenant collection and resource ID. + + WARNING: This mapping relies on current Open WebUI naming conventions for + collection names. If Open WebUI changes how it generates collection names + (e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash + formats), this mapping will break and route data to incorrect collections. + POTENTIALLY CAUSING HUGE DATA CORRUPTION, DATA CONSISTENCY ISSUES AND INCORRECT + DATA MAPPING INSIDE THE DATABASE. + """ + resource_id = collection_name + + if collection_name.startswith("user-memory-"): + return self.MEMORY_COLLECTION, resource_id + elif collection_name.startswith("file-"): + return self.FILE_COLLECTION, resource_id + elif collection_name.startswith("web-search-"): + return self.WEB_SEARCH_COLLECTION, resource_id + elif len(collection_name) == 63 and all( + c in "0123456789abcdef" for c in collection_name + ): + return self.HASH_BASED_COLLECTION, resource_id + else: + return self.KNOWLEDGE_COLLECTION, resource_id + + def _create_shared_collection(self, mt_collection_name: str, dimension: int): + fields = [ + FieldSchema( + name="id", + dtype=DataType.VARCHAR, + is_primary=True, + auto_id=False, + max_length=36, + ), + FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension), + FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), + FieldSchema(name="metadata", dtype=DataType.JSON), + FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255), + ] + schema = CollectionSchema(fields, "Shared collection for multi-tenancy") + collection = Collection(mt_collection_name, schema) + + index_params = { + "metric_type": MILVUS_METRIC_TYPE, + "index_type": MILVUS_INDEX_TYPE, + "params": {}, + } + if MILVUS_INDEX_TYPE == "HNSW": + index_params["params"] = { + "M": MILVUS_HNSW_M, + "efConstruction": MILVUS_HNSW_EFCONSTRUCTION, + } + elif MILVUS_INDEX_TYPE == "IVF_FLAT": + index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST} + + collection.create_index("vector", index_params) + collection.create_index(RESOURCE_ID_FIELD) + log.info(f"Created shared collection: {mt_collection_name}") + return collection + + def _ensure_collection(self, mt_collection_name: str, dimension: int): + if not utility.has_collection(mt_collection_name): + self._create_shared_collection(mt_collection_name, dimension) + + def has_collection(self, collection_name: str) -> bool: + mt_collection, resource_id = self._get_collection_and_resource_id( + collection_name + ) + if not utility.has_collection(mt_collection): + return False + + collection = Collection(mt_collection) + collection.load() + res = collection.query(expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'", limit=1) + return len(res) > 0 + + def upsert(self, collection_name: str, items: List[VectorItem]): + if not items: + return + mt_collection, resource_id = self._get_collection_and_resource_id( + collection_name + ) + dimension = len(items[0]["vector"]) + self._ensure_collection(mt_collection, dimension) + collection = Collection(mt_collection) + + entities = [ + { + "id": item["id"], + "vector": item["vector"], + "text": item["text"], + "metadata": item["metadata"], + RESOURCE_ID_FIELD: resource_id, + } + for item in items + ] + collection.insert(entities) + collection.flush() + + def search( + self, collection_name: str, vectors: List[List[float]], limit: int + ) -> Optional[SearchResult]: + if not vectors: + return None + + mt_collection, resource_id = self._get_collection_and_resource_id( + collection_name + ) + if not utility.has_collection(mt_collection): + return None + + collection = Collection(mt_collection) + collection.load() + + search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}} + results = collection.search( + data=vectors, + anns_field="vector", + param=search_params, + limit=limit, + expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'", + output_fields=["id", "text", "metadata"], + ) + + ids, documents, metadatas, distances = [], [], [], [] + for hits in results: + batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], [] + for hit in hits: + batch_ids.append(hit.entity.get("id")) + batch_docs.append(hit.entity.get("text")) + batch_metadatas.append(hit.entity.get("metadata")) + batch_dists.append(hit.distance) + ids.append(batch_ids) + documents.append(batch_docs) + metadatas.append(batch_metadatas) + distances.append(batch_dists) + + return SearchResult( + ids=ids, documents=documents, metadatas=metadatas, distances=distances + ) + + def delete( + self, + collection_name: str, + ids: Optional[List[str]] = None, + filter: Optional[Dict[str, Any]] = None, + ): + mt_collection, resource_id = self._get_collection_and_resource_id( + collection_name + ) + if not utility.has_collection(mt_collection): + return + + collection = Collection(mt_collection) + + # Build expression + expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"] + if ids: + # Milvus expects a string list for 'in' operator + id_list_str = ", ".join([f"'{id_val}'" for id_val in ids]) + expr.append(f"id in [{id_list_str}]") + + if filter: + for key, value in filter.items(): + expr.append(f"metadata['{key}'] == '{value}'") + + collection.delete(" and ".join(expr)) + + def reset(self): + for collection_name in self.shared_collections: + if utility.has_collection(collection_name): + utility.drop_collection(collection_name) + + def delete_collection(self, collection_name: str): + mt_collection, resource_id = self._get_collection_and_resource_id( + collection_name + ) + if not utility.has_collection(mt_collection): + return + + collection = Collection(mt_collection) + collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'") + + def query( + self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None + ) -> Optional[GetResult]: + mt_collection, resource_id = self._get_collection_and_resource_id( + collection_name + ) + if not utility.has_collection(mt_collection): + return None + + collection = Collection(mt_collection) + collection.load() + + expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"] + if filter: + for key, value in filter.items(): + if isinstance(value, str): + expr.append(f"metadata['{key}'] == '{value}'") + else: + expr.append(f"metadata['{key}'] == {value}") + + results = collection.query( + expr=" and ".join(expr), + output_fields=["id", "text", "metadata"], + limit=limit, + ) + + ids = [res["id"] for res in results] + documents = [res["text"] for res in results] + metadatas = [res["metadata"] for res in results] + + return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas]) + + def get(self, collection_name: str) -> Optional[GetResult]: + return self.query(collection_name, filter={}, limit=None) + + def insert(self, collection_name: str, items: List[VectorItem]): + return self.upsert(collection_name, items) diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py index 510070f97a7..2e946710e24 100644 --- a/backend/open_webui/retrieval/vector/dbs/opensearch.py +++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py @@ -2,7 +2,7 @@ from opensearchpy.helpers import bulk from typing import Optional -from open_webui.retrieval.vector.utils import stringify_metadata +from open_webui.retrieval.vector.utils import process_metadata from open_webui.retrieval.vector.main import ( VectorDBBase, VectorItem, @@ -201,7 +201,7 @@ def insert(self, collection_name: str, items: list[VectorItem]): "_source": { "vector": item["vector"], "text": item["text"], - "metadata": stringify_metadata(item["metadata"]), + "metadata": process_metadata(item["metadata"]), }, } for item in batch @@ -223,7 +223,7 @@ def upsert(self, collection_name: str, items: list[VectorItem]): "doc": { "vector": item["vector"], "text": item["text"], - "metadata": stringify_metadata(item["metadata"]), + "metadata": process_metadata(item["metadata"]), }, "doc_as_upsert": True, } diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index 06c1698cdd9..312b48944c9 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -27,7 +27,7 @@ from sqlalchemy.exc import NoSuchTableError -from open_webui.retrieval.vector.utils import stringify_metadata +from open_webui.retrieval.vector.utils import process_metadata from open_webui.retrieval.vector.main import ( VectorDBBase, VectorItem, @@ -265,7 +265,7 @@ def insert(self, collection_name: str, items: List[VectorItem]) -> None: vector=vector, collection_name=collection_name, text=item["text"], - vmetadata=stringify_metadata(item["metadata"]), + vmetadata=process_metadata(item["metadata"]), ) new_items.append(new_chunk) self.session.bulk_save_objects(new_items) @@ -323,7 +323,7 @@ def upsert(self, collection_name: str, items: List[VectorItem]) -> None: if existing: existing.vector = vector existing.text = item["text"] - existing.vmetadata = stringify_metadata(item["metadata"]) + existing.vmetadata = process_metadata(item["metadata"]) existing.collection_name = ( collection_name # Update collection_name if necessary ) @@ -333,7 +333,7 @@ def upsert(self, collection_name: str, items: List[VectorItem]) -> None: vector=vector, collection_name=collection_name, text=item["text"], - vmetadata=stringify_metadata(item["metadata"]), + vmetadata=process_metadata(item["metadata"]), ) self.session.add(new_chunk) self.session.commit() diff --git a/backend/open_webui/retrieval/vector/dbs/pinecone.py b/backend/open_webui/retrieval/vector/dbs/pinecone.py index 466b5a6e24f..5bef0d9ea7d 100644 --- a/backend/open_webui/retrieval/vector/dbs/pinecone.py +++ b/backend/open_webui/retrieval/vector/dbs/pinecone.py @@ -32,7 +32,7 @@ PINECONE_CLOUD, ) from open_webui.env import SRC_LOG_LEVELS -from open_webui.retrieval.vector.utils import stringify_metadata +from open_webui.retrieval.vector.utils import process_metadata NO_LIMIT = 10000 # Reasonable limit to avoid overwhelming the system @@ -185,7 +185,7 @@ def _create_points( point = { "id": item["id"], "values": item["vector"], - "metadata": stringify_metadata(metadata), + "metadata": process_metadata(metadata), } points.append(point) return points diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py index ed4a8bab348..e9fa03d4591 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py @@ -105,6 +105,13 @@ def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str] Returns: tuple: (collection_name, tenant_id) + + WARNING: This mapping relies on current Open WebUI naming conventions for + collection names. If Open WebUI changes how it generates collection names + (e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash + formats), this mapping will break and route data to incorrect collections. + POTENTIALLY CAUSING HUGE DATA CORRUPTION, DATA CONSISTENCY ISSUES AND INCORRECT + DATA MAPPING INSIDE THE DATABASE. """ # Check for user memory collections tenant_id = collection_name diff --git a/backend/open_webui/retrieval/vector/dbs/s3vector.py b/backend/open_webui/retrieval/vector/dbs/s3vector.py index 2ac6911769a..519ee5abad3 100644 --- a/backend/open_webui/retrieval/vector/dbs/s3vector.py +++ b/backend/open_webui/retrieval/vector/dbs/s3vector.py @@ -1,4 +1,4 @@ -from open_webui.retrieval.vector.utils import stringify_metadata +from open_webui.retrieval.vector.utils import process_metadata from open_webui.retrieval.vector.main import ( VectorDBBase, VectorItem, @@ -185,7 +185,7 @@ def insert(self, collection_name: str, items: List[VectorItem]) -> None: metadata["text"] = item["text"] # Convert metadata to string format for consistency - metadata = stringify_metadata(metadata) + metadata = process_metadata(metadata) # Filter metadata to comply with S3 Vector API limit of 10 keys metadata = self._filter_metadata(metadata, item["id"]) @@ -256,7 +256,7 @@ def upsert(self, collection_name: str, items: List[VectorItem]) -> None: metadata["text"] = item["text"] # Convert metadata to string format for consistency - metadata = stringify_metadata(metadata) + metadata = process_metadata(metadata) # Filter metadata to comply with S3 Vector API limit of 10 keys metadata = self._filter_metadata(metadata, item["id"]) diff --git a/backend/open_webui/retrieval/vector/factory.py b/backend/open_webui/retrieval/vector/factory.py index 36cb85c948c..7888c22be88 100644 --- a/backend/open_webui/retrieval/vector/factory.py +++ b/backend/open_webui/retrieval/vector/factory.py @@ -1,6 +1,10 @@ from open_webui.retrieval.vector.main import VectorDBBase from open_webui.retrieval.vector.type import VectorType -from open_webui.config import VECTOR_DB, ENABLE_QDRANT_MULTITENANCY_MODE +from open_webui.config import ( + VECTOR_DB, + ENABLE_QDRANT_MULTITENANCY_MODE, + ENABLE_MILVUS_MULTITENANCY_MODE, +) class Vector: @@ -12,9 +16,16 @@ def get_vector(vector_type: str) -> VectorDBBase: """ match vector_type: case VectorType.MILVUS: - from open_webui.retrieval.vector.dbs.milvus import MilvusClient + if ENABLE_MILVUS_MULTITENANCY_MODE: + from open_webui.retrieval.vector.dbs.milvus_multitenancy import ( + MilvusClient, + ) + + return MilvusClient() + else: + from open_webui.retrieval.vector.dbs.milvus import MilvusClient - return MilvusClient() + return MilvusClient() case VectorType.QDRANT: if ENABLE_QDRANT_MULTITENANCY_MODE: from open_webui.retrieval.vector.dbs.qdrant_multitenancy import ( diff --git a/backend/open_webui/retrieval/vector/utils.py b/backend/open_webui/retrieval/vector/utils.py index 1d9698c6b1e..a597390b920 100644 --- a/backend/open_webui/retrieval/vector/utils.py +++ b/backend/open_webui/retrieval/vector/utils.py @@ -1,10 +1,24 @@ from datetime import datetime +KEYS_TO_EXCLUDE = ["content", "pages", "tables", "paragraphs", "sections", "figures"] -def stringify_metadata( + +def filter_metadata(metadata: dict[str, any]) -> dict[str, any]: + metadata = { + key: value for key, value in metadata.items() if key not in KEYS_TO_EXCLUDE + } + return metadata + + +def process_metadata( metadata: dict[str, any], ) -> dict[str, any]: for key, value in metadata.items(): + # Remove large fields + if key in KEYS_TO_EXCLUDE: + del metadata[key] + + # Convert non-serializable fields to strings if ( isinstance(value, datetime) or isinstance(value, list) diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index e7b83663476..77c3d9ba535 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -167,7 +167,7 @@ async def delete_channel_by_id(id: str, user=Depends(get_admin_user)): class MessageUserResponse(MessageResponse): - user: UserNameResponse + pass @router.get("/{id}/messages", response_model=list[MessageUserResponse]) @@ -196,15 +196,17 @@ async def get_channel_messages( user = Users.get_user_by_id(message.user_id) users[message.user_id] = user - replies = Messages.get_replies_by_message_id(message.id) - latest_reply_at = replies[0].created_at if replies else None + thread_replies = Messages.get_thread_replies_by_message_id(message.id) + latest_thread_reply_at = ( + thread_replies[0].created_at if thread_replies else None + ) messages.append( MessageUserResponse( **{ **message.model_dump(), - "reply_count": len(replies), - "latest_reply_at": latest_reply_at, + "reply_count": len(thread_replies), + "latest_reply_at": latest_thread_reply_at, "reactions": Messages.get_reactions_by_message_id(message.id), "user": UserNameResponse(**users[message.user_id].model_dump()), } @@ -253,12 +255,26 @@ async def model_response_handler(request, channel, message, user): mentions = extract_mentions(message.content) message_content = replace_mentions(message.content) + model_mentions = {} + + # check if the message is a reply to a message sent by a model + if ( + message.reply_to_message + and message.reply_to_message.meta + and message.reply_to_message.meta.get("model_id", None) + ): + model_id = message.reply_to_message.meta.get("model_id", None) + model_mentions[model_id] = {"id": model_id, "id_type": "M"} + # check if any of the mentions are models - model_mentions = [mention for mention in mentions if mention["id_type"] == "M"] + for mention in mentions: + if mention["id_type"] == "M" and mention["id"] not in model_mentions: + model_mentions[mention["id"]] = mention + if not model_mentions: return False - for mention in model_mentions: + for mention in model_mentions.values(): model_id = mention["id"] model = MODELS.get(model_id, None) @@ -326,9 +342,9 @@ async def model_response_handler(request, channel, message, user): system_message = { "role": "system", - "content": f"You are {model.get('name', model_id)}, an AI assistant participating in a threaded conversation. Be helpful, concise, and conversational." + "content": f"You are {model.get('name', model_id)}, participating in a threaded conversation. Be concise and conversational." + ( - f"Here's the thread history:\n\n{''.join([f'{msg}' for msg in thread_history])}\n\nContinue the conversation naturally, addressing the most recent message while being aware of the full context." + f"Here's the thread history:\n\n{''.join([f'{msg}' for msg in thread_history])}\n\nContinue the conversation naturally as {model.get('name', model_id)}, addressing the most recent message while being aware of the full context." if thread_history else "" ), @@ -406,24 +422,14 @@ async def new_message_handler( try: message = Messages.insert_new_message(form_data, channel.id, user.id) - if message: + message = Messages.get_message_by_id(message.id) event_data = { "channel_id": channel.id, "message_id": message.id, "data": { "type": "message", - "data": MessageUserResponse( - **{ - **message.model_dump(), - "reply_count": 0, - "latest_reply_at": None, - "reactions": Messages.get_reactions_by_message_id( - message.id - ), - "user": UserNameResponse(**user.model_dump()), - } - ).model_dump(), + "data": message.model_dump(), }, "user": UserNameResponse(**user.model_dump()).model_dump(), "channel": channel.model_dump(), @@ -447,23 +453,16 @@ async def new_message_handler( "message_id": parent_message.id, "data": { "type": "message:reply", - "data": MessageUserResponse( - **{ - **parent_message.model_dump(), - "user": UserNameResponse( - **Users.get_user_by_id( - parent_message.user_id - ).model_dump() - ), - } - ).model_dump(), + "data": parent_message.model_dump(), }, "user": UserNameResponse(**user.model_dump()).model_dump(), "channel": channel.model_dump(), }, to=f"channel:{channel.id}", ) - return MessageModel(**message.model_dump()), channel + return message, channel + else: + raise Exception("Error creating message") except Exception as e: log.exception(e) raise HTTPException( @@ -651,14 +650,7 @@ async def update_message_by_id( "message_id": message.id, "data": { "type": "message:update", - "data": MessageUserResponse( - **{ - **message.model_dump(), - "user": UserNameResponse( - **user.model_dump() - ).model_dump(), - } - ).model_dump(), + "data": message.model_dump(), }, "user": UserNameResponse(**user.model_dump()).model_dump(), "channel": channel.model_dump(), @@ -724,9 +716,6 @@ async def add_reaction_to_message( "type": "message:reaction:add", "data": { **message.model_dump(), - "user": UserNameResponse( - **Users.get_user_by_id(message.user_id).model_dump() - ).model_dump(), "name": form_data.name, }, }, @@ -793,9 +782,6 @@ async def remove_reaction_by_id_and_user_id_and_name( "type": "message:reaction:remove", "data": { **message.model_dump(), - "user": UserNameResponse( - **Users.get_user_by_id(message.user_id).model_dump() - ).model_dump(), "name": form_data.name, }, }, @@ -882,16 +868,7 @@ async def delete_message_by_id( "message_id": parent_message.id, "data": { "type": "message:reply", - "data": MessageUserResponse( - **{ - **parent_message.model_dump(), - "user": UserNameResponse( - **Users.get_user_by_id( - parent_message.user_id - ).model_dump() - ), - } - ).model_dump(), + "data": parent_message.model_dump(), }, "user": UserNameResponse(**user.model_dump()).model_dump(), "channel": channel.model_dump(), diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py index 788e355f2b8..1f065988fe1 100644 --- a/backend/open_webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -218,6 +218,28 @@ async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user) ] +@router.get("/folder/{folder_id}/list") +async def get_chat_list_by_folder_id( + folder_id: str, page: Optional[int] = 1, user=Depends(get_verified_user) +): + try: + limit = 60 + skip = (page - 1) * limit + + return [ + {"title": chat.title, "id": chat.id, "updated_at": chat.updated_at} + for chat in Chats.get_chats_by_folder_id_and_user_id( + folder_id, user.id, skip=skip, limit=limit + ) + ] + + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + ############################ # GetPinnedChats ############################ @@ -339,6 +361,16 @@ async def archive_all_chats(user=Depends(get_verified_user)): return Chats.archive_all_chats_by_user_id(user.id) +############################ +# UnarchiveAllChats +############################ + + +@router.post("/unarchive/all", response_model=bool) +async def unarchive_all_chats(user=Depends(get_verified_user)): + return Chats.unarchive_all_chats_by_user_id(user.id) + + ############################ # GetSharedChatById ############################ diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index d4b88032e2b..f19fbeedd00 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -207,38 +207,39 @@ async def verify_tool_servers_config( if form_data.type == "mcp": if form_data.auth_type == "oauth_2.1": discovery_urls = get_discovery_urls(form_data.url) - async with aiohttp.ClientSession() as session: - async with session.get( - discovery_urls[0] - ) as oauth_server_metadata_response: - if oauth_server_metadata_response.status != 200: - raise HTTPException( - status_code=400, - detail=f"Failed to fetch OAuth 2.1 discovery document from {discovery_urls[0]}", - ) - - try: - oauth_server_metadata = OAuthMetadata.model_validate( - await oauth_server_metadata_response.json() - ) - return { - "status": True, - "oauth_server_metadata": oauth_server_metadata.model_dump( - mode="json" - ), - } - except Exception as e: - log.info( - f"Failed to parse OAuth 2.1 discovery document: {e}" - ) - raise HTTPException( - status_code=400, - detail=f"Failed to parse OAuth 2.1 discovery document from {discovery_urls[0]}", - ) + for discovery_url in discovery_urls: + log.debug( + f"Trying to fetch OAuth 2.1 discovery document from {discovery_url}" + ) + async with aiohttp.ClientSession() as session: + async with session.get( + discovery_urls[0] + ) as oauth_server_metadata_response: + if oauth_server_metadata_response.status == 200: + try: + oauth_server_metadata = ( + OAuthMetadata.model_validate( + await oauth_server_metadata_response.json() + ) + ) + return { + "status": True, + "oauth_server_metadata": oauth_server_metadata.model_dump( + mode="json" + ), + } + except Exception as e: + log.info( + f"Failed to parse OAuth 2.1 discovery document: {e}" + ) + raise HTTPException( + status_code=400, + detail=f"Failed to parse OAuth 2.1 discovery document from {discovery_urls[0]}", + ) raise HTTPException( status_code=400, - detail=f"Failed to fetch OAuth 2.1 discovery document from {discovery_urls[0]}", + detail=f"Failed to fetch OAuth 2.1 discovery document from {discovery_urls}", ) else: try: diff --git a/backend/open_webui/routers/folders.py b/backend/open_webui/routers/folders.py index ddee71ea4df..51c1eba5f4a 100644 --- a/backend/open_webui/routers/folders.py +++ b/backend/open_webui/routers/folders.py @@ -12,6 +12,7 @@ FolderForm, FolderUpdateForm, FolderModel, + FolderNameIdResponse, Folders, ) from open_webui.models.chats import Chats @@ -44,7 +45,7 @@ ############################ -@router.get("/", response_model=list[FolderModel]) +@router.get("/", response_model=list[FolderNameIdResponse]) async def get_folders(user=Depends(get_verified_user)): folders = Folders.get_folders_by_user_id(user.id) @@ -76,14 +77,6 @@ async def get_folders(user=Depends(get_verified_user)): return [ { **folder.model_dump(), - "items": { - "chats": [ - {"title": chat.title, "id": chat.id, "updated_at": chat.updated_at} - for chat in Chats.get_chats_by_folder_id_and_user_id( - folder.id, user.id - ) - ] - }, } for folder in folders ] diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index 05d7c680065..5c5a2dcd903 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -1,6 +1,9 @@ from typing import Optional import io import base64 +import json +import asyncio +import logging from open_webui.models.models import ( ModelForm, @@ -12,7 +15,14 @@ from pydantic import BaseModel from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, HTTPException, Request, status, Response +from fastapi import ( + APIRouter, + Depends, + HTTPException, + Request, + status, + Response, +) from fastapi.responses import FileResponse, StreamingResponse @@ -20,6 +30,8 @@ from open_webui.utils.access_control import has_access, has_permission from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR +log = logging.getLogger(__name__) + router = APIRouter() @@ -93,6 +105,50 @@ async def export_models(user=Depends(get_admin_user)): return Models.get_models() +############################ +# ImportModels +############################ + + +class ModelsImportForm(BaseModel): + models: list[dict] + + +@router.post("/import", response_model=bool) +async def import_models( + user: str = Depends(get_admin_user), form_data: ModelsImportForm = (...) +): + try: + data = form_data.models + if isinstance(data, list): + for model_data in data: + # Here, you can add logic to validate model_data if needed + model_id = model_data.get("id") + if model_id: + existing_model = Models.get_model_by_id(model_id) + if existing_model: + # Update existing model + model_data["meta"] = model_data.get("meta", {}) + model_data["params"] = model_data.get("params", {}) + + updated_model = ModelForm( + **{**existing_model.model_dump(), **model_data} + ) + Models.update_model_by_id(model_id, updated_model) + else: + # Insert new model + model_data["meta"] = model_data.get("meta", {}) + model_data["params"] = model_data.get("params", {}) + new_model = ModelForm(**model_data) + Models.insert_new_model(user_id=user.id, form_data=new_model) + return True + else: + raise HTTPException(status_code=400, detail="Invalid JSON format") + except Exception as e: + log.exception(e) + raise HTTPException(status_code=500, detail=str(e)) + + ############################ # SyncModels ############################ diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py index 0c420e4f12e..3858c4670f2 100644 --- a/backend/open_webui/routers/notes.py +++ b/backend/open_webui/routers/notes.py @@ -180,6 +180,18 @@ async def update_note_by_id( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) + # Check if user can share publicly + if ( + user.role != "admin" + and form_data.access_control == None + and not has_permission( + user.id, + "sharing.public_notes", + request.app.state.config.USER_PERMISSIONS, + ) + ): + form_data.access_control = {} + try: note = Notes.update_note_by_id(id, form_data) await sio.emit( diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 3681008c874..d322addfa64 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -78,6 +78,7 @@ query_doc, query_doc_with_hybrid_search, ) +from open_webui.retrieval.vector.utils import filter_metadata from open_webui.utils.misc import ( calculate_sha256_string, ) @@ -1535,7 +1536,7 @@ def process_file( Document( page_content=doc.page_content, metadata={ - **doc.metadata, + **filter_metadata(doc.metadata), "name": file.filename, "created_by": file.user_id, "file_id": file.id, diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index eb66a868253..2fa3f6abf61 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -17,7 +17,11 @@ ToolUserResponse, Tools, ) -from open_webui.utils.plugin import load_tool_module_by_id, replace_imports +from open_webui.utils.plugin import ( + load_tool_module_by_id, + replace_imports, + get_tool_module_from_cache, +) from open_webui.utils.tools import get_tool_specs from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission @@ -35,6 +39,14 @@ router = APIRouter() +def get_tool_module(request, tool_id, load_from_db=True): + """ + Get the tool module by its ID. + """ + tool_module, _ = get_tool_module_from_cache(request, tool_id, load_from_db) + return tool_module + + ############################ # GetTools ############################ @@ -42,15 +54,19 @@ @router.get("/", response_model=list[ToolUserResponse]) async def get_tools(request: Request, user=Depends(get_verified_user)): - tools = [ - ToolUserResponse( - **{ - **tool.model_dump(), - "has_user_valves": "class UserValves(BaseModel):" in tool.content, - } + tools = [] + + # Local Tools + for tool in Tools.get_tools(): + tool_module = get_tool_module(request, tool.id) + tools.append( + ToolUserResponse( + **{ + **tool.model_dump(), + "has_user_valves": hasattr(tool_module, "UserValves"), + } + ) ) - for tool in Tools.get_tools() - ] # OpenAPI Tool Servers for server in await get_tool_servers(request): diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index 9a0f8c6aaf6..2dd229eeb77 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -157,6 +157,7 @@ class SharingPermissions(BaseModel): public_knowledge: bool = True public_prompts: bool = True public_tools: bool = True + public_notes: bool = True class ChatPermissions(BaseModel): diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index b64eab08aca..e481571df43 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -705,6 +705,23 @@ async def __event_emitter__(event_data): }, ) + if "type" in event_data and event_data["type"] == "embeds": + message = Chats.get_message_by_id_and_message_id( + request_info["chat_id"], + request_info["message_id"], + ) + + embeds = event_data.get("data", {}).get("embeds", []) + embeds.extend(message.get("embeds", [])) + + Chats.upsert_message_to_chat_by_id_and_message_id( + request_info["chat_id"], + request_info["message_id"], + { + "embeds": embeds, + }, + ) + if "type" in event_data and event_data["type"] == "files": message = Chats.get_message_by_id_and_message_id( request_info["chat_id"], diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index ff8c2156078..e4bf1195ff7 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -133,6 +133,149 @@ DEFAULT_CODE_INTERPRETER_TAGS = [("", "")] +def process_tool_result( + request, + tool_function_name, + tool_result, + tool_type, + direct_tool=False, + metadata=None, + user=None, +): + tool_result_embeds = [] + + if isinstance(tool_result, HTMLResponse): + content_disposition = tool_result.headers.get("Content-Disposition", "") + if "inline" in content_disposition: + content = tool_result.body.decode("utf-8") + tool_result_embeds.append(content) + + if 200 <= tool_result.status_code < 300: + tool_result = { + "status": "success", + "code": "ui_component", + "message": f"{tool_function_name}: Embedded UI result is active and visible to the user.", + } + elif 400 <= tool_result.status_code < 500: + tool_result = { + "status": "error", + "code": "ui_component", + "message": f"{tool_function_name}: Client error {tool_result.status_code} from embedded UI result.", + } + elif 500 <= tool_result.status_code < 600: + tool_result = { + "status": "error", + "code": "ui_component", + "message": f"{tool_function_name}: Server error {tool_result.status_code} from embedded UI result.", + } + else: + tool_result = { + "status": "error", + "code": "ui_component", + "message": f"{tool_function_name}: Unexpected status code {tool_result.status_code} from embedded UI result.", + } + else: + tool_result = tool_result.body.decode("utf-8") + + elif (tool_type == "external" and isinstance(tool_result, tuple)) or ( + direct_tool and isinstance(tool_result, list) and len(tool_result) == 2 + ): + tool_result, tool_response_headers = tool_result + + try: + if not isinstance(tool_response_headers, dict): + tool_response_headers = dict(tool_response_headers) + except Exception as e: + tool_response_headers = {} + log.debug(e) + + if tool_response_headers and isinstance(tool_response_headers, dict): + content_disposition = tool_response_headers.get( + "Content-Disposition", + tool_response_headers.get("content-disposition", ""), + ) + + if "inline" in content_disposition: + content_type = tool_response_headers.get( + "Content-Type", + tool_response_headers.get("content-type", ""), + ) + location = tool_response_headers.get( + "Location", + tool_response_headers.get("location", ""), + ) + + if "text/html" in content_type: + # Display as iframe embed + tool_result_embeds.append(tool_result) + tool_result = { + "status": "success", + "code": "ui_component", + "message": f"{tool_function_name}: Embedded UI result is active and visible to the user.", + } + elif location: + tool_result_embeds.append(location) + tool_result = { + "status": "success", + "code": "ui_component", + "message": f"{tool_function_name}: Embedded UI result is active and visible to the user.", + } + + tool_result_files = [] + + if isinstance(tool_result, list): + if tool_type == "mcp": # MCP + tool_response = [] + for item in tool_result: + if isinstance(item, dict): + if item.get("type") == "text": + text = item.get("text", "") + if isinstance(text, str): + try: + text = json.loads(text) + except json.JSONDecodeError: + pass + tool_response.append(text) + elif item.get("type") in ["image", "audio"]: + file_url = get_file_url_from_base64( + request, + f"data:{item.get('mimeType')};base64,{item.get('data', item.get('blob', ''))}", + { + "chat_id": metadata.get("chat_id", None), + "message_id": metadata.get("message_id", None), + "session_id": metadata.get("session_id", None), + "result": item, + }, + user, + ) + + tool_result_files.append( + { + "type": item.get("type", "data"), + "url": file_url, + } + ) + tool_result = tool_response[0] if len(tool_response) == 1 else tool_response + else: # OpenAPI + for item in tool_result: + if isinstance(item, str) and item.startswith("data:"): + tool_result_files.append( + { + "type": "data", + "content": item, + } + ) + tool_result.remove(item) + + if isinstance(tool_result, list): + tool_result = {"results": tool_result} + + if isinstance(tool_result, dict) or isinstance(tool_result, list): + tool_result = json.dumps(tool_result, indent=2, ensure_ascii=False) + + return tool_result, tool_result_files, tool_result_embeds + + async def chat_completion_tools_handler( request: Request, body: dict, extra_params: dict, user: UserModel, models, tools ) -> tuple[dict, dict]: @@ -172,6 +315,7 @@ def get_tools_function_calling_payload(messages, task_model_id, content): } event_caller = extra_params["__event_call__"] + event_emitter = extra_params["__event_emitter__"] metadata = extra_params["__metadata__"] task_model_id = get_task_model_id( @@ -226,8 +370,14 @@ async def tool_call_handler(tool_call): tool_function_params = tool_call.get("parameters", {}) + tool = None + tool_type = "" + direct_tool = False + try: tool = tools[tool_function_name] + tool_type = tool.get("type", "") + direct_tool = tool.get("direct", False) spec = tool.get("spec", {}) allowed_params = ( @@ -259,18 +409,46 @@ async def tool_call_handler(tool_call): except Exception as e: tool_result = str(e) - tool_result_files = [] - if isinstance(tool_result, list): - for item in tool_result: - # check if string - if isinstance(item, str) and item.startswith("data:"): - tool_result_files.append(item) - tool_result.remove(item) + tool_result, tool_result_files, tool_result_embeds = ( + process_tool_result( + request, + tool_function_name, + tool_result, + tool_type, + direct_tool, + metadata, + user, + ) + ) + + if event_emitter: + if tool_result_files: + await event_emitter( + { + "type": "files", + "data": { + "files": tool_result_files, + }, + } + ) + + if tool_result_embeds: + await event_emitter( + { + "type": "embeds", + "data": { + "embeds": tool_result_embeds, + }, + } + ) - if isinstance(tool_result, dict) or isinstance(tool_result, list): - tool_result = json.dumps(tool_result, indent=2) + print( + f"Tool {tool_function_name} result: {tool_result}", + tool_result_files, + tool_result_embeds, + ) - if isinstance(tool_result, str): + if tool_result: tool = tools[tool_function_name] tool_id = tool.get("tool_id", "") @@ -284,18 +462,19 @@ async def tool_call_handler(tool_call): sources.append( { "source": { - "name": (f"TOOL:{tool_name}"), + "name": (f"{tool_name}"), }, - "document": [tool_result], + "document": [str(tool_result)], "metadata": [ { - "source": (f"TOOL:{tool_name}"), + "source": (f"{tool_name}"), "parameters": tool_function_params, } ], "tool_result": True, } ) + # Citation is not enabled for this tool body["messages"] = add_or_update_user_message( f"\nTool `{tool_name}` Output: {tool_result}", @@ -1010,7 +1189,7 @@ async def process_chat_payload(request, form_data, user, metadata, model): tools_dict = {} - mcp_clients = [] + mcp_clients = {} mcp_tools_dict = {} if tool_ids: @@ -1071,35 +1250,41 @@ async def process_chat_payload(request, form_data, user, metadata, model): log.error(f"Error getting OAuth token: {e}") oauth_token = None - mcp_client = MCPClient() - await mcp_client.connect( + mcp_clients[server_id] = MCPClient() + await mcp_clients[server_id].connect( url=mcp_server_connection.get("url", ""), headers=headers if headers else None, ) - tool_specs = await mcp_client.list_tool_specs() + tool_specs = await mcp_clients[server_id].list_tool_specs() for tool_spec in tool_specs: - def make_tool_function(function_name): + def make_tool_function(client, function_name): async def tool_function(**kwargs): - return await mcp_client.call_tool( + print(kwargs) + print(client) + print(await client.list_tool_specs()) + return await client.call_tool( function_name, function_args=kwargs, ) return tool_function - tool_function = make_tool_function(tool_spec["name"]) + tool_function = make_tool_function( + mcp_clients[server_id], tool_spec["name"] + ) - mcp_tools_dict[tool_spec["name"]] = { - "spec": tool_spec, + mcp_tools_dict[f"{server_id}_{tool_spec['name']}"] = { + "spec": { + **tool_spec, + "name": f"{server_id}_{tool_spec['name']}", + }, "callable": tool_function, "type": "mcp", - "client": mcp_client, + "client": mcp_clients[server_id], "direct": False, } - - mcp_clients.append(mcp_client) except Exception as e: log.debug(e) continue @@ -1140,7 +1325,6 @@ async def tool_function(**kwargs): {"type": "function", "function": tool.get("spec", {})} for tool in tools_dict.values() ] - else: # If the function calling is not native, then call the tools function calling handler try: @@ -1165,9 +1349,7 @@ async def tool_function(**kwargs): citation_idx_map = {} for source in sources: - is_tool_result = source.get("tool_result", False) - - if "document" in source and not is_tool_result: + if "document" in source: for document_text, document_metadata in zip( source["document"], source["metadata"] ): @@ -1228,6 +1410,10 @@ async def tool_function(**kwargs): } ) + print("Final form_data:", form_data) + print("Final metadata:", metadata) + print("Final events:", events) + return form_data, metadata, events @@ -2436,7 +2622,9 @@ async def flush_pending_delta_data(threshold: int = 0): print("tool_call", tool_call) tool_call_id = tool_call.get("id", "") - tool_name = tool_call.get("function", {}).get("name", "") + tool_function_name = tool_call.get("function", {}).get( + "name", "" + ) tool_args = tool_call.get("function", {}).get("arguments", "{}") tool_function_params = {} @@ -2466,11 +2654,17 @@ async def flush_pending_delta_data(threshold: int = 0): ) tool_result = None + tool = None + tool_type = None + direct_tool = False - if tool_name in tools: - tool = tools[tool_name] + if tool_function_name in tools: + tool = tools[tool_function_name] spec = tool.get("spec", {}) + tool_type = tool.get("type", "") + direct_tool = tool.get("direct", False) + try: allowed_params = ( spec.get("parameters", {}) @@ -2484,13 +2678,13 @@ async def flush_pending_delta_data(threshold: int = 0): if k in allowed_params } - if tool.get("direct", False): + if direct_tool: tool_result = await event_caller( { "type": "execute:tool", "data": { "id": str(uuid4()), - "name": tool_name, + "name": tool_function_name, "params": tool_function_params, "server": tool.get("server", {}), "session_id": metadata.get( @@ -2509,151 +2703,17 @@ async def flush_pending_delta_data(threshold: int = 0): except Exception as e: tool_result = str(e) - tool_result_embeds = [] - if isinstance(tool_result, HTMLResponse): - content_disposition = tool_result.headers.get( - "Content-Disposition", "" - ) - if "inline" in content_disposition: - content = tool_result.body.decode("utf-8") - tool_result_embeds.append(content) - - if 200 <= tool_result.status_code < 300: - tool_result = { - "status": "success", - "code": "ui_component", - "message": "Embedded UI result is active and visible to the user.", - } - elif 400 <= tool_result.status_code < 500: - tool_result = { - "status": "error", - "code": "ui_component", - "message": f"Client error {tool_result.status_code} from embedded UI result.", - } - elif 500 <= tool_result.status_code < 600: - tool_result = { - "status": "error", - "code": "ui_component", - "message": f"Server error {tool_result.status_code} from embedded UI result.", - } - else: - tool_result = { - "status": "error", - "code": "ui_component", - "message": f"Unexpected status code {tool_result.status_code} from embedded UI result.", - } - else: - tool_result = tool_result.body.decode("utf-8") - - elif ( - tool.get("type") == "external" - and isinstance(tool_result, tuple) - ) or ( - tool.get("direct", True) - and isinstance(tool_result, list) - and len(tool_result) == 2 - ): - tool_result, tool_response_headers = tool_result - - if tool_response_headers: - content_disposition = tool_response_headers.get( - "Content-Disposition", - tool_response_headers.get( - "content-disposition", "" - ), - ) - - if "inline" in content_disposition: - content_type = tool_response_headers.get( - "Content-Type", - tool_response_headers.get("content-type", ""), - ) - location = tool_response_headers.get( - "Location", - tool_response_headers.get("location", ""), - ) - - if "text/html" in content_type: - # Display as iframe embed - tool_result_embeds.append(tool_result) - tool_result = { - "status": "success", - "code": "ui_component", - "message": "Embedded UI result is active and visible to the user.", - } - elif location: - tool_result_embeds.append(location) - tool_result = { - "status": "success", - "code": "ui_component", - "message": "Embedded UI result is active and visible to the user.", - } - - tool_result_files = [] - if isinstance(tool_result, list): - for item in tool_result: - # check if string - if isinstance(item, str) and item.startswith("data:"): - tool_result_files.append( - { - "type": "data", - "content": item, - } - ) - tool_result.remove(item) - - if tool.get("type") == "mcp": - if isinstance(item, dict): - if ( - item.get("type") == "image" - or item.get("type") == "audio" - ): - file_url = get_file_url_from_base64( - request, - f"data:{item.get('mimeType')};base64,{item.get('data', item.get('blob', ''))}", - { - "chat_id": metadata.get( - "chat_id", None - ), - "message_id": metadata.get( - "message_id", None - ), - "session_id": metadata.get( - "session_id", None - ), - "result": item, - }, - user, - ) - - tool_result_files.append( - { - "type": item.get("type", "data"), - "url": file_url, - } - ) - tool_result.remove(item) - - if tool_result_files: - if not isinstance(tool_result, list): - tool_result = [ - tool_result, - ] - - for file in tool_result_files: - tool_result.append( - { - "type": file.get("type", "data"), - "content": "Result is being displayed as a file.", - } - ) - - if isinstance(tool_result, dict) or isinstance( - tool_result, list - ): - tool_result = json.dumps( - tool_result, indent=2, ensure_ascii=False + tool_result, tool_result_files, tool_result_embeds = ( + process_tool_result( + request, + tool_function_name, + tool_result, + tool_type, + direct_tool, + metadata, + user, ) + ) results.append( { @@ -2673,7 +2733,6 @@ async def flush_pending_delta_data(threshold: int = 0): ) content_blocks[-1]["results"] = results - content_blocks.append( { "type": "text", diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index e8cfa0d1580..81a4142ea00 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -391,17 +391,10 @@ def parse_ollama_modelfile(model_text): "top_k": int, "top_p": float, "num_keep": int, - "typical_p": float, "presence_penalty": float, "frequency_penalty": float, - "penalize_newline": bool, - "numa": bool, "num_batch": int, "num_gpu": int, - "main_gpu": int, - "low_vram": bool, - "f16_kv": bool, - "vocab_only": bool, "use_mmap": bool, "use_mlock": bool, "num_thread": int, diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index 7e69661f567..587e2a2c7de 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -263,6 +263,7 @@ def get_filter_items_from_module(function, module): "icon": function.meta.manifest.get("icon_url", None) or getattr(module, "icon_url", None) or getattr(module, "icon", None), + "has_user_valves": hasattr(module, "UserValves"), } ] diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 93992418538..6cf91e3f12b 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -198,13 +198,25 @@ def get_parsed_and_base_url(server_url) -> tuple[urllib.parse.ParseResult, str]: def get_discovery_urls(server_url) -> list[str]: - urls = [] parsed, base_url = get_parsed_and_base_url(server_url) - urls.append( - urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server") - ) - urls.append(urllib.parse.urljoin(base_url, "/.well-known/openid-configuration")) + urls = [ + urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server"), + urllib.parse.urljoin(base_url, "/.well-known/openid-configuration"), + ] + + if parsed.path and parsed.path != "/": + urls.append( + urllib.parse.urljoin( + base_url, + f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}", + ) + ) + urls.append( + urllib.parse.urljoin( + base_url, f"/.well-known/openid-configuration{parsed.path.rstrip('/')}" + ) + ) return urls diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py index 39c785854a5..8cb36b3759a 100644 --- a/backend/open_webui/utils/payload.py +++ b/backend/open_webui/utils/payload.py @@ -153,17 +153,11 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict: "repeat_last_n": int, "top_k": int, "min_p": float, - "typical_p": float, "repeat_penalty": float, "presence_penalty": float, "frequency_penalty": float, - "penalize_newline": bool, "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], - "numa": bool, "num_gpu": int, - "main_gpu": int, - "low_vram": bool, - "vocab_only": bool, "use_mmap": bool, "use_mlock": bool, "num_thread": int, diff --git a/backend/open_webui/utils/plugin.py b/backend/open_webui/utils/plugin.py index 8d9729bae2c..51c3f4f5f7f 100644 --- a/backend/open_webui/utils/plugin.py +++ b/backend/open_webui/utils/plugin.py @@ -166,6 +166,48 @@ def load_function_module_by_id(function_id: str, content: str | None = None): os.unlink(temp_file.name) +def get_tool_module_from_cache(request, tool_id, load_from_db=True): + if load_from_db: + # Always load from the database by default + tool = Tools.get_tool_by_id(tool_id) + if not tool: + raise Exception(f"Tool not found: {tool_id}") + content = tool.content + + new_content = replace_imports(content) + if new_content != content: + content = new_content + # Update the tool content in the database + Tools.update_tool_by_id(tool_id, {"content": content}) + + if ( + hasattr(request.app.state, "TOOL_CONTENTS") + and tool_id in request.app.state.TOOL_CONTENTS + ) and ( + hasattr(request.app.state, "TOOLS") and tool_id in request.app.state.TOOLS + ): + if request.app.state.TOOL_CONTENTS[tool_id] == content: + return request.app.state.TOOLS[tool_id], None + + tool_module, frontmatter = load_tool_module_by_id(tool_id, content) + else: + if hasattr(request.app.state, "TOOLS") and tool_id in request.app.state.TOOLS: + return request.app.state.TOOLS[tool_id], None + + tool_module, frontmatter = load_tool_module_by_id(tool_id) + + if not hasattr(request.app.state, "TOOLS"): + request.app.state.TOOLS = {} + + if not hasattr(request.app.state, "TOOL_CONTENTS"): + request.app.state.TOOL_CONTENTS = {} + + request.app.state.TOOLS[tool_id] = tool_module + request.app.state.TOOL_CONTENTS[tool_id] = content + + return tool_module, frontmatter + + def get_function_module_from_cache(request, function_id, load_from_db=True): if load_from_db: # Always load from the database by default diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 4c8289578e5..5cd73778767 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -588,28 +588,20 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]: error = str(err) raise Exception(error) - data = { - "openapi": res, - "info": res.get("info", {}), - "specs": convert_openapi_to_tool_payload(res), - } - - log.info(f"Fetched data: {data}") - return data + log.debug(f"Fetched data: {res}") + return res async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # Prepare list of enabled servers along with their original index + + tasks = [] server_entries = [] for idx, server in enumerate(servers): if ( server.get("config", {}).get("enable") and server.get("type", "openapi") == "openapi" ): - # Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL - openapi_path = server.get("path", "openapi.json") - full_url = get_tool_server_url(server.get("url"), openapi_path) - info = server.get("info", {}) auth_type = server.get("auth_type", "bearer") @@ -625,12 +617,34 @@ async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, if not id: id = str(idx) - server_entries.append((id, idx, server, full_url, info, token)) + server_url = server.get("url") + spec_type = server.get("spec_type", "url") + + # Create async tasks to fetch data + task = None + if spec_type == "url": + # Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL + openapi_path = server.get("path", "openapi.json") + spec_url = get_tool_server_url(server_url, openapi_path) + # Fetch from URL + task = get_tool_server_data(token, spec_url) + elif spec_type == "json" and server.get("spec", ""): + # Use provided JSON spec + spec_json = None + try: + spec_json = json.loads(server.get("spec", "")) + except Exception as e: + log.error(f"Error parsing JSON spec for tool server {id}: {e}") + + if spec_json: + task = asyncio.sleep( + 0, + result=spec_json, + ) - # Create async tasks to fetch data - tasks = [ - get_tool_server_data(token, url) for (_, _, _, url, _, token) in server_entries - ] + if task: + tasks.append(task) + server_entries.append((id, idx, server, server_url, info, token)) # Execute tasks concurrently responses = await asyncio.gather(*tasks, return_exceptions=True) @@ -642,8 +656,13 @@ async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, log.error(f"Failed to connect to {url} OpenAPI tool server") continue - openapi_data = response.get("openapi", {}) + response = { + "openapi": response, + "info": response.get("info", {}), + "specs": convert_openapi_to_tool_payload(response), + } + openapi_data = response.get("openapi", {}) if info and isinstance(openapi_data, dict): openapi_data["info"] = openapi_data.get("info", {}) diff --git a/package-lock.json b/package-lock.json index c6d6fc47af6..6b59776fa0d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.6.31", + "version": "0.6.32", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.6.31", + "version": "0.6.32", "dependencies": { "@azure/msal-browser": "^4.5.0", "@codemirror/lang-javascript": "^6.2.2", diff --git a/package.json b/package.json index 67f7b6dddae..658964de0bc 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.6.31", + "version": "0.6.32", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", diff --git a/src/lib/apis/channels/index.ts b/src/lib/apis/channels/index.ts index 548572c6fba..ac51e5a5d01 100644 --- a/src/lib/apis/channels/index.ts +++ b/src/lib/apis/channels/index.ts @@ -248,6 +248,7 @@ export const getChannelThreadMessages = async ( }; type MessageForm = { + reply_to_id?: string; parent_id?: string; content: string; data?: object; diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts index 59d86007713..b8073d94fa1 100644 --- a/src/lib/apis/chats/index.ts +++ b/src/lib/apis/chats/index.ts @@ -33,6 +33,38 @@ export const createNewChat = async (token: string, chat: object, folderId: strin return res; }; +export const unarchiveAllChats = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/unarchive/all`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const importChat = async ( token: string, chat: object, @@ -327,6 +359,45 @@ export const getChatsByFolderId = async (token: string, folderId: string) => { return res; }; +export const getChatListByFolderId = async (token: string, folderId: string, page: number = 1) => { + let error = null; + + const searchParams = new URLSearchParams(); + if (page !== null) { + searchParams.append('page', `${page}`); + } + + const res = await fetch( + `${WEBUI_API_BASE_URL}/chats/folder/${folderId}/list?${searchParams.toString()}`, + { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + } + ) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getAllArchivedChats = async (token: string) => { let error = null; diff --git a/src/lib/apis/files/index.ts b/src/lib/apis/files/index.ts index ac322200f97..6a1763edb88 100644 --- a/src/lib/apis/files/index.ts +++ b/src/lib/apis/files/index.ts @@ -23,7 +23,7 @@ export const uploadFile = async (token: string, file: File, metadata?: object | return res.json(); }) .catch((err) => { - error = err.detail; + error = err.detail || err.message; console.error(err); return null; }); diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 937e7cec817..43a5936a23e 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -337,14 +337,8 @@ export const getToolServerData = async (token: string, url: string) => { throw error; } - const data = { - openapi: res, - info: res.info, - specs: convertOpenApiToToolPayload(res) - }; - - console.log(data); - return data; + console.log(res); + return res; }; export const getToolServersData = async (servers: object[]) => { @@ -356,6 +350,7 @@ export const getToolServersData = async (servers: object[]) => { let error = null; let toolServerToken = null; + const auth_type = server?.auth_type ?? 'bearer'; if (auth_type === 'bearer') { toolServerToken = server?.key; @@ -365,18 +360,34 @@ export const getToolServersData = async (servers: object[]) => { toolServerToken = localStorage.token; } - const data = await getToolServerData( - toolServerToken, - (server?.path ?? '').includes('://') - ? server?.path - : `${server?.url}${(server?.path ?? '').startsWith('/') ? '' : '/'}${server?.path}` - ).catch((err) => { - error = err; - return null; - }); + let res = null; + const specType = server?.spec_type ?? 'url'; + + if (specType === 'url') { + res = await getToolServerData( + toolServerToken, + (server?.path ?? '').includes('://') + ? server?.path + : `${server?.url}${(server?.path ?? '').startsWith('/') ? '' : '/'}${server?.path}` + ).catch((err) => { + error = err; + return null; + }); + } else if ((specType === 'json' && server?.spec) ?? null) { + try { + res = JSON.parse(server?.spec); + } catch (e) { + error = 'Failed to parse JSON spec'; + } + } + + if (res) { + const { openapi, info, specs } = { + openapi: res, + info: res.info, + specs: convertOpenApiToToolPayload(res) + }; - if (data) { - const { openapi, info, specs } = data; return { url: server?.url, openapi: openapi, diff --git a/src/lib/apis/models/index.ts b/src/lib/apis/models/index.ts index 3e6e0d0c0bc..d324fa91733 100644 --- a/src/lib/apis/models/index.ts +++ b/src/lib/apis/models/index.ts @@ -31,6 +31,34 @@ export const getModels = async (token: string = '') => { return res; }; +export const importModels = async (token: string, models: object[]) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/import`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ models: models }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getBaseModels = async (token: string = '') => { let error = null; diff --git a/src/lib/components/AddToolServerModal.svelte b/src/lib/components/AddToolServerModal.svelte index c6894ddeeb8..08489348b3b 100644 --- a/src/lib/components/AddToolServerModal.svelte +++ b/src/lib/components/AddToolServerModal.svelte @@ -1,4 +1,9 @@ {#if loaded} - + +
{$i18n.t('Groups')} @@ -180,7 +149,7 @@
@@ -234,7 +203,7 @@
- - import { toast } from 'svelte-sonner'; - import { getContext, onMount } from 'svelte'; - const i18n = getContext('i18n'); - - import Spinner from '$lib/components/common/Spinner.svelte'; - import Modal from '$lib/components/common/Modal.svelte'; - import Textarea from '$lib/components/common/Textarea.svelte'; - import XMark from '$lib/components/icons/XMark.svelte'; - export let onSubmit: Function = () => {}; - export let show = false; - - let name = ''; - let description = ''; - let userIds = []; - - let loading = false; - - const submitHandler = async () => { - loading = true; - - const group = { - name, - description - }; - - await onSubmit(group); - - loading = false; - show = false; - - name = ''; - description = ''; - userIds = []; - }; - - - -
-
-
- {$i18n.t('Add User Group')} -
- -
- -
-
-
{ - e.preventDefault(); - submitHandler(); - }} - > -
-
-
-
{$i18n.t('Name')}
- -
- -
-
-
- -
-
{$i18n.t('Description')}
- -
-