
Commit 87db241

migrates mistral text vectorizer to new mistral client
1 parent 201d676 commit 87db241

File tree: 3 files changed, +31 -27 lines

docs/user_guide/vectorizers_04.ipynb
redisvl/utils/vectorize/text/mistral.py
tests/integration/test_vectorizers.py

docs/user_guide/vectorizers_04.ipynb

Lines changed: 18 additions & 15 deletions
@@ -532,14 +532,14 @@
     }
    ],
    "source": [
-    "# from redisvl.utils.vectorize import MistralAITextVectorizer\n",
+    "from redisvl.utils.vectorize import MistralAITextVectorizer\n",
     "\n",
-    "# mistral = MistralAITextVectorizer()\n",
+    "mistral = MistralAITextVectorizer()\n",
     "\n",
-    "# # embed a sentence using their asyncronous method\n",
-    "# test = await mistral.aembed(\"This is a test sentence.\")\n",
-    "# print(\"Vector dimensions: \", len(test))\n",
-    "# print(test[:10])"
+    "# embed a sentence using their asyncronous method\n",
+    "test = await mistral.aembed(\"This is a test sentence.\")\n",
+    "print(\"Vector dimensions: \", len(test))\n",
+    "print(test[:10])"
    ]
   },
   {
@@ -588,9 +588,17 @@
  },
  {
   "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 3,
   "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Vector dimensions: 1024\n"
+     ]
+    }
+   ],
   "source": [
    "from redisvl.utils.vectorize import BedrockTextVectorizer\n",
    "\n",
@@ -836,7 +844,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3.8.13 ('redisvl2')",
+   "display_name": "redisvl-dev",
    "language": "python",
    "name": "python3"
   },
@@ -852,12 +860,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  },
-  "orig_nbformat": 4,
-  "vscode": {
-   "interpreter": {
-    "hash": "9b1e6e9c2967143209c2f955cb869d1d3234f92dc4787f49f155f3abbdfb1316"
-   }
-  }
+  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2

redisvl/utils/vectorize/text/mistral.py

Lines changed: 11 additions & 10 deletions
@@ -44,7 +44,6 @@ class MistralAITextVectorizer(BaseVectorizer):
     """

     _client: Any = PrivateAttr()
-    _aclient: Any = PrivateAttr()

     def __init__(self, model: str = "mistral-embed", api_config: Optional[Dict] = None):
         """Initialize the MistralAI vectorizer.
@@ -69,8 +68,7 @@ def _initialize_clients(self, api_config: Optional[Dict]):
         """
         # Dynamic import of the mistralai module
         try:
-            from mistralai.async_client import MistralAsyncClient
-            from mistralai.client import MistralClient
+            from mistralai import Mistral
         except ImportError:
             raise ImportError(
                 "MistralAI vectorizer requires the mistralai library. \
@@ -88,13 +86,12 @@ def _initialize_clients(self, api_config: Optional[Dict]):
                     environment variable."
             )

-        self._client = MistralClient(api_key=api_key)
-        self._aclient = MistralAsyncClient(api_key=api_key)
+        self._client = Mistral(api_key=api_key)

     def _set_model_dims(self, model) -> int:
         try:
             embedding = (
-                self._client.embeddings(model=model, input=["dimension test"])
+                self._client.embeddings.create(model=model, inputs=["dimension test"])
                 .data[0]
                 .embedding
             )
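The hunks above replace the old MistralClient/MistralAsyncClient pair with the single Mistral client and move embedding calls to embeddings.create(..., inputs=...). A minimal sketch of the new synchronous call shape outside redisvl (illustrative only; assumes a current mistralai package and a valid MISTRAL_API_KEY):

import os

from mistralai import Mistral

client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

# previously: client.embeddings(model="mistral-embed", input=[...])
# now: embeddings live under .embeddings.create and the keyword is `inputs`
response = client.embeddings.create(model="mistral-embed", inputs=["dimension test"])
print(len(response.data[0].embedding))  # 1024 for mistral-embed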
@@ -144,7 +141,7 @@ def embed_many(

         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
-            response = self._client.embeddings(model=self.model, input=batch)
+            response = self._client.embeddings.create(model=self.model, inputs=batch)
             embeddings += [
                 self._process_embedding(r.embedding, as_buffer, dtype)
                 for r in response.data
@@ -186,7 +183,7 @@ def embed(

         dtype = kwargs.pop("dtype", None)

-        result = self._client.embeddings(model=self.model, input=[text])
+        result = self._client.embeddings.create(model=self.model, inputs=[text])
         return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

     @retry(
@@ -228,7 +225,9 @@ async def aembed_many(

         embeddings: List = []
         for batch in self.batchify(texts, batch_size, preprocess):
-            response = await self._aclient.embeddings(model=self.model, input=batch)
+            response = await self._client.embeddings.create_async(
+                model=self.model, inputs=batch
+            )
             embeddings += [
                 self._process_embedding(r.embedding, as_buffer, dtype)
                 for r in response.data
@@ -270,7 +269,9 @@ async def aembed(

         dtype = kwargs.pop("dtype", None)

-        result = await self._aclient.embeddings(model=self.model, input=[text])
+        result = await self._client.embeddings.create_async(
+            model=self.model, inputs=[text]
+        )
         return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

     @property
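With the separate async client gone, the async paths above await embeddings.create_async on the same Mistral instance. A sketch of that call on its own (same assumptions as the synchronous example above; illustrative only):

import asyncio
import os

from mistralai import Mistral

async def main():
    client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

    # create_async mirrors create, awaited on the one shared client
    response = await client.embeddings.create_async(
        model="mistral-embed",
        inputs=["This is a test sentence.", "And a second one."],
    )
    for item in response.data:
        print(len(item.embedding))

asyncio.run(main())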

tests/integration/test_vectorizers.py

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@ def skip_vectorizer() -> bool:
         CohereTextVectorizer,
         AzureOpenAITextVectorizer,
         BedrockTextVectorizer,
-        # MistralAITextVectorizer,
+        MistralAITextVectorizer,
         CustomTextVectorizer,
     ]
 )
@@ -242,7 +242,7 @@ def bad_return_type(text: str) -> str:
     params=[
         OpenAITextVectorizer,
         BedrockTextVectorizer,
-        # MistralAITextVectorizer,
+        MistralAITextVectorizer,
         CustomTextVectorizer,
     ]
 )
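Un-commenting MistralAITextVectorizer here puts Mistral back into the parametrized fixture rotation, so each integration test in these groups runs once per vectorizer class. The shape is roughly as follows (fixture and test names are illustrative, not copied from the test file):

import pytest

from redisvl.utils.vectorize import (
    BedrockTextVectorizer,
    CustomTextVectorizer,
    MistralAITextVectorizer,
    OpenAITextVectorizer,
)

@pytest.fixture(
    params=[
        OpenAITextVectorizer,
        BedrockTextVectorizer,
        MistralAITextVectorizer,  # re-enabled by this commit
        CustomTextVectorizer,
    ]
)
def vectorizer_class(request):
    # pytest re-runs every dependent test once per class in params
    return request.param

def test_vectorizer_exposes_embed(vectorizer_class):
    # the vectorizers used here expose the embed/aembed interface
    assert hasattr(vectorizer_class, "embed")
    assert hasattr(vectorizer_class, "aembed")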
