1
1
package openai
2
2
3
3
import (
4
+ "encoding/json"
4
5
"errors"
5
6
"fmt"
6
7
"github.com/tidwall/gjson"
7
8
"net/http"
9
+ "strings"
8
10
"webot/pkg/client"
9
11
)
10
12
@@ -23,13 +25,23 @@ func New(opts Options) *OpenAI {
23
25
}
24
26
}
25
27
26
- func (oai * OpenAI ) Chat (model string , messages []Message ) (* Message , error ) {
28
+ func (oai * OpenAI ) Chat (model string , messages []Message , tools ... Tool ) (* Message , error ) {
27
29
const api = "/v1/chat/completions"
28
30
31
+ var toolsBody json.RawMessage
32
+ if len (tools ) != 0 {
33
+ var funcList = make ([]string , 0 , len (tools ))
34
+ for _ , t := range tools {
35
+ funcList = append (funcList , t .Func )
36
+ }
37
+ toolsBody = json .RawMessage (fmt .Sprintf ("[%s]" , strings .Join (funcList , "," )))
38
+ }
39
+
29
40
reqBody := RequestBody {
30
41
Model : model ,
31
42
Messages : messages ,
32
43
Temperature : 0.5 ,
44
+ Tools : toolsBody ,
33
45
}
34
46
35
47
req , _ := http .NewRequest (http .MethodPost , oai .opts .BaseURL + api , client .MarshalBody (reqBody ))
@@ -39,7 +51,37 @@ func (oai *OpenAI) Chat(model string, messages []Message) (*Message, error) {
39
51
if err != nil {
40
52
return nil , err
41
53
}
42
- message := NewMessage (gjson .GetBytes (resp , "choices.0.message" ))
54
+
55
+ result := gjson .GetBytes (resp , "choices.0.message" )
56
+ var toolCalls = result .Get ("tool_calls" )
57
+ var toolCallsResult []Message
58
+ for _ , call := range toolCalls .Array () {
59
+ callId := call .Get ("id" ).String ()
60
+ call = call .Get ("function" )
61
+ name := call .Get ("name" ).String ()
62
+ args := call .Get ("arguments" ).String ()
63
+
64
+ for _ , t := range tools {
65
+ if t .Name != name {
66
+ continue
67
+ }
68
+ toolCallsResult = append (toolCallsResult , Message {
69
+ Role : RTool ,
70
+ Content : t .Call (args ),
71
+ ToolCallId : callId ,
72
+ })
73
+ }
74
+ }
75
+ if len (toolCallsResult ) != 0 {
76
+ messages = append (messages , Message {
77
+ Role : RAssistant ,
78
+ ToolCalls : json .RawMessage (toolCalls .Raw ),
79
+ })
80
+ messages = append (messages , toolCallsResult ... )
81
+ return oai .Chat (model , messages , tools ... )
82
+ }
83
+
84
+ message := NewMessage (result )
43
85
if message .Role == "" {
44
86
return nil , errors .New (string (resp ))
45
87
}
0 commit comments