# [Improvements] Manage segment cache and memory #1670
**`chromadb/config.py`**: add a `chroma_memory_limit_bytes` setting.

```diff
@@ -116,6 +116,7 @@ class Settings(BaseSettings):  # type: ignore
     is_persistent: bool = False
     persist_directory: str = "./chroma"
+    chroma_memory_limit_bytes: int = 0
     chroma_server_host: Optional[str] = None
     chroma_server_headers: Optional[Dict[str, str]] = None
     chroma_server_http_port: Optional[str] = None
```

> **Collaborator** (on `chroma_memory_limit_bytes`): How do I turn this capability on and off? Is 0 implicitly off?
>
> **Author:** Yes, 0 is unlimited.

> **Collaborator** (on the new setting's design): Can we introduce a config called `segment_manager_cache_policy` and make this one of many types?
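The thread above pins down the semantics: a positive value enables the cap, 0 leaves it off. A minimal usage sketch (the 1 GiB value is illustrative, not from the PR):

```python
from chromadb.config import Settings

# Cap the on-disk size of cached vector segments at ~1 GiB.
# chroma_memory_limit_bytes=0 (the default) means unlimited.
settings = Settings(
    is_persistent=True,
    persist_directory="./chroma",
    chroma_memory_limit_bytes=1 * 1024**3,
)
```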
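Nothing in this PR implements the `segment_manager_cache_policy` suggestion; a hypothetical sketch of the shape the reviewer may have in mind (all names invented for illustration):

```python
from enum import Enum

class SegmentManagerCachePolicy(str, Enum):
    """Hypothetical policy knob suggested in review; not part of this PR."""
    UNLIMITED = "unlimited"        # current default (chroma_memory_limit_bytes == 0)
    LRU_BY_DISK_SIZE = "lru_disk"  # the eviction strategy this PR hard-codes
```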
**`chromadb/segment/impl/manager/local.py`**: track on-disk segment size and evict least-recently-used vector segments when the cached total would exceed the limit.

```diff
@@ -7,6 +7,9 @@
     VectorReader,
     S,
 )
+import time
+import os
+
 from chromadb.config import System, get_class
 from chromadb.db.system import SysDB
 from overrides import override
```

```diff
@@ -37,6 +40,16 @@
     SegmentType.HNSW_LOCAL_MEMORY: "chromadb.segment.impl.vector.local_hnsw.LocalHnswSegment",
     SegmentType.HNSW_LOCAL_PERSISTED: "chromadb.segment.impl.vector.local_persistent_hnsw.PersistentLocalHnswSegment",
 }
+
+
+def get_size(start_path: str) -> int:
+    """Total size in bytes of all files under start_path."""
+    total_size = 0
+    for dirpath, _, filenames in os.walk(start_path):
+        for f in filenames:
+            fp = os.path.join(dirpath, f)
+            # Skip symbolic links
+            if not os.path.islink(fp):
+                total_size += os.path.getsize(fp)
+    return total_size
+
+
 class LocalSegmentManager(SegmentManager):
```

```diff
@@ -140,16 +153,52 @@ def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
         "LocalSegmentManager.get_segment",
         OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
     )
+    def _get_segment_disk_size(self, collection_id: UUID) -> int:
+        segments = self._sysdb.get_segments(
+            collection=collection_id, scope=SegmentScope.VECTOR
+        )
+        if len(segments) == 0:
+            return 0
+        # Size of the segment's directory under the persist directory
+        size = get_size(
+            os.path.join(
+                self._system.settings.require("persist_directory"),
+                str(segments[0]["id"]),
+            )
+        )
+        return size
+
+    def _cleanup_segment(self, collection_id: UUID, target_size: int) -> None:
+        segment_sizes = {
+            id: self._get_segment_disk_size(id)
+            for id in self._segment_cache
+            if SegmentScope.VECTOR in self._segment_cache[id]
+        }
+        total_size = sum(segment_sizes.values())
+        new_segment_size = self._get_segment_disk_size(collection_id)
+
+        # Evict least-recently-used vector segments until the incoming
+        # segment fits under the target size
+        while total_size + new_segment_size >= target_size and self._segment_cache:
+            oldest_key = min(
+                (
+                    k
+                    for k in self._segment_cache
+                    if SegmentScope.VECTOR in self._segment_cache[k]
+                ),
+                key=lambda k: self._segment_cache[k][SegmentScope.VECTOR]["last_used"],
+                default=None,
+            )
+            if oldest_key is not None:
+                # Stop the instance and remove it from the cache
+                instance = self._instance(
+                    self._segment_cache[oldest_key][SegmentScope.VECTOR]
+                )
+                instance.stop()
+                # Update total_size and drop the segment from the cache and
+                # the sizes dictionary
+                total_size -= segment_sizes[oldest_key]
+                del segment_sizes[oldest_key]
+                del self._segment_cache[oldest_key]
+            else:
+                break
+
     @override
     def get_segment(self, collection_id: UUID, type: Type[S]) -> S:
         if type == MetadataReader:
             scope = SegmentScope.METADATA
         elif type == VectorReader:
             scope = SegmentScope.VECTOR
         else:
             raise ValueError(f"Invalid segment type: {type}")

         if scope not in self._segment_cache[collection_id]:
+            memory_limit = self._system.settings.require("chroma_memory_limit_bytes")
+            if (
+                type == VectorReader
+                and self._system.settings.require("is_persistent")
+                and memory_limit > 0
+            ):
+                self._cleanup_segment(collection_id, memory_limit)
             segments = self._sysdb.get_segments(collection=collection_id, scope=scope)
             known_types = set([k.value for k in SEGMENT_TYPE_IMPLS.keys()])
             # Get the first segment of a known type
```

```diff
@@ -158,6 +207,7 @@ def get_segment(self, collection_id: UUID, type: Type[S]) -> S:
         # Instances must be atomically created, so we use a lock to ensure that only one thread
         # creates the instance.
+        self._segment_cache[collection_id][scope]["last_used"] = time.time()
         with self._lock:
             instance = self._instance(self._segment_cache[collection_id][scope])
         return cast(S, instance)
```

```diff
@@ -209,4 +259,5 @@ def _segment(type: SegmentType, scope: SegmentScope, collection: Collection) ->
         topic=collection["topic"],
         collection=collection["id"],
         metadata=metadata,
+        last_used=0,
     )
```
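Note the `>=` in the eviction loop: a new segment is admitted only if the cached total plus its own size stays strictly under the limit, which is exactly the boundary case raised in the review thread in the test section below. A self-contained sketch of that policy (toy sizes and a hypothetical helper, not the PR's code):

```python
from collections import OrderedDict

def evict_until_fits(cache: "OrderedDict[str, int]", new_size: int, limit: int) -> None:
    # Mirrors _cleanup_segment: evict least-recently-used entries while the
    # cached total plus the incoming size is still >= the limit.
    total = sum(cache.values())
    while total + new_size >= limit and cache:
        evicted_id, evicted_size = cache.popitem(last=False)  # oldest first
        total -= evicted_size
        print(f"evicted {evicted_id} ({evicted_size} bytes)")

cache: "OrderedDict[str, int]" = OrderedDict()

# Limit 10, segments of size 6 and 7: they can never be cached together.
evict_until_fits(cache, new_size=6, limit=10)
cache["segment-a"] = 6                         # fits: 0 + 6 < 10
evict_until_fits(cache, new_size=7, limit=10)  # evicts segment-a: 6 + 7 >= 10
cache["segment-b"] = 7                         # now the only cached segment
```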
**New test file** (124 added lines): a Hypothesis state machine that drives the segment manager and checks the cache invariants; `_get_segment_disk_size` is patched so each collection gets a random mock size.

```python
import uuid

import pytest
import chromadb.test.property.strategies as strategies
from unittest.mock import patch
from dataclasses import asdict
import random
from hypothesis.stateful import (
    Bundle,
    RuleBasedStateMachine,
    rule,
    initialize,
    multiple,
    precondition,
    invariant,
    run_state_machine_as_test,
    MultipleResults,
)
from typing import Dict, List
from chromadb.segment import VectorReader
from chromadb.segment import SegmentManager

from chromadb.segment.impl.manager.local import LocalSegmentManager
from chromadb.types import SegmentScope
from chromadb.db.system import SysDB
from chromadb.config import System, get_class

memory_limit = 100


class LastUse:
    """Tracks the last n collection IDs in least-recently-used order."""

    def __init__(self, n: int):
        self.n = n
        self.store: List[uuid.UUID] = []

    def add(self, id: uuid.UUID):
        # If the ID is already tracked, move it to the end of the list
        if id in self.store:
            self.store.remove(id)
            self.store.append(id)
        else:
            self.store.append(id)
        # Keep only the last n IDs
        while len(self.store) > self.n:
            self.store.pop(0)
        return self.store

    def reset(self):
        self.store = []


class SegmentManagerStateMachine(RuleBasedStateMachine):
    collections: Bundle[strategies.Collection]
    collections = Bundle("collections")
    collection_size_store: Dict[uuid.UUID, int] = {}

    def __init__(self, system: System):
        super().__init__()
        self.segment_manager = system.require(SegmentManager)
        self.segment_manager.start()
        self.segment_manager.reset_state()
        self.last_use = LastUse(n=40)
        self.collection_created_counter = 0
        self.sysdb = system.require(SysDB)
        self.system = system

    @invariant()
    def last_queried_segments_should_be_in_cache(self):
        cache_sum = 0
        index = 0
        for id in reversed(self.last_use.store):
            cache_sum += self.collection_size_store[id]
            if cache_sum >= memory_limit and index != 0:
                break
            assert self.segment_manager._segment_cache[id][SegmentScope.VECTOR] is not None
            index += 1

    @invariant()
    @precondition(lambda self: self.system.settings.is_persistent is True)
    def cache_should_not_be_bigger_than_settings(self):
        segment_sizes = {
            id: self.collection_size_store[id]
            for id in self.segment_manager._segment_cache
        }
        total_size = sum(segment_sizes.values())
        # A single segment may exceed the limit on its own; the limit only
        # constrains combinations of cached segments
        if len(segment_sizes) != 1:
            assert total_size <= memory_limit

    @initialize()
    def initialize(self) -> None:
        self.segment_manager.reset_state()
        self.segment_manager.start()
        self.collection_created_counter = 0
        self.last_use.reset()

    @rule(target=collections, coll=strategies.collections())
    @precondition(lambda self: self.collection_created_counter <= 50)
    def create_segment(
        self, coll: strategies.Collection
    ) -> MultipleResults[strategies.Collection]:
        segments = self.segment_manager.create_segments(asdict(coll))
        for segment in segments:
            self.sysdb.create_segment(segment)
        self.collection_created_counter += 1
        self.collection_size_store[coll.id] = random.randint(0, memory_limit)
        return multiple(coll)

    @rule(coll=collections)
    def get_segment(self, coll: strategies.Collection) -> None:
        segment = self.segment_manager.get_segment(collection_id=coll.id, type=VectorReader)
        self.last_use.add(coll.id)
        assert segment is not None

    @staticmethod
    def mock_collection_size(self, collection_id):
        # Patched in place of LocalSegmentManager._get_segment_disk_size, so
        # it receives the manager instance as `self`
        return SegmentManagerStateMachine.collection_size_store[collection_id]


@patch.object(
    LocalSegmentManager,
    "_get_segment_disk_size",
    SegmentManagerStateMachine.mock_collection_size,
)
def test_segment_manager(caplog: pytest.LogCaptureFixture, system: System) -> None:
    system.settings.chroma_memory_limit_bytes = memory_limit
    run_state_machine_as_test(
        lambda: SegmentManagerStateMachine(system=system))
```

> **Collaborator** (on `cache_should_not_be_bigger_than_settings`): What's the behavior for boundary conditions? E.g. if the limit is 10GB and we have two files, 6GB and 7GB, will we always only allow one?
>
> **Author:** Yes. Is this what we want? We could add a log message when a collection gets evicted because of the memory constraint.
>
> **Collaborator:** Yeah, let's log.

> **Collaborator** (on patching `_get_segment_disk_size` into the test): Nice.
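On the logging agreed above, a minimal self-contained sketch of what the eviction log might look like (hypothetical helper and message; the PR as shown does not include it):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("chromadb.segment.impl.manager.local")

def log_eviction(collection_id: str, size_bytes: int, limit_bytes: int) -> None:
    # Hypothetical helper: call after instance.stop() in _cleanup_segment,
    # before the entry is dropped from segment_sizes and the cache.
    logger.info(
        "Evicting vector segment of collection %s (%d bytes on disk) to "
        "respect chroma_memory_limit_bytes=%d",
        collection_id, size_bytes, limit_bytes,
    )

log_eviction("example-collection-id", 6 * 1024**3, 10 * 1024**3)
```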
> **Reviewer:** I would update the documentation.