Skip to content

Commit 45a005f

Browse files
committed
feat: add openai function call
1 parent fb2bf02 commit 45a005f

File tree

6 files changed

+125
-12
lines changed

6 files changed

+125
-12
lines changed

internal/cron/v2ex.go

+10-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package cron
22

33
import (
44
"fmt"
5-
"github.com/tidwall/gjson"
65
"log/slog"
76
"net/http"
87
"strings"
@@ -13,10 +12,12 @@ import (
1312
"webot/internal/types"
1413
"webot/pkg/client"
1514
"webot/pkg/openai"
15+
16+
"github.com/tidwall/gjson"
1617
)
1718

1819
const (
19-
v2exSpec = "30 8-18/3 * * *"
20+
v2exSpec = "30 10-18/3 * * *"
2021
v2exHotAPI = "https://www.v2ex.com/api/topics/hot.json"
2122
v2exRepliesAPI = "https://www.v2ex.com/api/replies/show.json?topic_id="
2223
)
@@ -62,7 +63,7 @@ func (ctx v2ex) run() {
6263
Title: title,
6364
})
6465
}
65-
if err = PushToAll("V2EX 热帖推送"+v2exSplit+strings.Join(summarizeList, v2exSplit), types.PushV2ex); err != nil {
66+
if err = PushToAll(v2exHeader+strings.Join(summarizeList, v2exSplit)+v2exSplit+v2exFooter, types.PushV2ex); err != nil {
6667
slog.Error("push v2ex post failed", slog.Any("err", err))
6768
}
6869
}
@@ -106,5 +107,10 @@ AI 点评:%s
106107

107108
const (
108109
v2exSystem = "你擅长点评 V2EX 论坛热门帖子,从主题内容和回复中提取关键信息,并使用 60 个汉字以内的简洁描述。"
109-
v2exSplit = "\n---------------\n"
110+
v2exSplit = "\n-------------------\n"
111+
112+
v2exHeader = `V2EX 热帖推送
113+
---------------
114+
`
115+
v2exFooter = "输入 v2ex 关闭推送"
110116
)

