Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions libs/community/langchain_community/vectorstores/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import base64
import logging
import os
import uuid
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -125,6 +126,15 @@ def __init__(
)

self._embedding_function = embedding_function
# --- BEGIN ADDED CPU GUARD LOGIC ---
if (
collection_metadata
and "hnsw:num_threads" in collection_metadata
and collection_metadata["hnsw:num_threads"] > os.cpu_count()
):
collection_metadata["hnsw:num_threads"] = os.cpu_count()
# --- END ADDED LOGIC ---

self._collection = self._client.get_or_create_collection(
name=collection_name,
embedding_function=None,
Expand Down
80 changes: 80 additions & 0 deletions libs/community/tests/unit_tests/vectorstores/test_chroma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Unit tests for Chroma vectorstore."""
import os
import unittest.mock
import pytest
from unittest.mock import patch, MagicMock

from langchain_community.vectorstores import Chroma


class TestChromaCPUGuard:
"""Test CPU guard functionality in Chroma initialization."""

@patch('langchain_community.vectorstores.chroma.chromadb')
@patch('os.cpu_count')
def test_cpu_guard_caps_excessive_threads(self, mock_cpu_count, mock_chromadb):
"""Test that hnsw:num_threads is capped to CPU count when excessive."""
mock_cpu_count.return_value = 4
mock_client = MagicMock()
mock_chromadb.Client.return_value = mock_client
mock_chromadb.config.Settings.return_value = MagicMock()

collection_metadata = {"hnsw:num_threads": 8} # More than CPU count

Chroma(
collection_name="test_cpu_guard",
collection_metadata=collection_metadata
)

# Verify the metadata was modified
assert collection_metadata["hnsw:num_threads"] == 4

@patch('langchain_community.vectorstores.chroma.chromadb')
@patch('os.cpu_count')
def test_cpu_guard_preserves_valid_threads(self, mock_cpu_count, mock_chromadb):
"""Test that valid thread counts are preserved."""
mock_cpu_count.return_value = 8
mock_client = MagicMock()
mock_chromadb.Client.return_value = mock_client
mock_chromadb.config.Settings.return_value = MagicMock()

collection_metadata = {"hnsw:num_threads": 4} # Less than CPU count

Chroma(
collection_name="test_cpu_guard_valid",
collection_metadata=collection_metadata
)

# Verify the metadata was NOT modified
assert collection_metadata["hnsw:num_threads"] == 4

@patch('langchain_community.vectorstores.chroma.chromadb')
def test_cpu_guard_handles_none_metadata(self, mock_chromadb):
"""Test that None collection_metadata doesn't cause issues."""
mock_client = MagicMock()
mock_chromadb.Client.return_value = mock_client
mock_chromadb.config.Settings.return_value = MagicMock()

# Should not raise any exceptions
Chroma(
collection_name="test_none_metadata",
collection_metadata=None
)

@patch('langchain_community.vectorstores.chroma.chromadb')
def test_cpu_guard_handles_missing_threads_key(self, mock_chromadb):
"""Test metadata without hnsw:num_threads key."""
mock_client = MagicMock()
mock_chromadb.Client.return_value = mock_client
mock_chromadb.config.Settings.return_value = MagicMock()

collection_metadata = {"hnsw:space": "cosine"} # No num_threads key

Chroma(
collection_name="test_missing_key",
collection_metadata=collection_metadata
)

# Should remain unchanged
assert "hnsw:num_threads" not in collection_metadata
assert collection_metadata["hnsw:space"] == "cosine"