Skip to content

Commit 64115ed

Browse files
Merge branch 'feat/RAAE-531/mistral-vectorizer' into feat/RAAE-517/default-float32
2 parents fb87288 + 84da8db commit 64115ed

File tree

5 files changed

+46
-45
lines changed

5 files changed

+46
-45
lines changed

docs/user_guide/vectorizers_04.ipynb

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
},
3232
{
3333
"cell_type": "code",
34-
"execution_count": 1,
34+
"execution_count": 2,
3535
"metadata": {},
3636
"outputs": [],
3737
"source": [
@@ -305,33 +305,25 @@
305305
},
306306
{
307307
"cell_type": "code",
308-
"execution_count": 6,
308+
"execution_count": 3,
309309
"metadata": {},
310310
"outputs": [
311-
{
312-
"name": "stderr",
313-
"output_type": "stream",
314-
"text": [
315-
"/Users/tyler.hutcherson/redis/redis-vl-python/.venv/lib/python3.9/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
316-
" return self.fget.__get__(instance, owner)()\n"
317-
]
318-
},
319311
{
320312
"data": {
321313
"text/plain": [
322-
"[0.00037810884532518685,\n",
323-
" -0.05080341175198555,\n",
324-
" -0.03514723479747772,\n",
325-
" -0.02325104922056198,\n",
326-
" -0.044158220291137695,\n",
327-
" 0.020487844944000244,\n",
328-
" 0.0014617963461205363,\n",
329-
" 0.031261757016181946,\n",
314+
"[0.0003780885017476976,\n",
315+
" -0.05080340430140495,\n",
316+
" -0.035147231072187424,\n",
317+
" -0.02325103059411049,\n",
318+
" -0.04415831342339516,\n",
319+
" 0.02048780582845211,\n",
320+
" 0.0014618589775636792,\n",
321+
" 0.03126184269785881,\n",
330322
" 0.05605152249336243,\n",
331-
" 0.018815357238054276]"
323+
" 0.018815429881215096]"
332324
]
333325
},
334-
"execution_count": 6,
326+
"execution_count": 3,
335327
"metadata": {},
336328
"output_type": "execute_result"
337329
}
@@ -532,14 +524,14 @@
532524
}
533525
],
534526
"source": [
535-
"# from redisvl.utils.vectorize import MistralAITextVectorizer\n",
527+
"from redisvl.utils.vectorize import MistralAITextVectorizer\n",
536528
"\n",
537-
"# mistral = MistralAITextVectorizer()\n",
529+
"mistral = MistralAITextVectorizer()\n",
538530
"\n",
539-
"# # embed a sentence using their asyncronous method\n",
540-
"# test = await mistral.aembed(\"This is a test sentence.\")\n",
541-
"# print(\"Vector dimensions: \", len(test))\n",
542-
"# print(test[:10])"
531+
"# embed a sentence using their asyncronous method\n",
532+
"test = await mistral.aembed(\"This is a test sentence.\")\n",
533+
"print(\"Vector dimensions: \", len(test))\n",
534+
"print(test[:10])"
543535
]
544536
},
545537
{
@@ -588,9 +580,17 @@
588580
},
589581
{
590582
"cell_type": "code",
591-
"execution_count": null,
583+
"execution_count": 3,
592584
"metadata": {},
593-
"outputs": [],
585+
"outputs": [
586+
{
587+
"name": "stdout",
588+
"output_type": "stream",
589+
"text": [
590+
"Vector dimensions: 1024\n"
591+
]
592+
}
593+
],
594594
"source": [
595595
"from redisvl.utils.vectorize import BedrockTextVectorizer\n",
596596
"\n",

poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ sentence-transformers = { version = ">=2.2.2", optional = true }
3232
google-cloud-aiplatform = { version = ">=1.26", optional = true }
3333
protobuf = { version = ">=5.29.1,<6.0.0.dev0", optional = true }
3434
cohere = { version = ">=4.44", optional = true }
35-
mistralai = { version = ">=0.2.0", optional = true }
35+
mistralai = { version = ">=1.0.0", optional = true }
3636
boto3 = { version = ">=1.34.0", optional = true }
3737

3838
[tool.poetry.extras]

redisvl/utils/vectorize/text/mistral.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ class MistralAITextVectorizer(BaseVectorizer):
4444
"""
4545

4646
_client: Any = PrivateAttr()
47-
_aclient: Any = PrivateAttr()
4847

4948
def __init__(
5049
self,
@@ -78,8 +77,7 @@ def _initialize_clients(self, api_config: Optional[Dict]):
7877
"""
7978
# Dynamic import of the mistralai module
8079
try:
81-
from mistralai.async_client import MistralAsyncClient
82-
from mistralai.client import MistralClient
80+
from mistralai import Mistral
8381
except ImportError:
8482
raise ImportError(
8583
"MistralAI vectorizer requires the mistralai library. \
@@ -97,13 +95,12 @@ def _initialize_clients(self, api_config: Optional[Dict]):
9795
environment variable."
9896
)
9997

100-
self._client = MistralClient(api_key=api_key)
101-
self._aclient = MistralAsyncClient(api_key=api_key)
98+
self._client = Mistral(api_key=api_key)
10299

103100
def _set_model_dims(self, model) -> int:
104101
try:
105102
embedding = (
106-
self._client.embeddings(model=model, input=["dimension test"])
103+
self._client.embeddings.create(model=model, inputs=["dimension test"])
107104
.data[0]
108105
.embedding
109106
)
@@ -153,7 +150,7 @@ def embed_many(
153150

154151
embeddings: List = []
155152
for batch in self.batchify(texts, batch_size, preprocess):
156-
response = self._client.embeddings(model=self.model, input=batch)
153+
response = self._client.embeddings.create(model=self.model, inputs=batch)
157154
embeddings += [
158155
self._process_embedding(r.embedding, as_buffer, dtype)
159156
for r in response.data
@@ -195,7 +192,7 @@ def embed(
195192

196193
dtype = kwargs.pop("dtype", self.dtype)
197194

198-
result = self._client.embeddings(model=self.model, input=[text])
195+
result = self._client.embeddings.create(model=self.model, inputs=[text])
199196
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
200197

201198
@retry(
@@ -237,7 +234,9 @@ async def aembed_many(
237234

238235
embeddings: List = []
239236
for batch in self.batchify(texts, batch_size, preprocess):
240-
response = await self._aclient.embeddings(model=self.model, input=batch)
237+
response = await self._client.embeddings.create_async(
238+
model=self.model, inputs=batch
239+
)
241240
embeddings += [
242241
self._process_embedding(r.embedding, as_buffer, dtype)
243242
for r in response.data
@@ -279,7 +278,9 @@ async def aembed(
279278

280279
dtype = kwargs.pop("dtype", self.dtype)
281280

282-
result = await self._aclient.embeddings(model=self.model, input=[text])
281+
result = await self._client.embeddings.create_async(
282+
model=self.model, inputs=[text]
283+
)
283284
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
284285

285286
@property

tests/integration/test_vectorizers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def skip_vectorizer() -> bool:
2828
CohereTextVectorizer,
2929
AzureOpenAITextVectorizer,
3030
BedrockTextVectorizer,
31-
# MistralAITextVectorizer,
31+
MistralAITextVectorizer,
3232
CustomTextVectorizer,
3333
]
3434
)
@@ -299,7 +299,7 @@ def test_dtypes(vector_class, skip_vectorizer):
299299
params=[
300300
OpenAITextVectorizer,
301301
BedrockTextVectorizer,
302-
# MistralAITextVectorizer,
302+
MistralAITextVectorizer,
303303
CustomTextVectorizer,
304304
]
305305
)

0 commit comments

Comments
 (0)