Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 43 additions & 20 deletions api/handler.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
package api

import (
"context"
// "fmt"
"net/http"
// "os"

"github.com/gin-gonic/gin"
"github.com/modelcontextprotocol/go-sdk/mcp"

"pansou/config"
"pansou/model"
"pansou/service"
jsonutil "pansou/util/json"
"pansou/util"
jsonutil "pansou/util/json"
"strings"
)

Expand All @@ -32,7 +35,7 @@ func SearchHandler(c *gin.Context) {
// GET方式:从URL参数获取
// 获取keyword,必填参数
keyword := c.Query("kw")

// 处理channels参数,支持逗号分隔
channelsStr := c.Query("channels")
var channels []string
Expand All @@ -46,32 +49,32 @@ func SearchHandler(c *gin.Context) {
}
}
}

// 处理并发数
concurrency := 0
concStr := c.Query("conc")
if concStr != "" && concStr != " " {
concurrency = util.StringToInt(concStr)
}

// 处理强制刷新
forceRefresh := false
refreshStr := c.Query("refresh")
if refreshStr != "" && refreshStr != " " && refreshStr == "true" {
forceRefresh = true
}

// 处理结果类型和来源类型
resultType := c.Query("res")
if resultType == "" || resultType == " " {
resultType = "merge" // 直接设置为默认值merge
}

sourceType := c.Query("src")
if sourceType == "" || sourceType == " " {
sourceType = "all" // 直接设置为默认值all
}

// 处理plugins参数,支持逗号分隔
var plugins []string
// 检查请求中是否存在plugins参数
Expand All @@ -91,7 +94,7 @@ func SearchHandler(c *gin.Context) {
// 如果请求中不存在plugins参数,设置为nil
plugins = nil
}

// 处理cloud_types参数,支持逗号分隔
var cloudTypes []string
// 检查请求中是否存在cloud_types参数
Expand All @@ -111,7 +114,7 @@ func SearchHandler(c *gin.Context) {
// 如果请求中不存在cloud_types参数,设置为nil
cloudTypes = nil
}

// 处理ext参数,JSON格式
var ext map[string]interface{}
extStr := c.Query("ext")
Expand All @@ -130,7 +133,7 @@ func SearchHandler(c *gin.Context) {
if ext == nil {
ext = make(map[string]interface{})
}

// 处理filter参数,JSON格式
var filter *model.FilterConfig
filterStr := c.Query("filter")
Expand Down Expand Up @@ -167,25 +170,25 @@ func SearchHandler(c *gin.Context) {
return
}
}

// 检查并设置默认值
if len(req.Channels) == 0 {
req.Channels = config.AppConfig.DefaultChannels
}

// 如果未指定结果类型,默认返回merge并转换为merged_by_type
if req.ResultType == "" {
req.ResultType = "merged_by_type"
} else if req.ResultType == "merge" {
// 将merge转换为merged_by_type,以兼容内部处理
req.ResultType = "merged_by_type"
}

// 如果未指定数据来源类型,默认为全部
if req.SourceType == "" {
req.SourceType = "all"
}

// 参数互斥逻辑:当src=tg时忽略plugins参数,当src=plugin时忽略channels参数
if req.SourceType == "tg" {
req.Plugins = nil // 忽略plugins参数
Expand All @@ -197,14 +200,14 @@ func SearchHandler(c *gin.Context) {
req.Plugins = nil
}
}

// 可选:启用调试输出(生产环境建议注释掉)
// fmt.Printf("🔧 [调试] 搜索参数: keyword=%s, channels=%v, concurrency=%d, refresh=%v, resultType=%s, sourceType=%s, plugins=%v, cloudTypes=%v, ext=%v\n",
// fmt.Printf("🔧 [调试] 搜索参数: keyword=%s, channels=%v, concurrency=%d, refresh=%v, resultType=%s, sourceType=%s, plugins=%v, cloudTypes=%v, ext=%v\n",
// req.Keyword, req.Channels, req.Concurrency, req.ForceRefresh, req.ResultType, req.SourceType, req.Plugins, req.CloudTypes, req.Ext)

// 执行搜索
result, err := searchService.Search(req.Keyword, req.Channels, req.Concurrency, req.ForceRefresh, req.ResultType, req.SourceType, req.Plugins, req.CloudTypes, req.Ext)

if err != nil {
response := model.NewErrorResponse(500, "搜索失败: "+err.Error())
jsonData, _ := jsonutil.Marshal(response)
Expand All @@ -221,4 +224,24 @@ func SearchHandler(c *gin.Context) {
response := model.NewSuccessResponse(result)
jsonData, _ := jsonutil.Marshal(response)
c.Data(http.StatusOK, "application/json", jsonData)
}
}

func SearchMcpHandler(_ context.Context, request *mcp.CallToolRequest, mcpReq model.McpSearchRequest) (toolCallResult *mcp.CallToolResult, result model.SearchResponse, err error) {
searchReq := model.SearchRequest{
Keyword: mcpReq.Keyword,
ForceRefresh: mcpReq.ForceRefresh,
Filter: mcpReq.Filter,
CloudTypes: mcpReq.CloudTypes,
Channels: config.AppConfig.DefaultChannels, // 设置默认值
SourceType: "all", // 默认为全部
ResultType: "merged_by_type",
Plugins: nil,
}
// 执行搜索
result, err = searchService.Search(searchReq.Keyword, searchReq.Channels, searchReq.Concurrency, searchReq.ForceRefresh, searchReq.ResultType, searchReq.SourceType, searchReq.Plugins, searchReq.CloudTypes, searchReq.Ext)
if err != nil {
return
}

return
}
49 changes: 49 additions & 0 deletions api/health.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package api

