Skip to content

Commit 99f1fb8

Browse files
qhenkartgrulex
authored and
grulex
committed
Fix broken implementation AssistantModify implementation (sashabaranov#685)
* add custom marshaller, documentation and isolate tests * fix linter
1 parent 14db069 commit 99f1fb8

File tree

2 files changed

+109
-30
lines changed

2 files changed

+109
-30
lines changed

Diff for: assistant.go

+28-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package openai
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"net/http"
78
"net/url"
@@ -21,7 +22,7 @@ type Assistant struct {
2122
Description *string `json:"description,omitempty"`
2223
Model string `json:"model"`
2324
Instructions *string `json:"instructions,omitempty"`
24-
Tools []AssistantTool `json:"tools,omitempty"`
25+
Tools []AssistantTool `json:"tools"`
2526
FileIDs []string `json:"file_ids,omitempty"`
2627
Metadata map[string]any `json:"metadata,omitempty"`
2728

@@ -41,16 +42,41 @@ type AssistantTool struct {
4142
Function *FunctionDefinition `json:"function,omitempty"`
4243
}
4344

45+
// AssistantRequest provides the assistant request parameters.
46+
// When modifying the tools the API functions as the following:
47+
// If Tools is undefined, no changes are made to the Assistant's tools.
48+
// If Tools is empty slice it will effectively delete all of the Assistant's tools.
49+
// If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools.
4450
type AssistantRequest struct {
4551
Model string `json:"model"`
4652
Name *string `json:"name,omitempty"`
4753
Description *string `json:"description,omitempty"`
4854
Instructions *string `json:"instructions,omitempty"`
49-
Tools []AssistantTool `json:"tools"`
55+
Tools []AssistantTool `json:"-"`
5056
FileIDs []string `json:"file_ids,omitempty"`
5157
Metadata map[string]any `json:"metadata,omitempty"`
5258
}
5359

60+
// MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases
61+
// If Tools is nil, the field is omitted from the JSON.
62+
// If Tools is an empty slice, it's included in the JSON as an empty array ([]).
63+
// If Tools is populated, it's included in the JSON with the elements.
64+
func (a AssistantRequest) MarshalJSON() ([]byte, error) {
65+
type Alias AssistantRequest
66+
assistantAlias := &struct {
67+
Tools *[]AssistantTool `json:"tools,omitempty"`
68+
*Alias
69+
}{
70+
Alias: (*Alias)(&a),
71+
}
72+
73+
if a.Tools != nil {
74+
assistantAlias.Tools = &a.Tools
75+
}
76+
77+
return json.Marshal(assistantAlias)
78+
}
79+
5480
// AssistantsList is a list of assistants.
5581
type AssistantsList struct {
5682
Assistants []Assistant `json:"data"`

Diff for: assistant_test.go

+81-28
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ When asked a question, write and run Python code to answer the question.`
9696
})
9797
fmt.Fprintln(w, string(resBytes))
9898
case http.MethodPost:
99-
var request openai.AssistantRequest
99+
var request openai.Assistant
100100
err := json.NewDecoder(r.Body).Decode(&request)
101101
checks.NoError(t, err, "Decode error")
102102

@@ -163,44 +163,97 @@ When asked a question, write and run Python code to answer the question.`
163163

164164
ctx := context.Background()
165165

166-
_, err := client.CreateAssistant(ctx, openai.AssistantRequest{
167-
Name: &assistantName,
168-
Description: &assistantDescription,
169-
Model: openai.GPT4TurboPreview,
170-
Instructions: &assistantInstructions,
166+
t.Run("create_assistant", func(t *testing.T) {
167+
_, err := client.CreateAssistant(ctx, openai.AssistantRequest{
168+
Name: &assistantName,
169+
Description: &assistantDescription,
170+
Model: openai.GPT4TurboPreview,
171+
Instructions: &assistantInstructions,
172+
})
173+
checks.NoError(t, err, "CreateAssistant error")
171174
})
172-
checks.NoError(t, err, "CreateAssistant error")
173175

174-
_, err = client.RetrieveAssistant(ctx, assistantID)
175-
checks.NoError(t, err, "RetrieveAssistant error")
176+
t.Run("retrieve_assistant", func(t *testing.T) {
177+
_, err := client.RetrieveAssistant(ctx, assistantID)
178+
checks.NoError(t, err, "RetrieveAssistant error")
179+
})
176180

177-
_, err = client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{
178-
Name: &assistantName,
179-
Description: &assistantDescription,
180-
Model: openai.GPT4TurboPreview,
181-
Instructions: &assistantInstructions,
181+
t.Run("delete_assistant", func(t *testing.T) {
182+
_, err := client.DeleteAssistant(ctx, assistantID)
183+
checks.NoError(t, err, "DeleteAssistant error")
182184
})
183-
checks.NoError(t, err, "ModifyAssistant error")
184185

185-
_, err = client.DeleteAssistant(ctx, assistantID)
186-
checks.NoError(t, err, "DeleteAssistant error")
186+
t.Run("list_assistant", func(t *testing.T) {
187+
_, err := client.ListAssistants(ctx, &limit, &order, &after, &before)
188+
checks.NoError(t, err, "ListAssistants error")
189+
})
187190

188-
_, err = client.ListAssistants(ctx, &limit, &order, &after, &before)
189-
checks.NoError(t, err, "ListAssistants error")
191+
t.Run("create_assistant_file", func(t *testing.T) {
192+
_, err := client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{
193+
FileID: assistantFileID,
194+
})
195+
checks.NoError(t, err, "CreateAssistantFile error")
196+
})
190197

191-
_, err = client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{
192-
FileID: assistantFileID,
198+
t.Run("list_assistant_files", func(t *testing.T) {
199+
_, err := client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before)
200+
checks.NoError(t, err, "ListAssistantFiles error")
193201
})
194-
checks.NoError(t, err, "CreateAssistantFile error")
195202

196-
_, err = client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before)
197-
checks.NoError(t, err, "ListAssistantFiles error")
203+
t.Run("retrieve_assistant_file", func(t *testing.T) {
204+
_, err := client.RetrieveAssistantFile(ctx, assistantID, assistantFileID)
205+
checks.NoError(t, err, "RetrieveAssistantFile error")
206+
})
198207

199-
_, err = client.RetrieveAssistantFile(ctx, assistantID, assistantFileID)
200-
checks.NoError(t, err, "RetrieveAssistantFile error")
208+
t.Run("delete_assistant_file", func(t *testing.T) {
209+
err := client.DeleteAssistantFile(ctx, assistantID, assistantFileID)
210+
checks.NoError(t, err, "DeleteAssistantFile error")
211+
})
201212

202-
err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID)
203-
checks.NoError(t, err, "DeleteAssistantFile error")
213+
t.Run("modify_assistant_no_tools", func(t *testing.T) {
214+
assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{
215+
Name: &assistantName,
216+
Description: &assistantDescription,
217+
Model: openai.GPT4TurboPreview,
218+
Instructions: &assistantInstructions,
219+
})
220+
checks.NoError(t, err, "ModifyAssistant error")
221+
222+
if assistant.Tools != nil {
223+
t.Errorf("expected nil got %v", assistant.Tools)
224+
}
225+
})
226+
227+
t.Run("modify_assistant_with_tools", func(t *testing.T) {
228+
assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{
229+
Name: &assistantName,
230+
Description: &assistantDescription,
231+
Model: openai.GPT4TurboPreview,
232+
Instructions: &assistantInstructions,
233+
Tools: []openai.AssistantTool{{Type: openai.AssistantToolTypeFunction}},
234+
})
235+
checks.NoError(t, err, "ModifyAssistant error")
236+
237+
if assistant.Tools == nil || len(assistant.Tools) != 1 {
238+
t.Errorf("expected a slice got %v", assistant.Tools)
239+
}
240+
})
241+
242+
t.Run("modify_assistant_empty_tools", func(t *testing.T) {
243+
assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{
244+
Name: &assistantName,
245+
Description: &assistantDescription,
246+
Model: openai.GPT4TurboPreview,
247+
Instructions: &assistantInstructions,
248+
Tools: make([]openai.AssistantTool, 0),
249+
})
250+
251+
checks.NoError(t, err, "ModifyAssistant error")
252+
253+
if assistant.Tools == nil {
254+
t.Errorf("expected a slice got %v", assistant.Tools)
255+
}
256+
})
204257
}
205258

206259
func TestAzureAssistant(t *testing.T) {

0 commit comments

Comments
 (0)