Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/maximhq/bifrost/core/providers/sgl"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/providers/vertex"
"github.com/maximhq/bifrost/core/providers/xai"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
Expand Down Expand Up @@ -1927,6 +1928,8 @@ func (bifrost *Bifrost) createBaseProvider(providerKey schemas.ModelProvider, co
return nebius.NewNebiusProvider(config, bifrost.logger)
case schemas.HuggingFace:
return huggingface.NewHuggingFaceProvider(config, bifrost.logger), nil
case schemas.XAI:
return xai.NewXAIProvider(config, bifrost.logger)
default:
return nil, fmt.Errorf("unsupported provider: %s", targetProviderKey)
}
Expand Down
48 changes: 48 additions & 0 deletions core/internal/testutil/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ func (account *ComprehensiveTestAccount) GetConfiguredProviders() ([]schemas.Mod
schemas.OpenRouter,
schemas.HuggingFace,
schemas.Nebius,
schemas.XAI,
ProviderOpenAICustom,
}, nil
}
Expand Down Expand Up @@ -346,6 +347,15 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx *context.Context
UseForBatchAPI: bifrost.Ptr(true),
},
}, nil
case schemas.XAI:
return []schemas.Key{
{
Value: os.Getenv("XAI_API_KEY"),
Models: []string{},
Weight: 1.0,
UseForBatchAPI: bifrost.Ptr(true),
},
}, nil
default:
return nil, fmt.Errorf("unsupported provider: %s", providerKey)
}
Expand Down Expand Up @@ -617,6 +627,19 @@ func (account *ComprehensiveTestAccount) GetConfigForProvider(providerKey schema
BufferSize: 10,
},
}, nil
case schemas.XAI:
return &schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{
DefaultRequestTimeoutInSeconds: 120,
MaxRetries: 10,
RetryBackoffInitial: 1 * time.Second,
RetryBackoffMax: 12 * time.Second,
},
ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{
Concurrency: Concurrency,
BufferSize: 10,
},
}, nil
default:
return nil, fmt.Errorf("unsupported provider: %s", providerKey)
}
Expand Down Expand Up @@ -1056,4 +1079,29 @@ var AllProviderConfigs = []ComprehensiveTestConfig{
{Provider: schemas.OpenAI, Model: "gpt-4o-mini"},
},
},
{
Provider: schemas.XAI,
ChatModel: "grok-4-0709",
TextModel: "", // XAI focuses on chat
Scenarios: TestScenarios{
TextCompletion: false, // Not typical
SimpleChat: true,
CompletionStream: true,
MultiTurnConversation: true,
ToolCalls: true,
MultipleToolCalls: true,
End2EndToolCalling: true,
AutomaticFunctionCall: true,
ImageURL: true,
ImageBase64: true,
MultipleImages: true,
CompleteEnd2End: true,
SpeechSynthesis: false, // Not supported
SpeechSynthesisStream: false, // Not supported
Transcription: false, // Not supported
TranscriptionStream: false, // Not supported
Embedding: false, // Not supported
ListModels: true,
},
},
}
4 changes: 3 additions & 1 deletion core/internal/testutil/text_completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"os"
"testing"


bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
Expand All @@ -29,6 +28,9 @@ func RunTextCompletionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Co
Input: &schemas.TextCompletionInput{
PromptStr: &prompt,
},
Params: &schemas.TextCompletionParameters{
MaxTokens: bifrost.Ptr(100),
},
Fallbacks: testConfig.TextCompletionFallbacks,
}

Expand Down
277 changes: 277 additions & 0 deletions core/providers/xai/xai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
// Package providers implements various LLM providers and their utility functions.
// This file contains the xAI provider implementation.
package xai

import (
"context"
"strings"
"time"

"github.com/maximhq/bifrost/core/providers/openai"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)

// xAIProvider implements the Provider interface for xAI's API.
type XAIProvider struct {
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for API requests
networkConfig schemas.NetworkConfig // Network configuration including extra headers
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
}

// NewXAIProvider creates a new xAI provider instance.
// It initializes the HTTP client with the provided configuration and sets up response pools.
// The client is configured with timeouts, concurrency limits, and optional proxy settings.
func NewXAIProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*XAIProvider, error) {
config.CheckAndSetDefaults()

client := &fasthttp.Client{
ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
MaxConnsPerHost: 5000,
MaxIdleConnDuration: 60 * time.Second,
MaxConnWaitTimeout: 10 * time.Second,
}

// Configure proxy if provided
client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger)

config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/")

if config.NetworkConfig.BaseURL == "" {
config.NetworkConfig.BaseURL = "https://api.x.ai"
}

return &XAIProvider{
logger: logger,
client: client,
networkConfig: config.NetworkConfig,
sendBackRawRequest: config.SendBackRawRequest,
sendBackRawResponse: config.SendBackRawResponse,
}, nil
}