import (
"pansou/config"
)

type HealthResponse struct {
Status string `json:"status" jsonschema:"状态,如果是ok则表示服务正常"`
AuthEnabled bool `json:"auth_enabled" jsonschema:"是否启用认证,如果启用则需要先通过登录获取token"`
PluginsEnabled bool `json:"plugins_enabled" jsonschema:"是否启用异步插件"`
ChannelsCount int `json:"channels_count" jsonschema:"是否启用异步插件"`
Channels []string `json:"channels" jsonschema:"异步插件列表"`
PluginCount int `json:"plugin_count" jsonschema:"插件数量"`
Plugins []string `json:"plugins" jsonschema:"插件列表"`
}

func Health() HealthResponse {
// 根据配置决定是否返回插件信息
pluginCount := 0
pluginNames := []string{}
pluginsEnabled := config.AppConfig.AsyncPluginEnabled

if pluginsEnabled && searchService != nil && searchService.GetPluginManager() != nil {
plugins := searchService.GetPluginManager().GetPlugins()
pluginCount = len(plugins)
for _, p := range plugins {
pluginNames = append(pluginNames, p.Name())
}
}
// 获取频道信息
channels := config.AppConfig.DefaultChannels
channelsCount := len(channels)

response := HealthResponse{
Status: "ok",
AuthEnabled: config.AppConfig.AuthEnabled, // 添加认证状态
PluginsEnabled: pluginsEnabled,
Channels: channels,
ChannelsCount: channelsCount,
}

// 只有当插件启用时才返回插件相关信息
if pluginsEnabled {
response.PluginCount = pluginCount
response.Plugins = pluginNames
}

return response
}
25 changes: 25 additions & 0 deletions api/mcp_tools.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package api

import (
"context"

"github.com/modelcontextprotocol/go-sdk/mcp"
)

const (
McpMethodNameSearch = "search"
McpMethodNameHealth = "health"
)

// SetupMcpTool 设置MCP工具
func SetupMcpTool() *mcp.Server {
server := mcp.NewServer(&mcp.Implementation{Name: "pansou", Version: "v2.0"}, nil)
server.AddReceivingMiddleware(McpAuthMiddleware)
mcp.AddTool(server, &mcp.Tool{Name: McpMethodNameHealth, Description: "获取服务器状态"}, func(_ context.Context, request *mcp.CallToolRequest, input map[string]any) (toolCallResult *mcp.CallToolResult, result HealthResponse, _ error) {
result = Health()
return
})
mcp.AddTool(server, &mcp.Tool{Name: McpMethodNameSearch, Description: "搜索网盘资源"}, SearchMcpHandler)

return server
}
45 changes: 32 additions & 13 deletions api/middleware.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package api

import (
"context"
"errors"
"fmt"
"net/url"
"strings"
"time"

"github.com/gin-gonic/gin"
"github.com/modelcontextprotocol/go-sdk/mcp"

"pansou/config"
"pansou/util"
)
Expand All @@ -17,12 +21,12 @@ func CORSMiddleware() gin.HandlerFunc {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")

if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}

c.Next()
}
}
Expand All @@ -32,22 +36,22 @@ func LoggerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// 开始时间
startTime := time.Now()

// 处理请求
c.Next()

// 结束时间
endTime := time.Now()

// 执行时间
latencyTime := endTime.Sub(startTime)

// 请求方式
reqMethod := c.Request.Method

// 请求路由
reqURI := c.Request.RequestURI

// 对于搜索API,尝试解码关键词以便更好地显示
displayURI := reqURI
if strings.Contains(reqURI, "/api/search") && strings.Contains(reqURI, "kw=") {
Expand All @@ -60,16 +64,16 @@ func LoggerMiddleware() gin.HandlerFunc {
}
}
}

// 状态码
statusCode := c.Writer.Status()

// 请求IP
clientIP := c.ClientIP()

// 日志格式
gin.DefaultWriter.Write([]byte(
fmt.Sprintf("| %s | %s | %s | %d | %s\n",
fmt.Sprintf("| %s | %s | %s | %d | %s\n",
clientIP, reqMethod, displayURI, statusCode, latencyTime.String())))
}
}
Expand Down Expand Up @@ -138,4 +142,19 @@ func AuthMiddleware() gin.HandlerFunc {
c.Set("username", claims.Username)
c.Next()
}
}
}

func McpAuthMiddleware(h mcp.MethodHandler) mcp.MethodHandler {
return func(ctx context.Context, method string, req mcp.Request) (result mcp.Result, err error) {
// mcp token为空则,直接返回有效
if config.AppConfig.AuthMcpToken == "" {
return h(ctx, method, req)
}
extra := req.GetExtra()
if extra == nil || extra.Header.Get("Authorization") != config.AppConfig.AuthMcpToken {
return nil, errors.New("未授权:确认McpToken")
}

return h(ctx, method, req)
}
}
Loading