internal/handler/handler.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ func onThink(msg *openwechat.Message, done <-chan struct{}, atUser string) {
105105
select {
106106
case <-done:
107107
return
108-
case <-time.After(time.Second * 2):
108+
case <-time.After(time.Second * 5):
109109
_, err := msg.ReplyText(atUser + thinkText())
110110
if err != nil {
111111
slog.Error("reply think failed", slog.Any("err", err))
@@ -132,7 +132,7 @@ func (h *Handler) onText(msg *openwechat.Message, ctx params) {
132132
done := make(chan struct{}, 1)
133133
go onThink(msg, done, ctx.atUser)
134134
start := time.Now()
135-
result, err := h.AI.Chat(h.Cfg.GetModel(false), messages)
135+
result, err := h.AI.Chat(h.Cfg.GetModel(false), messages, openai.Google)
136136
if err != nil {
137137
done <- struct{}{}
138138
slog.Error("AI chat failed", attr, slog.Any("err", err))

pkg/openai/google.go

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package openai
2+
3+
import (
4+
"fmt"
5+
"github.com/tidwall/gjson"
6+
"log/slog"
7+
"net/http"
8+
"net/url"
9+
"os"
10+
"webot/pkg/client"
11+
)
12+
13+
const (
14+
nameGoogle = "GoogleSearch"
15+
)
16+
17+
var Google = Tool{
18+
Name: nameGoogle,
19+
Func: fmt.Sprintf(`{
20+
"type": "function",
21+
"function": {
22+
"name": "%s",
23+
"strict": true,
24+
"description": "Using Google to search the internet",
25+
"parameters": {
26+
"type": "object",
27+
"properties": {
28+
"keyword": { "type": "string", "description": "Search keyword" }
29+
},
30+
"required": ["keyword"]
31+
}
32+
}
33+
}`, nameGoogle),
34+
Call: callGoogle,
35+
}
36+
37+
func callGoogle(keyword string) string {
38+
keyword = gjson.Get(keyword, "keyword").String()
39+
slog.Info("call google search", slog.String("keyword", keyword))
40+
41+
var apiKey, _ = os.LookupEnv("GOOGLE_API_KEY")
42+
var engineId, _ = os.LookupEnv("GOOGLE_ENGINE_ID")
43+
if apiKey == "" || engineId == "" {
44+
slog.Warn("not found env GOOGLE_API_KEY or GOOGLE_ENGINE_ID")
45+
return "nothing"
46+
}
47+
req, _ := http.NewRequest(http.MethodGet, fmt.Sprintf("https://www.googleapis.com/customsearch/v1?&fields=items(title,link,snippet,pagemap/metatags(og:description))&key=%s&cx=%s&q=%s", apiKey, engineId, url.QueryEscape(keyword)), nil)
48+
resp, err := client.Do(req)
49+
if err != nil {
50+
slog.Error("google search failed", slog.Any("err", err))
51+
return "nothing"
52+
}
53+
54+
return gjson.GetBytes(resp, "items").String()
55+
}

pkg/openai/openai.go

+44-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package openai
22

33
import (
4+
"encoding/json"
45
"errors"
56
"fmt"
67
"github.com/tidwall/gjson"
78
"net/http"
9+
"strings"
810
"webot/pkg/client"
911
)
1012

@@ -23,13 +25,23 @@ func New(opts Options) *OpenAI {
2325
}
2426
}
2527

26-
func (oai *OpenAI) Chat(model string, messages []Message) (*Message, error) {
28+
func (oai *OpenAI) Chat(model string, messages []Message, tools ...Tool) (*Message, error) {
2729
const api = "/v1/chat/completions"
2830

31+
var toolsBody json.RawMessage
32+
if len(tools) != 0 {
33+
var funcList = make([]string, 0, len(tools))
34+
for _, t := range tools {
35+
funcList = append(funcList, t.Func)
36+
}
37+
toolsBody = json.RawMessage(fmt.Sprintf("[%s]", strings.Join(funcList, ",")))
38+
}
39+
2940
reqBody := RequestBody{
3041
Model: model,
3142
Messages: messages,
3243
Temperature: 0.5,
44+
Tools: toolsBody,
3345
}
3446

3547
req, _ := http.NewRequest(http.MethodPost, oai.opts.BaseURL+api, client.MarshalBody(reqBody))
@@ -39,7 +51,37 @@ func (oai *OpenAI) Chat(model string, messages []Message) (*Message, error) {
3951
if err != nil {
4052
return nil, err
4153
}
42-
message := NewMessage(gjson.GetBytes(resp, "choices.0.message"))
54+
55+
result := gjson.GetBytes(resp, "choices.0.message")
56+
var toolCalls = result.Get("tool_calls")
57+
var toolCallsResult []Message
58+
for _, call := range toolCalls.Array() {
59+
callId := call.Get("id").String()
60+
call = call.Get("function")
61+
name := call.Get("name").String()
62+
args := call.Get("arguments").String()
63+
64+
for _, t := range tools {
65+
if t.Name != name {
66+
continue
67+
}
68+
toolCallsResult = append(toolCallsResult, Message{
69+
Role: RTool,
70+
Content: t.Call(args),
71+
ToolCallId: callId,
72+
})
73+
}
74+
}
75+
if len(toolCallsResult) != 0 {
76+
messages = append(messages, Message{
77+
Role: RAssistant,
78+
ToolCalls: json.RawMessage(toolCalls.Raw),
79+
})
80+
messages = append(messages, toolCallsResult...)
81+
return oai.Chat(model, messages, tools...)
82+
}
83+
84+
message := NewMessage(result)
4385
if message.Role == "" {
4486
return nil, errors.New(string(resp))
4587
}

pkg/openai/openai_test.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package openai
22

33
import (
4-
"fmt"
54
"os"
65
"testing"
76
)
@@ -24,11 +23,11 @@ func TestOpenAI_Chat(t *testing.T) {
2423
msg, err := openai.Chat(model, []Message{
2524
{
2625
Role: RUser,
27-
Content: "help me calculate 1+2*3",
26+
Content: "help me search linux.do",
2827
},
29-
})
28+
}, Google)
3029
if err != nil {
3130
t.Fatal(err)
3231
}
33-
fmt.Println(msg.Role, msg.Content)
32+
t.Logf("%s: %s\n", msg.Role, msg.Content)
3433
}

pkg/openai/types.go

+11
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ const (
1717
type Message struct {
1818
Role Role `json:"role"`
1919
Content string `json:"content"`
20+
21+
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
22+
ToolCallId string `json:"tool_call_id,omitempty"`
2023
}
2124

2225
func NewMessage(msg gjson.Result) *Message {
@@ -32,3 +35,11 @@ type RequestBody struct {
3235
Temperature float64 `json:"temperature"`
3336
Tools json.RawMessage `json:"tools,omitempty"`
3437
}
38+
39+
type ToolCall func(string) string
40+
41+
type Tool struct {
42+
Name string
43+
Func string
44+
Call ToolCall
45+
}

0 commit comments

Comments
 (0)