diff --git a/chromadb/config.py b/chromadb/config.py
index 59bea5ee0e4..61b789d0eee 100644
--- a/chromadb/config.py
+++ b/chromadb/config.py
@@ -116,6 +115,9 @@ class Settings(BaseSettings):  # type: ignore
     is_persistent: bool = False
     persist_directory: str = "./chroma"
 
+    chroma_memory_limit_bytes: int = 0
+    chroma_segment_cache_policy: Optional[str] = None
+
     chroma_server_host: Optional[str] = None
     chroma_server_headers: Optional[Dict[str, str]] = None
     chroma_server_http_port: Optional[str] = None
@@ -313,6 +315,15 @@ def __init__(self, settings: Settings):
             if settings[key] is not None:
                 raise ValueError(LEGACY_ERROR)
 
+        if settings["chroma_segment_cache_policy"] is not None:
+            if settings["chroma_segment_cache_policy"] != "LRU":
+                logger.error(
+                    "Failed to set chroma_segment_cache_policy: only LRU is available."
+                )
+            elif settings["chroma_memory_limit_bytes"] == 0:
+                logger.error(
+                    "Failed to set chroma_segment_cache_policy: chroma_memory_limit_bytes is required."
+                )
+
         # Apply the nofile limit if set
         if settings["chroma_server_nofile"] is not None:
             if platform.system() != "Windows":
diff --git a/chromadb/segment/impl/__init__.py b/chromadb/segment/impl/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/chromadb/segment/impl/manager/__init__.py b/chromadb/segment/impl/manager/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/chromadb/segment/impl/manager/cache/__init__.py b/chromadb/segment/impl/manager/cache/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/chromadb/segment/impl/manager/cache/cache.py b/chromadb/segment/impl/manager/cache/cache.py
new file mode 100644
index 00000000000..80cab0d8e91
--- /dev/null
+++ b/chromadb/segment/impl/manager/cache/cache.py
@@ -0,0 +1,104 @@
+import uuid
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Optional
+
+from overrides import override
+
+from chromadb.types import Segment
+
+
+class SegmentCache(ABC):
+    @abstractmethod
+    def get(self, key: uuid.UUID) -> Optional[Segment]:
+        pass
+
+    @abstractmethod
+    def pop(self, key: uuid.UUID) -> Optional[Segment]:
+        pass
+
+    @abstractmethod
+    def set(self, key: uuid.UUID, value: Segment) -> None:
+        pass
+
+    @abstractmethod
+    def reset(self) -> None:
+        pass
+
+
+class BasicCache(SegmentCache):
+    def __init__(self) -> None:
+        self.cache: Dict[uuid.UUID, Segment] = {}
+
+    @override
+    def get(self, key: uuid.UUID) -> Optional[Segment]:
+        return self.cache.get(key)
+
+    @override
+    def pop(self, key: uuid.UUID) -> Optional[Segment]:
+        return self.cache.pop(key, None)
+
+    @override
+    def set(self, key: uuid.UUID, value: Segment) -> None:
+        self.cache[key] = value
+
+    @override
+    def reset(self) -> None:
+        self.cache = {}
+
+
+class SegmentLRUCache(BasicCache):
+    """A simple LRU cache that handles objects with dynamic sizes.
+    The size of each object is determined by a user-provided size function."""
+
+    def __init__(
+        self,
+        capacity: int,
+        size_func: Callable[[uuid.UUID], int],
+        callback: Optional[Callable[[uuid.UUID, Segment], Any]] = None,
+    ):
+        self.capacity = capacity
+        self.size_func = size_func
+        self.cache: Dict[uuid.UUID, Segment] = {}
+        self.history: List[uuid.UUID] = []
+        self.callback = callback
+
+    def _upsert_key(self, key: uuid.UUID) -> None:
+        # Move the key to the most-recently-used position.
+        if key in self.history:
+            self.history.remove(key)
+        self.history.append(key)
+
+    @override
+    def get(self, key: uuid.UUID) -> Optional[Segment]:
+        # Only record a use for keys that are actually cached, so that
+        # misses do not pollute the eviction history.
+        if key not in self.cache:
+            return None
+        self._upsert_key(key)
+        return self.cache[key]
+
+    @override
+    def pop(self, key: uuid.UUID) -> Optional[Segment]:
+        if key in self.history:
+            self.history.remove(key)
+        return self.cache.pop(key, None)
+
+    @override
+    def set(self, key: uuid.UUID, value: Segment) -> None:
+        if key in self.cache:
+            return
+        item_size = self.size_func(key)
+        key_sizes = {key: self.size_func(key) for key in self.cache}
+        total_size = sum(key_sizes.values())
+        index = 0
+        # Evict least-recently-used items until the new item fits.
+        while total_size + item_size > self.capacity and index < len(self.history):
+            key_delete = self.history[index]
+            if key_delete in self.cache:
+                if self.callback is not None:
+                    self.callback(key_delete, self.cache[key_delete])
+                del self.cache[key_delete]
+                total_size -= key_sizes[key_delete]
+            index += 1
+
+        self.cache[key] = value
+        self._upsert_key(key)
+
+    @override
+    def reset(self) -> None:
+        self.cache = {}
+        self.history = []
diff --git a/chromadb/segment/impl/manager/local.py b/chromadb/segment/impl/manager/local.py
index 246d9a00c64..c5afef2d012 100644
--- a/chromadb/segment/impl/manager/local.py
+++ b/chromadb/segment/impl/manager/local.py
@@ -7,6 +7,10 @@
     VectorReader,
     S,
 )
+import logging
+import os
+
+from chromadb.segment.impl.manager.cache.cache import BasicCache, SegmentCache, SegmentLRUCache
+
 from chromadb.config import System, get_class
 from chromadb.db.system import SysDB
 from overrides import override
@@ -21,24 +25,23 @@
 from chromadb.types import Collection, Operation, Segment, SegmentScope, Metadata
 from typing import Dict, Type, Sequence, Optional, cast
 from uuid import UUID, uuid4
-from collections import defaultdict
 import platform
 
 from chromadb.utils.lru_cache import LRUCache
+from chromadb.utils.directory import get_directory_size
 
 if platform.system() != "Windows":
     import resource
 elif platform.system() == "Windows":
     import ctypes
 
 SEGMENT_TYPE_IMPLS = {
     SegmentType.SQLITE: "chromadb.segment.impl.metadata.sqlite.SqliteMetadataSegment",
     SegmentType.HNSW_LOCAL_MEMORY: "chromadb.segment.impl.vector.local_hnsw.LocalHnswSegment",
     SegmentType.HNSW_LOCAL_PERSISTED: "chromadb.segment.impl.vector.local_persistent_hnsw.PersistentLocalHnswSegment",
 }
 
 class LocalSegmentManager(SegmentManager):
     _sysdb: SysDB
     _system: System
@@ -47,9 +50,6 @@ class LocalSegmentManager(SegmentManager):
     _vector_instances_file_handle_cache: LRUCache[
         UUID, PersistentLocalHnswSegment
     ]  # LRU cache to manage file handles across vector segment instances
-    _segment_cache: Dict[
-        UUID, Dict[SegmentScope, Segment]
-    ]  # Tracks which segments are loaded for a given collection
     _vector_segment_type: SegmentType = SegmentType.HNSW_LOCAL_MEMORY
     _lock: Lock
     _max_file_handles: int
@@ -59,8 +59,17 @@ def __init__(self, system: System):
         self._sysdb = self.require(SysDB)
         self._system = system
         self._opentelemetry_client = system.require(OpenTelemetryClient)
+        self.logger = logging.getLogger(__name__)
         self._instances = {}
-        self._segment_cache = defaultdict(dict)
+        self.segment_cache: Dict[SegmentScope, SegmentCache] = {
+            SegmentScope.METADATA: BasicCache()
+        }
+        if (
+            system.settings.chroma_segment_cache_policy == "LRU"
+            and system.settings.chroma_memory_limit_bytes > 0
+        ):
+            self.segment_cache[SegmentScope.VECTOR] = SegmentLRUCache(
+                capacity=system.settings.chroma_memory_limit_bytes,
+                callback=lambda k, v: self.callback_cache_evict(v),
+                size_func=lambda k: self._get_segment_disk_size(k),
+            )
+        else:
+            self.segment_cache[SegmentScope.VECTOR] = BasicCache()
+
         self._lock = Lock()
 
         # TODO: prototyping with distributed segment for now, but this should be a configurable option
@@ -72,13 +81,21 @@ def __init__(self, system: System):
         else:
             self._max_file_handles = ctypes.windll.msvcrt._getmaxstdio()  # type: ignore
         segment_limit = (
             self._max_file_handles
             // PersistentLocalHnswSegment.get_file_handle_count()
         )
         self._vector_instances_file_handle_cache = LRUCache(
             segment_limit, callback=lambda _, v: v.close_persistent_index()
         )
 
+    def callback_cache_evict(self, segment: Segment) -> None:
+        collection_id = segment["collection"]
+        self.logger.info(f"LRU cache evict collection {collection_id}")
+        instance = self._instance(segment)
+        instance.stop()
+        del self._instances[segment["id"]]
+
     @override
     def start(self) -> None:
         for instance in self._instances.values():
@@ -97,7 +114,7 @@ def reset_state(self) -> None:
             instance.stop()
             instance.reset_state()
         self._instances = {}
-        self._segment_cache = defaultdict(dict)
+        for cache in self.segment_cache.values():
+            cache.reset()
         super().reset_state()
 
@@ -130,16 +147,31 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
                 instance = self.get_segment(collection_id, MetadataReader)
                 instance.delete()
                 del self._instances[segment["id"]]
-            if collection_id in self._segment_cache:
-                if segment["scope"] in self._segment_cache[collection_id]:
-                    del self._segment_cache[collection_id][segment["scope"]]
-                del self._segment_cache[collection_id]
+            if segment["scope"] in self.segment_cache:
+                self.segment_cache[segment["scope"]].pop(collection_id)
         return [s["id"] for s in segments]
 
+    def _get_segment_disk_size(self, collection_id: UUID) -> int:
+        segments = self._sysdb.get_segments(
+            collection=collection_id, scope=SegmentScope.VECTOR
+        )
+        if len(segments) == 0:
+            return 0
+        # With the local segment manager (single-node Chroma), a collection
+        # always has exactly one vector segment.
+        return get_directory_size(
+            os.path.join(
+                self._system.settings.require("persist_directory"),
+                str(segments[0]["id"]),
+            )
+        )
+
+    def _get_segment_sysdb(self, collection_id: UUID, scope: SegmentScope) -> Segment:
+        segments = self._sysdb.get_segments(collection=collection_id, scope=scope)
+        known_types = set([k.value for k in SEGMENT_TYPE_IMPLS.keys()])
+        # Get the first segment of a known type
+        return next(filter(lambda s: s["type"] in known_types, segments))
+
     @trace_method(
         "LocalSegmentManager.get_segment",
         OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
     )
     @override
     def get_segment(self, collection_id: UUID, type: Type[S]) -> S:
         if type == MetadataReader:
@@ -149,17 +181,15 @@ def get_segment(self, collection_id: UUID, type: Type[S]) -> S:
         else:
             raise ValueError(f"Invalid segment type: {type}")
 
-        if scope not in self._segment_cache[collection_id]:
-            segments = self._sysdb.get_segments(collection=collection_id, scope=scope)
-            known_types = set([k.value for k in SEGMENT_TYPE_IMPLS.keys()])
-            # Get the first segment of a known type
-            segment = next(filter(lambda s: s["type"] in known_types, segments))
-            self._segment_cache[collection_id][scope] = segment
+        segment = self.segment_cache[scope].get(collection_id)
+        if segment is None:
+            segment = self._get_segment_sysdb(collection_id, scope)
+            self.segment_cache[scope].set(collection_id, segment)
 
         # Instances must be atomically created, so we use a lock to ensure that only one thread
         # creates the instance.
         with self._lock:
-            instance = self._instance(self._segment_cache[collection_id][scope])
+            instance = self._instance(segment)
         return cast(S, instance)
diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py
index 5a4c0d905cc..89def8ac316 100644
--- a/chromadb/test/property/strategies.py
+++ b/chromadb/test/property/strategies.py
@@ -3,6 +3,7 @@
 import hypothesis.strategies as st
 from typing import Any, Optional, List, Dict, Union, cast
 from typing_extensions import TypedDict
+import uuid
 import numpy as np
 import numpy.typing as npt
 import chromadb.api.types as types
@@ -237,16 +238,17 @@ def embedding_function_strategy(
 @dataclass
 class Collection:
     name: str
+    id: uuid.UUID
     metadata: Optional[types.Metadata]
     dimension: int
     dtype: npt.DTypeLike
+    topic: str
     known_metadata_keys: types.Metadata
     known_document_keywords: List[str]
     has_documents: bool = False
     has_embeddings: bool = False
     embedding_function: Optional[types.EmbeddingFunction[Embeddable]] = None
 
 @st.composite
 def collections(
@@ -309,7 +311,9 @@ def collections(
         embedding_function = draw(embedding_function_strategy(dimension, dtype))
 
     return Collection(
+        id=uuid.uuid4(),
         name=name,
+        topic="topic",
         metadata=metadata,
         dimension=dimension,
         dtype=dtype,
diff --git a/chromadb/test/property/test_segment_manager.py b/chromadb/test/property/test_segment_manager.py
new file mode 100644
index 00000000000..ff5e057dff4
--- /dev/null
+++ b/chromadb/test/property/test_segment_manager.py
@@ -0,0 +1,128 @@
+import os
+import random
+import uuid
+from dataclasses import asdict
+from typing import Dict, List
+from unittest.mock import patch
+
+import pytest
+from hypothesis.stateful import (
+    Bundle,
+    RuleBasedStateMachine,
+    rule,
+    initialize,
+    multiple,
+    precondition,
+    invariant,
+    run_state_machine_as_test,
+    MultipleResults,
+)
+
+import chromadb.test.property.strategies as strategies
+from chromadb.config import System
+from chromadb.db.system import SysDB
+from chromadb.segment import SegmentManager, VectorReader
+from chromadb.types import SegmentScope
+
+# Memory limit used for testing
+memory_limit = 100
+
+
+# Helper class to keep track of the most recently used collection ids
+class LastUse:
+    def __init__(self, n: int):
+        self.n = n
+        self.store: List[uuid.UUID] = []
+
+    def add(self, id: uuid.UUID):
+        if id in self.store:
+            self.store.remove(id)
+        self.store.append(id)
+        while len(self.store) > self.n:
+            self.store.pop(0)
+        return self.store
+
+    def reset(self):
+        self.store = []
+
+
+class SegmentManagerStateMachine(RuleBasedStateMachine):
+    collections: Bundle[strategies.Collection]
+    collections = Bundle("collections")
+    collection_size_store: Dict[uuid.UUID, int] = {}
+    segment_collection: Dict[uuid.UUID, uuid.UUID] = {}
+
+    def __init__(self, system: System):
+        super().__init__()
+        self.segment_manager = system.require(SegmentManager)
+        self.segment_manager.start()
+        self.segment_manager.reset_state()
+        self.last_use = LastUse(n=40)
+        self.collection_created_counter = 0
+        self.sysdb = system.require(SysDB)
+        self.system = system
+
+    @invariant()
+    def last_queried_segments_should_be_in_cache(self):
+        cache_sum = 0
+        index = 0
+        for id in reversed(self.last_use.store):
+            cache_sum += self.collection_size_store[id]
+            if cache_sum >= memory_limit and index != 0:
+                break
+            assert id in self.segment_manager.segment_cache[SegmentScope.VECTOR].cache
+            index += 1
+
+    @invariant()
+    @precondition(lambda self: self.system.settings.is_persistent is True)
+    def cache_should_not_be_bigger_than_settings(self):
+        segment_sizes = {
+            id: self.collection_size_store[id]
+            for id in self.segment_manager.segment_cache[SegmentScope.VECTOR].cache
+        }
+        total_size = sum(segment_sizes.values())
+        # A single segment may exceed the limit on its own; otherwise the
+        # cache must stay within the configured memory budget.
+        if len(segment_sizes) != 1:
+            assert total_size <= memory_limit
+
+    @initialize()
+    def initialize(self) -> None:
+        self.segment_manager.reset_state()
+        self.segment_manager.start()
+        self.collection_created_counter = 0
+        self.last_use.reset()
+
+    @rule(target=collections, coll=strategies.collections())
+    @precondition(lambda self: self.collection_created_counter <= 50)
+    def create_segment(
+        self, coll: strategies.Collection
+    ) -> MultipleResults[strategies.Collection]:
+        segments = self.segment_manager.create_segments(asdict(coll))
+        for segment in segments:
+            self.sysdb.create_segment(segment)
+            self.segment_collection[segment["id"]] = coll.id
+        self.collection_created_counter += 1
+        self.collection_size_store[coll.id] = random.randint(0, memory_limit)
+        return multiple(coll)
+
+    @rule(coll=collections)
+    def get_segment(self, coll: strategies.Collection) -> None:
+        segment = self.segment_manager.get_segment(
+            collection_id=coll.id, type=VectorReader
+        )
+        self.last_use.add(coll.id)
+        assert segment is not None
+
+    @staticmethod
+    def mock_directory_size(directory: str) -> int:
+        # The directory name is the segment id; use basename for portability.
+        path_id = os.path.basename(directory)
+        collection_id = SegmentManagerStateMachine.segment_collection[uuid.UUID(path_id)]
+        return SegmentManagerStateMachine.collection_size_store[collection_id]
+
+
+@patch(
+    "chromadb.segment.impl.manager.local.get_directory_size",
+    SegmentManagerStateMachine.mock_directory_size,
+)
+def test_segment_manager(caplog: pytest.LogCaptureFixture, system: System) -> None:
+    system.settings.chroma_memory_limit_bytes = memory_limit
+    system.settings.chroma_segment_cache_policy = "LRU"
+
+    run_state_machine_as_test(lambda: SegmentManagerStateMachine(system=system))
diff --git a/chromadb/utils/directory.py b/chromadb/utils/directory.py
new file mode 100644
index 00000000000..d470a810ed5
--- /dev/null
+++ b/chromadb/utils/directory.py
@@ -0,0 +1,21 @@
+import os
+
+
+def get_directory_size(directory: str) -> int:
+    """
+    Calculate the total size of a directory by walking through each file.
+
+    Parameters:
+        directory (str): The path of the directory for which to calculate the size.
+
+    Returns:
+        total_size (int): The total size of the directory in bytes.
+    """
+    total_size = 0
+    for dirpath, _, filenames in os.walk(directory):
+        for f in filenames:
+            fp = os.path.join(dirpath, f)
+            # Skip symbolic links to avoid counting their targets
+            if not os.path.islink(fp):
+                total_size += os.path.getsize(fp)
+
+    return total_size
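
A note for reviewers: the sketch below shows how the two settings introduced by this patch would be enabled on a persistent local client. Only `chroma_segment_cache_policy` and `chroma_memory_limit_bytes` come from this diff; the `PersistentClient` entry point and the 10 GB figure are illustrative assumptions, not part of the change.

```python
import chromadb
from chromadb.config import Settings

# Illustrative values; only the two cache settings are added by this patch.
client = chromadb.PersistentClient(
    path="./chroma",
    settings=Settings(
        chroma_segment_cache_policy="LRU",         # only "LRU" is accepted
        chroma_memory_limit_bytes=10_000_000_000,  # cap vector segments at ~10 GB
    ),
)
```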
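
And a minimal sketch of the eviction behavior of `SegmentLRUCache` itself, assuming the module path from this diff. The plain dicts passed as values are stand-ins for full `Segment` records, which also carry `type`, `scope`, `topic`, and `metadata` fields.

```python
import uuid
from typing import Dict

from chromadb.segment.impl.manager.cache.cache import SegmentLRUCache

sizes: Dict[uuid.UUID, int] = {}

def size_func(key: uuid.UUID) -> int:
    # The segment manager backs this with get_directory_size();
    # here we just look up a recorded size per key.
    return sizes.get(key, 0)

def on_evict(key: uuid.UUID, segment: dict) -> None:
    # The manager uses this hook to stop the evicted segment instance
    # and release its file handles.
    print(f"evicted segment for collection {key}")

cache = SegmentLRUCache(capacity=100, size_func=size_func, callback=on_evict)

k1, k2 = uuid.uuid4(), uuid.uuid4()
sizes[k1], sizes[k2] = 60, 60

cache.set(k1, {"id": k1, "collection": k1})  # fits: 60 <= 100
cache.set(k2, {"id": k2, "collection": k2})  # 60 + 60 > 100, so k1 is evicted

assert cache.get(k1) is None
assert cache.get(k2) is not None
```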