Skip to content

Commit b8a111c

Browse files
committed
Make Tool APIs more type safe and require fewer temporaries
1 parent 083779b commit b8a111c

File tree

6 files changed

+77
-81
lines changed

6 files changed

+77
-81
lines changed

tools.go

+36-3
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,48 @@ import (
1111
"github.com/mark3labs/mcp-go/server"
1212
)
1313

14-
func MustTool(name, description string, toolHandler any) (mcp.Tool, server.ToolHandlerFunc) {
14+
// Tool is a struct that represents a tool definition and the function used
15+
// to handle tool calls.
16+
//
17+
// The simplest way to create a Tool is to use `MustTool`, or `ConvertTool`
18+
// if you wish to create tools at runtime and need to handle errors without
19+
// panicking.
20+
type Tool struct {
21+
Tool mcp.Tool
22+
Handler server.ToolHandlerFunc
23+
}
24+
25+
// Register adds the Tool to the given MCPServer.
26+
//
27+
// It is a convenience method that calls `server.MCPServer.Register` with the
28+
// Tool's Tool and Handler fields, allowing you to add the tool in a single
29+
// statement:
30+
//
31+
// mcpgrafana.MustTool(name, description, toolHandler).Register(server)
32+
func (t *Tool) Register(mcp *server.MCPServer) {
33+
mcp.AddTool(t.Tool, t.Handler)
34+
}
35+
36+
// MustTool creates a new Tool from the given name, description, and toolHandler.
37+
// It panics if the tool cannot be created.
38+
func MustTool[T any](name, description string, toolHandler ToolHandlerFunc[T]) Tool {
1539
tool, handler, err := ConvertTool(name, description, toolHandler)
1640
if err != nil {
1741
panic(err)
1842
}
19-
return tool, handler
43+
return Tool{Tool: tool, Handler: handler}
2044
}
2145

22-
func ConvertTool(name, description string, toolHandler any) (mcp.Tool, server.ToolHandlerFunc, error) {
46+
// ToolHandlerFunc is the type of a handler function for a tool.
47+
type ToolHandlerFunc[T any] = func(ctx context.Context, request T) (*mcp.CallToolResult, error)
48+
49+
// ConvertTool converts a toolHandler function to a Tool and ToolHandlerFunc.
50+
//
51+
// The toolHandler function must have two arguments: a context.Context and a struct
52+
// to be used as the parameters for the tool. The second argument must not be a pointer,
53+
// should be marshalable to JSON, and the fields should have a `jsonschema` tag with the
54+
// description of the parameter.
55+
func ConvertTool[T any](name, description string, toolHandler ToolHandlerFunc[T]) (mcp.Tool, server.ToolHandlerFunc, error) {
2356
zero := mcp.Tool{}
2457
handlerValue := reflect.ValueOf(toolHandler)
2558
handlerType := handlerValue.Type()

tools/datasources.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func listDatasources(ctx context.Context, args ListDatasourcesParams) (*mcp.Call
2626
return mcp.NewToolResultText(string(b)), nil
2727
}
2828

29-
var ListDatasourcesTool, ListDatasourcesHandler = mcpgrafana.MustTool(
29+
var ListDatasources = mcpgrafana.MustTool(
3030
"list_datasources",
3131
"List datasources",
3232
listDatasources,
@@ -49,7 +49,7 @@ func getDatasourceByUID(ctx context.Context, args GetDatasourceByUIDParams) (*mc
4949
return mcp.NewToolResultText(string(b)), nil
5050
}
5151

52-
var GetDatasourceByUIDTool, GetDatasourceByUIDHandler = mcpgrafana.MustTool(
52+
var GetDatasourceByUID = mcpgrafana.MustTool(
5353
"get_datasource_by_uid",
5454
"Get datasource by uid",
5555
getDatasourceByUID,
@@ -72,14 +72,14 @@ func getDatasourceByName(ctx context.Context, args GetDatasourceByNameParams) (*
7272
return mcp.NewToolResultText(string(b)), nil
7373
}
7474

75-
var GetDatasourceByNameTool, GetDatasourceByNameHandler = mcpgrafana.MustTool(
75+
var GetDatasourceByName = mcpgrafana.MustTool(
7676
"get_datasource_by_name",
7777
"Get datasource by name",
7878
getDatasourceByName,
7979
)
8080

8181
func AddDatasourceTools(mcp *server.MCPServer) {
82-
mcp.AddTool(ListDatasourcesTool, ListDatasourcesHandler)
83-
mcp.AddTool(GetDatasourceByUIDTool, GetDatasourceByUIDHandler)
84-
mcp.AddTool(GetDatasourceByNameTool, GetDatasourceByNameHandler)
82+
ListDatasources.Register(mcp)
83+
GetDatasourceByUID.Register(mcp)
84+
GetDatasourceByName.Register(mcp)
8585
}

tools/incident.go

+12-12
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ type ListIncidentsParams struct {
1717
Status string `json:"status" jsonschema:"description=The status of the incidents to include"`
1818
}
1919

20-
func ListIncidents(ctx context.Context, args ListIncidentsParams) (*mcp.CallToolResult, error) {
20+
func listIncidents(ctx context.Context, args ListIncidentsParams) (*mcp.CallToolResult, error) {
2121
c := mcpgrafana.IncidentClientFromContext(ctx)
2222
is := incident.NewIncidentsService(c)
2323
query := ""
@@ -44,10 +44,10 @@ func ListIncidents(ctx context.Context, args ListIncidentsParams) (*mcp.CallTool
4444
return mcp.NewToolResultText(string(b)), nil
4545
}
4646

47-
var ListIncidentsTool, ListIncidentsHandler = mcpgrafana.MustTool(
47+
var ListIncidents = mcpgrafana.MustTool(
4848
"list_incidents",
4949
"List incidents",
50-
ListIncidents,
50+
listIncidents,
5151
)
5252

5353
type CreateIncidentParams struct {
@@ -61,7 +61,7 @@ type CreateIncidentParams struct {
6161
Labels []incident.IncidentLabel `json:"labels" jsonschema:"description=The labels to add to the incident"`
6262
}
6363

64-
func CreateIncident(ctx context.Context, args CreateIncidentParams) (*mcp.CallToolResult, error) {
64+
func createIncident(ctx context.Context, args CreateIncidentParams) (*mcp.CallToolResult, error) {
6565
c := mcpgrafana.IncidentClientFromContext(ctx)
6666
is := incident.NewIncidentsService(c)
6767
incident, err := is.CreateIncident(ctx, incident.CreateIncidentRequest{
@@ -84,10 +84,10 @@ func CreateIncident(ctx context.Context, args CreateIncidentParams) (*mcp.CallTo
8484
return mcp.NewToolResultText(string(b)), nil
8585
}
8686

87-
var CreateIncidentTool, CreateIncidentHandler = mcpgrafana.MustTool(
87+
var CreateIncident = mcpgrafana.MustTool(
8888
"create_incident",
8989
"Create an incident",
90-
CreateIncident,
90+
createIncident,
9191
)
9292

9393
type AddActivityToIncidentParams struct {
@@ -96,7 +96,7 @@ type AddActivityToIncidentParams struct {
9696
EventTime string `json:"eventTime" jsonschema:"description=The time that the activity occurred. If not provided, the current time will be used"`
9797
}
9898

99-
func AddActivityToIncident(ctx context.Context, args AddActivityToIncidentParams) (*mcp.CallToolResult, error) {
99+
func addActivityToIncident(ctx context.Context, args AddActivityToIncidentParams) (*mcp.CallToolResult, error) {
100100
c := mcpgrafana.IncidentClientFromContext(ctx)
101101
as := incident.NewActivityService(c)
102102
activity, err := as.AddActivity(ctx, incident.AddActivityRequest{
@@ -115,14 +115,14 @@ func AddActivityToIncident(ctx context.Context, args AddActivityToIncidentParams
115115
return mcp.NewToolResultText(string(b)), nil
116116
}
117117

118-
var AddActivityToIncidentTool, AddActivityToIncidentHandler = mcpgrafana.MustTool(
118+
var AddActivityToIncident = mcpgrafana.MustTool(
119119
"add_activity_to_incident",
120120
"Add an activity to an incident",
121-
AddActivityToIncident,
121+
addActivityToIncident,
122122
)
123123

124124
func AddIncidentTools(mcp *server.MCPServer) {
125-
mcp.AddTool(ListIncidentsTool, ListIncidentsHandler)
126-
mcp.AddTool(CreateIncidentTool, CreateIncidentHandler)
127-
mcp.AddTool(AddActivityToIncidentTool, AddActivityToIncidentHandler)
125+
ListIncidents.Register(mcp)
126+
CreateIncident.Register(mcp)
127+
AddActivityToIncident.Register(mcp)
128128
}

tools/prometheus.go

+20-20
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ type ListPrometheusMetricMetadataParams struct {
4242
Metric string `json:"metric" jsonschema:"description=The metric to query"`
4343
}
4444

45-
func ListPrometheusMetricMetadata(ctx context.Context, args ListPrometheusMetricMetadataParams) (*mcp.CallToolResult, error) {
45+
func listPrometheusMetricMetadata(ctx context.Context, args ListPrometheusMetricMetadataParams) (*mcp.CallToolResult, error) {
4646
promClient, err := promClientFromContext(ctx, args.DatasourceUID)
4747
if err != nil {
4848
return nil, fmt.Errorf("getting Prometheus client: %w", err)
@@ -64,10 +64,10 @@ func ListPrometheusMetricMetadata(ctx context.Context, args ListPrometheusMetric
6464
return mcp.NewToolResultText(string(b)), nil
6565
}
6666

67-
var ListPrometheusMetricMetadataTool, ListPrometheusMetricMetadataHandler = mcpgrafana.MustTool(
67+
var ListPrometheusMetricMetadata = mcpgrafana.MustTool(
6868
"list_prometheus_metric_metadata",
6969
"List Prometheus metric metadata",
70-
ListPrometheusMetricMetadata,
70+
listPrometheusMetricMetadata,
7171
)
7272

7373
type QueryPrometheusParams struct {
@@ -79,7 +79,7 @@ type QueryPrometheusParams struct {
7979
QueryType string `json:"queryType,omitempty" jsonschema:"description=The type of query to use. Either 'range' or 'instant'"`
8080
}
8181

82-
func QueryPrometheus(ctx context.Context, args QueryPrometheusParams) (*mcp.CallToolResult, error) {
82+
func queryPrometheus(ctx context.Context, args QueryPrometheusParams) (*mcp.CallToolResult, error) {
8383
promClient, err := promClientFromContext(ctx, args.DatasourceUID)
8484
if err != nil {
8585
return nil, fmt.Errorf("getting Prometheus client: %w", err)
@@ -136,10 +136,10 @@ func QueryPrometheus(ctx context.Context, args QueryPrometheusParams) (*mcp.Call
136136
return nil, fmt.Errorf("invalid query type: %s", queryType)
137137
}
138138

139-
var QueryPrometheusTool, QueryPrometheusHandler = mcpgrafana.MustTool(
139+
var QueryPrometheus = mcpgrafana.MustTool(
140140
"query_prometheus",
141141
"Query Prometheus using a range or instant request",
142-
QueryPrometheus,
142+
queryPrometheus,
143143
)
144144

145145
type ListPrometheusMetricNamesParams struct {
@@ -149,7 +149,7 @@ type ListPrometheusMetricNamesParams struct {
149149
Page int `json:"page,omitempty" jsonschema:"description=The page number to return"`
150150
}
151151

152-
func ListPrometheusMetricNames(ctx context.Context, args ListPrometheusMetricNamesParams) (*mcp.CallToolResult, error) {
152+
func listPrometheusMetricNames(ctx context.Context, args ListPrometheusMetricNamesParams) (*mcp.CallToolResult, error) {
153153
promClient, err := promClientFromContext(ctx, args.DatasourceUID)
154154
if err != nil {
155155
return nil, fmt.Errorf("getting Prometheus client: %w", err)
@@ -207,10 +207,10 @@ func ListPrometheusMetricNames(ctx context.Context, args ListPrometheusMetricNam
207207
return mcp.NewToolResultText(string(b)), nil
208208
}
209209

210-
var ListPrometheusMetricNamesTool, ListPrometheusMetricNamesHandler = mcpgrafana.MustTool(
210+
var ListPrometheusMetricNames = mcpgrafana.MustTool(
211211
"list_prometheus_metric_names",
212212
"List metric names in a Prometheus datasource that match the given regex",
213-
ListPrometheusMetricNames,
213+
listPrometheusMetricNames,
214214
)
215215

216216
type LabelMatcher struct {
@@ -244,7 +244,7 @@ type ListPrometheusLabelNamesParams struct {
244244
Limit int `json:"limit,omitempty" jsonschema:"description=Optionally, the maximum number of results to return"`
245245
}
246246

247-
func ListPrometheusLabelNames(ctx context.Context, args ListPrometheusLabelNamesParams) (*mcp.CallToolResult, error) {
247+
func listPrometheusLabelNames(ctx context.Context, args ListPrometheusLabelNamesParams) (*mcp.CallToolResult, error) {
248248
promClient, err := promClientFromContext(ctx, args.DatasourceUID)
249249
if err != nil {
250250
return nil, fmt.Errorf("getting Prometheus client: %w", err)
@@ -289,10 +289,10 @@ func ListPrometheusLabelNames(ctx context.Context, args ListPrometheusLabelNames
289289
return mcp.NewToolResultText(string(b)), nil
290290
}
291291

292-
var ListPrometheusLabelNamesTool, ListPrometheusLabelNamesHandler = mcpgrafana.MustTool(
292+
var ListPrometheusLabelNames = mcpgrafana.MustTool(
293293
"list_prometheus_label_names",
294294
"List the label names in a Prometheus datasource",
295-
ListPrometheusLabelNames,
295+
listPrometheusLabelNames,
296296
)
297297

298298
type ListPrometheusLabelValuesParams struct {
@@ -304,7 +304,7 @@ type ListPrometheusLabelValuesParams struct {
304304
Limit int `json:"limit,omitempty" jsonschema:"description=Optionally, the maximum number of results to return"`
305305
}
306306

307-
func ListPrometheusLabelValues(ctx context.Context, args ListPrometheusLabelValuesParams) (*mcp.CallToolResult, error) {
307+
func listPrometheusLabelValues(ctx context.Context, args ListPrometheusLabelValuesParams) (*mcp.CallToolResult, error) {
308308
promClient, err := promClientFromContext(ctx, args.DatasourceUID)
309309
if err != nil {
310310
return nil, fmt.Errorf("getting Prometheus client: %w", err)
@@ -349,16 +349,16 @@ func ListPrometheusLabelValues(ctx context.Context, args ListPrometheusLabelValu
349349
return mcp.NewToolResultText(string(b)), nil
350350
}
351351

352-
var ListPrometheusLabelValuesTool, ListPrometheusLabelValuesHandler = mcpgrafana.MustTool(
352+
var ListPrometheusLabelValues = mcpgrafana.MustTool(
353353
"list_prometheus_label_values",
354354
"Get the values of a label in Prometheus",
355-
ListPrometheusLabelValues,
355+
listPrometheusLabelValues,
356356
)
357357

358358
func AddPrometheusTools(mcp *server.MCPServer) {
359-
mcp.AddTool(ListPrometheusMetricMetadataTool, ListPrometheusMetricMetadataHandler)
360-
mcp.AddTool(QueryPrometheusTool, QueryPrometheusHandler)
361-
mcp.AddTool(ListPrometheusMetricNamesTool, ListPrometheusMetricNamesHandler)
362-
mcp.AddTool(ListPrometheusLabelNamesTool, ListPrometheusLabelNamesHandler)
363-
mcp.AddTool(ListPrometheusLabelValuesTool, ListPrometheusLabelValuesHandler)
359+
ListPrometheusMetricMetadata.Register(mcp)
360+
QueryPrometheus.Register(mcp)
361+
ListPrometheusMetricNames.Register(mcp)
362+
ListPrometheusLabelNames.Register(mcp)
363+
ListPrometheusLabelValues.Register(mcp)
364364
}

tools/search.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ func searchDashboards(ctx context.Context, args SearchDashboardsParams) (*mcp.Ca
3333
return mcp.NewToolResultText(string(b)), nil
3434
}
3535

36-
var SearchDashboardsTool, SearchDashboardsHandler = mcpgrafana.MustTool(
36+
var SearchDashboards = mcpgrafana.MustTool(
3737
"search_dashboards",
3838
"Search for dashboards",
3939
searchDashboards,
4040
)
4141

4242
func AddSearchTools(mcp *server.MCPServer) {
43-
mcp.AddTool(SearchDashboardsTool, SearchDashboardsHandler)
43+
SearchDashboards.Register(mcp)
4444
}

tools_test.go

+1-38
Original file line numberDiff line numberDiff line change
@@ -131,48 +131,11 @@ func TestConvertTool(t *testing.T) {
131131
})
132132

133133
t.Run("invalid handler types", func(t *testing.T) {
134-
// Test non-function handler
135-
_, _, err := ConvertTool("invalid", "description", "not a function")
136-
assert.Error(t, err)
137-
assert.Contains(t, err.Error(), "must be a function")
138-
139-
// Test wrong number of arguments
140-
wrongArgsFunc := func(ctx context.Context) (*mcp.CallToolResult, error) {
141-
return nil, nil
142-
}
143-
_, _, err = ConvertTool("invalid", "description", wrongArgsFunc)
144-
assert.Error(t, err)
145-
assert.Contains(t, err.Error(), "must have 2 arguments")
146-
147-
// Test wrong number of return values
148-
wrongReturnFunc := func(ctx context.Context, params testToolParams) *mcp.CallToolResult {
149-
return nil
150-
}
151-
_, _, err = ConvertTool("invalid", "description", wrongReturnFunc)
152-
assert.Error(t, err)
153-
assert.Contains(t, err.Error(), "must return 2 values")
154-
155-
// Test wrong first argument type
156-
wrongFirstArgFunc := func(s string, params testToolParams) (*mcp.CallToolResult, error) {
157-
return nil, nil
158-
}
159-
_, _, err = ConvertTool("invalid", "description", wrongFirstArgFunc)
160-
assert.Error(t, err)
161-
assert.Contains(t, err.Error(), "first argument must be context.Context")
162-
163-
// Test wrong first return value type
164-
wrongFirstReturnFunc := func(ctx context.Context, params testToolParams) (string, error) {
165-
return "", nil
166-
}
167-
_, _, err = ConvertTool("invalid", "description", wrongFirstReturnFunc)
168-
assert.Error(t, err)
169-
assert.Contains(t, err.Error(), "first return value must be mcp.CallToolResult")
170-
171134
// Test wrong second argument type (not a struct)
172135
wrongSecondArgFunc := func(ctx context.Context, s string) (*mcp.CallToolResult, error) {
173136
return nil, nil
174137
}
175-
_, _, err = ConvertTool("invalid", "description", wrongSecondArgFunc)
138+
_, _, err := ConvertTool("invalid", "description", wrongSecondArgFunc)
176139
assert.Error(t, err)
177140
assert.Contains(t, err.Error(), "second argument must be a struct")
178141
})

0 commit comments

Comments
 (0)