// GetProviderKey returns the provider identifier for xAI.
func (provider *XAIProvider) GetProviderKey() schemas.ModelProvider {
return schemas.XAI
}

// ListModels performs a list models request to xAI's API.
func (provider *XAIProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
if provider.networkConfig.BaseURL == "" {
return nil, providerUtils.NewConfigurationError("base_url is not set", provider.GetProviderKey())
}
return openai.HandleOpenAIListModelsRequest(
ctx,
provider.client,
request,
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/models"),
keys,
provider.networkConfig.ExtraHeaders,
provider.GetProviderKey(),
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
provider.logger,
)
}

// TextCompletion performs a text completion request to the xAI API.
func (provider *XAIProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
return openai.HandleOpenAITextCompletionRequest(
ctx,
provider.client,
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
request,
key,
provider.networkConfig.ExtraHeaders,
provider.GetProviderKey(),
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
provider.logger,
)
}

// TextCompletionStream performs a streaming text completion request to xAI's API.
// It formats the request, sends it to xAI, and processes the response.
// Returns a channel of BifrostStream objects or an error if the request fails.
func (provider *XAIProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
return openai.HandleOpenAITextCompletionStreaming(
ctx,
provider.client,
provider.networkConfig.BaseURL+"/v1/completions",
request,
nil,
provider.networkConfig.ExtraHeaders,
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
provider.GetProviderKey(),
postHookRunner,
nil,
provider.logger,
)
}

// ChatCompletion performs a chat completion request to the xAI API.
func (provider *XAIProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
return openai.HandleOpenAIChatCompletionRequest(
ctx,
provider.client,
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
request,
key,
provider.networkConfig.ExtraHeaders,
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
provider.GetProviderKey(),
provider.logger,
)
}

// ChatCompletionStream performs a streaming chat completion request to the xAI API.
// It supports real-time streaming of responses using Server-Sent Events (SSE).
// Uses xAI's OpenAI-compatible streaming format.
// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails.
func (provider *XAIProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
var authHeader map[string]string
if key.Value != "" {
authHeader = map[string]string{"Authorization": "Bearer " + key.Value}
}
// Use shared OpenAI-compatible streaming logic
return openai.HandleOpenAIChatCompletionStreaming(
ctx,
provider.client,
provider.networkConfig.BaseURL+"/v1/chat/completions",
request,
authHeader,
provider.networkConfig.ExtraHeaders,
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
schemas.XAI,
postHookRunner,
nil,
nil,
nil,
provider.logger,
)
}

// Responses performs a responses request to the xAI API.
func (provider *XAIProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
return openai.HandleOpenAIResponsesRequest(
ctx,
provider.client,
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/responses"),
request,
key,
provider.networkConfig.ExtraHeaders,
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
provider.GetProviderKey(),
provider.logger,
)
}

// ResponsesStream performs a streaming responses request to the xAI API.
func (provider *XAIProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
var authHeader map[string]string
if key.Value != "" {
authHeader = map[string]string{"Authorization": "Bearer " + key.Value}
}
return openai.HandleOpenAIResponsesStreaming(
ctx,
provider.client,
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/responses"),
request,
authHeader,
provider.networkConfig.ExtraHeaders,
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
provider.GetProviderKey(),
postHookRunner,
nil,
nil,
provider.logger,
)
}

// Embedding is not supported by the xAI provider.
func (provider *XAIProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey())
}

// Speech is not supported by the xAI provider.
func (provider *XAIProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey())
}

// SpeechStream is not supported by the xAI provider.
func (provider *XAIProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey())
}

// Transcription is not supported by the xAI provider.
func (provider *XAIProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey())
}

// TranscriptionStream is not supported by the xAI provider.
func (provider *XAIProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey())
}

// BatchCreate is not supported by xAI provider.
func (provider *XAIProvider) BatchCreate(_ context.Context, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey())
}

// BatchList is not supported by xAI provider.
func (provider *XAIProvider) BatchList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey())
}

// BatchRetrieve is not supported by xAI provider.
func (provider *XAIProvider) BatchRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey())
}

// BatchCancel is not supported by xAI provider.
func (provider *XAIProvider) BatchCancel(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey())
}

// BatchResults is not supported by xAI provider.
func (provider *XAIProvider) BatchResults(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey())
}

// FileUpload is not supported by xAI provider.
func (provider *XAIProvider) FileUpload(_ context.Context, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey())
}

// FileList is not supported by xAI provider.
func (provider *XAIProvider) FileList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey())
}

// FileRetrieve is not supported by xAI provider.
func (provider *XAIProvider) FileRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey())
}

// FileDelete is not supported by xAI provider.
func (provider *XAIProvider) FileDelete(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey())
}

// FileContent is not supported by xAI provider.
func (provider *XAIProvider) FileContent(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey())
}

func (provider *XAIProvider) CountTokens(_ context.Context, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey())
}
Loading