Skip to content

Commit 807f9aa

Browse files
committed
feat: implement ExtraBody support for ChatCompletionRequest
1 parent c125ae2 commit 807f9aa

File tree

3 files changed

+227
-0
lines changed

3 files changed

+227
-0
lines changed

chat.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,8 @@ type ChatCompletionRequest struct {
280280
// Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false}
281281
// https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes
282282
ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"`
283+
// Add additional JSON properties to the request
284+
ExtraBody map[string]any `json:"extra_body,omitempty"`
283285
}
284286

285287
type StreamOptions struct {
@@ -425,6 +427,7 @@ func (c *Client) CreateChatCompletion(
425427
http.MethodPost,
426428
c.fullURL(urlSuffix, withModel(request.Model)),
427429
withBody(request),
430+
withExtraBody(request.ExtraBody),
428431
)
429432
if err != nil {
430433
return

chat_test.go

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,188 @@ func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error
916916
return completion, nil
917917
}
918918

919+
func TestChatCompletionRequestExtraBody(t *testing.T) {
920+
t.Run("ExtraBodySerialization", func(t *testing.T) {
921+
req := openai.ChatCompletionRequest{
922+
Model: "gpt-4",
923+
Messages: []openai.ChatCompletionMessage{
924+
{
925+
Role: openai.ChatMessageRoleUser,
926+
Content: "Hello!",
927+
},
928+
},
929+
ExtraBody: map[string]any{
930+
"custom_param": "custom_value",
931+
"numeric_param": 42,
932+
"boolean_param": true,
933+
"array_param": []string{"item1", "item2"},
934+
"object_param": map[string]any{
935+
"nested_key": "nested_value",
936+
},
937+
},
938+
}
939+
940+
data, err := json.Marshal(req)
941+
checks.NoError(t, err, "Failed to marshal request with ExtraBody")
942+
943+
// Verify that ExtraBody fields are included in JSON
944+
jsonStr := string(data)
945+
if !strings.Contains(jsonStr, `"extra_body"`) {
946+
t.Error("ExtraBody should be serialized in JSON")
947+
}
948+
if !strings.Contains(jsonStr, `"custom_param":"custom_value"`) {
949+
t.Error("Custom string parameter should be serialized")
950+
}
951+
if !strings.Contains(jsonStr, `"numeric_param":42`) {
952+
t.Error("Numeric parameter should be serialized")
953+
}
954+
if !strings.Contains(jsonStr, `"boolean_param":true`) {
955+
t.Error("Boolean parameter should be serialized")
956+
}
957+
958+
// Verify that we can unmarshal it back
959+
var unmarshaled openai.ChatCompletionRequest
960+
err = json.Unmarshal(data, &unmarshaled)
961+
checks.NoError(t, err, "Failed to unmarshal request with ExtraBody")
962+
963+
if unmarshaled.ExtraBody["custom_param"] != "custom_value" {
964+
t.Error("Custom parameter not correctly unmarshaled")
965+
}
966+
if int(unmarshaled.ExtraBody["numeric_param"].(float64)) != 42 {
967+
t.Error("Numeric parameter not correctly unmarshaled")
968+
}
969+
if unmarshaled.ExtraBody["boolean_param"] != true {
970+
t.Error("Boolean parameter not correctly unmarshaled")
971+
}
972+
})
973+
974+
t.Run("EmptyExtraBody", func(t *testing.T) {
975+
req := openai.ChatCompletionRequest{
976+
Model: "gpt-4",
977+
Messages: []openai.ChatCompletionMessage{
978+
{
979+
Role: openai.ChatMessageRoleUser,
980+
Content: "Hello!",
981+
},
982+
},
983+
ExtraBody: map[string]any{},
984+
}
985+
986+
data, err := json.Marshal(req)
987+
checks.NoError(t, err, "Failed to marshal request with empty ExtraBody")
988+
989+
// Empty ExtraBody should be omitted due to omitempty tag
990+
jsonStr := string(data)
991+
if strings.Contains(jsonStr, `"extra_body"`) {
992+
t.Error("Empty ExtraBody should be omitted from JSON")
993+
}
994+
})
995+
996+
t.Run("NilExtraBody", func(t *testing.T) {
997+
req := openai.ChatCompletionRequest{
998+
Model: "gpt-4",
999+
Messages: []openai.ChatCompletionMessage{
1000+
{
1001+
Role: openai.ChatMessageRoleUser,
1002+
Content: "Hello!",
1003+
},
1004+
},
1005+
ExtraBody: nil,
1006+
}
1007+
1008+
data, err := json.Marshal(req)
1009+
checks.NoError(t, err, "Failed to marshal request with nil ExtraBody")
1010+
1011+
// Nil ExtraBody should be omitted due to omitempty tag
1012+
jsonStr := string(data)
1013+
if strings.Contains(jsonStr, `"extra_body"`) {
1014+
t.Error("Nil ExtraBody should be omitted from JSON")
1015+
}
1016+
})
1017+
}
1018+
1019+
func TestChatCompletionWithExtraBody(t *testing.T) {
1020+
client, server, teardown := setupOpenAITestServer()
1021+
defer teardown()
1022+
1023+
// Set up a handler that verifies ExtraBody fields are merged into the request body
1024+
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
1025+
var reqBody map[string]any
1026+
body, err := io.ReadAll(r.Body)
1027+
if err != nil {
1028+
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
1029+
return
1030+
}
1031+
1032+
err = json.Unmarshal(body, &reqBody)
1033+
if err != nil {
1034+
http.Error(w, "Failed to parse request body", http.StatusInternalServerError)
1035+
return
1036+
}
1037+
1038+
// Verify that ExtraBody fields are merged at the top level
1039+
if reqBody["custom_parameter"] != "test_value" {
1040+
http.Error(w, "ExtraBody custom_parameter not found in request", http.StatusBadRequest)
1041+
return
1042+
}
1043+
if reqBody["additional_config"] != true {
1044+
http.Error(w, "ExtraBody additional_config not found in request", http.StatusBadRequest)
1045+
return
1046+
}
1047+
1048+
// Verify standard fields are still present
1049+
if reqBody["model"] != "gpt-4" {
1050+
http.Error(w, "Standard model field not found", http.StatusBadRequest)
1051+
return
1052+
}
1053+
1054+
// Return a mock response
1055+
res := openai.ChatCompletionResponse{
1056+
ID: "test-id",
1057+
Object: "chat.completion",
1058+
Created: time.Now().Unix(),
1059+
Model: "gpt-4",
1060+
Choices: []openai.ChatCompletionChoice{
1061+
{
1062+
Index: 0,
1063+
Message: openai.ChatCompletionMessage{
1064+
Role: openai.ChatMessageRoleAssistant,
1065+
Content: "Hello! I received your message with extra parameters.",
1066+
},
1067+
FinishReason: openai.FinishReasonStop,
1068+
},
1069+
},
1070+
Usage: openai.Usage{
1071+
PromptTokens: 10,
1072+
CompletionTokens: 20,
1073+
TotalTokens: 30,
1074+
},
1075+
}
1076+
1077+
w.Header().Set("Content-Type", "application/json")
1078+
json.NewEncoder(w).Encode(res)
1079+
})
1080+
1081+
// Test ChatCompletion with ExtraBody
1082+
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
1083+
Model: "gpt-4",
1084+
Messages: []openai.ChatCompletionMessage{
1085+
{
1086+
Role: openai.ChatMessageRoleUser,
1087+
Content: "Hello!",
1088+
},
1089+
},
1090+
ExtraBody: map[string]any{
1091+
"custom_parameter": "test_value",
1092+
"additional_config": true,
1093+
"numeric_setting": 123,
1094+
"array_setting": []string{"option1", "option2"},
1095+
},
1096+
})
1097+
1098+
checks.NoError(t, err, "CreateChatCompletion with ExtraBody should not fail")
1099+
}
1100+
9191101
func TestFinishReason(t *testing.T) {
9201102
c := &openai.ChatCompletionChoice{
9211103
FinishReason: openai.FinishReasonNull,

client.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,48 @@ func withBody(body any) requestOption {
8484
}
8585
}
8686

87+
func withExtraBody(extraBody map[string]any) requestOption {
88+
return func(args *requestOptions) {
89+
if len(extraBody) == 0 {
90+
return // No extra body to merge
91+
}
92+
93+
// Check if args.body is already a map[string]any
94+
if bodyMap, ok := args.body.(map[string]any); ok {
95+
// If it's already a map[string]any, directly add extraBody fields
96+
for key, value := range extraBody {
97+
bodyMap[key] = value
98+
}
99+
return
100+
}
101+
102+
// If args.body is a struct, convert it to map[string]any first
103+
if args.body != nil {
104+
var err error
105+
var jsonBytes []byte
106+
// Marshal the struct to JSON bytes
107+
jsonBytes, err = json.Marshal(args.body)
108+
if err != nil {
109+
return // If marshaling fails, skip merging ExtraBody
110+
}
111+
112+
// Unmarshal JSON bytes to map[string]any
113+
var bodyMap map[string]any
114+
if err = json.Unmarshal(jsonBytes, &bodyMap); err != nil {
115+
return // If unmarshaling fails, skip merging ExtraBody
116+
}
117+
118+
// Merge ExtraBody fields into the map
119+
for key, value := range extraBody {
120+
bodyMap[key] = value
121+
}
122+
123+
// Replace args.body with the merged map
124+
args.body = bodyMap
125+
}
126+
}
127+
}
128+
87129
func withContentType(contentType string) requestOption {
88130
return func(args *requestOptions) {
89131
args.header.Set("Content-Type", contentType)

0 commit comments

Comments
 (0)