diff --git a/Cargo.toml b/Cargo.toml index 839c9977..2caed693 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,8 +17,9 @@ tonic = { version = "0.8.3", features = ["tls", "tls-roots"] } prost = "0.11.8" prost-types = "0.11.8" anyhow = "1" -reqwest = { version = "0.11.14", optional = true, features = ["stream"] } +reqwest = { version = "0.11.14", optional = true, features = ["stream", "multipart"] } futures-util = { version = "0.3.27", optional = true } +tokio = {version = "1.26.0", features = ["fs"]} [build-dependencies] tonic-build = { version = "0.8.4", features = ["prost"] } diff --git a/src/client.rs b/src/client.rs index 4ecbbd49..e3a40e98 100644 --- a/src/client.rs +++ b/src/client.rs @@ -9,8 +9,26 @@ use crate::qdrant::snapshots_client::SnapshotsClient; use crate::qdrant::value::Kind; use crate::qdrant::vectors::VectorsOptions; use crate::qdrant::with_payload_selector::SelectorOptions; -use crate::qdrant::{qdrant_client, with_vectors_selector, AliasOperations, ChangeAliases, ClearPayloadPoints, CollectionOperationResponse, Condition, CountPoints, CountResponse, CreateAlias, CreateCollection, CreateFieldIndexCollection, CreateFullSnapshotRequest, CreateSnapshotRequest, CreateSnapshotResponse, DeleteAlias, DeleteCollection, DeleteFieldIndexCollection, DeleteFullSnapshotRequest, DeletePayloadPoints, DeletePoints, DeleteSnapshotRequest, DeleteSnapshotResponse, FieldCondition, FieldType, Filter, GetCollectionInfoRequest, GetCollectionInfoResponse, GetPoints, GetResponse, HasIdCondition, HealthCheckReply, HealthCheckRequest, IsEmptyCondition, ListAliasesRequest, ListAliasesResponse, ListCollectionAliasesRequest, ListCollectionsRequest, ListCollectionsResponse, ListFullSnapshotsRequest, ListSnapshotsRequest, ListSnapshotsResponse, ListValue, NamedVectors, OptimizersConfigDiff, PayloadIncludeSelector, PayloadIndexParams, PointId, PointStruct, PointsIdsList, PointsOperationResponse, PointsSelector, RecommendBatchPoints, RecommendBatchResponse, RecommendPoints, RecommendResponse, RenameAlias, ScrollPoints, ScrollResponse, SearchBatchPoints, SearchBatchResponse, SearchPoints, SearchResponse, SetPayloadPoints, Struct, UpdateCollection, UpsertPoints, Value, Vector, Vectors, VectorsSelector, WithPayloadSelector, WithVectorsSelector, WriteOrdering, ReadConsistency}; +use crate::qdrant::{ + qdrant_client, with_vectors_selector, AliasOperations, ChangeAliases, ClearPayloadPoints, + CollectionOperationResponse, Condition, CountPoints, CountResponse, CreateAlias, + CreateCollection, CreateFieldIndexCollection, CreateFullSnapshotRequest, CreateSnapshotRequest, + CreateSnapshotResponse, DeleteAlias, DeleteCollection, DeleteFieldIndexCollection, + DeleteFullSnapshotRequest, DeletePayloadPoints, DeletePoints, DeleteSnapshotRequest, + DeleteSnapshotResponse, FieldCondition, FieldType, Filter, GetCollectionInfoRequest, + GetCollectionInfoResponse, GetPoints, GetResponse, HasIdCondition, HealthCheckReply, + HealthCheckRequest, IsEmptyCondition, ListAliasesRequest, ListAliasesResponse, + ListCollectionAliasesRequest, ListCollectionsRequest, ListCollectionsResponse, + ListFullSnapshotsRequest, ListSnapshotsRequest, ListSnapshotsResponse, ListValue, NamedVectors, + OptimizersConfigDiff, PayloadIncludeSelector, PayloadIndexParams, PointId, PointStruct, + PointsIdsList, PointsOperationResponse, PointsSelector, ReadConsistency, RecommendBatchPoints, + RecommendBatchResponse, RecommendPoints, RecommendResponse, RenameAlias, ScrollPoints, + ScrollResponse, SearchBatchPoints, SearchBatchResponse, SearchPoints, SearchResponse, + SetPayloadPoints, Struct, UpdateCollection, UpsertPoints, Value, Vector, Vectors, + VectorsSelector, WithPayloadSelector, WithVectorsSelector, WriteOrdering, +}; use anyhow::{bail, Result}; +use reqwest::multipart::{Form, Part}; use std::collections::HashMap; use std::future::Future; use std::path::PathBuf; @@ -30,7 +48,10 @@ pub struct QdrantClientConfig { impl QdrantClientConfig { pub fn from_url(url: &str) -> Self { - QdrantClientConfig { uri: url.to_string(), ..Self::default() } + QdrantClientConfig { + uri: url.to_string(), + ..Self::default() + } } pub fn set_api_key(&mut self, api_key: &str) { @@ -499,7 +520,8 @@ impl QdrantClient { points: Vec, ordering: Option, ) -> Result { - self._upsert_points(collection_name, &points, false, ordering).await + self._upsert_points(collection_name, &points, false, ordering) + .await } pub async fn upsert_points_blocking( @@ -508,7 +530,8 @@ impl QdrantClient { points: Vec, ordering: Option, ) -> Result { - self._upsert_points(collection_name, &points, true, ordering).await + self._upsert_points(collection_name, &points, true, ordering) + .await } #[inline] @@ -517,7 +540,7 @@ impl QdrantClient { collection_name: impl ToString, points: &Vec, block: bool, - ordering: Option + ordering: Option, ) -> Result { let collection_name = collection_name.to_string(); let collection_name_ref = collection_name.as_str(); @@ -554,7 +577,7 @@ impl QdrantClient { collection_name: impl ToString, points: &PointsSelector, payload: Payload, - ordering: Option + ordering: Option, ) -> Result { self._set_payload(collection_name, points, &payload, true, ordering) .await @@ -594,7 +617,7 @@ impl QdrantClient { collection_name: impl ToString, points: &PointsSelector, payload: Payload, - ordering: Option + ordering: Option, ) -> Result { self._overwrite_payload(collection_name, points, &payload, false, ordering) .await @@ -618,7 +641,7 @@ impl QdrantClient { points: &PointsSelector, payload: &Payload, block: bool, - ordering: Option + ordering: Option, ) -> Result { let collection_name = collection_name.to_string(); let collection_name_ref = collection_name.as_str(); @@ -800,7 +823,8 @@ impl QdrantClient { points: &PointsSelector, ordering: Option, ) -> Result { - self._delete_points(collection_name, false, points, ordering).await + self._delete_points(collection_name, false, points, ordering) + .await } pub async fn delete_points_blocking( @@ -809,7 +833,8 @@ impl QdrantClient { points: &PointsSelector, ordering: Option, ) -> Result { - self._delete_points(collection_name, true, points, ordering).await + self._delete_points(collection_name, true, points, ordering) + .await } async fn _delete_points( @@ -1105,6 +1130,7 @@ impl QdrantClient { { use futures_util::StreamExt; use std::io::Write; + use tokio::io::AsyncWriteExt; let snapshot_name = match snapshot_name { Some(sn) => sn.to_string(), @@ -1134,15 +1160,53 @@ impl QdrantClient { .bytes_stream(); let out_path = out_path.into(); - let _ = std::fs::remove_file(&out_path); - let mut file = std::fs::OpenOptions::new() + let _ = tokio::fs::remove_file(&out_path).await; + let mut file = tokio::fs::OpenOptions::new() .write(true) .create_new(true) - .open(out_path)?; + .open(out_path) + .await?; while let Some(chunk) = stream.next().await { - let _written = file.write(&chunk?)?; + let _written = file.write(&chunk?).await?; + } + + Ok(()) + } + + pub async fn upload_snapshot( + &self, + in_path: impl Into, + collection_name: T, + rest_api_uri: Option, + ) -> Result<()> + where + T: ToString + Clone, + { + let snapshot_path: PathBuf = in_path.into(); + let filename = snapshot_path + .file_name() + .and_then(|name| name.to_str().map(|name| name.to_string())); + + let snapshot_file = tokio::fs::read(snapshot_path).await?; + + let mut part = Part::bytes(snapshot_file); + if let Some(filename) = filename { + part = part.file_name(filename); } + let form = Form::new().part("snapshot", part); + let client = reqwest::Client::new(); + client + .post(format!( + "{}/collections/{}/snapshots/upload", + rest_api_uri + .map(|uri| uri.to_string()) + .unwrap_or_else(|| String::from("http://localhost:6333")), + collection_name.to_string() + )) + .multipart(form) + .send() + .await?; Ok(()) }