Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 96 additions & 2 deletions src/llm/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,24 @@ impl OpenAiProvider {
}

const OPENAI_API_KEY_ENV: &str = "OPENAI_API_KEY";
const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";
const OPENAI_BASE_URL_ENV: &str = "OPENAI_BASE_URL";
const DEFAULT_OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";

/// Get the API endpoint URL for chat completions
/// Uses OPENAI_BASE_URL environment variable if set, otherwise uses default OpenAI URL
fn get_api_endpoint() -> String {
match env::var(OPENAI_BASE_URL_ENV) {
Ok(base) if !base.trim().is_empty() => {
// User provided custom base URL
let base = base.trim().trim_end_matches('/');
format!("{}/chat/completions", base)
}
_ => {
// Use default OpenAI URL
DEFAULT_OPENAI_API_URL.to_string()
}
}
}

#[async_trait::async_trait]
impl AiProvider for OpenAiProvider {
Expand Down Expand Up @@ -597,15 +614,17 @@ async fn execute_openai_request(
) -> Result<ProviderResponse> {
let client = Client::new();
let start_time = std::time::Instant::now();
let api_url = get_api_endpoint();

let response = retry::retry_with_exponential_backoff(
|| {
let client = client.clone();
let api_key = api_key.clone();
let request_body = request_body.clone();
let api_url = api_url.clone();
Box::pin(async move {
client
.post(OPENAI_API_URL)
.post(&api_url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.json(&request_body)
Expand Down Expand Up @@ -969,4 +988,79 @@ mod tests {
assert!(!provider.supports_vision("o1-mini"));
assert!(!provider.supports_vision("text-davinci-003"));
}

#[test]
fn test_get_api_endpoint_default() {
// When OPENAI_BASE_URL is not set, should return default URL
let _lock = std::env::var("_TEST_LOCK").ok(); // Ensure tests run serially
std::env::remove_var(OPENAI_BASE_URL_ENV);
let result = get_api_endpoint();
assert_eq!(result, DEFAULT_OPENAI_API_URL);
}

#[test]
fn test_get_api_endpoint_custom_base_url() {
// When OPENAI_BASE_URL is set, should append /chat/completions
let _lock = std::env::var("_TEST_LOCK").ok();
std::env::remove_var(OPENAI_BASE_URL_ENV); // Clean first
std::env::set_var(OPENAI_BASE_URL_ENV, "http://localhost:8080/v1");
let result = get_api_endpoint();
assert_eq!(result, "http://localhost:8080/v1/chat/completions");
std::env::remove_var(OPENAI_BASE_URL_ENV);
}

#[test]
fn test_get_api_endpoint_trailing_slash() {
// Should normalize trailing slash
let _lock = std::env::var("_TEST_LOCK").ok();
std::env::remove_var(OPENAI_BASE_URL_ENV); // Clean first
std::env::set_var(OPENAI_BASE_URL_ENV, "http://localhost:8080/v1/");
let result = get_api_endpoint();
assert_eq!(result, "http://localhost:8080/v1/chat/completions");
std::env::remove_var(OPENAI_BASE_URL_ENV);
}

#[test]
fn test_get_api_endpoint_empty_string() {
// Empty string should fallback to default
let _lock = std::env::var("_TEST_LOCK").ok();
std::env::remove_var(OPENAI_BASE_URL_ENV); // Clean first
std::env::set_var(OPENAI_BASE_URL_ENV, "");
let result = get_api_endpoint();
assert_eq!(result, DEFAULT_OPENAI_API_URL);
std::env::remove_var(OPENAI_BASE_URL_ENV);
}

#[test]
fn test_get_api_endpoint_whitespace() {
// Whitespace-only should fallback to default
let _lock = std::env::var("_TEST_LOCK").ok();
std::env::remove_var(OPENAI_BASE_URL_ENV); // Clean first
std::env::set_var(OPENAI_BASE_URL_ENV, " ");
let result = get_api_endpoint();
assert_eq!(result, DEFAULT_OPENAI_API_URL);
std::env::remove_var(OPENAI_BASE_URL_ENV);
}

#[test]
fn test_get_api_endpoint_https() {
// Should support HTTPS
let _lock = std::env::var("_TEST_LOCK").ok();
std::env::remove_var(OPENAI_BASE_URL_ENV); // Clean first
std::env::set_var(OPENAI_BASE_URL_ENV, "https://proxy.example.com/v1");
let result = get_api_endpoint();
assert_eq!(result, "https://proxy.example.com/v1/chat/completions");
std::env::remove_var(OPENAI_BASE_URL_ENV);
}

#[test]
fn test_get_api_endpoint_http() {
// Should support HTTP (for local development)
let _lock = std::env::var("_TEST_LOCK").ok();
std::env::remove_var(OPENAI_BASE_URL_ENV); // Clean first
std::env::set_var(OPENAI_BASE_URL_ENV, "http://localhost:11434/v1");
let result = get_api_endpoint();
assert_eq!(result, "http://localhost:11434/v1/chat/completions");
std::env::remove_var(OPENAI_BASE_URL_ENV);
}
}