Skip to content

Commit 8c6d16a

Browse files
Fix Embedding Dimension Parameter Not Being Passed (mem0ai#2304)
1 parent dd1f298 commit 8c6d16a

File tree

4 files changed

+9
-9
lines changed

4 files changed

+9
-9
lines changed

mem0/embeddings/gemini.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,5 @@ def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]
2828
list: The embedding vector.
2929
"""
3030
text = text.replace("\n", " ")
31-
response = genai.embed_content(model=self.config.model, content=text)
31+
response = genai.embed_content(model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims)
3232
return response["embedding"]

mem0/embeddings/openai.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,4 @@ def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]
2929
list: The embedding vector.
3030
"""
3131
text = text.replace("\n", " ")
32-
return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding
32+
return self.client.embeddings.create(input=[text], model=self.config.model, dimensions = self.config.embedding_dims).data[0].embedding

tests/embeddings/test_gemini.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def mock_genai():
1212

1313
@pytest.fixture
1414
def config():
15-
return BaseEmbedderConfig(api_key="dummy_api_key", model="test_model")
15+
return BaseEmbedderConfig(api_key="dummy_api_key", model="test_model", embedding_dims=786)
1616

1717

1818
def test_embed_query(mock_genai, config):
@@ -25,4 +25,4 @@ def test_embed_query(mock_genai, config):
2525
embedding = embedder.embed(text)
2626

2727
assert embedding == [0.1, 0.2, 0.3, 0.4]
28-
mock_genai.assert_called_once_with(model="test_model", content="Hello, world!")
28+
mock_genai.assert_called_once_with(model="test_model", content="Hello, world!", output_dimensionality=786)

tests/embeddings/test_openai_embeddings.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test_embed_default_model(mock_openai_client):
2121

2222
result = embedder.embed("Hello world")
2323

24-
mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small")
24+
mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small", dimensions = 1536)
2525
assert result == [0.1, 0.2, 0.3]
2626

2727

@@ -35,7 +35,7 @@ def test_embed_custom_model(mock_openai_client):
3535
result = embedder.embed("Test embedding")
3636

3737
mock_openai_client.embeddings.create.assert_called_once_with(
38-
input=["Test embedding"], model="text-embedding-2-medium"
38+
input=["Test embedding"], model="text-embedding-2-medium", dimensions = 1024
3939
)
4040
assert result == [0.4, 0.5, 0.6]
4141

@@ -49,7 +49,7 @@ def test_embed_removes_newlines(mock_openai_client):
4949

5050
result = embedder.embed("Hello\nworld")
5151

52-
mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small")
52+
mock_openai_client.embeddings.create.assert_called_once_with(input=["Hello world"], model="text-embedding-3-small", dimensions = 1536)
5353
assert result == [0.7, 0.8, 0.9]
5454

5555

@@ -63,7 +63,7 @@ def test_embed_without_api_key_env_var(mock_openai_client):
6363
result = embedder.embed("Testing API key")
6464

6565
mock_openai_client.embeddings.create.assert_called_once_with(
66-
input=["Testing API key"], model="text-embedding-3-small"
66+
input=["Testing API key"], model="text-embedding-3-small", dimensions = 1536
6767
)
6868
assert result == [1.0, 1.1, 1.2]
6969

@@ -79,6 +79,6 @@ def test_embed_uses_environment_api_key(mock_openai_client, monkeypatch):
7979
result = embedder.embed("Environment key test")
8080

8181
mock_openai_client.embeddings.create.assert_called_once_with(
82-
input=["Environment key test"], model="text-embedding-3-small"
82+
input=["Environment key test"], model="text-embedding-3-small", dimensions = 1536
8383
)
8484
assert result == [1.3, 1.4, 1.5]

0 commit comments

Comments
 (0)