diff --git a/examples/audio/create_transcription/src/main.rs b/examples/audio/create_transcription/src/main.rs index 518125f2..c70b410c 100644 --- a/examples/audio/create_transcription/src/main.rs +++ b/examples/audio/create_transcription/src/main.rs @@ -1,5 +1,5 @@ use openai_dive::v1::api::Client; -use openai_dive::v1::models::WhisperModel; +use openai_dive::v1::models::TranscriptionModel; use openai_dive::v1::resources::audio::{AudioOutputFormat, AudioTranscriptionParametersBuilder}; use openai_dive::v1::resources::shared::FileUpload; @@ -9,7 +9,7 @@ async fn main() { let parameters = AudioTranscriptionParametersBuilder::default() .file(FileUpload::File("./audio/micro-machines.mp3".to_string())) - .model(WhisperModel::Whisper1.to_string()) + .model(TranscriptionModel::Whisper1.to_string()) .response_format(AudioOutputFormat::VerboseJson) .build() .unwrap(); diff --git a/examples/audio/create_translation/src/main.rs b/examples/audio/create_translation/src/main.rs index 9c7b602b..e9351184 100644 --- a/examples/audio/create_translation/src/main.rs +++ b/examples/audio/create_translation/src/main.rs @@ -1,5 +1,5 @@ use openai_dive::v1::api::Client; -use openai_dive::v1::models::WhisperModel; +use openai_dive::v1::models::TranscriptionModel; use openai_dive::v1::resources::audio::{AudioOutputFormat, AudioTranslationParametersBuilder}; use openai_dive::v1::resources::shared::FileUpload; @@ -9,7 +9,7 @@ async fn main() { let parameters = AudioTranslationParametersBuilder::default() .file(FileUpload::File("./audio/multilingual.mp3".to_string())) - .model(WhisperModel::Whisper1.to_string()) + .model(TranscriptionModel::Whisper1.to_string()) .response_format(AudioOutputFormat::Srt) .build() .unwrap(); diff --git a/examples/chat/create_chat_completion/src/main.rs b/examples/chat/create_chat_completion/src/main.rs index c316aa14..277d0f7d 100644 --- a/examples/chat/create_chat_completion/src/main.rs +++ b/examples/chat/create_chat_completion/src/main.rs @@ -14,15 +14,15 @@ 
async fn main() -> Result<(), Box> { let client = Client::new_from_env(); let parameters = ChatCompletionParametersBuilder::default() - .model(ReasoningModel::O1Mini.to_string()) + .model(ReasoningModel::O3Mini.to_string()) .messages(vec![ ChatMessage::User { content: ChatMessageContent::Text("Hello!".to_string()), - name: None, + name: Some("Judy".to_string()), }, ChatMessage::User { - content: ChatMessageContent::Text("What is the capital of Vietnam?".to_string()), - name: None, + content: ChatMessageContent::Text("What is the capital of Singapore?".to_string()), + name: Some("Judy".to_string()), }, ]) .response_format(ChatCompletionResponseFormat::Text) @@ -32,5 +32,11 @@ async fn main() -> Result<(), Box> { println!("{:#?}", result); + for choice in &result.choices { + if let Some(text) = choice.message.text() { + println!("{}", text); + } + } + Ok(()) } diff --git a/examples/chat/structured_outputs/src/main.rs b/examples/chat/structured_outputs/src/main.rs index 02fa3c62..f99d854e 100644 --- a/examples/chat/structured_outputs/src/main.rs +++ b/examples/chat/structured_outputs/src/main.rs @@ -26,31 +26,33 @@ async fn main() -> Result<(), Box> { name: None, }, ]) - .response_format(ChatCompletionResponseFormat::JsonSchema(JsonSchemaBuilder::default() - .name("math_reasoning") - .schema(serde_json::json!({ - "type": "object", - "properties": { - "steps": { - "type": "array", - "items": { - "type": "object", - "properties": { - "explanation": { "type": "string" }, - "output": { "type": "string" } - }, - "required": ["explanation", "output"], - "additionalProperties": false - } + .response_format(ChatCompletionResponseFormat::JsonSchema { + json_schema: JsonSchemaBuilder::default() + .name("math_reasoning") + .schema(serde_json::json!({ + "type": "object", + "properties": { + "steps": { + "type": "array", + "items": { + "type": "object", + "properties": { + "explanation": { "type": "string" }, + "output": { "type": "string" } + }, + "required": ["explanation", 
"output"], + "additionalProperties": false + } + }, + "final_answer": { "type": "string" } }, - "final_answer": { "type": "string" } - }, - "required": ["steps", "final_answer"], - "additionalProperties": false - })) - .strict(true) - .build()? - )) + "required": ["steps", "final_answer"], + "additionalProperties": false + })) + .strict(true) + .build()? + } + ) .build()?; let result = client.chat().create(parameters).await?; diff --git a/examples/images/create_image/src/main.rs b/examples/images/create_image/src/main.rs index 5c2c304b..4a9761bc 100644 --- a/examples/images/create_image/src/main.rs +++ b/examples/images/create_image/src/main.rs @@ -1,5 +1,5 @@ use openai_dive::v1::api::Client; -use openai_dive::v1::models::DallEModel; +use openai_dive::v1::models::ImageModel; use openai_dive::v1::resources::image::{ CreateImageParametersBuilder, ImageQuality, ImageSize, ImageStyle, ResponseFormat, }; @@ -10,7 +10,7 @@ async fn main() { let parameters = CreateImageParametersBuilder::default() .prompt("A cute dog in the park") - .model(DallEModel::DallE3.to_string()) + .model(ImageModel::DallE3.to_string()) .n(1u32) .quality(ImageQuality::Standard) .response_format(ResponseFormat::Url) diff --git a/examples/images/create_image_edit/src/main.rs b/examples/images/create_image_edit/src/main.rs index 3d2f52e4..e1f81a1e 100644 --- a/examples/images/create_image_edit/src/main.rs +++ b/examples/images/create_image_edit/src/main.rs @@ -1,5 +1,5 @@ use openai_dive::v1::api::Client; -use openai_dive::v1::resources::image::{EditImageParametersBuilder, ImageSize}; +use openai_dive::v1::resources::image::{EditImageParametersBuilder, ImageSize, MimeType}; use openai_dive::v1::resources::shared::FileUpload; #[tokio::main] @@ -13,6 +13,7 @@ async fn main() { .prompt("A cute baby sea otter") .mask(FileUpload::File("./images/image_edit_mask.png".to_string())) .n(1u32) + .mime_type(MimeType::Png) .size(ImageSize::Size512X512) .build() .unwrap(); diff --git 
a/examples/images/create_multiple_images/Cargo.toml b/examples/images/create_multiple_images/Cargo.toml new file mode 100644 index 00000000..1784b027 --- /dev/null +++ b/examples/images/create_multiple_images/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "create_multiple_images" +version = "0.1.0" +edition = "2021" + +[dependencies] +openai_dive = { path = "./../../../openai_dive" } +tokio = { version = "1.0", features = ["full"] } diff --git a/examples/images/create_multiple_images/images/image_edit_original.png b/examples/images/create_multiple_images/images/image_edit_original.png new file mode 100644 index 00000000..c62d8c67 Binary files /dev/null and b/examples/images/create_multiple_images/images/image_edit_original.png differ diff --git a/examples/images/create_multiple_images/images/person.png b/examples/images/create_multiple_images/images/person.png new file mode 100644 index 00000000..94f8b694 Binary files /dev/null and b/examples/images/create_multiple_images/images/person.png differ diff --git a/examples/images/create_multiple_images/src/main.rs b/examples/images/create_multiple_images/src/main.rs new file mode 100644 index 00000000..70ae4545 --- /dev/null +++ b/examples/images/create_multiple_images/src/main.rs @@ -0,0 +1,44 @@ +use openai_dive::v1::api::Client; +use openai_dive::v1::models::ImageModel; +use openai_dive::v1::resources::image::{ + EditImageParametersBuilder, ImageQuality, ImageSize, MimeType, +}; +use openai_dive::v1::resources::shared::FileUpload; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = Client::new_from_env(); + + let parameters = EditImageParametersBuilder::default() + .prompt("Make this person smile with full teeth") + .image(FileUpload::File("./images/person.png".to_string())) + .model(ImageModel::GptImage1.to_string()) + .quality(ImageQuality::Low) + .mime_type(MimeType::Png) + .n(1u32) + .size(ImageSize::Size1024X1024) + .build()?; + + let result = client.images().edit(parameters).await?; + + 
println!("{:#?}", result); + + let parameters = EditImageParametersBuilder::default() + .prompt("Combine the person into the original image") + .image(FileUpload::FileArray(vec![ + "./images/image_edit_original.png".to_string(), + "./images/person.png".to_string(), + ])) + .model("gpt-image-1") + .quality(ImageQuality::Low) + .mime_type(MimeType::Png) + .n(1u32) + .size(ImageSize::Size1024X1024) + .build()?; + + let result = client.images().edit(parameters).await?; + + println!("{:#?}", result); + + Ok(()) +} diff --git a/examples/responses/functions/src/main.rs b/examples/responses/functions/src/main.rs index d41cf3f6..553b86b0 100644 --- a/examples/responses/functions/src/main.rs +++ b/examples/responses/functions/src/main.rs @@ -1,7 +1,11 @@ use ftail::Ftail; use openai_dive::v1::api::Client; use openai_dive::v1::models::CostOptimizedModel; -use openai_dive::v1::resources::response::request::{ResponseInput, ResponseParametersBuilder}; +use openai_dive::v1::resources::response::items::{FunctionToolCallOutput, InputItemStatus}; +use openai_dive::v1::resources::response::request::{ + InputItem, ResponseInput, ResponseInputItem, ResponseParametersBuilder, +}; +use openai_dive::v1::resources::response::response::ResponseOutput; use openai_dive::v1::resources::response::shared::{ResponseTool, ResponseToolChoice}; #[tokio::main] @@ -44,5 +48,27 @@ async fn main() -> Result<(), Box> { println!("{:#?}", result); + let call = match &result.output[0] { + ResponseOutput::FunctionToolCall(call) => call, + _ => panic!("unexpected output"), + }; + + let parameters = ResponseParametersBuilder::default() + .model(CostOptimizedModel::Gpt4OMini.to_string()) + .input(ResponseInput::List(vec![ResponseInputItem::Item( + InputItem::FunctionToolCallOutput(FunctionToolCallOutput { + id: None, + call_id: call.call_id.clone(), + output: "{\"temperature_2m\":30,\"wind_speed_10m\":5}".to_string(), + status: InputItemStatus::Completed, + }), + )])) + .previous_response_id(result.id) 
.build()?; + + let result = client.responses().create(parameters).await?; + + println!("{:#?}", result); + Ok(()) } diff --git a/openai_dive/Cargo.toml b/openai_dive/Cargo.toml index 52042e4d..435d0b6e 100644 --- a/openai_dive/Cargo.toml +++ b/openai_dive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "openai_dive" -version = "1.0.1" +version = "1.2.2" edition = "2021" license = "MIT" description = "OpenAI Dive is an unofficial async Rust library that allows you to interact with the OpenAI API." diff --git a/openai_dive/README.md b/openai_dive/README.md index dcd1ca41..3c487a83 100644 --- a/openai_dive/README.md +++ b/openai_dive/README.md @@ -9,7 +9,7 @@ OpenAI Dive is an unofficial async Rust library that allows you to interact with ```ini [dependencies] -openai_dive = "1.0" +openai_dive = "1.2" ``` ## Get started @@ -264,31 +264,33 @@ let parameters = ChatCompletionParametersBuilder::default() name: None, }, ]) - .response_format(ChatCompletionResponseFormat::JsonSchema(JsonSchemaBuilder::default() - .name("math_reasoning") - .schema(serde_json::json!({ - "type": "object", - "properties": { - "steps": { - "type": "array", - "items": { - "type": "object", - "properties": { - "explanation": { "type": "string" }, - "output": { "type": "string" } - }, - "required": ["explanation", "output"], - "additionalProperties": false - } + .response_format(ChatCompletionResponseFormat::JsonSchema { + json_schema: JsonSchemaBuilder::default() + .name("math_reasoning") + .schema(serde_json::json!({ + "type": "object", + "properties": { + "steps": { + "type": "array", + "items": { + "type": "object", + "properties": { + "explanation": { "type": "string" }, + "output": { "type": "string" } + }, + "required": ["explanation", "output"], + "additionalProperties": false + } + }, + "final_answer": { "type": "string" } }, - "final_answer": { "type": "string" } - }, - "required": ["steps", "final_answer"], - "additionalProperties": false - })) - .strict(true) - .build()? 
- )) + "required": ["steps", "final_answer"], + "additionalProperties": false + })) + .strict(true) + .build()? + } + ) .build()?; let result = client.chat().create(parameters).await?; @@ -534,6 +536,8 @@ let mut client = Client::new(deepseek_api_key); client.set_base_url("https://api.deepseek.com"); ``` +Use `extra_body` in `ChatCompletionParametersBuilder` to pass non-standard parameters supported by OpenAI-compatible APIs. + ### Set organization/project ID You can create multiple organizations and projects in the OpenAI platform. This allows you to group files, fine-tuned models and other resources. @@ -587,20 +591,20 @@ You can use these predefined constants to set the model in the parameters or use #### Flagship Models -- Gpt45Preview (`gpt-4.5-preview`) +- Gpt41 (`gpt-4.1`) - Gpt4O (`gpt-4o`) - Gpt4OAudioPreview (`gpt-4o-audio-preview`) #### Cost-Optimized Models +- O4Mini (`o4-mini`) +- Gpt41Nano (`gpt-4.1-nano`) - Gpt4OMini (`gpt-4o-mini`) -- Gpt4OMiniAudioPreview (`gpt-4o-mini-audio-preview`) #### Reasoning Models +- O4Mini (`o4-mini`) - O3Mini (`o3-mini`) -- O1 (`o1`) -- O1Mini (`o1-mini`) #### Tool Models @@ -611,24 +615,26 @@ You can use these predefined constants to set the model in the parameters or use #### Moderation Models - OmniModerationLatest (`omni-moderation-latest`) -- TextModerationLatest (`text-moderation-latest`) #### Embedding Models - TextEmbedding3Small (`text-embedding-3-small`) - TextEmbedding3Large (`text-embedding-3-large`) -#### Whisper Models +#### Transcription Models +- Gpt4OTranscribe (`gpt-4o-transcribe`) - Whisper1 (`whisper-1`) #### TTS Models +- Gpt4OMiniTts (`gpt-4o-mini-tts`) - Tts1 (`tts-1`) - Tts1HD (`tts-1-hd`) -#### DALL·E Models +#### Image Models +- GptImage1 (`gpt-image-1`) - DallE3 (`dall-e-3`) - DallE2 (`dall-e-2`) diff --git a/openai_dive/src/lib.rs b/openai_dive/src/lib.rs index e8acd480..4b53cac1 100644 --- a/openai_dive/src/lib.rs +++ b/openai_dive/src/lib.rs @@ -4,7 +4,7 @@ //! //! ```ini //! 
[dependencies] -//! openai_dive = "1.0" +//! openai_dive = "1.2" //! ``` //! //! ## Get started @@ -259,31 +259,33 @@ //! name: None, //! }, //! ]) -//! .response_format(ChatCompletionResponseFormat::JsonSchema(JsonSchemaBuilder::default() -//! .name("math_reasoning") -//! .schema(serde_json::json!({ -//! "type": "object", -//! "properties": { -//! "steps": { -//! "type": "array", -//! "items": { -//! "type": "object", -//! "properties": { -//! "explanation": { "type": "string" }, -//! "output": { "type": "string" } -//! }, -//! "required": ["explanation", "output"], -//! "additionalProperties": false -//! } +//! .response_format(ChatCompletionResponseFormat::JsonSchema { +//! json_schema: JsonSchemaBuilder::default() +//! .name("math_reasoning") +//! .schema(serde_json::json!({ +//! "type": "object", +//! "properties": { +//! "steps": { +//! "type": "array", +//! "items": { +//! "type": "object", +//! "properties": { +//! "explanation": { "type": "string" }, +//! "output": { "type": "string" } +//! }, +//! "required": ["explanation", "output"], +//! "additionalProperties": false +//! } +//! }, +//! "final_answer": { "type": "string" } //! }, -//! "final_answer": { "type": "string" } -//! }, -//! "required": ["steps", "final_answer"], -//! "additionalProperties": false -//! })) -//! .strict(true) -//! .build()? -//! )) +//! "required": ["steps", "final_answer"], +//! "additionalProperties": false +//! })) +//! .strict(true) +//! .build()? +//! } +//! ) //! .build()?; //! //! let result = client.chat().create(parameters).await?; @@ -529,6 +531,8 @@ //! client.set_base_url("https://api.deepseek.com"); //! ``` //! +//! Use `extra_body` in `ChatCompletionParametersBuilder` to pass non-standard parameters supported by OpenAI-compatible APIs. +//! //! ### Set organization/project ID //! //! You can create multiple organizations and projects in the OpenAI platform. This allows you to group files, fine-tuned models and other resources. @@ -582,20 +586,20 @@ //! //! 
#### Flagship Models //! -//! - Gpt45Preview (`gpt-4.5-preview`) +//! - Gpt41 (`gpt-4.1`) //! - Gpt4O (`gpt-4o`) //! - Gpt4OAudioPreview (`gpt-4o-audio-preview`) //! //! #### Cost-Optimized Models //! +//! - O4Mini (`o4-mini`) +//! - Gpt41Nano (`gpt-4.1-nano`) //! - Gpt4OMini (`gpt-4o-mini`) -//! - Gpt4OMiniAudioPreview (`gpt-4o-mini-audio-preview`) //! //! #### Reasoning Models //! +//! - O4Mini (`o4-mini`) //! - O3Mini (`o3-mini`) -//! - O1 (`o1`) -//! - O1Mini (`o1-mini`) //! //! #### Tool Models //! @@ -606,24 +610,26 @@ //! #### Moderation Models //! //! - OmniModerationLatest (`omni-moderation-latest`) -//! - TextModerationLatest (`text-moderation-latest`) //! //! #### Embedding Models //! //! - TextEmbedding3Small (`text-embedding-3-small`) //! - TextEmbedding3Large (`text-embedding-3-large`) //! -//! #### Whisper Models +//! #### Transcription Models //! +//! - Gpt4OTranscribe (`gpt-4o-transcribe`) //! - Whisper1 (`whisper-1`) //! //! #### TTS Models //! +//! - Gpt4OMiniTts (`gpt-4o-mini-tts`) //! - Tts1 (`tts-1`) //! - Tts1HD (`tts-1-hd`) //! -//! #### DALL·E Models +//! #### Image Models //! +//! - GptImage1 (`gpt-image-1`) //! - DallE3 (`dall-e-3`) //! - DallE2 (`dall-e-2`) //! 
diff --git a/openai_dive/src/v1/endpoints/images.rs b/openai_dive/src/v1/endpoints/images.rs index 126deb2a..42663cf1 100644 --- a/openai_dive/src/v1/endpoints/images.rs +++ b/openai_dive/src/v1/endpoints/images.rs @@ -4,6 +4,7 @@ use crate::v1::helpers::format_response; use crate::v1::resources::image::{ CreateImageParameters, CreateImageVariationParameters, EditImageParameters, ImageResponse, }; +use crate::v1::resources::shared::FileUpload; pub struct Images<'a> { pub client: &'a Client, @@ -33,11 +34,65 @@ impl Images<'_> { pub async fn edit(&self, parameters: EditImageParameters) -> Result { let mut form = reqwest::multipart::Form::new(); - let image = parameters.image.into_part().await?; - form = form.part("image", image); + let mime_type = parameters.mime_type; + + match parameters.image { + #[cfg(all(feature = "tokio", feature = "tokio-util"))] + FileUpload::File(_) => { + let mut image = parameters.image.into_part().await?; + + if let Some(ref mime_type) = mime_type { + image = image + .mime_str(&mime_type.to_string()) + .map_err(|error| APIError::FileError(error.to_string()))?; + } + form = form.part("image", image); + } + #[cfg(all(feature = "tokio", feature = "tokio-util"))] + FileUpload::FileArray(_) => { + let images = parameters.image.into_parts().await?; + for mut image in images { + if let Some(ref mime_type) = mime_type { + image = image + .mime_str(&mime_type.to_string()) + .map_err(|error| APIError::FileError(error.to_string()))?; + } + form = form.part("image[]", image); + } + } + FileUpload::Bytes(_) => { + let mut image = parameters.image.into_part().await?; + + if let Some(ref mime_type) = mime_type { + image = image + .mime_str(&mime_type.to_string()) + .map_err(|error| APIError::FileError(error.to_string()))?; + } + form = form.part("image", image); + } + FileUpload::BytesArray(_) => { + let images = parameters.image.into_parts().await?; + for mut image in images { + if let Some(ref mime_type) = mime_type { + image = image + 
.mime_str(&mime_type.to_string()) + .map_err(|error| APIError::FileError(error.to_string()))?; + } + form = form.part("image[]", image); + } + } + } form = form.text("prompt", parameters.prompt); + if let Some(background) = parameters.background { + form = form.text("background", background.to_string()); + } + + if let Some(quality) = parameters.quality { + form = form.text("quality", quality.to_string()); + } + if let Some(mask) = parameters.mask { let image = mask.into_part().await?; form = form.part("mask", image); diff --git a/openai_dive/src/v1/models.rs b/openai_dive/src/v1/models.rs index d6964815..a689265d 100644 --- a/openai_dive/src/v1/models.rs +++ b/openai_dive/src/v1/models.rs @@ -2,8 +2,8 @@ use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub enum FlagshipModel { - #[serde(rename = "gpt-4.5-preview")] - Gpt45Preview, + #[serde(rename = "gpt-4.1")] + Gpt41, #[serde(rename = "gpt-4o")] Gpt4O, #[serde(rename = "gpt-4o-audio-preview")] @@ -12,20 +12,20 @@ pub enum FlagshipModel { #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub enum CostOptimizedModel { + #[serde(rename = "o4-mini")] + O4Mini, + #[serde(rename = "gpt-4.1-nano")] + Gpt41Nano, #[serde(rename = "gpt-4o-mini")] Gpt4OMini, - #[serde(rename = "gpt-4o-mini-audio-preview")] - Gpt4OMiniAudioPreview, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub enum ReasoningModel { + #[serde(rename = "o4-mini")] + O4Mini, #[serde(rename = "o3-mini")] O3Mini, - #[serde(rename = "o1")] - O1, - #[serde(rename = "o1-mini")] - O1Mini, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] @@ -42,8 +42,6 @@ pub enum ToolModel { pub enum ModerationModel { #[serde(rename = "omni-moderation-latest")] OmniModerationLatest, - #[serde(rename = "text-moderation-latest")] - TextModerationLatest, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] @@ -55,13 +53,17 @@ pub enum EmbeddingModel { } #[derive(Serialize, Deserialize, Debug, 
Clone, PartialEq)] -pub enum WhisperModel { +pub enum TranscriptionModel { + #[serde(rename = "gpt-4o-transcribe")] + Gpt4OTranscribe, #[serde(rename = "whisper-1")] Whisper1, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub enum TTSModel { + #[serde(rename = "gpt-4o-mini-tts")] + Gpt4OMiniTts, #[serde(rename = "tts-1")] Tts1, #[serde(rename = "tts-1-hd")] @@ -69,7 +71,9 @@ pub enum TTSModel { } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub enum DallEModel { +pub enum ImageModel { + #[serde(rename = "gpt-image-1")] + GptImage1, #[serde(rename = "dall-e-3")] DallE3, #[serde(rename = "dall-e-2")] @@ -94,6 +98,6 @@ impl_display!(ReasoningModel); impl_display!(ToolModel); impl_display!(ModerationModel); impl_display!(EmbeddingModel); -impl_display!(WhisperModel); +impl_display!(TranscriptionModel); impl_display!(TTSModel); -impl_display!(DallEModel); +impl_display!(ImageModel); diff --git a/openai_dive/src/v1/resources/audio.rs b/openai_dive/src/v1/resources/audio.rs index 88497950..7f1c9e16 100644 --- a/openai_dive/src/v1/resources/audio.rs +++ b/openai_dive/src/v1/resources/audio.rs @@ -18,6 +18,9 @@ pub struct AudioSpeechParameters { pub input: String, /// The voice to use when generating the audio. pub voice: AudioVoice, + /// Control the voice of your generated audio with additional instructions. Does not work with tts-1 or tts-1-hd + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, /// The format to audio in. Supported formats are mp3, opus, aac, flac, wav and pcm. 
#[serde(skip_serializing_if = "Option::is_none")] pub response_format: Option, diff --git a/openai_dive/src/v1/resources/chat.rs b/openai_dive/src/v1/resources/chat.rs index ebb7778e..baea3a43 100644 --- a/openai_dive/src/v1/resources/chat.rs +++ b/openai_dive/src/v1/resources/chat.rs @@ -1,13 +1,12 @@ +use super::shared::{ReasoningEffort, WebSearchContextSize}; use crate::v1::resources::shared::StopToken; use crate::v1::resources::shared::{FinishReason, Usage}; use derive_builder::Builder; -use serde::ser::SerializeStruct; -use serde::{Deserialize, Serialize, Serializer}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; use std::collections::HashMap; use std::fmt::Display; -use super::shared::{ReasoningEffort, WebSearchContextSize}; - #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct ChatCompletionResponse { /// A unique identifier for the chat completion. @@ -158,6 +157,10 @@ pub struct ChatCompletionParameters { /// This tool searches the web for relevant results to use in a response. #[serde(skip_serializing_if = "Option::is_none")] pub web_search_options: Option, + /// Allows passing arbitrary JSON as extra body parameters (flattened into the request), for specific features/OpenAI-compatible endpoints. 
+ #[serde(flatten)] + #[serde(skip_serializing_if = "Option::is_none")] + pub extra_body: Option, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] @@ -192,38 +195,12 @@ pub struct ChatCompletionFunction { pub parameters: serde_json::Value, } -#[derive(Deserialize, Debug, Clone, PartialEq)] -#[serde(rename_all = "snake_case")] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(tag = "type", rename_all = "snake_case")] pub enum ChatCompletionResponseFormat { Text, JsonObject, - JsonSchema(JsonSchema), -} - -impl Serialize for ChatCompletionResponseFormat { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - match self { - ChatCompletionResponseFormat::Text => { - let mut state = serializer.serialize_struct("ChatCompletionResponseFormat", 1)?; - state.serialize_field("type", "text")?; - state.end() - } - ChatCompletionResponseFormat::JsonObject => { - let mut state = serializer.serialize_struct("ChatCompletionResponseFormat", 1)?; - state.serialize_field("type", "json_object")?; - state.end() - } - ChatCompletionResponseFormat::JsonSchema(json_schema) => { - let mut state = serializer.serialize_struct("ChatCompletionResponseFormat", 2)?; - state.serialize_field("type", "json_schema")?; - state.serialize_field("json_schema", json_schema)?; - state.end() - } - } - } + JsonSchema { json_schema: JsonSchema }, } #[derive(Serialize, Deserialize, Debug, Default, Builder, Clone, PartialEq)] @@ -303,6 +280,55 @@ pub enum ChatMessage { }, } +impl ChatMessage { + /// Get the ChatMessageContent data, if it exists. + pub fn message(&self) -> Option<&ChatMessageContent> { + match self { + ChatMessage::Developer { content, .. } + | ChatMessage::System { content, .. } + | ChatMessage::User { content, .. } + | ChatMessage::Assistant { + content: Some(content), + .. + } => Some(content), + ChatMessage::Assistant { content: None, .. } => None, + ChatMessage::Tool { .. 
} => None, + } + } + + /// Get the content of the message as text, if it is a simple text message. + pub fn text(&self) -> Option<&str> { + match self { + ChatMessage::Developer { content, .. } + | ChatMessage::System { content, .. } + | ChatMessage::User { content, .. } + | ChatMessage::Assistant { + content: Some(content), + .. + } => { + if let ChatMessageContent::Text(text) = content { + Some(text) + } else { + None + } + } + ChatMessage::Assistant { content: None, .. } => None, + ChatMessage::Tool { content, .. } => Some(content), + } + } + + /// Get the name of the message sender, if it exists. + pub fn name(&self) -> Option<&str> { + match self { + ChatMessage::Developer { name, .. } + | ChatMessage::System { name, .. } + | ChatMessage::User { name, .. } + | ChatMessage::Assistant { name, .. } => name.as_deref(), + ChatMessage::Tool { .. } => None, + } + } +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[serde(tag = "role", rename_all = "lowercase")] pub enum DeltaChatMessage { @@ -633,11 +659,11 @@ pub enum ChatCompletionToolType { #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[serde(rename_all = "lowercase")] -#[serde(untagged)] pub enum ChatCompletionToolChoice { None, Auto, Required, + #[serde(untagged)] ChatCompletionToolChoiceFunction(ChatCompletionToolChoiceFunction), } @@ -703,3 +729,81 @@ impl DeltaFunction { self.name.is_none() && self.arguments.is_none() } } + +#[cfg(test)] +mod tests { + use crate::v1::resources::chat::{ + ChatCompletionResponseFormat, ChatCompletionToolChoice, ChatCompletionToolChoiceFunction, + ChatCompletionToolChoiceFunctionName, ChatCompletionToolType, JsonSchemaBuilder, + }; + use serde_json; + + #[test] + fn test_chat_completion_response_format_serialization_deserialization() { + let json_schema = JsonSchemaBuilder::default() + .description("This is a test schema".to_string()) + .name("test_schema".to_string()) + .schema(Some(serde_json::json!({"type": "object"}))) + .strict(true) + .build() 
+ .unwrap(); + + let response_format = ChatCompletionResponseFormat::JsonSchema { json_schema }; + + // Serialize the response format to a JSON string + let serialized = serde_json::to_string(&response_format).unwrap(); + assert_eq!(serialized, "{\"type\":\"json_schema\",\"json_schema\":{\"description\":\"This is a test schema\",\"name\":\"test_schema\",\"schema\":{\"type\":\"object\"},\"strict\":true}}"); + + // Deserialize the JSON string back to a ChatCompletionResponseFormat + let deserialized: ChatCompletionResponseFormat = serde_json::from_str(&serialized).unwrap(); + match deserialized { + ChatCompletionResponseFormat::JsonSchema { json_schema } => { + assert_eq!( + json_schema.description, + Some("This is a test schema".to_string()) + ); + assert_eq!(json_schema.name, "test_schema".to_string()); + assert_eq!( + json_schema.schema, + Some(serde_json::json!({"type": "object"})) + ); + assert_eq!(json_schema.strict, Some(true)); + } + _ => panic!("Deserialized format should be JsonSchema"), + } + } + + #[test] + fn test_chat_completion_tool_choice_required_serialization_deserialization() { + let tool_choice = ChatCompletionToolChoice::Required; + + let serialized = serde_json::to_string(&tool_choice).unwrap(); + assert_eq!(serialized, "\"required\""); + + let deserialized: ChatCompletionToolChoice = + serde_json::from_str(serialized.as_str()).unwrap(); + assert_eq!(deserialized, tool_choice) + } + + #[test] + fn test_chat_completion_tool_choice_named_function_serialization_deserialization() { + let tool_choice = ChatCompletionToolChoice::ChatCompletionToolChoiceFunction( + ChatCompletionToolChoiceFunction { + r#type: Some(ChatCompletionToolType::Function), + function: ChatCompletionToolChoiceFunctionName { + name: "get_current_weather".to_string(), + }, + }, + ); + + let serialized = serde_json::to_string(&tool_choice).unwrap(); + assert_eq!( + serialized, + "{\"type\":\"function\",\"function\":{\"name\":\"get_current_weather\"}}" + ); + + let deserialized: 
ChatCompletionToolChoice = + serde_json::from_str(serialized.as_str()).unwrap(); + assert_eq!(deserialized, tool_choice) + } +} diff --git a/openai_dive/src/v1/resources/image.rs b/openai_dive/src/v1/resources/image.rs index 01cac166..1b5860d4 100644 --- a/openai_dive/src/v1/resources/image.rs +++ b/openai_dive/src/v1/resources/image.rs @@ -53,10 +53,20 @@ pub struct EditImageParameters { pub image: FileUpload, /// A text description of the desired image(s). The maximum length is 1000 characters. pub prompt: String, + /// Sets the transparency of the background of the generated image(s). + #[serde(skip_serializing_if = "Option::is_none")] + pub background: Option, /// An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where image should be edited. /// Must be a valid PNG file, less than 4MB, and have the same dimensions as image. #[serde(skip_serializing_if = "Option::is_none")] pub mask: Option, + /// The quality of the image that will be generated. `hd` creates images with finer details and greater consistency across the image. + #[serde(skip_serializing_if = "Option::is_none")] + pub quality: Option, + /// The MIME type of the image. If not provided, the MIME type will be set to application/octet-stream. + /// gpt-image-1 expects `image/png`, `image/jpeg` or `image/webp`. + #[serde(skip_serializing_if = "Option::is_none")] + pub mime_type: Option, /// The model to use for image generation. Only dall-e-2 is supported at this time. 
#[serde(skip_serializing_if = "Option::is_none")] pub model: Option, @@ -114,17 +124,45 @@ pub enum ImageSize { Size512X512, #[serde(rename = "1024x1024")] Size1024X1024, + #[serde(rename = "1024x1536")] + Size1024X1536, + #[serde(rename = "1536x1024")] + Size1536X1024, #[serde(rename = "1792x1024")] Size1792X1024, #[serde(rename = "1024x1792")] Size1024X1792, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum BackgroundStyle { + Transparent, + Opaque, + Auto, +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[serde(rename_all = "lowercase")] pub enum ImageQuality { Standard, Hd, + High, + Medium, + Low, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum MimeType { + #[serde(rename = "image/png")] + Png, + #[serde(rename = "image/jpeg")] + Jpeg, + #[serde(rename = "image/webp")] + Webp, + #[serde(rename = "application/octet-stream")] + OctetStream, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] @@ -234,6 +272,36 @@ impl ImageData { } } +impl Display for BackgroundStyle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + BackgroundStyle::Transparent => "transparent", + BackgroundStyle::Opaque => "opaque", + BackgroundStyle::Auto => "auto", + } + ) + } +} + +impl Display for ImageQuality { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + ImageQuality::Standard => "standard", + ImageQuality::Hd => "hd", + ImageQuality::High => "high", + ImageQuality::Medium => "medium", + ImageQuality::Low => "low", + } + ) + } +} + impl Display for ImageSize { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( @@ -243,6 +311,8 @@ impl Display for ImageSize { ImageSize::Size256X256 => "256x256", ImageSize::Size512X512 => "512x512", ImageSize::Size1024X1024 => "1024x1024", + 
ImageSize::Size1536X1024 => "1536x1024", + ImageSize::Size1024X1536 => "1024x1536", ImageSize::Size1792X1024 => "1792x1024", ImageSize::Size1024X1792 => "1024x1792", } @@ -250,6 +320,21 @@ impl Display for ImageSize { } } +impl Display for MimeType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + MimeType::Png => "image/png", + MimeType::Jpeg => "image/jpeg", + MimeType::Webp => "image/webp", + MimeType::OctetStream => "application/octet-stream", + } + ) + } +} + impl Display for ResponseFormat { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( diff --git a/openai_dive/src/v1/resources/response/items.rs b/openai_dive/src/v1/resources/response/items.rs index 87aca396..63efbc4d 100644 --- a/openai_dive/src/v1/resources/response/items.rs +++ b/openai_dive/src/v1/resources/response/items.rs @@ -46,7 +46,7 @@ pub struct FunctionToolCall { #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct FunctionToolCallOutput { - pub id: String, + pub id: Option, pub call_id: String, pub output: String, pub status: InputItemStatus, diff --git a/openai_dive/src/v1/resources/shared.rs b/openai_dive/src/v1/resources/shared.rs index 0df8ce53..7f4247c2 100644 --- a/openai_dive/src/v1/resources/shared.rs +++ b/openai_dive/src/v1/resources/shared.rs @@ -295,14 +295,53 @@ impl FileUploadBytes { #[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] pub enum FileUpload { Bytes(FileUploadBytes), + BytesArray(Vec), #[cfg(all(feature = "tokio", feature = "tokio-util"))] File(String), + #[cfg(all(feature = "tokio", feature = "tokio-util"))] + FileArray(Vec), } impl FileUpload { #[cfg(feature = "reqwest")] pub(crate) async fn into_part(self) -> Result { match self { FileUpload::Bytes(bytes) => bytes.into_part(), + FileUpload::BytesArray(_) => { + unimplemented!("BytesArray is not supported for this route") + } + #[cfg(all(feature = "tokio", feature = "tokio-util"))] + 
FileUpload::File(path) => { + use tokio::fs::File; + use tokio_util::codec::{BytesCodec, FramedRead}; + + let file = File::open(&path) + .await + .map_err(|error| APIError::FileError(error.to_string()))?; + + let stream = FramedRead::new(file, BytesCodec::new()); + let file_body = reqwest::Body::wrap_stream(stream); + + let file_part = reqwest::multipart::Part::stream(file_body).file_name(path); + // .mime_str("application/octet-stream") + // .unwrap(); + + Ok(file_part) + } + #[cfg(all(feature = "tokio", feature = "tokio-util"))] + FileUpload::FileArray(_) => { + unimplemented!("FileArray is not supported for this route") + } + } + } + + #[cfg(feature = "reqwest")] + pub(crate) async fn into_parts(self) -> Result, APIError> { + match self { + FileUpload::Bytes(bytes) => bytes.into_part().map(|part| vec![part]), + FileUpload::BytesArray(bytes) => bytes + .into_iter() + .map(|bytes| bytes.into_part()) + .collect::, APIError>>(), #[cfg(all(feature = "tokio", feature = "tokio-util"))] FileUpload::File(path) => { use tokio::fs::File; @@ -320,7 +359,31 @@ impl FileUpload { .mime_str("application/octet-stream") .unwrap(); - Ok(file_part) + Ok(vec![file_part]) + } + #[cfg(all(feature = "tokio", feature = "tokio-util"))] + FileUpload::FileArray(paths) => { + use tokio::fs::File; + use tokio_util::codec::{BytesCodec, FramedRead}; + + let mut file_parts = vec![]; + for path in paths { + let file = File::open(&path) + .await + .map_err(|error| APIError::FileError(error.to_string()))?; + + let stream = FramedRead::new(file, BytesCodec::new()); + let file_body = reqwest::Body::wrap_stream(stream); + + let file_part = reqwest::multipart::Part::stream(file_body) + .file_name(path) + .mime_str("application/octet-stream") + .unwrap(); + + file_parts.push(file_part); + } + + Ok(file_parts) } } }