diff --git a/.github/workflows/lint-pr.yml b/.github/workflows/lint-pr.yml index 460ed9330..9c311a2c5 100644 --- a/.github/workflows/lint-pr.yml +++ b/.github/workflows/lint-pr.yml @@ -15,7 +15,7 @@ jobs: name: Validate PR title runs-on: ubuntu-latest steps: - - uses: amannn/action-semantic-pull-request@069817c298f23fab00a8f29a2e556a5eac0f6390 + - uses: amannn/action-semantic-pull-request@71b07ef490c9e8ef772f64a62d41545ae5b9ef22 with: types: | build diff --git a/.github/workflows/zizmor.yml b/.github/workflows/zizmor.yml index b2b3e5622..4b6729751 100644 --- a/.github/workflows/zizmor.yml +++ b/.github/workflows/zizmor.yml @@ -22,7 +22,7 @@ jobs: persist-credentials: false - name: Install the latest version of uv - uses: astral-sh/setup-uv@681c641aba71e4a1c380be3ab5e12ad51f415867 + uses: astral-sh/setup-uv@702b425af1c366e68b4a9449a8de6dd98b63e979 - name: Run zizmor ๐ŸŒˆ run: uvx zizmor --format sarif . > results.sarif diff --git a/go.mod b/go.mod index 1f9cb8f8e..433d73f8f 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/netapp/harvest/v2 go 1.24.0 require ( - github.com/goccy/go-yaml v1.19.0 + github.com/goccy/go-yaml v1.19.1 github.com/google/go-cmp v0.7.0 github.com/rivo/uniseg v0.4.7 github.com/spf13/cobra v1.10.2 diff --git a/go.sum b/go.sum index e52b2a262..da5f7d730 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/goccy/go-yaml v1.19.0 h1:EmkZ9RIsX+Uq4DYFowegAuJo8+xdX3T/2dwNPXbxEYE= github.com/goccy/go-yaml v1.19.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/goccy/go-yaml v1.19.1 h1:3rG3+v8pkhRqoQ/88NYNMHYVGYztCOCIZ7UQhu7H+NE= +github.com/goccy/go-yaml v1.19.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= diff --git a/mcp/go.mod b/mcp/go.mod index d44f2b16c..09e9c07c1 100644 --- a/mcp/go.mod +++ b/mcp/go.mod @@ -5,9 +5,9 @@ go 1.25 replace github.com/netapp/harvest/v2 => ../ require ( - github.com/goccy/go-yaml v1.19.0 - github.com/modelcontextprotocol/go-sdk v1.1.0 - github.com/netapp/harvest/v2 v2.0.0-20251212120439-ea75a8047ce8 + github.com/goccy/go-yaml v1.19.1 + github.com/modelcontextprotocol/go-sdk v1.2.0 + github.com/netapp/harvest/v2 v2.0.0-20251215084222-367b927e5360 github.com/spf13/cobra v1.10.2 ) diff --git a/mcp/go.sum b/mcp/go.sum index 7995c2143..044df673e 100644 --- a/mcp/go.sum +++ b/mcp/go.sum @@ -3,6 +3,8 @@ github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/goccy/go-yaml v1.19.0 h1:EmkZ9RIsX+Uq4DYFowegAuJo8+xdX3T/2dwNPXbxEYE= github.com/goccy/go-yaml v1.19.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/goccy/go-yaml v1.19.1 h1:3rG3+v8pkhRqoQ/88NYNMHYVGYztCOCIZ7UQhu7H+NE= +github.com/goccy/go-yaml v1.19.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= @@ -13,6 +15,8 @@ github.com/modelcontextprotocol/go-sdk v1.0.0 h1:Z4MSjLi38bTgLrd/LjSmofqRqyBiVKR github.com/modelcontextprotocol/go-sdk v1.0.0/go.mod h1:nYtYQroQ2KQiM0/SbyEPUWQ6xs4B95gJjEalc9AQyOs= github.com/modelcontextprotocol/go-sdk v1.1.0 h1:Qjayg53dnKC4UZ+792W21e4BpwEZBzwgRW6LrjLWSwA= github.com/modelcontextprotocol/go-sdk v1.1.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= +github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s= +github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s= github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0= diff --git a/mcp/vendor/github.com/goccy/go-yaml/ast/ast.go b/mcp/vendor/github.com/goccy/go-yaml/ast/ast.go index ca1505381..a8078a5f5 100644 --- a/mcp/vendor/github.com/goccy/go-yaml/ast/ast.go +++ b/mcp/vendor/github.com/goccy/go-yaml/ast/ast.go @@ -1623,7 +1623,11 @@ func (n *SequenceNode) flowStyleString() string { for _, value := range n.Values { values = append(values, value.String()) } - return fmt.Sprintf("[%s]", strings.Join(values, ", ")) + seqText := fmt.Sprintf("[%s]", strings.Join(values, ", ")) + if n.Comment != nil { + return addCommentString(seqText, n.Comment) + } + return seqText } func (n *SequenceNode) blockStyleString() string { diff --git a/mcp/vendor/github.com/goccy/go-yaml/decode.go b/mcp/vendor/github.com/goccy/go-yaml/decode.go index 43c317f8f..d490add63 100644 --- a/mcp/vendor/github.com/goccy/go-yaml/decode.go +++ b/mcp/vendor/github.com/goccy/go-yaml/decode.go @@ -288,7 +288,9 @@ func (d *Decoder) addSequenceNodeCommentToMap(node *ast.SequenceNode) { texts = append(texts, comment.Token.Value) } if len(texts) != 0 { - d.addCommentToMap(node.Values[0].GetPath(), HeadComment(texts...)) + if len(node.Values) != 0 { + d.addCommentToMap(node.Values[0].GetPath(), HeadComment(texts...)) + } } } } @@ -1750,14 +1752,11 @@ func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node return err } } else { - keyVal, err := d.nodeToValue(ctx, key) + keyVal, err := d.createDecodedNewValue(ctx, keyType, reflect.Value{}, key) if err != nil { return err } - k = reflect.ValueOf(keyVal) - if k.IsValid() && k.Type().ConvertibleTo(keyType) { - k = k.Convert(keyType) - } + k = keyVal } if k.IsValid() { diff --git a/mcp/vendor/github.com/goccy/go-yaml/internal/format/format.go b/mcp/vendor/github.com/goccy/go-yaml/internal/format/format.go index 2d55652ff..461dc36d2 100644 --- a/mcp/vendor/github.com/goccy/go-yaml/internal/format/format.go +++ b/mcp/vendor/github.com/goccy/go-yaml/internal/format/format.go @@ -351,8 +351,9 @@ func (f *Formatter) formatMapping(n *ast.MappingNode) string { var ret string if n.IsFlowStyle { ret = f.origin(n.Start) + } else { + ret += f.formatCommentGroup(n.Comment) } - ret += f.formatCommentGroup(n.Comment) for _, value := range n.Values { if value.CollectEntry != nil { ret += f.origin(value.CollectEntry) @@ -361,6 +362,7 @@ func (f *Formatter) formatMapping(n *ast.MappingNode) string { } if n.IsFlowStyle { ret += f.origin(n.End) + ret += f.formatCommentGroup(n.Comment) } return ret } @@ -377,8 +379,7 @@ func (f *Formatter) formatSequence(n *ast.SequenceNode) string { var ret string if n.IsFlowStyle { ret = f.origin(n.Start) - } - if n.Comment != nil { + } else { // add head comment. ret += f.formatCommentGroup(n.Comment) } @@ -387,6 +388,7 @@ func (f *Formatter) formatSequence(n *ast.SequenceNode) string { } if n.IsFlowStyle { ret += f.origin(n.End) + ret += f.formatCommentGroup(n.Comment) } ret += f.formatCommentGroup(n.FootComment) return ret diff --git a/mcp/vendor/github.com/goccy/go-yaml/parser/parser.go b/mcp/vendor/github.com/goccy/go-yaml/parser/parser.go index 2c79d3690..f5bfd1a96 100644 --- a/mcp/vendor/github.com/goccy/go-yaml/parser/parser.go +++ b/mcp/vendor/github.com/goccy/go-yaml/parser/parser.go @@ -426,6 +426,11 @@ func (p *parser) parseFlowMap(ctx *context) (*ast.MappingNode, error) { if node.End == nil { return nil, errors.ErrSyntax("could not find flow mapping end token '}'", node.Start) } + + // set line comment if exists. e.g.) } # comment + if err := setLineComment(ctx, node, ctx.currentToken()); err != nil { + return nil, err + } ctx.goNext() // skip mapping end token. return node, nil } @@ -1066,6 +1071,11 @@ func (p *parser) parseFlowSequence(ctx *context) (*ast.SequenceNode, error) { if node.End == nil { return nil, errors.ErrSyntax("sequence end token ']' not found", node.Start) } + + // set line comment if exists. e.g.) ] # comment + if err := setLineComment(ctx, node, ctx.currentToken()); err != nil { + return nil, err + } ctx.goNext() // skip sequence end token. return node, nil } diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go index 0eea1d873..87665121c 100644 --- a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/auth/auth.go @@ -6,17 +6,25 @@ package auth import ( "context" + "encoding/json" "errors" "net/http" "slices" "strings" "time" + + "github.com/modelcontextprotocol/go-sdk/oauthex" ) // TokenInfo holds information from a bearer token. type TokenInfo struct { Scopes []string Expiration time.Time + // UserID is an optional identifier for the authenticated user. + // If set by a TokenVerifier, it can be used by transports to prevent + // session hijacking by ensuring that all requests for a given session + // come from the same user. + UserID string // TODO: add standard JWT fields Extra map[string]any } @@ -118,3 +126,43 @@ func verify(req *http.Request, verifier TokenVerifier, opts *RequireBearerTokenO } return tokenInfo, "", 0 } + +// ProtectedResourceMetadataHandler returns an http.Handler that serves OAuth 2.0 +// protected resource metadata (RFC 9728) with CORS support. +// +// This handler allows cross-origin requests from any origin (Access-Control-Allow-Origin: *) +// because OAuth metadata is public information intended for client discovery (RFC 9728 ยง3.1). +// The metadata contains only non-sensitive configuration data about authorization servers +// and supported scopes. +// +// No validation of metadata fields is performed; ensure metadata accuracy at configuration time. +// +// For more sophisticated CORS policies or to restrict origins, wrap this handler with a +// CORS middleware like github.com/rs/cors or github.com/jub0bs/cors. +func ProtectedResourceMetadataHandler(metadata *oauthex.ProtectedResourceMetadata) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Set CORS headers for cross-origin client discovery. + // OAuth metadata is public information, so allowing any origin is safe. + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + + // Handle CORS preflight requests + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + // Only GET allowed for metadata retrieval + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(metadata); err != nil { + http.Error(w, "Failed to encode metadata", http.StatusInternalServerError) + return + } + }) +} diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/conn.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/conn.go index 5549ee1c9..627ffe7b6 100644 --- a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/conn.go +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/conn.go @@ -55,9 +55,16 @@ type ConnectionOptions struct { } // Connection manages the jsonrpc2 protocol, connecting responses back to their -// calls. -// Connection is bidirectional; it does not have a designated server or client -// end. +// calls. Connection is bidirectional; it does not have a designated server or +// client end. +// +// Note that the word 'Connection' is overloaded: the mcp.Connection represents +// the bidirectional stream of messages between client an server. The +// jsonrpc2.Connection layers RPC logic on top of that stream, dispatching RPC +// handlers, and correlating requests with responses from the peer. +// +// Some of the complexity of the Connection type is grown out of its usage in +// gopls: it could probably be simplified based on our usage in MCP. type Connection struct { seq int64 // must only be accessed using atomic operations @@ -361,19 +368,26 @@ func (c *Connection) Call(ctx context.Context, method string, params any) *Async if err := c.write(ctx, call); err != nil { // Sending failed. We will never get a response, so deliver a fake one if it // wasn't already retired by the connection breaking. - c.updateInFlight(func(s *inFlightState) { - if s.outgoingCalls[ac.id] == ac { - delete(s.outgoingCalls, ac.id) - ac.retire(&Response{ID: id, Error: err}) - } else { - // ac was already retired by the readIncoming goroutine: - // perhaps our write raced with the Read side of the connection breaking. - } - }) + c.Retire(ac, err) } return ac } +// Retire stops tracking the call, and reports err as its terminal error. +// +// Retire is safe to call multiple times: if the call is already no longer +// tracked, Retire is a no op. +func (c *Connection) Retire(ac *AsyncCall, err error) { + c.updateInFlight(func(s *inFlightState) { + if s.outgoingCalls[ac.id] == ac { + delete(s.outgoingCalls, ac.id) + ac.retire(&Response{ID: ac.id, Error: err}) + } else { + // ac was already retired elsewhere. + } + }) +} + // Async, signals that the current jsonrpc2 request may be handled // asynchronously to subsequent requests, when ctx is the request context. // @@ -437,6 +451,9 @@ func (ac *AsyncCall) IsReady() bool { } // retire processes the response to the call. +// +// It is an error to call retire more than once: retire is guarded by the +// connection's outgoingCalls map. func (ac *AsyncCall) retire(response *Response) { select { case <-ac.ready: @@ -450,6 +467,9 @@ func (ac *AsyncCall) retire(response *Response) { // Await waits for (and decodes) the results of a Call. // The response will be unmarshaled from JSON into the result. +// +// If the call is cancelled due to context cancellation, the result is +// ctx.Err(). func (ac *AsyncCall) Await(ctx context.Context, result any) error { select { case <-ctx.Done(): @@ -772,13 +792,9 @@ func (c *Connection) write(ctx context.Context, msg Message) error { err = c.writer.Write(ctx, msg) } - // For rejected requests, we don't set the writeErr (which would break the - // connection). They can just be returned to the caller. - if errors.Is(err, ErrRejected) { - return err - } - - if err != nil && ctx.Err() == nil { + // For cancelled or rejected requests, we don't set the writeErr (which would + // break the connection). They can just be returned to the caller. + if err != nil && ctx.Err() == nil && !errors.Is(err, ErrRejected) { // The call to Write failed, and since ctx.Err() is nil we can't attribute // the failure (even indirectly) to Context cancellation. The writer appears // to be broken, and future writes are likely to also fail. diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/wire.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/wire.go index 8be2872e4..c0a41bffb 100644 --- a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/wire.go +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2/wire.go @@ -47,7 +47,7 @@ var ( // Such failures do not indicate that the connection is broken, but rather // should be returned to the caller to indicate that the specific request is // invalid in the current context. - ErrRejected = NewError(-32004, "rejected by transport") + ErrRejected = NewError(-32005, "rejected by transport") ) const wireVersion = "2.0" diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/jsonrpc/jsonrpc.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/jsonrpc/jsonrpc.go index 1633d4e3c..a9ea78fa8 100644 --- a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/jsonrpc/jsonrpc.go +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/jsonrpc/jsonrpc.go @@ -17,6 +17,8 @@ type ( Request = jsonrpc2.Request // Response is a JSON-RPC response. Response = jsonrpc2.Response + // Error is a structured error in a JSON-RPC response. + Error = jsonrpc2.WireError ) // MakeID coerces the given Go value to an ID. The value should be the @@ -37,3 +39,18 @@ func EncodeMessage(msg Message) ([]byte, error) { func DecodeMessage(data []byte) (Message, error) { return jsonrpc2.DecodeMessage(data) } + +// Standard JSON-RPC 2.0 error codes. +// See https://www.jsonrpc.org/specification#error_object +const ( + // CodeParseError indicates invalid JSON was received by the server. + CodeParseError = -32700 + // CodeInvalidRequest indicates the JSON sent is not a valid Request object. + CodeInvalidRequest = -32600 + // CodeMethodNotFound indicates the method does not exist or is not available. + CodeMethodNotFound = -32601 + // CodeInvalidParams indicates invalid method parameter(s). + CodeInvalidParams = -32602 + // CodeInternalError indicates an internal JSON-RPC error. + CodeInternalError = -32603 +) diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go index d7e3ae5a6..2dc1a86c0 100644 --- a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/client.go @@ -7,9 +7,12 @@ package mcp import ( "context" "encoding/json" + "errors" "fmt" "iter" + "log/slog" "slices" + "strings" "sync" "sync/atomic" "time" @@ -24,6 +27,7 @@ import ( type Client struct { impl *Implementation opts ClientOptions + logger *slog.Logger // TODO: file proposal to export this mu sync.Mutex roots *featureSet[*Root] sessions []*ClientSession @@ -44,8 +48,9 @@ func NewClient(impl *Implementation, opts *ClientOptions) *Client { } c := &Client{ impl: impl, + logger: ensureLogger(nil), // ensure we have a logger roots: newFeatureSet(func(r *Root) string { return r.URI }), - sendingMethodHandler_: defaultSendingMethodHandler[*ClientSession], + sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], } if opts != nil { @@ -58,14 +63,59 @@ func NewClient(impl *Implementation, opts *ClientOptions) *Client { type ClientOptions struct { // CreateMessageHandler handles incoming requests for sampling/createMessage. // - // Setting CreateMessageHandler to a non-nil value causes the client to - // advertise the sampling capability. + // Setting CreateMessageHandler to a non-nil value automatically causes the + // client to advertise the sampling capability, with default value + // &SamplingCapabilities{}. If [ClientOptions.Capabilities] is set and has a + // non nil value for [ClientCapabilities.Sampling], that value overrides the + // inferred capability. CreateMessageHandler func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) // ElicitationHandler handles incoming requests for elicitation/create. // - // Setting ElicitationHandler to a non-nil value causes the client to - // advertise the elicitation capability. + // Setting ElicitationHandler to a non-nil value automatically causes the + // client to advertise the elicitation capability, with default value + // &ElicitationCapabilities{}. If [ClientOptions.Capabilities] is set and has + // a non nil value for [ClientCapabilities.ELicitattion], that value + // overrides the inferred capability. ElicitationHandler func(context.Context, *ElicitRequest) (*ElicitResult, error) + // Capabilities optionally configures the client's default capabilities, + // before any capabilities are inferred from other configuration. + // + // If Capabilities is nil, the default client capabilities are + // {"roots":{"listChanged":true}}, for historical reasons. Setting + // Capabilities to a non-nil value overrides this default. As a special case, + // to work around #607, Capabilities.Roots is ignored: set + // Capabilities.RootsV2 to configure the roots capability. This allows the + // "roots" capability to be disabled entirely. + // + // For example: + // - To disable the "roots" capability, use &ClientCapabilities{} + // - To configure "roots", but disable "listChanged" notifications, use + // &ClientCapabilities{RootsV2:&RootCapabilities{}}. + // + // # Interaction with capability inference + // + // Sampling and elicitation capabilities are automatically added when their + // corresponding handlers are set, with the default value described at + // [ClientOptions.CreateMessageHandler] and + // [ClientOptions.ElicitationHandler]. If the Sampling or Elicitation fields + // are set in the Capabilities field, their values override the inferred + // value. + // + // For example, to to configure elicitation modes: + // + // Capabilities: &ClientCapabilities{ + // Elicitation: &ElicitationCapabilities{ + // Form: &FormElicitationCapabilities{}, + // URL: &URLElicitationCapabilities{}, + // }, + // } + // + // Conversely, if Capabilities does not set a field (for example, if the + // Elicitation field is nil), the inferred elicitation capability will be + // used. + Capabilities *ClientCapabilities + // ElicitationCompleteHandler handles incoming notifications for notifications/elicitation/complete. + ElicitationCompleteHandler func(context.Context, *ElicitationCompleteNotificationRequest) // Handlers for notifications from the server. ToolListChangedHandler func(context.Context, *ToolListChangedRequest) PromptListChangedHandler func(context.Context, *PromptListChangedRequest) @@ -113,16 +163,50 @@ func (e unsupportedProtocolVersionError) Error() string { } // ClientSessionOptions is reserved for future use. -type ClientSessionOptions struct{} +type ClientSessionOptions struct { + // protocolVersion overrides the protocol version sent in the initialize + // request, for testing. If empty, latestProtocolVersion is used. + protocolVersion string +} + +func (c *Client) capabilities(protocolVersion string) *ClientCapabilities { + // Start with user-provided capabilities as defaults, or use SDK defaults. + var caps *ClientCapabilities + if c.opts.Capabilities != nil { + // Deep copy the user-provided capabilities to avoid mutation. + caps = c.opts.Capabilities.clone() + } else { + // SDK defaults: roots with listChanged. + // (this was the default behavior at v1.0.0, and so cannot be changed) + caps = &ClientCapabilities{ + RootsV2: &RootCapabilities{ + ListChanged: true, + }, + } + } -func (c *Client) capabilities() *ClientCapabilities { - caps := &ClientCapabilities{} - caps.Roots.ListChanged = true + // Sync Roots from RootsV2 for backward compatibility (#607). + if caps.RootsV2 != nil { + caps.Roots = *caps.RootsV2 + } + + // Augment with sampling capability if handler is set. if c.opts.CreateMessageHandler != nil { - caps.Sampling = &SamplingCapabilities{} + if caps.Sampling == nil { + caps.Sampling = &SamplingCapabilities{} + } } + + // Augment with elicitation capability if handler is set. if c.opts.ElicitationHandler != nil { - caps.Elicitation = &ElicitationCapabilities{} + if caps.Elicitation == nil { + caps.Elicitation = &ElicitationCapabilities{} + // Form elicitation was added in 2025-11-25; for older versions, + // {} is treated the same as {"form":{}}. + if protocolVersion >= protocolVersion20251125 { + caps.Elicitation.Form = &FormElicitationCapabilities{} + } + } } return caps } @@ -134,16 +218,20 @@ func (c *Client) capabilities() *ClientCapabilities { // when it is no longer needed. However, if the connection is closed by the // server, calls or notifications will return an error wrapping // [ErrConnectionClosed]. -func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptions) (cs *ClientSession, err error) { +func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOptions) (cs *ClientSession, err error) { cs, err = connect(ctx, t, c, (*clientSessionState)(nil), nil) if err != nil { return nil, err } + protocolVersion := latestProtocolVersion + if opts != nil && opts.protocolVersion != "" { + protocolVersion = opts.protocolVersion + } params := &InitializeParams{ - ProtocolVersion: latestProtocolVersion, + ProtocolVersion: protocolVersion, ClientInfo: c.impl, - Capabilities: c.capabilities(), + Capabilities: c.capabilities(protocolVersion), } req := &InitializeRequest{Session: cs, Params: params} res, err := handleSend[*InitializeResult](ctx, methodInitialize, req) @@ -192,6 +280,10 @@ type ClientSession struct { // No mutex is (currently) required to guard the session state, because it is // only set synchronously during Client.Connect. state clientSessionState + + // Pending URL elicitations waiting for completion notifications. + pendingElicitationsMu sync.Mutex + pendingElicitations map[string]chan struct{} } type clientSessionState struct { @@ -236,6 +328,46 @@ func (cs *ClientSession) Wait() error { return cs.conn.Wait() } +// registerElicitationWaiter registers a waiter for an elicitation complete +// notification with the given elicitation ID. It returns two functions: an await +// function that waits for the notification or context cancellation, and a cleanup +// function that must be called to unregister the waiter. This must be called before +// triggering the elicitation to avoid a race condition where the notification +// arrives before the waiter is registered. +// +// The cleanup function must be called even if the await function is never called, +// to prevent leaking the registration. +func (cs *ClientSession) registerElicitationWaiter(elicitationID string) (await func(context.Context) error, cleanup func()) { + // Create a channel for this elicitation. + ch := make(chan struct{}, 1) + + // Register the channel. + cs.pendingElicitationsMu.Lock() + if cs.pendingElicitations == nil { + cs.pendingElicitations = make(map[string]chan struct{}) + } + cs.pendingElicitations[elicitationID] = ch + cs.pendingElicitationsMu.Unlock() + + // Return await and cleanup functions. + await = func(ctx context.Context) error { + select { + case <-ctx.Done(): + return fmt.Errorf("context cancelled while waiting for elicitation completion: %w", ctx.Err()) + case <-ch: + return nil + } + } + + cleanup = func() { + cs.pendingElicitationsMu.Lock() + delete(cs.pendingElicitations, elicitationID) + cs.pendingElicitationsMu.Unlock() + } + + return await, cleanup +} + // startKeepalive starts the keepalive mechanism for this client session. func (cs *ClientSession) startKeepalive(interval time.Duration) { startKeepalive(cs, interval, &cs.keepaliveCancel) @@ -269,10 +401,36 @@ func changeAndNotify[P Params](c *Client, notification string, params P, change // Lock for the change, but not for the notification. c.mu.Lock() if change() { - sessions = slices.Clone(c.sessions) + // Check if listChanged is enabled for this notification type. + if c.shouldSendListChangedNotification(notification) { + sessions = slices.Clone(c.sessions) + } } c.mu.Unlock() - notifySessions(sessions, notification, params) + notifySessions(sessions, notification, params, c.logger) +} + +// shouldSendListChangedNotification checks if the client's capabilities allow +// sending the given list-changed notification. +func (c *Client) shouldSendListChangedNotification(notification string) bool { + // Get effective capabilities (considering user-provided defaults). + caps := c.opts.Capabilities + + switch notification { + case notificationRootsListChanged: + // If user didn't specify capabilities, default behavior sends notifications. + if caps == nil { + return true + } + // Check RootsV2 first (preferred), then fall back to Roots. + if caps.RootsV2 != nil { + return caps.RootsV2.ListChanged + } + return caps.Roots.ListChanged + default: + // Unknown notification, allow by default. + return true + } } func (c *Client) listRoots(_ context.Context, req *ListRootsRequest) (*ListRootsResult, error) { @@ -290,44 +448,166 @@ func (c *Client) listRoots(_ context.Context, req *ListRootsRequest) (*ListRoots func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { if c.opts.CreateMessageHandler == nil { // TODO: wrap or annotate this error? Pick a standard code? - return nil, jsonrpc2.NewError(codeUnsupportedMethod, "client does not support CreateMessage") + return nil, &jsonrpc.Error{Code: codeUnsupportedMethod, Message: "client does not support CreateMessage"} } return c.opts.CreateMessageHandler(ctx, req) } -func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { - if c.opts.ElicitationHandler == nil { - // TODO: wrap or annotate this error? Pick a standard code? - return nil, jsonrpc2.NewError(codeUnsupportedMethod, "client does not support elicitation") +// urlElicitationMiddleware returns middleware that automatically handles URL elicitation +// required errors by executing the elicitation handler, waiting for completion notifications, +// and retrying the operation. +// +// This middleware should be added to clients that want automatic URL elicitation handling: +// +// client := mcp.NewClient(impl, opts) +// client.AddSendingMiddleware(mcp.urlElicitationMiddleware()) +// +// TODO(rfindley): this isn't strictly necessary for the SEP, but may be +// useful. Propose exporting it. +func urlElicitationMiddleware() Middleware { + return func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + // Call the underlying handler. + res, err := next(ctx, method, req) + if err == nil { + return res, nil + } + + // Check if this is a URL elicitation required error. + var rpcErr *jsonrpc.Error + if !errors.As(err, &rpcErr) || rpcErr.Code != CodeURLElicitationRequired { + return res, err + } + + // Notifications don't support retries. + if strings.HasPrefix(method, "notifications/") { + return res, err + } + + // Extract the client session. + cs, ok := req.GetSession().(*ClientSession) + if !ok { + return res, err + } + + // Check if the client has an elicitation handler. + if cs.client.opts.ElicitationHandler == nil { + return res, err + } + + // Parse the elicitations from the error data. + var errorData struct { + Elicitations []*ElicitParams `json:"elicitations"` + } + if rpcErr.Data != nil { + if err := json.Unmarshal(rpcErr.Data, &errorData); err != nil { + return nil, fmt.Errorf("failed to parse URL elicitation error data: %w", err) + } + } + + // Validate that all elicitations are URL mode. + for _, elicit := range errorData.Elicitations { + mode := elicit.Mode + if mode == "" { + mode = "form" // Default mode. + } + if mode != "url" { + return nil, fmt.Errorf("URLElicitationRequired error must only contain URL mode elicitations, got %q", mode) + } + } + + // Register waiters for all elicitations before executing handlers + // to avoid race condition where notification arrives before waiter is registered. + type waiter struct { + await func(context.Context) error + cleanup func() + } + waiters := make([]waiter, 0, len(errorData.Elicitations)) + for _, elicitParams := range errorData.Elicitations { + await, cleanup := cs.registerElicitationWaiter(elicitParams.ElicitationID) + waiters = append(waiters, waiter{await: await, cleanup: cleanup}) + } + + // Ensure cleanup happens even if we return early. + defer func() { + for _, w := range waiters { + w.cleanup() + } + }() + + // Execute the elicitation handler for each elicitation. + for _, elicitParams := range errorData.Elicitations { + elicitReq := newClientRequest(cs, elicitParams) + _, elicitErr := cs.client.elicit(ctx, elicitReq) + if elicitErr != nil { + return nil, fmt.Errorf("URL elicitation failed: %w", elicitErr) + } + } + + // Wait for all elicitations to complete. + for _, w := range waiters { + if err := w.await(ctx); err != nil { + return nil, err + } + } + + // All elicitations complete, retry the original operation. + return next(ctx, method, req) + } } +} - // Validate that the requested schema only contains top-level properties without nesting - schema, err := validateElicitSchema(req.Params.RequestedSchema) - if err != nil { - return nil, jsonrpc2.NewError(codeInvalidParams, err.Error()) +func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { + if c.opts.ElicitationHandler == nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "client does not support elicitation"} } - res, err := c.opts.ElicitationHandler(ctx, req) - if err != nil { - return nil, err + // Validate the elicitation parameters based on the mode. + mode := req.Params.Mode + if mode == "" { + mode = "form" } - // Validate elicitation result content against requested schema - if schema != nil && res.Content != nil { - // TODO: is this the correct behavior if validation fails? - // It isn't the *server's* params that are invalid, so why would we return - // this code to the server? - resolved, err := schema.Resolve(nil) + switch mode { + case "form": + if req.Params.URL != "" { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "URL must not be set for form elicitation"} + } + schema, err := validateElicitSchema(req.Params.RequestedSchema) if err != nil { - return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to resolve requested schema: %v", err)) + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: err.Error()} } - - if err := resolved.Validate(res.Content); err != nil { - return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("elicitation result content does not match requested schema: %v", err)) + res, err := c.opts.ElicitationHandler(ctx, req) + if err != nil { + return nil, err } + // Validate elicitation result content against requested schema. + if schema != nil && res.Content != nil { + resolved, err := schema.Resolve(nil) + if err != nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("failed to resolve requested schema: %v", err)} + } + if err := resolved.Validate(res.Content); err != nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("elicitation result content does not match requested schema: %v", err)} + } + err = resolved.ApplyDefaults(&res.Content) + if err != nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("failed to apply schema defalts to elicitation result: %v", err)} + } + } + return res, nil + case "url": + if req.Params.RequestedSchema != nil { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "requestedSchema must not be set for URL elicitation"} + } + if req.Params.URL == "" { + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "URL must be set for URL elicitation"} + } + // No schema validation for URL mode, just pass through to handler. + return c.opts.ElicitationHandler(ctx, req) + default: + return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("unsupported elicitation mode: %q", mode)} } - - return res, nil } // validateElicitSchema validates that the schema conforms to MCP elicitation schema requirements. @@ -341,6 +621,9 @@ func validateElicitSchema(wireSchema any) (*jsonschema.Schema, error) { if err := remarshal(wireSchema, &schema); err != nil { return nil, err } + if schema == nil { + return nil, nil + } // The root schema must be of type "object" if specified if schema.Type != "" && schema.Type != "object" { @@ -369,7 +652,6 @@ func validateElicitProperty(propName string, propSchema *jsonschema.Schema) erro if len(propSchema.Properties) > 0 { return fmt.Errorf("elicit schema property %q contains nested properties, only primitive properties are allowed", propName) } - // Validate based on the property type - only primitives are supported switch propSchema.Type { case "string": @@ -439,7 +721,7 @@ func validateElicitStringProperty(propName string, propSchema *jsonschema.Schema } } - return nil + return validateDefaultProperty[string](propName, propSchema) } // validateElicitNumberProperty validates number and integer-type properties. @@ -450,19 +732,28 @@ func validateElicitNumberProperty(propName string, propSchema *jsonschema.Schema } } + intDefaultError := validateDefaultProperty[int](propName, propSchema) + floatDefaultError := validateDefaultProperty[float64](propName, propSchema) + if intDefaultError != nil && floatDefaultError != nil { + return fmt.Errorf("elicit schema property %q has default value that cannot be interpreted as an int or float", propName) + } + return nil } // validateElicitBooleanProperty validates boolean-type properties. func validateElicitBooleanProperty(propName string, propSchema *jsonschema.Schema) error { - // Validate default value if specified - must be a valid boolean + return validateDefaultProperty[bool](propName, propSchema) +} + +func validateDefaultProperty[T any](propName string, propSchema *jsonschema.Schema) error { + // Validate default value if specified - must be a valid T if propSchema.Default != nil { - var defaultValue bool + var defaultValue T if err := json.Unmarshal(propSchema.Default, &defaultValue); err != nil { - return fmt.Errorf("elicit schema property %q has invalid default value, must be a boolean: %v", propName, err) + return fmt.Errorf("elicit schema property %q has invalid default value, must be a %T: %v", propName, defaultValue, err) } } - return nil } @@ -514,6 +805,7 @@ var clientMethodInfos = map[string]methodInfo{ notificationResourceUpdated: newClientMethodInfo(clientMethod((*Client).callResourceUpdatedHandler), notification|missingParamsOK), notificationLoggingMessage: newClientMethodInfo(clientMethod((*Client).callLoggingHandler), notification), notificationProgress: newClientMethodInfo(clientSessionMethod((*ClientSession).callProgressNotificationHandler), notification), + notificationElicitationComplete: newClientMethodInfo(clientMethod((*Client).callElicitationCompleteHandler), notification|missingParamsOK), } func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { @@ -678,6 +970,27 @@ func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, pa return nil, nil } +func (c *Client) callElicitationCompleteHandler(ctx context.Context, req *ElicitationCompleteNotificationRequest) (Result, error) { + // Check if there's a pending elicitation waiting for this notification. + if cs, ok := req.GetSession().(*ClientSession); ok { + cs.pendingElicitationsMu.Lock() + if ch, exists := cs.pendingElicitations[req.Params.ElicitationID]; exists { + select { + case ch <- struct{}{}: + default: + // Channel already signaled. + } + } + cs.pendingElicitationsMu.Unlock() + } + + // Call the user's handler if provided. + if h := c.opts.ElicitationCompleteHandler; h != nil { + h(ctx, req) + } + return nil, nil +} + // NotifyProgress sends a progress notification from the client to the server // associated with this session. // This can be used if the client is performing a long-running task that was diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go index e53cad14b..fb1a0d1e5 100644 --- a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/content.go @@ -130,6 +130,8 @@ type ResourceLink struct { Size *int64 Meta Meta Annotations *Annotations + // Icons for the resource link, if any. + Icons []Icon `json:"icons,omitempty"` } func (c *ResourceLink) MarshalJSON() ([]byte, error) { @@ -143,6 +145,7 @@ func (c *ResourceLink) MarshalJSON() ([]byte, error) { Size: c.Size, Meta: c.Meta, Annotations: c.Annotations, + Icons: c.Icons, }) } @@ -155,6 +158,7 @@ func (c *ResourceLink) fromWire(wire *wireContent) { c.Size = wire.Size c.Meta = wire.Meta c.Annotations = wire.Annotations + c.Icons = wire.Icons } // EmbeddedResource contains embedded resources. @@ -237,6 +241,7 @@ type wireContent struct { Size *int64 `json:"size,omitempty"` Meta Meta `json:"_meta,omitempty"` Annotations *Annotations `json:"annotations,omitempty"` + Icons []Icon `json:"icons,omitempty"` } func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, error) { diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go index 281f5925a..5c322c4a3 100644 --- a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/event.go @@ -29,14 +29,15 @@ const validateMemoryEventStore = false // An Event is a server-sent event. // See https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#fields. type Event struct { - Name string // the "event" field - ID string // the "id" field - Data []byte // the "data" field + Name string // the "event" field + ID string // the "id" field + Data []byte // the "data" field + Retry string // the "retry" field } // Empty reports whether the Event is empty. func (e Event) Empty() bool { - return e.Name == "" && e.ID == "" && len(e.Data) == 0 + return e.Name == "" && e.ID == "" && len(e.Data) == 0 && e.Retry == "" } // writeEvent writes the event to w, and flushes. @@ -48,6 +49,9 @@ func writeEvent(w io.Writer, evt Event) (int, error) { if evt.ID != "" { fmt.Fprintf(&b, "id: %s\n", evt.ID) } + if evt.Retry != "" { + fmt.Fprintf(&b, "retry: %s\n", evt.Retry) + } fmt.Fprintf(&b, "data: %s\n\n", string(evt.Data)) n, err := w.Write(b.Bytes()) if f, ok := w.(http.Flusher); ok { @@ -73,6 +77,7 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { eventKey = []byte("event") idKey = []byte("id") dataKey = []byte("data") + retryKey = []byte("retry") ) return func(yield func(Event, error) bool) { @@ -119,6 +124,8 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { evt.Name = strings.TrimSpace(string(after)) case bytes.Equal(before, idKey): evt.ID = strings.TrimSpace(string(after)) + case bytes.Equal(before, retryKey): + evt.Retry = strings.TrimSpace(string(after)) case bytes.Equal(before, dataKey): data := bytes.TrimSpace(after) if dataBuf != nil { @@ -191,12 +198,8 @@ type dataList struct { } func (dl *dataList) appendData(d []byte) { - // If we allowed empty data, we would consume memory without incrementing the size. - // We could of course account for that, but we keep it simple and assume there is no - // empty data. - if len(d) == 0 { - panic("empty data item") - } + // Empty data consumes memory but doesn't increment size. However, it should + // be rare. dl.data = append(dl.data, d) dl.size += len(d) } diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go index 1312dfbdc..26c8982f8 100644 --- a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/protocol.go @@ -177,23 +177,89 @@ func (x *CancelledParams) isParams() {} func (x *CancelledParams) GetProgressToken() any { return getProgressToken(x) } func (x *CancelledParams) SetProgressToken(t any) { setProgressToken(x, t) } +// RootCapabilities describes a client's support for roots. +type RootCapabilities struct { + // ListChanged reports whether the client supports notifications for + // changes to the roots list. + ListChanged bool `json:"listChanged,omitempty"` +} + // Capabilities a client may support. Known capabilities are defined here, in // this schema, but this is not a closed set: any client can define its own, // additional capabilities. type ClientCapabilities struct { - // Experimental, non-standard capabilities that the client supports. + + // NOTE: any addition to ClientCapabilities must also be reflected in + // [ClientCapabilities.clone]. + + // Experimental reports non-standard capabilities that the client supports. Experimental map[string]any `json:"experimental,omitempty"` - // Present if the client supports listing roots. + // Roots describes the client's support for roots. + // + // Deprecated: use RootsV2. As described in #607, Roots should have been a + // pointer to a RootCapabilities value. Roots will be continue to be + // populated, but any new fields will only be added in the RootsV2 field. Roots struct { - // Whether the client supports notifications for changes to the roots list. + // ListChanged reports whether the client supports notifications for + // changes to the roots list. ListChanged bool `json:"listChanged,omitempty"` } `json:"roots,omitempty"` - // Present if the client supports sampling from an LLM. + // RootsV2 is present if the client supports roots. When capabilities are explicitly configured via [ClientOptions.Capabilities] + RootsV2 *RootCapabilities `json:"-"` + // Sampling is present if the client supports sampling from an LLM. Sampling *SamplingCapabilities `json:"sampling,omitempty"` - // Present if the client supports elicitation from the server. + // Elicitation is present if the client supports elicitation from the server. Elicitation *ElicitationCapabilities `json:"elicitation,omitempty"` } +// clone returns a deep copy of the ClientCapabilities. +func (c *ClientCapabilities) clone() *ClientCapabilities { + cp := *c + cp.RootsV2 = shallowClone(c.RootsV2) + cp.Sampling = shallowClone(c.Sampling) + if c.Elicitation != nil { + x := *c.Elicitation + x.Form = shallowClone(c.Elicitation.Form) + x.URL = shallowClone(c.Elicitation.URL) + cp.Elicitation = &x + } + return &cp +} + +// shallowClone returns a shallow clone of *p, or nil if p is nil. +func shallowClone[T any](p *T) *T { + if p == nil { + return nil + } + x := *p + return &x +} + +func (c *ClientCapabilities) toV2() *clientCapabilitiesV2 { + return &clientCapabilitiesV2{ + ClientCapabilities: *c, + Roots: c.RootsV2, + } +} + +// clientCapabilitiesV2 is a version of ClientCapabilities that fixes the bug +// described in #607: Roots should have been a pointer to value type +// RootCapabilities. +type clientCapabilitiesV2 struct { + ClientCapabilities + Roots *RootCapabilities `json:"roots,omitempty"` +} + +func (c *clientCapabilitiesV2) toV1() *ClientCapabilities { + caps := c.ClientCapabilities + caps.RootsV2 = c.Roots + // Sync Roots from RootsV2 for backward compatibility (#607). + if caps.RootsV2 != nil { + caps.Roots = *caps.RootsV2 + } + return &caps +} + type CompleteParamsArgument struct { // The name of the argument Name string `json:"name"` @@ -373,27 +439,53 @@ type GetPromptResult struct { func (*GetPromptResult) isResult() {} +// InitializeParams is sent by the client to initialize the session. type InitializeParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. - Meta `json:"_meta,omitempty"` + Meta `json:"_meta,omitempty"` + // Capabilities describes the client's capabilities. Capabilities *ClientCapabilities `json:"capabilities"` - ClientInfo *Implementation `json:"clientInfo"` - // The latest version of the Model Context Protocol that the client supports. - // The client may decide to support older versions as well. + // ClientInfo provides information about the client. + ClientInfo *Implementation `json:"clientInfo"` + // ProtocolVersion is the latest version of the Model Context Protocol that + // the client supports. ProtocolVersion string `json:"protocolVersion"` } +func (p *InitializeParams) toV2() *initializeParamsV2 { + return &initializeParamsV2{ + InitializeParams: *p, + Capabilities: p.Capabilities.toV2(), + } +} + +// initializeParamsV2 works around the mistake in #607: Capabilities.Roots +// should have been a pointer. +type initializeParamsV2 struct { + InitializeParams + Capabilities *clientCapabilitiesV2 `json:"capabilities"` +} + +func (p *initializeParamsV2) toV1() *InitializeParams { + p1 := p.InitializeParams + if p.Capabilities != nil { + p1.Capabilities = p.Capabilities.toV1() + } + return &p1 +} + func (x *InitializeParams) isParams() {} func (x *InitializeParams) GetProgressToken() any { return getProgressToken(x) } func (x *InitializeParams) SetProgressToken(t any) { setProgressToken(x, t) } -// After receiving an initialize request from the client, the server sends this -// response. +// InitializeResult is sent by the server in response to an initialize request +// from the client. type InitializeResult struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. - Meta `json:"_meta,omitempty"` + Meta `json:"_meta,omitempty"` + // Capabilities describes the server's capabilities. Capabilities *ServerCapabilities `json:"capabilities"` // Instructions describing how to use the server and its features. // @@ -411,8 +503,8 @@ type InitializeResult struct { func (*InitializeResult) isResult() {} type InitializedParams struct { - // This property is reserved by the protocol to allow clients and servers to - // attach additional metadata to their responses. + // Meta is reserved by the protocol to allow clients and servers to attach + // additional metadata to their responses. Meta `json:"_meta,omitempty"` } @@ -658,6 +750,34 @@ type ProgressNotificationParams struct { func (*ProgressNotificationParams) isParams() {} +// IconTheme specifies the theme an icon is designed for. +type IconTheme string + +const ( + // IconThemeLight indicates the icon is designed for a light background. + IconThemeLight IconTheme = "light" + // IconThemeDark indicates the icon is designed for a dark background. + IconThemeDark IconTheme = "dark" +) + +// Icon provides visual identifiers for their resources, tools, prompts, and implementations +// See [/specification/draft/basic/index#icons] for notes on icons +// +// TODO(iamsurajbobade): update specification url from draft. +type Icon struct { + // Source is A URI pointing to the icon resource (required). This can be: + // - An HTTP/HTTPS URL pointing to an image file + // - A data URI with base64-encoded image data + Source string `json:"src"` + // Optional MIME type if the server's type is missing or generic + MIMEType string `json:"mimeType,omitempty"` + // Optional size specification (e.g., ["48x48"], ["any"] for scalable formats like SVG, or ["48x48", "96x96"] for multiple sizes) + Sizes []string `json:"sizes,omitempty"` + // Optional theme specifier. "light" indicates the icon is designed for a light + // background, "dark" indicates the icon is designed for a dark background. + Theme IconTheme `json:"theme,omitempty"` +} + // A prompt or prompt template that the server offers. type Prompt struct { // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta @@ -673,6 +793,8 @@ type Prompt struct { // Intended for UI and end-user contexts โ€” optimized to be human-readable and // easily understood, even by those unfamiliar with domain-specific terminology. Title string `json:"title,omitempty"` + // Icons for the prompt, if any. + Icons []Icon `json:"icons,omitempty"` } // Describes an argument that a prompt can accept. @@ -782,6 +904,8 @@ type Resource struct { Title string `json:"title,omitempty"` // The URI of this resource. URI string `json:"uri"` + // Icons for the resource, if any. + Icons []Icon `json:"icons,omitempty"` } type ResourceListChangedParams struct { @@ -822,6 +946,8 @@ type ResourceTemplate struct { // A URI template (according to RFC 6570) that can be used to construct resource // URIs. URITemplate string `json:"uriTemplate"` + // Icons for the resource template, if any. + Icons []Icon `json:"icons,omitempty"` } // The sender or recipient of messages and data in a conversation. @@ -852,11 +978,27 @@ func (x *RootsListChangedParams) isParams() {} func (x *RootsListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *RootsListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } -// SamplingCapabilities describes the capabilities for sampling. +// TODO: to be consistent with ServerCapabilities, move the capability types +// below directly above ClientCapabilities. + +// SamplingCapabilities describes the client's support for sampling. type SamplingCapabilities struct{} // ElicitationCapabilities describes the capabilities for elicitation. -type ElicitationCapabilities struct{} +// +// If neither Form nor URL is set, the 'Form' capabilitiy is assumed. +type ElicitationCapabilities struct { + Form *FormElicitationCapabilities + URL *URLElicitationCapabilities +} + +// FormElicitationCapabilities describes capabilities for form elicitation. +type FormElicitationCapabilities struct { +} + +// URLElicitationCapabilities describes capabilities for url elicitation. +type URLElicitationCapabilities struct { +} // Describes a message issued to or received from an LLM API. type SamplingMessage struct { @@ -948,6 +1090,8 @@ type Tool struct { // If not provided, Annotations.Title should be used for display if present, // otherwise Name. Title string `json:"title,omitempty"` + // Icons for the tool, if any. + Icons []Icon `json:"icons,omitempty"` } // Additional properties describing a Tool to clients. @@ -1042,6 +1186,10 @@ type ElicitParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` + // The mode of elicitation to use. + // + // If unset, will be inferred from the other fields. + Mode string `json:"mode"` // The message to present to the user. Message string `json:"message"` // A JSON schema object defining the requested elicitation schema. @@ -1055,7 +1203,17 @@ type ElicitParams struct { // map[string]any). // // Only top-level properties are allowed, without nesting. - RequestedSchema any `json:"requestedSchema"` + // + // This is only used for "form" elicitation. + RequestedSchema any `json:"requestedSchema,omitempty"` + // The URL to present to the user. + // + // This is only used for "url" elicitation. + URL string `json:"url,omitempty"` + // The ID of the elicitation. + // + // This is only used for "url" elicitation. + ElicitationID string `json:"elicitationId,omitempty"` } func (x *ElicitParams) isParams() {} @@ -1080,6 +1238,18 @@ type ElicitResult struct { func (*ElicitResult) isResult() {} +// ElicitationCompleteParams is sent from the server to the client, informing it that an out-of-band elicitation interaction has completed. +type ElicitationCompleteParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The ID of the elicitation that has completed. This must correspond to the + // elicitationId from the original elicitation/create request. + ElicitationID string `json:"elicitationId"` +} + +func (*ElicitationCompleteParams) isParams() {} + // An Implementation describes the name and version of an MCP implementation, with an optional // title for UI representation. type Implementation struct { @@ -1090,50 +1260,71 @@ type Implementation struct { // easily understood, even by those unfamiliar with domain-specific terminology. Title string `json:"title,omitempty"` Version string `json:"version"` + // WebsiteURL for the server, if any. + WebsiteURL string `json:"websiteUrl,omitempty"` + // Icons for the Server, if any. + Icons []Icon `json:"icons,omitempty"` } -// Present if the server supports argument autocompletion suggestions. +// CompletionCapabilities describes the server's support for argument autocompletion. type CompletionCapabilities struct{} -// Present if the server supports sending log messages to the client. +// LoggingCapabilities describes the server's support for sending log messages to the client. type LoggingCapabilities struct{} -// Present if the server offers any prompt templates. +// PromptCapabilities describes the server's support for prompts. type PromptCapabilities struct { // Whether this server supports notifications for changes to the prompt list. ListChanged bool `json:"listChanged,omitempty"` } -// Present if the server offers any resources to read. +// ResourceCapabilities describes the server's support for resources. type ResourceCapabilities struct { - // Whether this server supports notifications for changes to the resource list. + // ListChanged reports whether the client supports notifications for + // changes to the resource list. ListChanged bool `json:"listChanged,omitempty"` - // Whether this server supports subscribing to resource updates. + // Subscribe reports whether this server supports subscribing to resource + // updates. Subscribe bool `json:"subscribe,omitempty"` } -// Capabilities that a server may support. Known capabilities are defined here, -// in this schema, but this is not a closed set: any server can define its own, -// additional capabilities. +// ToolCapabilities describes the server's support for tools. +type ToolCapabilities struct { + // ListChanged reports whether the client supports notifications for + // changes to the tool list. + ListChanged bool `json:"listChanged,omitempty"` +} + +// ServerCapabilities describes capabilities that a server supports. type ServerCapabilities struct { - // Present if the server supports argument autocompletion suggestions. - Completions *CompletionCapabilities `json:"completions,omitempty"` - // Experimental, non-standard capabilities that the server supports. + + // NOTE: any addition to ServerCapabilities must also be reflected in + // [ServerCapabilities.clone]. + + // Experimental reports non-standard capabilities that the server supports. Experimental map[string]any `json:"experimental,omitempty"` - // Present if the server supports sending log messages to the client. + // Completions is present if the server supports argument autocompletion + // suggestions. + Completions *CompletionCapabilities `json:"completions,omitempty"` + // Logging is present if the server supports log messages. Logging *LoggingCapabilities `json:"logging,omitempty"` - // Present if the server offers any prompt templates. + // Prompts is present if the server supports prompts. Prompts *PromptCapabilities `json:"prompts,omitempty"` - // Present if the server offers any resources to read. + // Resources is present if the server supports resourcs. Resources *ResourceCapabilities `json:"resources,omitempty"` - // Present if the server offers any tools to call. + // Tools is present if the supports tools. Tools *ToolCapabilities `json:"tools,omitempty"` } -// Present if the server offers any tools to call. -type ToolCapabilities struct { - // Whether this server supports notifications for changes to the tool list. - ListChanged bool `json:"listChanged,omitempty"` +// clone returns a deep copy of the ServerCapabilities. +func (c *ServerCapabilities) clone() *ServerCapabilities { + cp := *c + cp.Completions = shallowClone(c.Completions) + cp.Logging = shallowClone(c.Logging) + cp.Prompts = shallowClone(c.Prompts) + cp.Resources = shallowClone(c.Resources) + cp.Tools = shallowClone(c.Tools) + return &cp } const ( @@ -1142,6 +1333,7 @@ const ( methodComplete = "completion/complete" methodCreateMessage = "sampling/createMessage" methodElicit = "elicitation/create" + notificationElicitationComplete = "notifications/elicitation/complete" methodGetPrompt = "prompts/get" methodInitialize = "initialize" notificationInitialized = "notifications/initialized" diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go index 82b700f56..f64d6fb62 100644 --- a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/requests.go @@ -23,15 +23,16 @@ type ( ) type ( - CreateMessageRequest = ClientRequest[*CreateMessageParams] - ElicitRequest = ClientRequest[*ElicitParams] - initializedClientRequest = ClientRequest[*InitializedParams] - InitializeRequest = ClientRequest[*InitializeParams] - ListRootsRequest = ClientRequest[*ListRootsParams] - LoggingMessageRequest = ClientRequest[*LoggingMessageParams] - ProgressNotificationClientRequest = ClientRequest[*ProgressNotificationParams] - PromptListChangedRequest = ClientRequest[*PromptListChangedParams] - ResourceListChangedRequest = ClientRequest[*ResourceListChangedParams] - ResourceUpdatedNotificationRequest = ClientRequest[*ResourceUpdatedNotificationParams] - ToolListChangedRequest = ClientRequest[*ToolListChangedParams] + CreateMessageRequest = ClientRequest[*CreateMessageParams] + ElicitRequest = ClientRequest[*ElicitParams] + initializedClientRequest = ClientRequest[*InitializedParams] + InitializeRequest = ClientRequest[*InitializeParams] + ListRootsRequest = ClientRequest[*ListRootsParams] + LoggingMessageRequest = ClientRequest[*LoggingMessageParams] + ProgressNotificationClientRequest = ClientRequest[*ProgressNotificationParams] + PromptListChangedRequest = ClientRequest[*PromptListChangedParams] + ResourceListChangedRequest = ClientRequest[*ResourceListChangedParams] + ResourceUpdatedNotificationRequest = ClientRequest[*ResourceUpdatedNotificationParams] + ToolListChangedRequest = ClientRequest[*ToolListChangedParams] + ElicitationCompleteNotificationRequest = ClientRequest[*ElicitationCompleteParams] ) diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go index 8746edaed..dc657f5dd 100644 --- a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/resource.go @@ -15,8 +15,8 @@ import ( "path/filepath" "strings" - "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/internal/util" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" "github.com/yosida95/uritemplate/v3" ) @@ -40,8 +40,8 @@ type ResourceHandler func(context.Context, *ReadResourceRequest) (*ReadResourceR // ResourceNotFoundError returns an error indicating that a resource being read could // not be found. func ResourceNotFoundError(uri string) error { - return &jsonrpc2.WireError{ - Code: codeResourceNotFound, + return &jsonrpc.Error{ + Code: CodeResourceNotFound, Message: "Resource not found", Data: json.RawMessage(fmt.Sprintf(`{"uri":%q}`, uri)), } diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go index 4a7bc89a8..1f7edf9c5 100644 --- a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/server.go @@ -10,6 +10,7 @@ import ( "encoding/base64" "encoding/gob" "encoding/json" + "errors" "fmt" "iter" "log/slog" @@ -50,6 +51,7 @@ type Server struct { sendingMethodHandler_ MethodHandler receivingMethodHandler_ MethodHandler resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool + pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send } // ServerOptions is used to configure behavior of the server. @@ -79,14 +81,51 @@ type ServerOptions struct { SubscribeHandler func(context.Context, *SubscribeRequest) error // Function called when a client session unsubscribes from a resource. UnsubscribeHandler func(context.Context, *UnsubscribeRequest) error + + // Capabilities optionally configures the server's default capabilities, + // before any capabilities are inferred from other configuration or server + // features. + // + // If Capabilities is nil, the default server capabilities are {"logging":{}}, + // for historical reasons. Setting Capabilities to a non-nil value overrides + // this default. For example, setting Capabilities to `&ServerCapabilities{}` + // disables the logging capability. + // + // # Interaction with capability inference + // + // "tools", "prompts", and "resources" capabilities are automatically added when + // tools, prompts, or resources are added to the server (for example, via + // [Server.AddPrompt]), with default value `{"listChanged":true}`. Similarly, + // if the [ClientOptions.SubscribeHandler] or + // [ClientOptions.CompletionHandler] are set, the inferred capabilities are + // adjusted accordingly. + // + // Any non-nil field in Capabilities overrides the inferred value. + // For example: + // + // - To advertise the "tools" capability, even if no tools are added, set + // Capabilities.Tools to &ToolCapabilities{ListChanged:true}. + // - To disable tool list notifications, set Capabilities.Tools to + // &ToolCapabilities{}. + // + // Conversely, if Capabilities does not set a field (for example, if the + // Prompts field is nil), the inferred capability will be used. + Capabilities *ServerCapabilities + // If true, advertises the prompts capability during initialization, // even if no prompts have been registered. + // + // Deprecated: Use Capabilities instead. HasPrompts bool // If true, advertises the resources capability during initialization, // even if no resources have been registered. + // + // Deprecated: Use Capabilities instead. HasResources bool // If true, advertises the tools capability during initialization, // even if no tools have been registered. + // + // Deprecated: Use Capabilities instead. HasTools bool // GetSessionID provides the next session ID to use for an incoming request. @@ -145,9 +184,10 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }), resources: newFeatureSet(func(r *serverResource) string { return r.resource.URI }), resourceTemplates: newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }), - sendingMethodHandler_: defaultSendingMethodHandler[*ServerSession], + sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], resourceSubscriptions: make(map[string]map[*ServerSession]bool), + pendingNotifications: make(map[string]*time.Timer), } } @@ -157,15 +197,13 @@ func (s *Server) AddPrompt(p *Prompt, h PromptHandler) { // (It's possible an item was replaced with an identical one, but not worth checking.) s.changeAndNotify( notificationPromptListChanged, - &PromptListChangedParams{}, func() bool { s.prompts.add(&serverPrompt{p, h}); return true }) } // RemovePrompts removes the prompts with the given names. // It is not an error to remove a nonexistent prompt. func (s *Server) RemovePrompts(names ...string) { - s.changeAndNotify(notificationPromptListChanged, &PromptListChangedParams{}, - func() bool { return s.prompts.remove(names...) }) + s.changeAndNotify(notificationPromptListChanged, func() bool { return s.prompts.remove(names...) }) } // AddTool adds a [Tool] to the server, or replaces one with the same name. @@ -191,6 +229,9 @@ func (s *Server) RemovePrompts(names ...string) { // Most users should use the top-level function [AddTool], which handles all these // responsibilities. func (s *Server) AddTool(t *Tool, h ToolHandler) { + if err := validateToolName(t.Name); err != nil { + s.opts.Logger.Error(fmt.Sprintf("AddTool: invalid tool name %q: %v", t.Name, err)) + } if t.InputSchema == nil { // This prevents the tool author from forgetting to write a schema where // one should be provided. If we papered over this by supplying the empty @@ -231,8 +272,7 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { // (It's possible a tool was replaced with an identical one, but not worth checking.) // TODO: Batch these changes by size and time? The typescript SDK doesn't. // TODO: Surface notify error here? best not, in case we need to batch. - s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{}, - func() bool { s.tools.add(st); return true }) + s.changeAndNotify(notificationToolListChanged, func() bool { s.tools.add(st); return true }) } func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) { @@ -289,12 +329,12 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan // Call typed handler. res, out, err := h(ctx, req, in) // Handle server errors appropriately: - // - If the handler returns a structured error (like jsonrpc2.WireError), return it directly + // - If the handler returns a structured error (like jsonrpc.Error), return it directly // - If the handler returns a regular error, wrap it in a CallToolResult with IsError=true // - This allows tools to distinguish between protocol errors and tool execution errors if err != nil { // Check if this is already a structured JSON-RPC error - if wireErr, ok := err.(*jsonrpc2.WireError); ok { + if wireErr, ok := err.(*jsonrpc.Error); ok { return nil, wireErr } // For regular errors, embed them in the tool result as per MCP spec @@ -415,14 +455,13 @@ func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { // RemoveTools removes the tools with the given names. // It is not an error to remove a nonexistent tool. func (s *Server) RemoveTools(names ...string) { - s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{}, - func() bool { return s.tools.remove(names...) }) + s.changeAndNotify(notificationToolListChanged, func() bool { return s.tools.remove(names...) }) } // AddResource adds a [Resource] to the server, or replaces one with the same URI. // AddResource panics if the resource URI is invalid or not absolute (has an empty scheme). func (s *Server) AddResource(r *Resource, h ResourceHandler) { - s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, + s.changeAndNotify(notificationResourceListChanged, func() bool { if _, err := url.Parse(r.URI); err != nil { panic(err) // url.Parse includes the URI in the error @@ -435,14 +474,13 @@ func (s *Server) AddResource(r *Resource, h ResourceHandler) { // RemoveResources removes the resources with the given URIs. // It is not an error to remove a nonexistent resource. func (s *Server) RemoveResources(uris ...string) { - s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, - func() bool { return s.resources.remove(uris...) }) + s.changeAndNotify(notificationResourceListChanged, func() bool { return s.resources.remove(uris...) }) } // AddResourceTemplate adds a [ResourceTemplate] to the server, or replaces one with the same URI. // AddResourceTemplate panics if a URI template is invalid or not absolute (has an empty scheme). func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) { - s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, + s.changeAndNotify(notificationResourceListChanged, func() bool { // Validate the URI template syntax _, err := uritemplate.New(t.URITemplate) @@ -457,32 +495,56 @@ func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) { // RemoveResourceTemplates removes the resource templates with the given URI templates. // It is not an error to remove a nonexistent resource. func (s *Server) RemoveResourceTemplates(uriTemplates ...string) { - s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, - func() bool { return s.resourceTemplates.remove(uriTemplates...) }) + s.changeAndNotify(notificationResourceListChanged, func() bool { return s.resourceTemplates.remove(uriTemplates...) }) } func (s *Server) capabilities() *ServerCapabilities { s.mu.Lock() defer s.mu.Unlock() - caps := &ServerCapabilities{ - Logging: &LoggingCapabilities{}, + // Start with user-provided capabilities as defaults, or use SDK defaults. + var caps *ServerCapabilities + if s.opts.Capabilities != nil { + // Deep copy the user-provided capabilities to avoid mutation. + caps = s.opts.Capabilities.clone() + } else { + // SDK defaults: only logging capability. + caps = &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + } } + + // Augment with tools capability if tools exist or legacy HasTools is set. if s.opts.HasTools || s.tools.len() > 0 { - caps.Tools = &ToolCapabilities{ListChanged: true} + if caps.Tools == nil { + caps.Tools = &ToolCapabilities{ListChanged: true} + } } + + // Augment with prompts capability if prompts exist or legacy HasPrompts is set. if s.opts.HasPrompts || s.prompts.len() > 0 { - caps.Prompts = &PromptCapabilities{ListChanged: true} + if caps.Prompts == nil { + caps.Prompts = &PromptCapabilities{ListChanged: true} + } } + + // Augment with resources capability if resources/templates exist or legacy HasResources is set. if s.opts.HasResources || s.resources.len() > 0 || s.resourceTemplates.len() > 0 { - caps.Resources = &ResourceCapabilities{ListChanged: true} + if caps.Resources == nil { + caps.Resources = &ResourceCapabilities{ListChanged: true} + } if s.opts.SubscribeHandler != nil { caps.Resources.Subscribe = true } } + + // Augment with completions capability if handler is set. if s.opts.CompletionHandler != nil { - caps.Completions = &CompletionCapabilities{} + if caps.Completions == nil { + caps.Completions = &CompletionCapabilities{} + } } + return caps } @@ -493,18 +555,72 @@ func (s *Server) complete(ctx context.Context, req *CompleteRequest) (*CompleteR return s.opts.CompletionHandler(ctx, req) } +// Map from notification name to its corresponding params. The params have no fields, +// so a single struct can be reused. +var changeNotificationParams = map[string]Params{ + notificationToolListChanged: &ToolListChangedParams{}, + notificationPromptListChanged: &PromptListChangedParams{}, + notificationResourceListChanged: &ResourceListChangedParams{}, +} + +// How long to wait before sending a change notification. +const notificationDelay = 10 * time.Millisecond + // changeAndNotify is called when a feature is added or removed. // It calls change, which should do the work and report whether a change actually occurred. -// If there was a change, it notifies a snapshot of the sessions. -func (s *Server) changeAndNotify(notification string, params Params, change func() bool) { - var sessions []*ServerSession - // Lock for the change, but not for the notification. +// If there was a change, it sets a timer to send a notification. +// This debounces change notifications: a single notification is sent after +// multiple changes occur in close proximity. +func (s *Server) changeAndNotify(notification string, change func() bool) { s.mu.Lock() - if change() { - sessions = slices.Clone(s.sessions) + defer s.mu.Unlock() + if change() && s.shouldSendListChangedNotification(notification) { + // Reset the outstanding delayed call, if any. + if t := s.pendingNotifications[notification]; t == nil { + s.pendingNotifications[notification] = time.AfterFunc(notificationDelay, func() { s.notifySessions(notification) }) + } else { + t.Reset(notificationDelay) + } + } +} + +// notifySessions sends the notification n to all existing sessions. +// It is called asynchronously by changeAndNotify. +func (s *Server) notifySessions(n string) { + s.mu.Lock() + sessions := slices.Clone(s.sessions) + s.pendingNotifications[n] = nil + s.mu.Unlock() // Don't hold the lock during notification: it causes deadlock. + notifySessions(sessions, n, changeNotificationParams[n], s.opts.Logger) +} + +// shouldSendListChangedNotification checks if the server's capabilities allow +// sending the given list-changed notification. +func (s *Server) shouldSendListChangedNotification(notification string) bool { + // Get effective capabilities (considering user-provided defaults). + caps := s.opts.Capabilities + + switch notification { + case notificationToolListChanged: + // If user didn't specify capabilities, default behavior sends notifications. + if caps == nil || caps.Tools == nil { + return true + } + return caps.Tools.ListChanged + case notificationPromptListChanged: + if caps == nil || caps.Prompts == nil { + return true + } + return caps.Prompts.ListChanged + case notificationResourceListChanged: + if caps == nil || caps.Resources == nil { + return true + } + return caps.Resources.ListChanged + default: + // Unknown notification, allow by default. + return true } - s.mu.Unlock() - notifySessions(sessions, notification, params) } // Sessions returns an iterator that yields the current set of server sessions. @@ -538,8 +654,8 @@ func (s *Server) getPrompt(ctx context.Context, req *GetPromptRequest) (*GetProm s.mu.Unlock() if !ok { // Return a proper JSON-RPC error with the correct error code - return nil, &jsonrpc2.WireError{ - Code: codeInvalidParams, + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("unknown prompt %q", req.Params.Name), } } @@ -565,8 +681,8 @@ func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolR st, ok := s.tools.get(req.Params.Name) s.mu.Unlock() if !ok { - return nil, &jsonrpc2.WireError{ - Code: codeInvalidParams, + return nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("unknown tool %q", req.Params.Name), } } @@ -704,7 +820,7 @@ func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNot subscribedSessions := s.resourceSubscriptions[params.URI] sessions := slices.Collect(maps.Keys(subscribedSessions)) s.mu.Unlock() - notifySessions(sessions, notificationResourceUpdated, params) + notifySessions(sessions, notificationResourceUpdated, params, s.opts.Logger) s.opts.Logger.Info("resource updated notification sent", "uri", params.URI, "subscriber_count", len(sessions)) return nil } @@ -1015,7 +1131,66 @@ func (ss *ServerSession) Elicit(ctx context.Context, params *ElicitParams) (*Eli if err := ss.checkInitialized(methodElicit); err != nil { return nil, err } - return handleSend[*ElicitResult](ctx, methodElicit, newServerRequest(ss, orZero[Params](params))) + if params == nil { + return nil, fmt.Errorf("%w: params cannot be nil", jsonrpc2.ErrInvalidParams) + } + + if params.Mode == "" { + params2 := *params + if params.URL != "" || params.ElicitationID != "" { + params2.Mode = "url" + } else { + params2.Mode = "form" + } + params = ¶ms2 + } + + if iparams := ss.InitializeParams(); iparams == nil || iparams.Capabilities == nil || iparams.Capabilities.Elicitation == nil { + return nil, fmt.Errorf("client does not support elicitation") + } + caps := ss.InitializeParams().Capabilities.Elicitation + switch params.Mode { + case "form": + if caps.Form == nil && caps.URL != nil { + // Note: if both 'Form' and 'URL' are nil, we assume the client supports + // form elicitation for backward compatibility. + return nil, errors.New(`client does not support "form" elicitation`) + } + case "url": + if caps.URL == nil { + return nil, errors.New(`client does not support "url" elicitation`) + } + } + + res, err := handleSend[*ElicitResult](ctx, methodElicit, newServerRequest(ss, orZero[Params](params))) + if err != nil { + return nil, err + } + + if params.RequestedSchema == nil { + return res, nil + } + schema, err := validateElicitSchema(params.RequestedSchema) + if err != nil { + return nil, err + } + if schema == nil { + return res, nil + } + + resolved, err := schema.Resolve(nil) + if err != nil { + return nil, err + } + if err := resolved.Validate(res.Content); err != nil { + return nil, fmt.Errorf("elicitation result content does not match requested schema: %v", err) + } + err = resolved.ApplyDefaults(&res.Content) + if err != nil { + return nil, fmt.Errorf("failed to apply schema defalts to elicitation result: %v", err) + } + + return res, nil } // Log sends a log message to the client. @@ -1074,7 +1249,7 @@ func (s *Server) AddReceivingMiddleware(middleware ...Middleware) { // curating these method flags. var serverMethodInfos = map[string]methodInfo{ methodComplete: newServerMethodInfo(serverMethod((*Server).complete), 0), - methodInitialize: newServerMethodInfo(serverSessionMethod((*ServerSession).initialize), 0), + methodInitialize: initializeMethodInfo(), methodPing: newServerMethodInfo(serverSessionMethod((*ServerSession).ping), missingParamsOK), methodListPrompts: newServerMethodInfo(serverMethod((*Server).listPrompts), missingParamsOK), methodGetPrompt: newServerMethodInfo(serverMethod((*Server).getPrompt), 0), @@ -1092,6 +1267,25 @@ var serverMethodInfos = map[string]methodInfo{ notificationProgress: newServerMethodInfo(serverSessionMethod((*ServerSession).callProgressNotificationHandler), notification), } +// initializeMethodInfo handles the workaround for #607: we must set +// params.Capabilities.RootsV2. +func initializeMethodInfo() methodInfo { + info := newServerMethodInfo(serverSessionMethod((*ServerSession).initialize), 0) + info.unmarshalParams = func(m json.RawMessage) (Params, error) { + var params *initializeParamsV2 + if m != nil { + if err := json.Unmarshal(m, ¶ms); err != nil { + return nil, fmt.Errorf("unmarshaling %q into a %T: %w", m, params, err) + } + } + if params == nil { + return nil, fmt.Errorf(`missing required "params"`) + } + return params.toV1(), nil + } + return info +} + func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos } func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { return serverMethodInfos } diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/shared.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/shared.go index e90bcbd8d..d83eae7da 100644 --- a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/shared.go +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/shared.go @@ -15,7 +15,7 @@ import ( "context" "encoding/json" "fmt" - "log" + "log/slog" "net/http" "reflect" "slices" @@ -34,12 +34,14 @@ const ( // It is the version that the client sends in the initialization request, and // the default version used by the server. latestProtocolVersion = protocolVersion20250618 + protocolVersion20251125 = "2025-11-25" // not yet released protocolVersion20250618 = "2025-06-18" protocolVersion20250326 = "2025-03-26" protocolVersion20241105 = "2024-11-05" ) var supportedProtocolVersions = []string{ + protocolVersion20251125, protocolVersion20250618, protocolVersion20250326, protocolVersion20241105, @@ -86,20 +88,28 @@ func addMiddleware(handlerp *MethodHandler, middleware []Middleware) { } } -func defaultSendingMethodHandler[S Session](ctx context.Context, method string, req Request) (Result, error) { +func defaultSendingMethodHandler(ctx context.Context, method string, req Request) (Result, error) { info, ok := req.GetSession().sendingMethodInfos()[method] if !ok { // This can be called from user code, with an arbitrary value for method. return nil, jsonrpc2.ErrNotHandled } + params := req.GetParams() + if initParams, ok := params.(*InitializeParams); ok { + // Fix the marshaling of initialize params, to work around #607. + // + // The initialize params we produce should never be nil, nor have nil + // capabilities, so any panic here is a bug. + params = initParams.toV2() + } // Notifications don't have results. if strings.HasPrefix(method, "notifications/") { - return nil, req.GetSession().getConn().Notify(ctx, method, req.GetParams()) + return nil, req.GetSession().getConn().Notify(ctx, method, params) } // Create the result to unmarshal into. // The concrete type of the result is the return type of the receiving function. res := info.newResult() - if err := call(ctx, req.GetSession().getConn(), method, req.GetParams(), res); err != nil { + if err := call(ctx, req.GetSession().getConn(), method, params, res); err != nil { return nil, err } return res, nil @@ -329,21 +339,63 @@ func clientSessionMethod[P Params, R Result](f func(*ClientSession, context.Cont } } -// Error codes +// MCP-specific error codes. +const ( + // CodeResourceNotFound indicates that a requested resource could not be found. + CodeResourceNotFound = -32002 + // CodeURLElicitationRequired indicates that the server requires URL elicitation + // before processing the request. The client should execute the elicitation handler + // with the elicitations provided in the error data. + CodeURLElicitationRequired = -32042 +) + +// URLElicitationRequiredError returns an error indicating that URL elicitation is required +// before the request can be processed. The elicitations parameter should contain the +// elicitation requests that must be completed. +func URLElicitationRequiredError(elicitations []*ElicitParams) error { + // Validate that all elicitations are URL mode + for _, elicit := range elicitations { + mode := elicit.Mode + if mode == "" { + mode = "form" // default mode + } + if mode != "url" { + panic(fmt.Sprintf("URLElicitationRequiredError requires all elicitations to be URL mode, got %q", mode)) + } + } + + data, err := json.Marshal(map[string]any{ + "elicitations": elicitations, + }) + if err != nil { + // This should never happen with valid ElicitParams + panic(fmt.Sprintf("failed to marshal elicitations: %v", err)) + } + return &jsonrpc.Error{ + Code: CodeURLElicitationRequired, + Message: "URL elicitation required", + Data: json.RawMessage(data), + } +} + +// Internal error codes const ( - codeResourceNotFound = -32002 // The error code if the method exists and was called properly, but the peer does not support it. + // + // TODO(rfindley): this code is wrong, and we should fix it to be + // consistent with other SDKs. codeUnsupportedMethod = -31001 - // The error code for invalid parameters - codeInvalidParams = -32602 ) // notifySessions calls Notify on all the sessions. // Should be called on a copy of the peer sessions. -func notifySessions[S Session, P Params](sessions []S, method string, params P) { +// The logger must be non-nil. +func notifySessions[S Session, P Params](sessions []S, method string, params P, logger *slog.Logger) { if sessions == nil { return } + // Notify with the background context, so the messages are sent on the + // standalone stream. // TODO: make this timeout configurable, or call handleNotify asynchronously. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -353,8 +405,7 @@ func notifySessions[S Session, P Params](sessions []S, method string, params P) for _, s := range sessions { req := newRequest(s, params) if err := handleNotify(ctx, method, req); err != nil { - // TODO(jba): surface this error better - log.Printf("calling %s: %v", method, err) + logger.Warn(fmt.Sprintf("calling %s: %v", method, err)) } } } @@ -427,6 +478,24 @@ type ServerRequest[P Params] struct { type RequestExtra struct { TokenInfo *auth.TokenInfo // bearer token info (e.g. from OAuth) if any Header http.Header // header from HTTP request, if any + + // If set, CloseSSEStream explicitly closes the current SSE request stream. + // + // [SEP-1699] introduced server-side SSE stream disconnection: for + // long-running requests, servers may opt to close the SSE stream and + // ask the client to retry at a later time. CloseSSEStream implements this + // feature; if RetryAfter is set, an event is sent with a `retry:` field + // to configure the reconnection delay. + // + // [SEP-1699]: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1699 + CloseSSEStream func(CloseSSEStreamArgs) +} + +// CloseSSEStreamArgs are arguments for [RequestExtra.CloseSSEStream]. +type CloseSSEStreamArgs struct { + // RetryAfter configures the reconnection delay sent to the client via the + // SSE retry field. If zero, no retry field is sent. + RetryAfter time.Duration } func (*ClientRequest[P]) isRequest() {} diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go index 178b24662..b4b2fa310 100644 --- a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable.go @@ -2,6 +2,10 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. +// NOTE: see streamable_server.go and streamable_client.go for detailed +// documentation of the streamable server design. +// TODO: move the client and server logic into those files. + package mcp import ( @@ -25,12 +29,14 @@ import ( "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/internal/xcontext" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) const ( protocolVersionHeader = "Mcp-Protocol-Version" sessionIDHeader = "Mcp-Session-Id" + lastEventIDHeader = "Last-Event-ID" ) // A StreamableHTTPHandler is an http.Handler that serves streamable MCP @@ -50,6 +56,10 @@ type StreamableHTTPHandler struct { type sessionInfo struct { session *ServerSession transport *StreamableServerTransport + // userID is the user ID from the TokenInfo when the session was created. + // If non-empty, subsequent requests must have the same user ID to prevent + // session hijacking. + userID string // If timeout is set, automatically close the session after an idle period. timeout time.Duration @@ -237,6 +247,15 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque http.Error(w, "session not found", http.StatusNotFound) return } + // Prevent session hijacking: if the session was created with a user ID, + // verify that subsequent requests come from the same user. + if sessInfo != nil && sessInfo.userID != "" { + tokenInfo := auth.TokenInfoFromContext(req.Context()) + if tokenInfo == nil || tokenInfo.UserID != sessInfo.userID { + http.Error(w, "session user mismatch", http.StatusForbidden) + return + } + } } if req.Method == http.MethodDelete { @@ -403,9 +422,16 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque http.Error(w, "failed connection", http.StatusInternalServerError) return } + // Capture the user ID from the token info to enable session hijacking + // prevention on subsequent requests. + var userID string + if tokenInfo := auth.TokenInfoFromContext(req.Context()); tokenInfo != nil { + userID = tokenInfo.UserID + } sessInfo = &sessionInfo{ session: session, transport: transport, + userID: userID, } if stateless { @@ -592,24 +618,147 @@ type stream struct { // The standalone SSE stream has id "". id string + // logger is used for logging errors during stream operations. + logger *slog.Logger + // mu guards the fields below, as well as storage of new messages in the // connection's event store (if any). mu sync.Mutex - // If non-nil, deliver writes data directly to the HTTP response. + // If pendingJSONMessages is non-nil, this is a JSON stream and messages are + // collected here until the stream is complete, at which point they are + // flushed as a single JSON response. Note that the non-nilness of this field + // is significant, as it signals the expected content type. // - // Only one HTTP response may receive messages at a given time. An active - // HTTP connection acquires ownership of the stream by setting this field. - deliver func(data []byte, final bool) error + // Note: if we remove support for batching, this could just be a bool. + pendingJSONMessages []json.RawMessage + + // w is the HTTP response writer for this stream. A non-nil w indicates + // that the stream is claimed by an HTTP request (the hanging POST or GET); + // it is set to nil when the request completes. + w http.ResponseWriter + + // done is closed to release the hanging HTTP request. + // + // Invariant: a non-nil done implies w is also non-nil, though the converse + // is not necessarily true: done is set to nil when it is closed, to avoid + // duplicate closure. + done chan struct{} + + // lastIdx is the index of the last written SSE event, for event ID generation. + // It starts at -1 since indices start at 0. + lastIdx int - // streamRequests is the set of unanswered incoming requests for the stream. + // protocolVersion is the protocol version for this stream. + protocolVersion string + + // requests is the set of unanswered incoming requests for the stream. // // Requests are removed when their response has been received. + // In practice, there is only one request, but in the 2025-03-26 version of + // the spec and earlier there was a concept of batching, in which POST + // payloads could hold multiple requests or responses. requests map[jsonrpc.ID]struct{} } +// close sends a 'close' event to the client (if protocolVersion >= 2025-11-25 +// and reconnectAfter > 0) and closes the done channel. +// +// The done channel is set to nil after closing, so that done != nil implies +// the stream is active and done is open. This simplifies checks elsewhere. +func (s *stream) close(reconnectAfter time.Duration) { + s.mu.Lock() + defer s.mu.Unlock() + if s.done == nil { + return // stream not connected or already closed + } + if s.protocolVersion >= protocolVersion20251125 && reconnectAfter > 0 { + reconnectStr := strconv.FormatInt(reconnectAfter.Milliseconds(), 10) + if _, err := writeEvent(s.w, Event{ + Name: "close", + Retry: reconnectStr, + }); err != nil { + s.logger.Warn(fmt.Sprintf("Writing close event: %v", err)) + } + } + close(s.done) + s.done = nil +} + +// release releases the stream from its HTTP request, allowing it to be +// claimed by another request (e.g., for resumption). +func (s *stream) release() { + s.mu.Lock() + defer s.mu.Unlock() + s.w = nil + s.done = nil // may already be nil, if the stream is done or closed +} + +// deliverLocked writes data to the stream (for SSE) or stores it in +// pendingJSONMessages (for JSON mode). The eventID is used for SSE event ID; +// pass "" to omit. +// +// If responseTo is valid, it is removed from the requests map. When all +// requests have been responded to, the done channel is closed and set to nil. +// +// Returns true if the stream is now done (all requests have been responded to). +// The done value is always accurate, even if an error is returned. +// +// s.mu must be held when calling this method. +func (s *stream) deliverLocked(data []byte, eventID string, responseTo jsonrpc.ID) (done bool, err error) { + // First, record the response. We must do this *before* returning an error + // below, as even if the stream is disconnected we want to update our + // accounting. + if responseTo.IsValid() { + delete(s.requests, responseTo) + } + // Now, try to deliver the message to the client. + done = len(s.requests) == 0 && s.id != "" + if s.done == nil { + return done, fmt.Errorf("stream not connected or already closed") + } + if done { + defer func() { close(s.done); s.done = nil }() + } + // Try to write to the response. + // + // If we get here, the request is still hanging (because s.done != nil + // implies s.w != nil), but may have been cancelled by the client/http layer: + // there's a brief race between request cancellation and releasing the + // stream. + if s.pendingJSONMessages != nil { + s.pendingJSONMessages = append(s.pendingJSONMessages, data) + if done { + // Flush all pending messages as JSON response. + var toWrite []byte + if len(s.pendingJSONMessages) == 1 { + toWrite = s.pendingJSONMessages[0] + } else { + toWrite, err = json.Marshal(s.pendingJSONMessages) + if err != nil { + return done, err + } + } + if _, err := s.w.Write(toWrite); err != nil { + return done, err + } + } + } else { + // SSE mode: write event to response writer. + s.lastIdx++ + if _, err := writeEvent(s.w, Event{Name: "message", Data: data, ID: eventID}); err != nil { + return done, err + } + } + return done, nil +} + // doneLocked reports whether the stream is logically complete. // +// s.requests was populated when reading the POST body, requests are deleted as +// they are responded to. Once all requests have been responded to, the stream +// is done. +// // s.mu must be held while calling this function. func (s *stream) doneLocked() bool { return len(s.requests) == 0 && s.id != "" @@ -624,6 +773,8 @@ func (c *streamableServerConn) newStream(ctx context.Context, requests map[jsonr return &stream{ id: id, requests: requests, + lastIdx: -1, // indices start at 0, incremented before each write + logger: c.logger, }, nil } @@ -679,8 +830,8 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request // By default, we haven't seen a last index. Since indices start at 0, we represent // that by -1. This is incremented just before each event is written. lastIdx := -1 - if len(req.Header.Values("Last-Event-ID")) > 0 { - eid := req.Header.Get("Last-Event-ID") + if len(req.Header.Values(lastEventIDHeader)) > 0 { + eid := req.Header.Get(lastEventIDHeader) var ok bool streamID, lastIdx, ok = parseEventID(eid) if !ok { @@ -693,52 +844,42 @@ func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request } } - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() + ctx := req.Context() - stream, done := c.acquireStream(ctx, w, streamID, &lastIdx) + // Read the protocol version from the header. For GET requests, this should + // always be present since GET only happens after initialization. + protocolVersion := req.Header.Get(protocolVersionHeader) + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + + stream, done := c.acquireStream(ctx, w, streamID, lastIdx, protocolVersion) if stream == nil { return } - // Release the stream when we're done. - defer func() { - stream.mu.Lock() - stream.deliver = nil - stream.mu.Unlock() - }() + defer stream.release() + c.hangResponse(ctx, done) +} +// hangResponse blocks the HTTP response until one of three conditions is met: +// - ctx is cancelled (the client disconnected or the request timed out) +// - done is closed (all responses have been sent, or the stream was explicitly closed) +// - the session is closed +// +// This keeps the HTTP connection open so that server-sent events can be +// written to the response. +func (c *streamableServerConn) hangResponse(ctx context.Context, done <-chan struct{}) { select { case <-ctx.Done(): - // request cancelled case <-done: - // request complete case <-c.done: - // session closed } } -// writeEvent writes an SSE event to w corresponding to the given stream, data, and index. -// lastIdx is incremented before writing, so that it continues to point to the index of the -// last event written to the stream. -func (c *streamableServerConn) writeEvent(w http.ResponseWriter, stream *stream, data []byte, lastIdx *int) error { - *lastIdx++ - e := Event{ - Name: "message", - Data: data, - } - if c.eventStore != nil { - e.ID = formatEventID(stream.id, *lastIdx) - } - if _, err := writeEvent(w, e); err != nil { - return err - } - return nil -} - -// acquireStream acquires the stream and replays all events since lastIdx, if -// any, updating lastIdx accordingly. If non-nil, the resulting stream will be -// registered for receiving new messages, and the resulting done channel will -// be closed when all related messages have been delivered. +// acquireStream replays all events since lastIdx, and acquires the ongoing +// stream, if any. If non-nil, the resulting stream will be registered for +// receiving new messages, and the stream's done channel will be closed when +// all related messages have been delivered. // // If any errors occur, they will be written to w and the resulting stream will // be nil. The resulting stream may also be nil if the stream is complete. @@ -746,10 +887,15 @@ func (c *streamableServerConn) writeEvent(w http.ResponseWriter, stream *stream, // Importantly, this function must hold the stream mutex until done replaying // all messages, so that no delivery or storage of new messages occurs while // the stream is still replaying. -func (c *streamableServerConn) acquireStream(ctx context.Context, w http.ResponseWriter, streamID string, lastIdx *int) (*stream, chan struct{}) { +// +// protocolVersion is the protocol version for this stream, used to determine +// feature support (e.g. prime and close events were added in 2025-11-25). +func (c *streamableServerConn) acquireStream(ctx context.Context, w http.ResponseWriter, streamID string, lastIdx int, protocolVersion string) (*stream, chan struct{}) { // if tempStream is set, the stream is done and we're just replaying messages. // - // We record a temporary stream to claim exclusive replay rights. + // We record a temporary stream to claim exclusive replay rights. The spec + // (https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#resumability-and-redelivery) + // does not explicitly require exclusive replay, but we enforce it defensively. tempStream := false c.mu.Lock() s, ok := c.streams[streamID] @@ -757,12 +903,12 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons // The stream is logically done, but claim exclusive rights to replay it by // adding a temporary entry in the streams map. // - // We create this entry with a non-nil deliver function, to ensure it isn't - // claimed by another request before we lock it below. + // We create this entry with a non-nil w, to ensure it isn't claimed by + // another request before we lock it below. tempStream = true s = &stream{ - id: streamID, - deliver: func([]byte, bool) error { return nil }, + id: streamID, + w: w, } c.streams[streamID] = s @@ -779,7 +925,7 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons defer s.mu.Unlock() // Check that this stream wasn't claimed by another request. - if !tempStream && s.deliver != nil { + if !tempStream && s.w != nil { http.Error(w, "stream ID conflicts with ongoing stream", http.StatusConflict) return nil, nil } @@ -792,7 +938,7 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons // messages, and registered our delivery function. var toReplay [][]byte if c.eventStore != nil { - for data, err := range c.eventStore.After(ctx, c.SessionID(), s.id, *lastIdx) { + for data, err := range c.eventStore.After(ctx, c.SessionID(), s.id, lastIdx) { if err != nil { // We can't replay events, perhaps because the underlying event store // has garbage collected its storage. @@ -805,7 +951,9 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons http.Error(w, "failed to replay events", http.StatusBadRequest) return nil, nil } - toReplay = append(toReplay, data) + if len(data) > 0 { + toReplay = append(toReplay, data) + } } } @@ -823,7 +971,12 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons } for _, data := range toReplay { - if err := c.writeEvent(w, s, data, lastIdx); err != nil { + lastIdx++ + e := Event{Name: "message", Data: data} + if c.eventStore != nil { + e.ID = formatEventID(s.id, lastIdx) + } + if _, err := writeEvent(w, e); err != nil { return nil, nil } } @@ -833,20 +986,13 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons return nil, nil } - // The stream is not done: register a delivery function before the stream is + // The stream is not done: set up delivery state before the stream is // unlocked, allowing the connection to write new events. - done := make(chan struct{}) - s.deliver = func(data []byte, final bool) error { - if err := ctx.Err(); err != nil { - return err - } - err := c.writeEvent(w, s, data, lastIdx) - if final { - close(done) - } - return err - } - return s, done + s.w = w + s.done = make(chan struct{}) + s.lastIdx = lastIdx + s.protocolVersion = protocolVersion + return s, s.done } // servePOST handles an incoming message, and replies with either an outgoing @@ -855,7 +1001,7 @@ func (c *streamableServerConn) acquireStream(ctx context.Context, w http.Respons // // It returns an HTTP status code and error message. func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Request) { - if len(req.Header.Values("Last-Event-ID")) > 0 { + if len(req.Header.Values(lastEventIDHeader)) > 0 { http.Error(w, "can't send Last-Event-ID for POST request", http.StatusBadRequest) return } @@ -870,6 +1016,9 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques http.Error(w, "POST requires a non-empty body", http.StatusBadRequest) return } + // TODO(#674): once we've documented the support matrix for 2025-03-26 and + // earlier, drop support for matching entirely; that will simplify this + // logic. incoming, isBatch, err := readBatch(body) if err != nil { http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) @@ -896,6 +1045,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques calls := make(map[jsonrpc.ID]struct{}) tokenInfo := auth.TokenInfoFromContext(req.Context()) isInitialize := false + var initializeProtocolVersion string for _, msg := range incoming { if jreq, ok := msg.(*jsonrpc.Request); ok { // Preemptively check that this is a valid request, so that we can fail @@ -907,19 +1057,53 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } if jreq.Method == methodInitialize { isInitialize = true + // Extract the protocol version from InitializeParams. + var params InitializeParams + if err := json.Unmarshal(jreq.Params, ¶ms); err == nil { + initializeProtocolVersion = params.ProtocolVersion + } } + // Include metadata for all requests (including notifications). jreq.Extra = &RequestExtra{ TokenInfo: tokenInfo, Header: req.Header, } if jreq.IsCall() { calls[jreq.ID] = struct{}{} + // See the doc for CloseSSEStream: allow the request handler to + // explicitly close the ongoing stream. + jreq.Extra.(*RequestExtra).CloseSSEStream = func(args CloseSSEStreamArgs) { + c.mu.Lock() + streamID, ok := c.requestStreams[jreq.ID] + var stream *stream + if ok { + stream = c.streams[streamID] + } + c.mu.Unlock() + + if stream != nil { + stream.close(args.RetryAfter) + } + } } } } + // The prime and close events were added in protocol version 2025-11-25 (SEP-1699). + // Use the version from InitializeParams if this is an initialize request, + // otherwise use the protocol version header. + effectiveVersion := protocolVersion + if isInitialize && initializeProtocolVersion != "" { + effectiveVersion = initializeProtocolVersion + } + // If we don't have any calls, we can just publish the incoming messages and return. // No need to track a logical stream. + // + // See section [ยง2.1.4] of the spec: "If the server accepts the input, the + // server MUST return HTTP status code 202 Accepted with no body." + // + // [ยง2.1.4]: https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#sending-messages-to-the-server if len(calls) == 0 { for _, msg := range incoming { select { @@ -959,56 +1143,39 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques w.Header().Set(sessionIDHeader, c.sessionID) } - // Message delivery has two paths, depending on whether we're responding with JSON or - // event stream. - done := make(chan struct{}) // closed after the final response is written + // Set up stream delivery state. + stream.w = w + done := make(chan struct{}) + stream.done = done + stream.protocolVersion = effectiveVersion if c.jsonResponse { - var msgs []json.RawMessage - stream.deliver = func(data []byte, final bool) error { - // Collect messages until we've received the final response. + // JSON mode: collect messages in pendingJSONMessages until done. + // Set pendingJSONMessages to a non-nil value to signal that this is an + // application/json stream. + stream.pendingJSONMessages = []json.RawMessage{} + } else { + // SSE mode: write a priming event if supported. + if c.eventStore != nil && effectiveVersion >= protocolVersion20251125 { + // Write a priming event, as defined by [ยง2.1.6] of the spec. // - // In recent protocol versions, there should only be one message as - // batching is disabled, as checked above. - msgs = append(msgs, data) - if !final { - return nil - } - defer close(done) // final response - - // Write either the JSON object corresponding to the one response, or a - // JSON array corresponding to the batch response. - var toWrite []byte - if len(msgs) == 1 { - toWrite = []byte(msgs[0]) - } else { - var err error - toWrite, err = json.Marshal(msgs) - if err != nil { - return err - } + // [ยง2.1.6]: https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#sending-messages-to-the-server + // + // We must also write it to the event store in order for indexes to + // align. + if err := c.eventStore.Append(req.Context(), c.sessionID, stream.id, nil); err != nil { + c.logger.Warn(fmt.Sprintf("Storing priming event: %v", err)) } - _, err = w.Write(toWrite) - return err - } - } else { - // Write events in the order we receive them. - lastIndex := -1 - stream.deliver = func(data []byte, final bool) error { - if final { - defer close(done) + stream.lastIdx++ + e := Event{Name: "prime", ID: formatEventID(stream.id, stream.lastIdx)} + if _, err := writeEvent(w, e); err != nil { + c.logger.Warn(fmt.Sprintf("Writing priming event: %v", err)) } - return c.writeEvent(w, stream, data, &lastIndex) } } - // Release ownership of the stream by unsetting deliver. - defer func() { - stream.mu.Lock() - // TODO(rfindley): if we have no event store, we should really cancel all - // remaining requests here, since the client will never get the results. - stream.deliver = nil - stream.mu.Unlock() - }() + // TODO(rfindley): if we have no event store, we should really cancel all + // remaining requests here, since the client will never get the results. + defer stream.release() // The stream is now set up to deliver messages. // @@ -1027,6 +1194,8 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // Note: don't select on req.Context().Done() here, since we've already // received the requests and may have already published a response message // or notification. The client could resume the stream. + // + // In fact, this send could be in a separate goroutine. case <-c.done: // Session closed: we don't know if any data has been written, so it's // too late to write a status code here. @@ -1034,14 +1203,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } } - select { - case <-req.Context().Done(): - // request cancelled - case <-done: - // request complete - case <-c.done: - // session is closed - } + c.hangResponse(req.Context(), done) } // Event IDs: encode both the logical connection ID and the index, as @@ -1100,7 +1262,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e return err } - if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() && (c.stateless || c.sessionID == "") { + if req, ok := msg.(*jsonrpc.Request); ok && req.IsCall() && (c.stateless || c.sessionID == "") { // Requests aren't possible with stateless servers, or when there's no session ID. return fmt.Errorf("%w: stateless servers cannot make requests", jsonrpc2.ErrRejected) } @@ -1161,48 +1323,42 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e s.mu.Lock() defer s.mu.Unlock() - if s.doneLocked() { - // It's possible that the stream was completed in between getting s above, - // and acquiring the stream lock. In order to avoid acquiring s.mu while - // holding c.mu, we check the terminal condition again. - return fmt.Errorf("%w: write to closed stream", jsonrpc2.ErrRejected) - } - // Perform accounting on responses. - if responseTo.IsValid() { - if _, ok := s.requests[responseTo]; !ok { - panic(fmt.Sprintf("internal error: stream %v: response to untracked request %v", s.id, responseTo)) - } - if s.id == "" { - // This should be guaranteed not to happen by the stream resolution logic - // above, but be defensive: we don't ever want to delete the standalone - // stream. - panic("internal error: response on standalone stream") - } - delete(s.requests, responseTo) - if len(s.requests) == 0 { - c.mu.Lock() - delete(c.streams, s.id) - c.mu.Unlock() - } - } + // Store in eventStore before delivering. + // TODO(rfindley): we should only append if the response is SSE, not JSON, by + // pushing down into the delivery layer. delivered := false + var errs []error if c.eventStore != nil { if err := c.eventStore.Append(ctx, c.sessionID, s.id, data); err != nil { - // TODO: report a side-channel error. + errs = append(errs, err) } else { delivered = true } } - if s.deliver != nil { - if err := s.deliver(data, s.doneLocked()); err != nil { - // TODO: report a side-channel error. - } else { - delivered = true - } + + // Compute eventID for SSE streams with event store. + // Use s.lastIdx + 1 because deliverLocked increments before writing. + var eventID string + if c.eventStore != nil { + eventID = formatEventID(s.id, s.lastIdx+1) } + + done, err := s.deliverLocked(data, eventID, responseTo) + if err != nil { + errs = append(errs, err) + } else { + delivered = true + } + + if done { + c.mu.Lock() + delete(c.streams, s.id) + c.mu.Unlock() + } + if !delivered { - return fmt.Errorf("%w: undelivered message", jsonrpc2.ErrRejected) + return fmt.Errorf("%w: undelivered message: %v", jsonrpc2.ErrRejected, errors.Join(errs...)) } return nil } @@ -1249,12 +1405,17 @@ const ( // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time. // It must be 1.0 or greater if MaxRetries is greater than 0. reconnectGrowFactor = 1.5 - // reconnectInitialDelay is the base delay for the first reconnect attempt. - reconnectInitialDelay = 1 * time.Second // reconnectMaxDelay caps the backoff delay, preventing it from growing indefinitely. reconnectMaxDelay = 30 * time.Second ) +var ( + // reconnectInitialDelay is the base delay for the first reconnect attempt. + // + // Mutable for testing. + reconnectInitialDelay = 1 * time.Second +) + // Connect implements the [Transport] interface. // // The resulting [Connection] writes messages via POST requests to the @@ -1277,7 +1438,20 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er // Create a new cancellable context that will manage the connection's lifecycle. // This is crucial for cleanly shutting down the background SSE listener by // cancelling its blocking network operations, which prevents hangs on exit. - connCtx, cancel := context.WithCancel(ctx) + // + // This context should be detached from the incoming context: the standalone + // SSE request should not break when the connection context is done. + // + // For example, consider that the user may want to wait at most 5s to connect + // to the server, and therefore uses a context with a 5s timeout when calling + // client.Connect. Let's suppose that Connect returns after 1s, and the user + // starts using the resulting session. If we didn't detach here, the session + // would break after 4s, when the background SSE stream is terminated. + // + // Instead, creating a cancellable context detached from the incoming context + // allows us to preserve context values (which may be necessary for auth + // middleware), yet only cancel the standalone stream when the connection is closed. + connCtx, cancel := context.WithCancel(xcontext.Detach(ctx)) conn := &streamableClientConn{ url: t.Endpoint, client: client, @@ -1285,7 +1459,7 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er done: make(chan struct{}), maxRetries: maxRetries, strict: t.strict, - logger: t.logger, + logger: ensureLogger(t.logger), // must be non-nil for safe logging ctx: connCtx, cancel: cancel, failed: make(chan struct{}), @@ -1296,8 +1470,8 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er type streamableClientConn struct { url string client *http.Client - ctx context.Context - cancel context.CancelFunc + ctx context.Context // connection context, detached from Connect + cancel context.CancelFunc // cancels ctx incoming chan jsonrpc.Message maxRetries int strict bool // from [StreamableClientTransport.strict] @@ -1360,9 +1534,13 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { } func (c *streamableClientConn) connectStandaloneSSE() { - resp, err := c.connectSSE("") + resp, err := c.connectSSE(c.ctx, "", 0, true) if err != nil { - c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err)) + // If the client didn't cancel the request, and failure breaks the logical + // session. + if c.ctx.Err() == nil { + c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err)) + } return } @@ -1377,14 +1555,13 @@ func (c *streamableClientConn) connectStandaloneSSE() { resp.Body.Close() return } - if resp.StatusCode == http.StatusNotFound && !c.strict { - // modelcontextprotocol/gosdk#393: some servers return NotFound instead - // of MethodNotAllowed for the standalone SSE stream. + if resp.StatusCode >= 400 && resp.StatusCode < 500 && !c.strict { + // modelcontextprotocol/go-sdk#393,#610: some servers return NotFound or + // other status codes instead of MethodNotAllowed for the standalone SSE + // stream. // // Treat this like MethodNotAllowed in non-strict mode. - if c.logger != nil { - c.logger.Warn("got 404 instead of 405 for standalone SSE stream") - } + c.logger.Warn(fmt.Sprintf("got %d instead of 405 for standalone SSE stream", resp.StatusCode)) resp.Body.Close() return } @@ -1393,7 +1570,7 @@ func (c *streamableClientConn) connectStandaloneSSE() { c.fail(err) return } - go c.handleSSE(summary, resp, true, nil) + go c.handleSSE(c.ctx, summary, resp, nil) } // fail handles an asynchronous error while reading. @@ -1452,11 +1629,13 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e } var requestSummary string - var isCall bool + var forCall *jsonrpc.Request switch msg := msg.(type) { case *jsonrpc.Request: requestSummary = fmt.Sprintf("sending %q", msg.Method) - isCall = msg.IsCall() + if msg.IsCall() { + forCall = msg + } case *jsonrpc.Response: requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID) default: @@ -1478,11 +1657,18 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e resp, err := c.client.Do(req) if err != nil { - return fmt.Errorf("%s: %v", requestSummary, err) + // Any error from client.Do means the request didn't reach the server. + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, err) } if err := c.checkResponse(requestSummary, resp); err != nil { - c.fail(err) + // Only fail the connection for non-transient errors. + // Transient errors (wrapped with ErrRejected) should not break the connection. + if !errors.Is(err, jsonrpc2.ErrRejected) { + c.fail(err) + } return err } @@ -1498,22 +1684,24 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return fmt.Errorf("mismatching session IDs %q and %q", hadSessionID, sessionID) } } - // TODO(rfindley): this logic isn't quite right. - // We should keep going even if the server returns 202, if we have a call. - if resp.StatusCode == http.StatusNoContent || resp.StatusCode == http.StatusAccepted { + + if forCall == nil { + resp.Body.Close() + // [ยง2.1.4]: "If the input is a JSON-RPC response or notification: // If the server accepts the input, the server MUST return HTTP status code 202 Accepted with no body." // // [ยง2.1.4]: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server - resp.Body.Close() - return nil - } else if !isCall && !c.strict { - // Some servers return 200, even with an empty json body. - // Ignore this response in non-strict mode. - if c.logger != nil { - c.logger.Warn(fmt.Sprintf("unexpected status code %d from non-call", resp.StatusCode)) + if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusAccepted { + errMsg := fmt.Sprintf("unexpected status code %d from non-call", resp.StatusCode) + // Some servers return 200, even with an empty json body. + // + // In strict mode, return an error to the caller. + c.logger.Warn(errMsg) + if c.strict { + return errors.New(errMsg) + } } - resp.Body.Close() return nil } @@ -1527,8 +1715,11 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e if jsonReq, ok := msg.(*jsonrpc.Request); ok && jsonReq.IsCall() { forCall = jsonReq } - // TODO: should we cancel this logical SSE request if/when jsonReq is canceled? - go c.handleSSE(requestSummary, resp, false, forCall) + // Handle the resulting stream. Note that ctx comes from the call, and + // therefore is already cancelled when the JSON-RPC request is cancelled + // (or rather, context cancellation is what *triggers* JSON-RPC + // cancellation) + go c.handleSSE(ctx, requestSummary, resp, forCall) default: resp.Body.Close() @@ -1579,34 +1770,43 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp // persistent (for the main GET listener) or temporary (for a POST response). // // If forCall is set, it is the call that initiated the stream, and the -// stream is complete when we receive its response. -func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Response, persistent bool, forCall *jsonrpc2.Request) { +// stream is complete when we receive its response. Otherwise, this is the +// standalone stream. +func (c *streamableClientConn) handleSSE(ctx context.Context, requestSummary string, resp *http.Response, forCall *jsonrpc2.Request) { for { // Connection was successful. Continue the loop with the new response. - // TODO: we should set a reasonable limit on the number of times we'll try - // getting a response for a given request. + // + // TODO(#679): we should set a reasonable limit on the number of times + // we'll try getting a response for a given request, or enforce that we + // actually make progress. // // Eventually, if we don't get the response, we should stop trying and // fail the request. - lastEventID, clientClosed := c.processStream(requestSummary, resp, forCall) + lastEventID, reconnectDelay, clientClosed := c.processStream(ctx, requestSummary, resp, forCall) // If the connection was closed by the client, we're done. if clientClosed { return } - // If the stream has ended, then do not reconnect if the stream is - // temporary (POST initiated SSE). - if lastEventID == "" && !persistent { + // If we don't have a last event ID, we can never get the call response, so + // there's nothing to resume. For the standalone stream, we can reconnect, + // but we may just miss messages. + if lastEventID == "" && forCall != nil { return } // The stream was interrupted or ended by the server. Attempt to reconnect. - newResp, err := c.connectSSE(lastEventID) + newResp, err := c.connectSSE(ctx, lastEventID, reconnectDelay, false) if err != nil { - // All reconnection attempts failed: fail the connection. - c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err)) + // If the client didn't cancel this request, any failure to execute it + // breaks the logical MCP session. + if ctx.Err() == nil { + // All reconnection attempts failed: fail the connection. + c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err)) + } return } + resp = newResp if err := c.checkResponse(requestSummary, resp); err != nil { c.fail(err) @@ -1633,8 +1833,13 @@ func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.R // session is already gone. return fmt.Errorf("%s: failed to connect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing) } + // Transient server errors (502, 503, 504, 429) should not break the connection. + // Wrap them with ErrRejected so the jsonrpc2 layer doesn't set writeErr. + if isTransientHTTPStatus(resp.StatusCode) { + return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, http.StatusText(resp.StatusCode)) + } if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return fmt.Errorf("%s: failed to connect: %v", requestSummary, http.StatusText(resp.StatusCode)) + return fmt.Errorf("%s: %v", requestSummary, http.StatusText(resp.StatusCode)) } return nil } @@ -1643,11 +1848,17 @@ func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.R // incoming channel. It returns the ID of the last processed event and a flag // indicating if the connection was closed by the client. If resp is nil, it // returns "", false. -func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, clientClosed bool) { - defer resp.Body.Close() +func (c *streamableClientConn) processStream(ctx context.Context, requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, reconnectDelay time.Duration, clientClosed bool) { + defer func() { + // Drain any remaining unprocessed body. This allows the connection to be re-used after closing. + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() for evt, err := range scanEvents(resp.Body) { if err != nil { - // TODO: we should differentiate EOF from other errors here. + if ctx.Err() != nil { + return "", 0, true // don't reconnect: client cancelled + } break } @@ -1655,30 +1866,46 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R lastEventID = evt.ID } + if evt.Retry != "" { + if n, err := strconv.ParseInt(evt.Retry, 10, 64); err == nil { + reconnectDelay = time.Duration(n) * time.Millisecond + } + } + // According to SSE spec, events with no name default to "message" + if evt.Name != "" && evt.Name != "message" { + continue + } + msg, err := jsonrpc.DecodeMessage(evt.Data) if err != nil { c.fail(fmt.Errorf("%s: failed to decode event: %v", requestSummary, err)) - return "", true + return "", 0, true } select { case c.incoming <- msg: + // Check if this is the response to our call, which terminates the request. + // (it could also be a server->client request or notification). if jsonResp, ok := msg.(*jsonrpc.Response); ok && forCall != nil { // TODO: we should never get a response when forReq is nil (the standalone SSE request). // We should detect this case. if jsonResp.ID == forCall.ID { - return "", true + return "", 0, true } } + case <-c.done: // The connection was closed by the client; exit gracefully. - return "", true + return "", 0, true } } // The loop finished without an error, indicating the server closed the stream. // // If the lastEventID is "", the stream is not retryable and we should // report a synthetic error for the call. + // + // Note that this is different from the cancellation case above, since the + // caller is still waiting for a response that will never come. if lastEventID == "" && forCall != nil { errmsg := &jsonrpc2.Response{ ID: forCall.ID, @@ -1689,7 +1916,7 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R case <-c.done: } } - return lastEventID, false + return lastEventID, reconnectDelay, false } // connectSSE handles the logic of connecting a text/event-stream connection. @@ -1699,31 +1926,53 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R // If connection fails, connectSSE retries with an exponential backoff // strategy. It returns a new, valid HTTP response if successful, or an error // if all retries are exhausted. -func (c *streamableClientConn) connectSSE(lastEventID string) (*http.Response, error) { +// +// reconnectDelay is the delay set by the server using the SSE retry field, or +// 0. +// +// If initial is set, this is the initial attempt. +// +// If connectSSE exits due to context cancellation, the result is (nil, ctx.Err()). +func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID string, reconnectDelay time.Duration, initial bool) (*http.Response, error) { var finalErr error - // If lastEventID is set, we've already connected successfully once, so - // consider that to be the first attempt. attempt := 0 - if lastEventID != "" { + if !initial { + // We've already connected successfully once, so delay subsequent + // reconnections. Otherwise, if the server returns 200 but terminates the + // connection, we'll reconnect as fast as we can, ad infinitum. + // + // TODO: we should consider also setting a limit on total attempts for one + // logical request. attempt = 1 } + delay := calculateReconnectDelay(attempt) + if reconnectDelay > 0 { + delay = reconnectDelay // honor the server's requested initial delay + } for ; attempt <= c.maxRetries; attempt++ { select { case <-c.done: return nil, fmt.Errorf("connection closed by client during reconnect") - case <-time.After(calculateReconnectDelay(attempt)): - req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil) + + case <-ctx.Done(): + // If the connection context is canceled, the request below will not + // succeed anyway. + return nil, ctx.Err() + + case <-time.After(delay): + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil) if err != nil { return nil, err } c.setMCPHeaders(req) if lastEventID != "" { - req.Header.Set("Last-Event-ID", lastEventID) + req.Header.Set(lastEventIDHeader, lastEventID) } req.Header.Set("Accept", "text/event-stream") resp, err := c.client.Do(req) if err != nil { finalErr = err // Store the error and try again. + delay = calculateReconnectDelay(attempt + 1) continue } return resp, nil @@ -1775,3 +2024,17 @@ func calculateReconnectDelay(attempt int) time.Duration { return backoffDuration + jitter } + +// isTransientHTTPStatus reports whether the HTTP status code indicates a +// transient server error that should not permanently break the connection. +func isTransientHTTPStatus(statusCode int) bool { + switch statusCode { + case http.StatusInternalServerError, // 500 + http.StatusBadGateway, // 502 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout, // 504 + http.StatusTooManyRequests: // 429 + return true + } + return false +} diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_client.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_client.go new file mode 100644 index 000000000..41a100461 --- /dev/null +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_client.go @@ -0,0 +1,226 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// TODO: move client-side streamable HTTP logic from streamable.go to this file. + +package mcp + +/* +Streamable HTTP Client Design + +This document describes the client-side implementation of the MCP streamable +HTTP transport, as defined by the MCP spec: +https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#streamable-http + +# Overview + +The client-side streamable transport allows an MCP client to communicate with a +server over HTTP, sending messages via POST and receiving responses via either +JSON or server-sent events (SSE). The implementation consists of two main +components: + + โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” + โ”‚ [StreamableClientTransport] โ”‚ + โ”‚ Transport configuration; creates connections via Connect() โ”‚ + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ”‚ + โ–ผ + โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” + โ”‚ [streamableClientConn] โ”‚ + โ”‚ Connection implementation; handles HTTP request/response โ”‚ + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ”‚ + โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” + โ–ผ โ–ผ + โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” + โ”‚ POST request handlers โ”‚ โ”‚ Standalone SSE stream โ”‚ + โ”‚ (one per outgoing message/call) โ”‚ โ”‚ (server-initiated messages) โ”‚ + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + +# Sessions + +The client maintains a session with the server, identified by a session ID +(Mcp-Session-Id header): + + - Session ID is received from the server after initialization + - Client includes the session ID in all subsequent requests + - Session ends when the client calls Close() (sends DELETE) or server returns 404 + +[streamableClientConn] stores the session state: + - [streamableClientConn.sessionID]: Server-assigned session identifier + - [streamableClientConn.initializedResult]: Protocol version and server capabilities + +# Connection Lifecycle + +1. Connect: [StreamableClientTransport.Connect] creates a [streamableClientConn] + with a detached context for the connection's lifetime. The context is detached + to prevent the standalone SSE stream from being cancelled when the original + Connect context times out. + +2. Initialize: The MCP client sends initialize/initialized messages. Upon + receiving [InitializeResult], the connection: + - Stores the negotiated protocol version for the Mcp-Protocol-Version header + - Captures the session ID from the Mcp-Session-Id response header + - Starts the standalone SSE stream via [streamableClientConn.connectStandaloneSSE] + +3. Operation: Messages are sent via POST, responses received via JSON or SSE. + +4. Close: [streamableClientConn.Close] sends a DELETE request to terminate + the session (unless the session is already gone), then cancels the connection + context to clean up the standalone SSE stream. + +# Sending Messages (Write) + +[streamableClientConn.Write] sends all outgoing messages via HTTP POST: + + POST /endpoint + Content-Type: application/json + Accept: application/json, text/event-stream + Mcp-Protocol-Version: + Mcp-Session-Id: + + + +The server may respond with: + - 202 Accepted: Message received, no response body (notifications/responses) + - 200 OK with application/json: Single JSON-RPC response + - 200 OK with text/event-stream: SSE stream of responses + +# Receiving Messages (Read) + +[streamableClientConn.Read] returns messages from the [streamableClientConn.incoming] +channel, which is populated by multiple concurrent goroutines: + +1. POST response handlers ([streamableClientConn.handleJSON] and + [streamableClientConn.handleSSE]): Process responses from POST requests + +2. Standalone SSE stream: Receives server-initiated requests and notifications + +The client handles both response formats: + - JSON: [streamableClientConn.handleJSON] reads body, decodes message + - SSE: [streamableClientConn.handleSSE] scans events, decodes each message + +# Standalone SSE Stream + +After initialization, [streamableClientConn.sessionUpdated] triggers +[streamableClientConn.connectStandaloneSSE] to open a GET request for +server-initiated messages: + + GET /endpoint + Accept: text/event-stream + Mcp-Session-Id: + +Stream behavior: + - Optional: Server may return 405 Method Not Allowed (spec-compliant) or + other 4xx errors (tolerated in non-strict mode for compatibility) + - Persistent: Runs for the connection lifetime in a background goroutine + - Resumable: Uses Last-Event-ID header on reconnection if server provides event IDs + - Reconnects: Automatic reconnection with exponential backoff on interruption + +# Stream Resumption + +When an SSE stream (standalone or POST response) is interrupted, the client +attempts to reconnect using [streamableClientConn.connectSSE]: + +Event ID tracking: + - [streamableClientConn.processStream] tracks the last received event ID + - On reconnection, the Last-Event-ID header is set to resume from that point + - Server replays missed events if it has an [EventStore] configured + +See [calculateReconnectDelay] for the reconnect delay details. + +Server-initiated reconnection (SEP-1699) + - SSE retry field: Sets the delay for the next reconnect attempt + - If server doesn't provide event IDs, non-standalone streams don't reconnect + +# Response Formats + +The client must handle two response formats from POST requests: + +1. application/json: Single JSON-RPC response + - Body contains one JSON-RPC message + - Handled by [streamableClientConn.handleJSON] + - Simpler but doesn't support streaming or server-initiated messages + +2. text/event-stream: SSE stream of messages + - Body contains SSE events with JSON-RPC messages + - Handled by [streamableClientConn.handleSSE] + - Supports multiple messages and server-initiated communication + - Stream completes when the response to the originating call is received + +# HTTP Methods + + - POST: Send JSON-RPC messages (requests, responses, notifications) + - Used by [streamableClientConn.Write] + - Response may be JSON or SSE + + - GET: Open or resume SSE stream for server-initiated messages + - Used by [streamableClientConn.connectSSE] + - Always expects text/event-stream response (or 405) + + - DELETE: Terminate the session + - Used by [streamableClientConn.Close] + - Skipped if session is already known to be gone ([errSessionMissing]) + +# Error Handling + +Errors are categorized and handled differently: + +1. Transient (recoverable via reconnection): + - Network interruption during SSE streaming + - Connection reset or timeout + - Triggers reconnection in [streamableClientConn.handleSSE] + +2. Terminal (breaks the connection): + - 404 Not Found: Session terminated by server ([errSessionMissing]) + - Message decode errors: Protocol violation + - Context cancellation: Client closed connection + - Mismatched session IDs: Protocol error + - See issue #683: our terminal errors are too strict. + +Terminal errors are stored via [streamableClientConn.fail] and returned by +subsequent [streamableClientConn.Read] calls. The [streamableClientConn.failed] +channel signals that the connection is broken. + +Special case: [errSessionMissing] indicates the server has terminated the session, +so [streamableClientConn.Close] skips the DELETE request. + +# Protocol Version Header + +After initialization, all requests include: + + Mcp-Protocol-Version: + +This header (set by [streamableClientConn.setMCPHeaders]): + - Allows the server to handle requests per the negotiated protocol + - Is omitted before initialization completes + - Uses the version from [streamableClientConn.initializedResult] + +# Key Implementation Details + +[StreamableClientTransport] configuration: + - [StreamableClientTransport.Endpoint]: URL of the MCP server + - [StreamableClientTransport.HTTPClient]: Custom HTTP client (optional) + - [StreamableClientTransport.MaxRetries]: Reconnection attempts (default 5) + +[streamableClientConn] handles the [Connection] interface: + - [streamableClientConn.Read]: Returns messages from incoming channel + - [streamableClientConn.Write]: Sends messages via POST, starts response handlers + - [streamableClientConn.Close]: Sends DELETE, cancels context, closes done channel + +State management: + - [streamableClientConn.incoming]: Buffered channel for received messages + - [streamableClientConn.sessionID]: Server-assigned session identifier + - [streamableClientConn.initializedResult]: Cached for protocol version header + - [streamableClientConn.failed]: Channel closed on terminal error + - [streamableClientConn.done]: Channel closed on graceful shutdown + - [streamableClientConn.ctx]: Detached context for connection lifetime + - [streamableClientConn.cancel]: Cancels ctx to terminate SSE streams + +Context handling: + - Connection context is detached from [StreamableClientTransport.Connect] context + using [xcontext.Detach] to preserve context values (for auth middleware) while + preventing premature cancellation of the standalone SSE stream + - Individual POST requests use caller-provided contexts for cancellation +*/ diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_server.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_server.go new file mode 100644 index 000000000..8a573e56a --- /dev/null +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/streamable_server.go @@ -0,0 +1,160 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// TODO: move server-side streamable HTTP logic from streamable.go to this file. + +package mcp + +/* +Streamable HTTP Server Design + +This document describes the server-side implementation of the MCP streamable +HTTP transport, as defined by the MCP spec: +https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#streamable-http + +# Overview + +The streamable HTTP transport enables MCP communication over HTTP, with +server-sent events (SSE) for server-to-client messages. The implementation +consists of several layered components: + + โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” + โ”‚ [StreamableHTTPHandler] โ”‚ + โ”‚ http.Handler that manages sessions and routes HTTP requests โ”‚ + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ”‚ + โ–ผ + โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” + โ”‚ [StreamableServerTransport] โ”‚ + โ”‚ transport implementation, one per session; exposes ServeHTTP โ”‚ + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ”‚ + โ–ผ + โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” + โ”‚ [streamableServerConn] โ”‚ + โ”‚ Connection implementation, handles message routing โ”‚ + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ”‚ + โ–ผ + โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” + โ”‚ [stream] โ”‚ + โ”‚ Logical message channel within a session, may be resumed โ”‚ + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + +# Sessions + +As with other transports, a session represents a logical MCP connection between +a client and server. In the streamable transport, sessions are identified by a +unique session ID (Mcp-Session-Id header) and persist across multiple HTTP +requests. + +[StreamableHTTPHandler] maintains a map of active sessions ([sessionInfo]), +each containing: + - The [ServerSession] (MCP-level session state) + - The [StreamableServerTransport] (for message I/O) + - Optional timeout management for idle session cleanup + +Sessions are created on the first POST request (typically containing the +initialize request) and destroyed either by: + - Client sending a DELETE request + - Session timeout due to inactivity + - Server explicitly closing the session + +# Streams + +Within a session, there can be multiple concurrent "streams" - logical channels +for message delivery. This is distinct from HTTP streams; a single [stream] may +span multiple HTTP request/response cycles (via resumption). + +There are two types of streams: + +1. Optional standalone SSE stream (id = ""): + - Created when client sends a GET request to the endpoint + - Used for server-initiated messages (requests/notifications to client) + - Persists for the lifetime of the session + - Only one standalone stream per session + +2. Request streams (id = random string): + - Created for each POST request containing JSON-RPC calls + - Used to route responses back to the originating HTTP request + - Completed when all responses have been sent + - Can be resumed via GET with Last-Event-ID if interrupted + +# Message Routing + +When the server writes a message, it must be routed to the correct [stream]: + + - Responses: Routed to the stream that originated the request + - Requests/Notifications made during request handling: Routed to the same + stream as the triggering request (via context) + - Requests/Notifications made outside request handling: Routed to the + standalone SSE stream + +This routing is implemented using: + - [streamableServerConn.requestStreams] maps request IDs to stream IDs + - [idContextKey] is used to store the originating request ID in Context + - [streamableServerConn.streams] maps stream IDs to [stream] objects + +# Stream Resumption + +If an HTTP connection is interrupted (network issues, etc.), clients can +resume a stream by sending a GET request with the Last-Event-ID header. +This requires an [EventStore] to be configured on the server. + + - [EventStore.Open] is called when a new stream is created + - [EventStore.Append] is called for each message written to the stream + - [EventStore.After] is called to replay messages after a given index + - [EventStore.SessionClosed] is called when the session ends + +Event IDs are formatted as "_" to identify both the +stream and position within that stream (see [formatEventID] and [parseEventID]). + +# Stateless Mode + +For simpler deployments, the handler supports "stateless" mode +([StreamableHTTPOptions.Stateless]) where: + - No session ID validation is performed + - Each request creates a temporary session that's closed after the request + - Server-to-client requests are not supported (no way to receive response) + +This mode is useful for simple tool servers that don't need bidirectional +communication. + +# Response Formats + +The server can respond to POST requests in two formats: + +1. text/event-stream (default): Messages sent as SSE events, supports + streaming multiple messages and server-initiated communication during + request handling. + +2. application/json ([StreamableHTTPOptions.JSONResponse]): Single JSON + response, simpler but doesn't support streaming. Server-initiated messages + during request handling go to the standalone SSE stream instead. + +# HTTP Methods + + - POST: Send JSON-RPC messages (requests, responses, notifications) + - GET: Open standalone SSE stream or resume an interrupted stream + - DELETE: Terminate the session + +# Key Implementation Details + +The [stream] struct manages delivery of messages to HTTP responses. + +Fields: + - [stream.w] is the ResponseWriter for the current HTTP response (non-nil indicates claimed) + - [stream.done] is closed to release the hanging HTTP request + - [stream.requests] tracks pending request IDs (stream completes when empty) + +Methods: + - [stream.deliverLocked] delivers a message to the stream + - [stream.close] sends a close event and releases the stream + - [stream.release] releases the stream from the HTTP request, allowing resumption + +[streamableServerConn] handles the [Connection] interface: + - [streamableServerConn.Read] receives messages from the incoming channel (fed by POST handlers) + - [streamableServerConn.Write] routes messages to appropriate streams + - [streamableServerConn.Close] terminates the session and notifies the [EventStore] +*/ diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/tool.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/tool.go index 12b02b7bb..8aa7c3c0d 100644 --- a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/tool.go +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/tool.go @@ -8,6 +8,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "github.com/google/jsonschema-go/jsonschema" ) @@ -101,3 +102,38 @@ func applySchema(data json.RawMessage, resolved *jsonschema.Resolved) (json.RawM } return data, nil } + +// validateToolName checks whether name is a valid tool name, reporting a +// non-nil error if not. +func validateToolName(name string) error { + if name == "" { + return fmt.Errorf("tool name cannot be empty") + } + if len(name) > 128 { + return fmt.Errorf("tool name exceeds maximum length of 128 characters (current: %d)", len(name)) + } + // For consistency with other SDKs, report characters in the order the appear + // in the name. + var invalidChars []string + seen := make(map[rune]bool) + for _, r := range name { + if !validToolNameRune(r) { + if !seen[r] { + invalidChars = append(invalidChars, fmt.Sprintf("%q", string(r))) + seen[r] = true + } + } + } + if len(invalidChars) > 0 { + return fmt.Errorf("tool name contains invalid characters: %s", strings.Join(invalidChars, ", ")) + } + return nil +} + +// validToolNameRune reports whether r is valid within tool names. +func validToolNameRune(r rune) bool { + return (r >= 'a' && r <= 'z') || + (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') || + r == '_' || r == '-' || r == '.' +} diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go index cacd65fd5..25f1d5d05 100644 --- a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/mcp/transport.go @@ -69,7 +69,7 @@ type Connection interface { type clientConnection interface { Connection - // SessionUpdated is called whenever the client session state changes. + // sessionUpdated is called whenever the client session state changes. sessionUpdated(clientSessionState) } @@ -204,8 +204,7 @@ func (c *canceller) Preempt(ctx context.Context, req *jsonrpc.Request) (result a // call executes and awaits a jsonrpc2 call on the given connection, // translating errors into the mcp domain. func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params Params, result Result) error { - // TODO: the "%w"s in this function effectively make jsonrpc2.WireError part of the API. - // Consider alternatives. + // The "%w"s in this function expose jsonrpc.Error as part of the API. call := conn.Call(ctx, method, params) err := call.Await(ctx, result) switch { @@ -217,6 +216,18 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params Reason: ctx.Err().Error(), RequestID: call.ID().Raw(), }) + // By default, the jsonrpc2 library waits for graceful shutdown when the + // connection is closed, meaning it expects all outgoing and incoming + // requests to complete. However, for MCP this expectation is unrealistic, + // and can lead to hanging shutdown. For example, if a streamable client is + // killed, the server will not be able to detect this event, except via + // keepalive pings (if they are configured), and so outgoing calls may hang + // indefinitely. + // + // Therefore, we choose to eagerly retire calls, removing them from the + // outgoingCalls map, when the caller context is cancelled: if the caller + // will never receive the response, there's no need to track it. + conn.Retire(call, ctx.Err()) return errors.Join(ctx.Err(), err) case err != nil: return fmt.Errorf("calling %q: %w", method, err) @@ -381,7 +392,8 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn { var tr [1]byte if n, readErr := dec.Buffered().Read(tr[:]); n > 0 { // If read byte is not a newline, it is an error. - if tr[0] != '\n' { + // Support both Unix (\n) and Windows (\r\n) line endings. + if tr[0] != '\n' && tr[0] != '\r' { err = fmt.Errorf("invalid trailing data at the end of stream") } } else if readErr != nil && readErr != io.EOF { diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/auth_meta.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/auth_meta.go new file mode 100644 index 000000000..9aa0c8d7d --- /dev/null +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/auth_meta.go @@ -0,0 +1,187 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Authorization Server Metadata. +// See https://www.rfc-editor.org/rfc/rfc8414.html. + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "errors" + "fmt" + "net/http" +) + +// AuthServerMeta represents the metadata for an OAuth 2.0 authorization server, +// as defined in [RFC 8414]. +// +// Not supported: +// - signed metadata +// +// Note: URL fields in this struct are validated by validateAuthServerMetaURLs to +// prevent XSS attacks. If you add a new URL field, you must also add it to that +// function. +// +// [RFC 8414]: https://tools.ietf.org/html/rfc8414) +type AuthServerMeta struct { + // GENERATED BY GEMINI 2.5. + + // Issuer is the REQUIRED URL identifying the authorization server. + Issuer string `json:"issuer"` + + // AuthorizationEndpoint is the REQUIRED URL of the server's OAuth 2.0 authorization endpoint. + AuthorizationEndpoint string `json:"authorization_endpoint"` + + // TokenEndpoint is the REQUIRED URL of the server's OAuth 2.0 token endpoint. + TokenEndpoint string `json:"token_endpoint"` + + // JWKSURI is the REQUIRED URL of the server's JSON Web Key Set [JWK] document. + JWKSURI string `json:"jwks_uri"` + + // RegistrationEndpoint is the RECOMMENDED URL of the server's OAuth 2.0 Dynamic Client Registration endpoint. + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + + // ScopesSupported is a RECOMMENDED JSON array of strings containing a list of the OAuth 2.0 + // "scope" values that this server supports. + ScopesSupported []string `json:"scopes_supported,omitempty"` + + // ResponseTypesSupported is a REQUIRED JSON array of strings containing a list of the OAuth 2.0 + // "response_type" values that this server supports. + ResponseTypesSupported []string `json:"response_types_supported"` + + // ResponseModesSupported is a RECOMMENDED JSON array of strings containing a list of the OAuth 2.0 + // "response_mode" values that this server supports. + ResponseModesSupported []string `json:"response_modes_supported,omitempty"` + + // GrantTypesSupported is a RECOMMENDED JSON array of strings containing a list of the OAuth 2.0 + // grant type values that this server supports. + GrantTypesSupported []string `json:"grant_types_supported,omitempty"` + + // TokenEndpointAuthMethodsSupported is a RECOMMENDED JSON array of strings containing a list of + // client authentication methods supported by this token endpoint. + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"` + + // TokenEndpointAuthSigningAlgValuesSupported is a RECOMMENDED JSON array of strings containing + // a list of the JWS signing algorithms ("alg" values) supported by the token endpoint for + // the signature on the JWT used to authenticate the client. + TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"` + + // ServiceDocumentation is a RECOMMENDED URL of a page containing human-readable documentation + // for the service. + ServiceDocumentation string `json:"service_documentation,omitempty"` + + // UILocalesSupported is a RECOMMENDED JSON array of strings representing supported + // BCP47 [RFC5646] language tag values for display in the user interface. + UILocalesSupported []string `json:"ui_locales_supported,omitempty"` + + // OpPolicyURI is a RECOMMENDED URL that the server provides to the person registering + // the client to read about the server's operator policies. + OpPolicyURI string `json:"op_policy_uri,omitempty"` + + // OpTOSURI is a RECOMMENDED URL that the server provides to the person registering the + // client to read about the server's terms of service. + OpTOSURI string `json:"op_tos_uri,omitempty"` + + // RevocationEndpoint is a RECOMMENDED URL of the server's OAuth 2.0 revocation endpoint. + RevocationEndpoint string `json:"revocation_endpoint,omitempty"` + + // RevocationEndpointAuthMethodsSupported is a RECOMMENDED JSON array of strings containing + // a list of client authentication methods supported by this revocation endpoint. + RevocationEndpointAuthMethodsSupported []string `json:"revocation_endpoint_auth_methods_supported,omitempty"` + + // RevocationEndpointAuthSigningAlgValuesSupported is a RECOMMENDED JSON array of strings + // containing a list of the JWS signing algorithms ("alg" values) supported by the revocation + // endpoint for the signature on the JWT used to authenticate the client. + RevocationEndpointAuthSigningAlgValuesSupported []string `json:"revocation_endpoint_auth_signing_alg_values_supported,omitempty"` + + // IntrospectionEndpoint is a RECOMMENDED URL of the server's OAuth 2.0 introspection endpoint. + IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"` + + // IntrospectionEndpointAuthMethodsSupported is a RECOMMENDED JSON array of strings containing + // a list of client authentication methods supported by this introspection endpoint. + IntrospectionEndpointAuthMethodsSupported []string `json:"introspection_endpoint_auth_methods_supported,omitempty"` + + // IntrospectionEndpointAuthSigningAlgValuesSupported is a RECOMMENDED JSON array of strings + // containing a list of the JWS signing algorithms ("alg" values) supported by the introspection + // endpoint for the signature on the JWT used to authenticate the client. + IntrospectionEndpointAuthSigningAlgValuesSupported []string `json:"introspection_endpoint_auth_signing_alg_values_supported,omitempty"` + + // CodeChallengeMethodsSupported is a RECOMMENDED JSON array of strings containing a list of + // PKCE code challenge methods supported by this authorization server. + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported,omitempty"` +} + +var wellKnownPaths = []string{ + "/.well-known/oauth-authorization-server", + "/.well-known/openid-configuration", +} + +// GetAuthServerMeta issues a GET request to retrieve authorization server metadata +// from an OAuth authorization server with the given issuerURL. +// +// It follows [RFC 8414]: +// - The well-known paths specified there are inserted into the URL's path, one at time. +// The first to succeed is used. +// - The Issuer field is checked against issuerURL. +// +// [RFC 8414]: https://tools.ietf.org/html/rfc8414 +func GetAuthServerMeta(ctx context.Context, issuerURL string, c *http.Client) (*AuthServerMeta, error) { + var errs []error + for _, p := range wellKnownPaths { + u, err := prependToPath(issuerURL, p) + if err != nil { + // issuerURL is bad; no point in continuing. + return nil, err + } + asm, err := getJSON[AuthServerMeta](ctx, c, u, 1<<20) + if err == nil { + if asm.Issuer != issuerURL { // section 3.3 + // Security violation; don't keep trying. + return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuerURL) + } + + if len(asm.CodeChallengeMethodsSupported) == 0 { + return nil, fmt.Errorf("authorization server at %s does not implement PKCE", issuerURL) + } + + // Validate endpoint URLs to prevent XSS attacks (see #526). + if err := validateAuthServerMetaURLs(asm); err != nil { + return nil, err + } + + return asm, nil + } + errs = append(errs, err) + } + return nil, fmt.Errorf("failed to get auth server metadata from %q: %w", issuerURL, errors.Join(errs...)) +} + +// validateAuthServerMetaURLs validates all URL fields in AuthServerMeta +// to ensure they don't use dangerous schemes that could enable XSS attacks. +func validateAuthServerMetaURLs(asm *AuthServerMeta) error { + urls := []struct { + name string + value string + }{ + {"authorization_endpoint", asm.AuthorizationEndpoint}, + {"token_endpoint", asm.TokenEndpoint}, + {"jwks_uri", asm.JWKSURI}, + {"registration_endpoint", asm.RegistrationEndpoint}, + {"service_documentation", asm.ServiceDocumentation}, + {"op_policy_uri", asm.OpPolicyURI}, + {"op_tos_uri", asm.OpTOSURI}, + {"revocation_endpoint", asm.RevocationEndpoint}, + {"introspection_endpoint", asm.IntrospectionEndpoint}, + } + + for _, u := range urls { + if err := checkURLScheme(u.value); err != nil { + return fmt.Errorf("%s: %w", u.name, err) + } + } + return nil +} diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/dcr.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/dcr.go new file mode 100644 index 000000000..c64cb8cd4 --- /dev/null +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/dcr.go @@ -0,0 +1,261 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Authorization Server Metadata. +// See https://www.rfc-editor.org/rfc/rfc8414.html. + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +// ClientRegistrationMetadata represents the client metadata fields for the DCR POST request (RFC 7591). +// +// Note: URL fields in this struct are validated by validateClientRegistrationURLs +// to prevent XSS attacks. If you add a new URL field, you must also add it to +// that function. +type ClientRegistrationMetadata struct { + // RedirectURIs is a REQUIRED JSON array of redirection URI strings for use in + // redirect-based flows (such as the authorization code grant). + RedirectURIs []string `json:"redirect_uris"` + + // TokenEndpointAuthMethod is an OPTIONAL string indicator of the requested + // authentication method for the token endpoint. + // If omitted, the default is "client_secret_basic". + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` + + // GrantTypes is an OPTIONAL JSON array of OAuth 2.0 grant type strings + // that the client will restrict itself to using. + // If omitted, the default is ["authorization_code"]. + GrantTypes []string `json:"grant_types,omitempty"` + + // ResponseTypes is an OPTIONAL JSON array of OAuth 2.0 response type strings + // that the client will restrict itself to using. + // If omitted, the default is ["code"]. + ResponseTypes []string `json:"response_types,omitempty"` + + // ClientName is a RECOMMENDED human-readable name of the client to be presented + // to the end-user. + ClientName string `json:"client_name,omitempty"` + + // ClientURI is a RECOMMENDED URL of a web page providing information about the client. + ClientURI string `json:"client_uri,omitempty"` + + // LogoURI is an OPTIONAL URL of a logo for the client, which may be displayed + // to the end-user. + LogoURI string `json:"logo_uri,omitempty"` + + // Scope is an OPTIONAL string containing a space-separated list of scope values + // that the client will restrict itself to using. + Scope string `json:"scope,omitempty"` + + // Contacts is an OPTIONAL JSON array of strings representing ways to contact + // people responsible for this client (e.g., email addresses). + Contacts []string `json:"contacts,omitempty"` + + // TOSURI is an OPTIONAL URL that the client provides to the end-user + // to read about the client's terms of service. + TOSURI string `json:"tos_uri,omitempty"` + + // PolicyURI is an OPTIONAL URL that the client provides to the end-user + // to read about the client's privacy policy. + PolicyURI string `json:"policy_uri,omitempty"` + + // JWKSURI is an OPTIONAL URL for the client's JSON Web Key Set [JWK] document. + // This is preferred over the 'jwks' parameter. + JWKSURI string `json:"jwks_uri,omitempty"` + + // JWKS is an OPTIONAL client's JSON Web Key Set [JWK] document, passed by value. + // This is an alternative to providing a JWKSURI. + JWKS string `json:"jwks,omitempty"` + + // SoftwareID is an OPTIONAL unique identifier string for the client software, + // constant across all instances and versions. + SoftwareID string `json:"software_id,omitempty"` + + // SoftwareVersion is an OPTIONAL version identifier string for the client software. + SoftwareVersion string `json:"software_version,omitempty"` + + // SoftwareStatement is an OPTIONAL JWT that asserts client metadata values. + // Values in the software statement take precedence over other metadata values. + SoftwareStatement string `json:"software_statement,omitempty"` +} + +// ClientRegistrationResponse represents the fields returned by the Authorization Server +// (RFC 7591, Section 3.2.1 and 3.2.2). +type ClientRegistrationResponse struct { + // ClientRegistrationMetadata contains all registered client metadata, returned by the + // server on success, potentially with modified or defaulted values. + ClientRegistrationMetadata + + // ClientID is the REQUIRED newly issued OAuth 2.0 client identifier. + ClientID string `json:"client_id"` + + // ClientSecret is an OPTIONAL client secret string. + ClientSecret string `json:"client_secret,omitempty"` + + // ClientIDIssuedAt is an OPTIONAL Unix timestamp when the ClientID was issued. + ClientIDIssuedAt time.Time `json:"client_id_issued_at,omitempty"` + + // ClientSecretExpiresAt is the REQUIRED (if client_secret is issued) Unix + // timestamp when the secret expires, or 0 if it never expires. + ClientSecretExpiresAt time.Time `json:"client_secret_expires_at,omitempty"` +} + +func (r *ClientRegistrationResponse) MarshalJSON() ([]byte, error) { + type alias ClientRegistrationResponse + var clientIDIssuedAt int64 + var clientSecretExpiresAt int64 + + if !r.ClientIDIssuedAt.IsZero() { + clientIDIssuedAt = r.ClientIDIssuedAt.Unix() + } + if !r.ClientSecretExpiresAt.IsZero() { + clientSecretExpiresAt = r.ClientSecretExpiresAt.Unix() + } + + return json.Marshal(&struct { + ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"` + ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"` + *alias + }{ + ClientIDIssuedAt: clientIDIssuedAt, + ClientSecretExpiresAt: clientSecretExpiresAt, + alias: (*alias)(r), + }) +} + +func (r *ClientRegistrationResponse) UnmarshalJSON(data []byte) error { + type alias ClientRegistrationResponse + aux := &struct { + ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"` + ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"` + *alias + }{ + alias: (*alias)(r), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + if aux.ClientIDIssuedAt != 0 { + r.ClientIDIssuedAt = time.Unix(aux.ClientIDIssuedAt, 0) + } + if aux.ClientSecretExpiresAt != 0 { + r.ClientSecretExpiresAt = time.Unix(aux.ClientSecretExpiresAt, 0) + } + return nil +} + +// ClientRegistrationError is the error response from the Authorization Server +// for a failed registration attempt (RFC 7591, Section 3.2.2). +type ClientRegistrationError struct { + // ErrorCode is the REQUIRED error code if registration failed (RFC 7591, 3.2.2). + ErrorCode string `json:"error"` + + // ErrorDescription is an OPTIONAL human-readable error message. + ErrorDescription string `json:"error_description,omitempty"` +} + +func (e *ClientRegistrationError) Error() string { + return fmt.Sprintf("registration failed: %s (%s)", e.ErrorCode, e.ErrorDescription) +} + +// RegisterClient performs Dynamic Client Registration according to RFC 7591. +func RegisterClient(ctx context.Context, registrationEndpoint string, clientMeta *ClientRegistrationMetadata, c *http.Client) (*ClientRegistrationResponse, error) { + if registrationEndpoint == "" { + return nil, fmt.Errorf("registration_endpoint is required") + } + + if c == nil { + c = http.DefaultClient + } + + payload, err := json.Marshal(clientMeta) + if err != nil { + return nil, fmt.Errorf("failed to marshal client metadata: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", registrationEndpoint, bytes.NewBuffer(payload)) + if err != nil { + return nil, fmt.Errorf("failed to create registration request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := c.Do(req) + if err != nil { + return nil, fmt.Errorf("registration request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read registration response body: %w", err) + } + + if resp.StatusCode == http.StatusCreated { + var regResponse ClientRegistrationResponse + if err := json.Unmarshal(body, ®Response); err != nil { + return nil, fmt.Errorf("failed to decode successful registration response: %w (%s)", err, string(body)) + } + if regResponse.ClientID == "" { + return nil, fmt.Errorf("registration response is missing required 'client_id' field") + } + // Validate URL fields to prevent XSS attacks (see #526). + if err := validateClientRegistrationURLs(®Response.ClientRegistrationMetadata); err != nil { + return nil, err + } + return ®Response, nil + } + + if resp.StatusCode == http.StatusBadRequest { + var regError ClientRegistrationError + if err := json.Unmarshal(body, ®Error); err != nil { + return nil, fmt.Errorf("failed to decode registration error response: %w (%s)", err, string(body)) + } + return nil, ®Error + } + + return nil, fmt.Errorf("registration failed with status %s: %s", resp.Status, string(body)) +} + +// validateClientRegistrationURLs validates all URL fields in ClientRegistrationMetadata +// to ensure they don't use dangerous schemes that could enable XSS attacks. +func validateClientRegistrationURLs(meta *ClientRegistrationMetadata) error { + // Validate redirect URIs + for i, uri := range meta.RedirectURIs { + if err := checkURLScheme(uri); err != nil { + return fmt.Errorf("redirect_uris[%d]: %w", i, err) + } + } + + // Validate other URL fields + urls := []struct { + name string + value string + }{ + {"client_uri", meta.ClientURI}, + {"logo_uri", meta.LogoURI}, + {"tos_uri", meta.TOSURI}, + {"policy_uri", meta.PolicyURI}, + {"jwks_uri", meta.JWKSURI}, + } + + for _, u := range urls { + if err := checkURLScheme(u.value); err != nil { + return fmt.Errorf("%s: %w", u.name, err) + } + } + return nil +} diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauth2.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauth2.go new file mode 100644 index 000000000..cdda695b7 --- /dev/null +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauth2.go @@ -0,0 +1,91 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package oauthex implements extensions to OAuth2. + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "encoding/json" + "fmt" + "io" + "mime" + "net/http" + "net/url" + "strings" +) + +// prependToPath prepends pre to the path of urlStr. +// When pre is the well-known path, this is the algorithm specified in both RFC 9728 +// section 3.1 and RFC 8414 section 3.1. +func prependToPath(urlStr, pre string) (string, error) { + u, err := url.Parse(urlStr) + if err != nil { + return "", err + } + p := "/" + strings.Trim(pre, "/") + if u.Path != "" { + p += "/" + } + + u.Path = p + strings.TrimLeft(u.Path, "/") + return u.String(), nil +} + +// getJSON retrieves JSON and unmarshals JSON from the URL, as specified in both +// RFC 9728 and RFC 8414. +// It will not read more than limit bytes from the body. +func getJSON[T any](ctx context.Context, c *http.Client, url string, limit int64) (*T, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + if c == nil { + c = http.DefaultClient + } + res, err := c.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + // Specs require a 200. + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("bad status %s", res.Status) + } + // Specs require application/json. + ct := res.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(ct) + if err != nil || mediaType != "application/json" { + return nil, fmt.Errorf("bad content type %q", ct) + } + + var t T + dec := json.NewDecoder(io.LimitReader(res.Body, limit)) + if err := dec.Decode(&t); err != nil { + return nil, err + } + return &t, nil +} + +// checkURLScheme ensures that its argument is a valid URL with a scheme +// that prevents XSS attacks. +// See #526. +func checkURLScheme(u string) error { + if u == "" { + return nil + } + uu, err := url.Parse(u) + if err != nil { + return err + } + scheme := strings.ToLower(uu.Scheme) + if scheme == "javascript" || scheme == "data" || scheme == "vbscript" { + return fmt.Errorf("URL has disallowed scheme %q", scheme) + } + return nil +} diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauthex.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauthex.go new file mode 100644 index 000000000..34ed55b59 --- /dev/null +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/oauthex.go @@ -0,0 +1,92 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package oauthex implements extensions to OAuth2. +package oauthex + +// ProtectedResourceMetadata is the metadata for an OAuth 2.0 protected resource, +// as defined in section 2 of https://www.rfc-editor.org/rfc/rfc9728.html. +// +// The following features are not supported: +// - additional keys (ยง2, last sentence) +// - human-readable metadata (ยง2.1) +// - signed metadata (ยง2.2) +type ProtectedResourceMetadata struct { + // GENERATED BY GEMINI 2.5. + + // Resource (resource) is the protected resource's resource identifier. + // Required. + Resource string `json:"resource"` + + // AuthorizationServers (authorization_servers) is an optional slice containing a list of + // OAuth authorization server issuer identifiers (as defined in RFC 8414) that can be + // used with this protected resource. + AuthorizationServers []string `json:"authorization_servers,omitempty"` + + // JWKSURI (jwks_uri) is an optional URL of the protected resource's JSON Web Key (JWK) Set + // document. This contains public keys belonging to the protected resource, such as + // signing key(s) that the resource server uses to sign resource responses. + JWKSURI string `json:"jwks_uri,omitempty"` + + // ScopesSupported (scopes_supported) is a recommended slice containing a list of scope + // values (as defined in RFC 6749) used in authorization requests to request access + // to this protected resource. + ScopesSupported []string `json:"scopes_supported,omitempty"` + + // BearerMethodsSupported (bearer_methods_supported) is an optional slice containing + // a list of the supported methods of sending an OAuth 2.0 bearer token to the + // protected resource. Defined values are "header", "body", and "query". + BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"` + + // ResourceSigningAlgValuesSupported (resource_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms (alg values) supported by the protected + // resource for signing resource responses. + ResourceSigningAlgValuesSupported []string `json:"resource_signing_alg_values_supported,omitempty"` + + // ResourceName (resource_name) is a human-readable name of the protected resource + // intended for display to the end user. It is RECOMMENDED that this field be included. + // This value may be internationalized. + ResourceName string `json:"resource_name,omitempty"` + + // ResourceDocumentation (resource_documentation) is an optional URL of a page containing + // human-readable information for developers using the protected resource. + // This value may be internationalized. + ResourceDocumentation string `json:"resource_documentation,omitempty"` + + // ResourcePolicyURI (resource_policy_uri) is an optional URL of a page containing + // human-readable policy information on how a client can use the data provided. + // This value may be internationalized. + ResourcePolicyURI string `json:"resource_policy_uri,omitempty"` + + // ResourceTOSURI (resource_tos_uri) is an optional URL of a page containing the protected + // resource's human-readable terms of service. This value may be internationalized. + ResourceTOSURI string `json:"resource_tos_uri,omitempty"` + + // TLSClientCertificateBoundAccessTokens (tls_client_certificate_bound_access_tokens) is an + // optional boolean indicating support for mutual-TLS client certificate-bound + // access tokens (RFC 8705). Defaults to false if omitted. + TLSClientCertificateBoundAccessTokens bool `json:"tls_client_certificate_bound_access_tokens,omitempty"` + + // AuthorizationDetailsTypesSupported (authorization_details_types_supported) is an optional + // slice of 'type' values supported by the resource server for the + // 'authorization_details' parameter (RFC 9396). + AuthorizationDetailsTypesSupported []string `json:"authorization_details_types_supported,omitempty"` + + // DPOPSigningAlgValuesSupported (dpop_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms supported by the resource server for validating + // DPoP proof JWTs (RFC 9449). + DPOPSigningAlgValuesSupported []string `json:"dpop_signing_alg_values_supported,omitempty"` + + // DPOPBoundAccessTokensRequired (dpop_bound_access_tokens_required) is an optional boolean + // specifying whether the protected resource always requires the use of DPoP-bound + // access tokens (RFC 9449). Defaults to false if omitted. + DPOPBoundAccessTokensRequired bool `json:"dpop_bound_access_tokens_required,omitempty"` + + // SignedMetadata (signed_metadata) is an optional JWT containing metadata parameters + // about the protected resource as claims. If present, these values take precedence + // over values conveyed in plain JSON. + // TODO:implement. + // Note that ยง2.2 says it's okay to ignore this. + // SignedMetadata string `json:"signed_metadata,omitempty"` +} diff --git a/mcp/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/resource_meta.go b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/resource_meta.go new file mode 100644 index 000000000..bb61f7974 --- /dev/null +++ b/mcp/vendor/github.com/modelcontextprotocol/go-sdk/oauthex/resource_meta.go @@ -0,0 +1,281 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Protected Resource Metadata. +// See https://www.rfc-editor.org/rfc/rfc9728.html. + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "path" + "strings" + "unicode" + + "github.com/modelcontextprotocol/go-sdk/internal/util" +) + +const defaultProtectedResourceMetadataURI = "/.well-known/oauth-protected-resource" + +// GetProtectedResourceMetadataFromID issues a GET request to retrieve protected resource +// metadata from a resource server by its ID. +// The resource ID is an HTTPS URL, typically with a host:port and possibly a path. +// For example: +// +// https://example.com/server +// +// This function, following the spec (ยง3), inserts the default well-known path into the +// URL. In our example, the result would be +// +// https://example.com/.well-known/oauth-protected-resource/server +// +// It then retrieves the metadata at that location using the given client (or the +// default client if nil) and validates its resource field against resourceID. +func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, c *http.Client) (_ *ProtectedResourceMetadata, err error) { + defer util.Wrapf(&err, "GetProtectedResourceMetadataFromID(%q)", resourceID) + + u, err := url.Parse(resourceID) + if err != nil { + return nil, err + } + // Insert well-known URI into URL. + u.Path = path.Join(defaultProtectedResourceMetadataURI, u.Path) + return getPRM(ctx, u.String(), c, resourceID) +} + +// GetProtectedResourceMetadataFromHeader retrieves protected resource metadata +// using information in the given header, using the given client (or the default +// client if nil). +// It issues a GET request to a URL discovered by parsing the WWW-Authenticate headers in the given request. +// Per RFC 9728 section 3.3, it validates that the resource field of the resulting metadata +// matches the serverURL (the URL that the client used to make the original request to the resource server). +// If there is no metadata URL in the header, it returns nil, nil. +func GetProtectedResourceMetadataFromHeader(ctx context.Context, serverURL string, header http.Header, c *http.Client) (_ *ProtectedResourceMetadata, err error) { + defer util.Wrapf(&err, "GetProtectedResourceMetadataFromHeader") + headers := header[http.CanonicalHeaderKey("WWW-Authenticate")] + if len(headers) == 0 { + return nil, nil + } + cs, err := ParseWWWAuthenticate(headers) + if err != nil { + return nil, err + } + metadataURL := ResourceMetadataURL(cs) + if metadataURL == "" { + return nil, nil + } + return getPRM(ctx, metadataURL, c, serverURL) +} + +// getPRM makes a GET request to the given URL, and validates the response. +// As part of the validation, it compares the returned resource field to wantResource. +func getPRM(ctx context.Context, purl string, c *http.Client, wantResource string) (*ProtectedResourceMetadata, error) { + if !strings.HasPrefix(strings.ToUpper(purl), "HTTPS://") { + return nil, fmt.Errorf("resource URL %q does not use HTTPS", purl) + } + prm, err := getJSON[ProtectedResourceMetadata](ctx, c, purl, 1<<20) + if err != nil { + return nil, err + } + // Validate the Resource field (see RFC 9728, section 3.3). + if prm.Resource != wantResource { + return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, wantResource) + } + // Validate the authorization server URLs to prevent XSS attacks (see #526). + for _, u := range prm.AuthorizationServers { + if err := checkURLScheme(u); err != nil { + return nil, err + } + } + return prm, nil +} + +// challenge represents a single authentication challenge from a WWW-Authenticate header. +// As per RFC 9110, Section 11.6.1, a challenge consists of a scheme and optional parameters. +type challenge struct { + // GENERATED BY GEMINI 2.5. + // + // Scheme is the authentication scheme (e.g., "Bearer", "Basic"). + // It is case-insensitive. A parsed value will always be lower-case. + Scheme string + // Params is a map of authentication parameters. + // Keys are case-insensitive. Parsed keys are always lower-case. + Params map[string]string +} + +// ResourceMetadataURL returns a resource metadata URL from the given challenges, +// or the empty string if there is none. +func ResourceMetadataURL(cs []challenge) string { + for _, c := range cs { + if u := c.Params["resource_metadata"]; u != "" { + return u + } + } + return "" +} + +// ParseWWWAuthenticate parses a WWW-Authenticate header string. +// The header format is defined in RFC 9110, Section 11.6.1, and can contain +// one or more challenges, separated by commas. +// It returns a slice of challenges or an error if one of the headers is malformed. +func ParseWWWAuthenticate(headers []string) ([]challenge, error) { + // GENERATED BY GEMINI 2.5 (human-tweaked) + var challenges []challenge + for _, h := range headers { + challengeStrings, err := splitChallenges(h) + if err != nil { + return nil, err + } + for _, cs := range challengeStrings { + if strings.TrimSpace(cs) == "" { + continue + } + challenge, err := parseSingleChallenge(cs) + if err != nil { + return nil, fmt.Errorf("failed to parse challenge %q: %w", cs, err) + } + challenges = append(challenges, challenge) + } + } + return challenges, nil +} + +// splitChallenges splits a header value containing one or more challenges. +// It correctly handles commas within quoted strings and distinguishes between +// commas separating auth-params and commas separating challenges. +func splitChallenges(header string) ([]string, error) { + // GENERATED BY GEMINI 2.5. + var challenges []string + inQuotes := false + start := 0 + for i, r := range header { + if r == '"' { + if i > 0 && header[i-1] != '\\' { + inQuotes = !inQuotes + } else if i == 0 { + // A challenge begins with an auth-scheme, which is a token, which cannot contain + // a quote. + return nil, errors.New(`challenge begins with '"'`) + } + } else if r == ',' && !inQuotes { + // This is a potential challenge separator. + // A new challenge does not start with `key=value`. + // We check if the part after the comma looks like a parameter. + lookahead := strings.TrimSpace(header[i+1:]) + eqPos := strings.Index(lookahead, "=") + + isParam := false + if eqPos > 0 { + // Check if the part before '=' is a single token (no spaces). + token := lookahead[:eqPos] + if strings.IndexFunc(token, unicode.IsSpace) == -1 { + isParam = true + } + } + + if !isParam { + // The part after the comma does not look like a parameter, + // so this comma separates challenges. + challenges = append(challenges, header[start:i]) + start = i + 1 + } + } + } + // Add the last (or only) challenge to the list. + challenges = append(challenges, header[start:]) + return challenges, nil +} + +// parseSingleChallenge parses a string containing exactly one challenge. +// challenge = auth-scheme [ 1*SP ( token68 / #auth-param ) ] +func parseSingleChallenge(s string) (challenge, error) { + // GENERATED BY GEMINI 2.5, human-tweaked. + s = strings.TrimSpace(s) + if s == "" { + return challenge{}, errors.New("empty challenge string") + } + + scheme, paramsStr, found := strings.Cut(s, " ") + c := challenge{Scheme: strings.ToLower(scheme)} + if !found { + return c, nil + } + + params := make(map[string]string) + + // Parse the key-value parameters. + for paramsStr != "" { + // Find the end of the parameter key. + keyEnd := strings.Index(paramsStr, "=") + if keyEnd <= 0 { + return challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr) + } + key := strings.TrimSpace(paramsStr[:keyEnd]) + + // Move the string past the key and the '='. + paramsStr = strings.TrimSpace(paramsStr[keyEnd+1:]) + + var value string + if strings.HasPrefix(paramsStr, "\"") { + // The value is a quoted string. + paramsStr = paramsStr[1:] // Consume the opening quote. + var valBuilder strings.Builder + i := 0 + for ; i < len(paramsStr); i++ { + // Handle escaped characters. + if paramsStr[i] == '\\' && i+1 < len(paramsStr) { + valBuilder.WriteByte(paramsStr[i+1]) + i++ // We've consumed two characters. + } else if paramsStr[i] == '"' { + // End of the quoted string. + break + } else { + valBuilder.WriteByte(paramsStr[i]) + } + } + + // A quoted string must be terminated. + if i == len(paramsStr) { + return challenge{}, fmt.Errorf("unterminated quoted string in auth parameter") + } + + value = valBuilder.String() + // Move the string past the value and the closing quote. + paramsStr = strings.TrimSpace(paramsStr[i+1:]) + } else { + // The value is a token. It ends at the next comma or the end of the string. + commaPos := strings.Index(paramsStr, ",") + if commaPos == -1 { + value = paramsStr + paramsStr = "" + } else { + value = strings.TrimSpace(paramsStr[:commaPos]) + paramsStr = strings.TrimSpace(paramsStr[commaPos:]) // Keep comma for next check + } + } + if value == "" { + return challenge{}, fmt.Errorf("no value for auth param %q", key) + } + + // Per RFC 9110, parameter keys are case-insensitive. + params[strings.ToLower(key)] = value + + // If there is a comma, consume it and continue to the next parameter. + if strings.HasPrefix(paramsStr, ",") { + paramsStr = strings.TrimSpace(paramsStr[1:]) + } else if paramsStr != "" { + // If there's content but it's not a new parameter, the format is wrong. + return challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr) + } + } + + // Per RFC 9110, the scheme is case-insensitive. + return challenge{Scheme: strings.ToLower(scheme), Params: params}, nil +} diff --git a/mcp/vendor/modules.txt b/mcp/vendor/modules.txt index 8a52fe0fd..57d6290b9 100644 --- a/mcp/vendor/modules.txt +++ b/mcp/vendor/modules.txt @@ -1,4 +1,4 @@ -# github.com/goccy/go-yaml v1.19.0 +# github.com/goccy/go-yaml v1.19.1 ## explicit; go 1.21.0 github.com/goccy/go-yaml github.com/goccy/go-yaml/ast @@ -15,7 +15,7 @@ github.com/google/jsonschema-go/jsonschema # github.com/inconshreveable/mousetrap v1.1.0 ## explicit; go 1.18 github.com/inconshreveable/mousetrap -# github.com/modelcontextprotocol/go-sdk v1.1.0 +# github.com/modelcontextprotocol/go-sdk v1.2.0 ## explicit; go 1.23.0 github.com/modelcontextprotocol/go-sdk/auth github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2 @@ -23,7 +23,8 @@ github.com/modelcontextprotocol/go-sdk/internal/util github.com/modelcontextprotocol/go-sdk/internal/xcontext github.com/modelcontextprotocol/go-sdk/jsonrpc github.com/modelcontextprotocol/go-sdk/mcp -# github.com/netapp/harvest/v2 v2.0.0-20251212120439-ea75a8047ce8 => ../ +github.com/modelcontextprotocol/go-sdk/oauthex +# github.com/netapp/harvest/v2 v2.0.0-20251215084222-367b927e5360 => ../ ## explicit; go 1.24.0 github.com/netapp/harvest/v2/pkg/slogx # github.com/spf13/cobra v1.10.2 diff --git a/vendor/github.com/goccy/go-yaml/ast/ast.go b/vendor/github.com/goccy/go-yaml/ast/ast.go index ca1505381..a8078a5f5 100644 --- a/vendor/github.com/goccy/go-yaml/ast/ast.go +++ b/vendor/github.com/goccy/go-yaml/ast/ast.go @@ -1623,7 +1623,11 @@ func (n *SequenceNode) flowStyleString() string { for _, value := range n.Values { values = append(values, value.String()) } - return fmt.Sprintf("[%s]", strings.Join(values, ", ")) + seqText := fmt.Sprintf("[%s]", strings.Join(values, ", ")) + if n.Comment != nil { + return addCommentString(seqText, n.Comment) + } + return seqText } func (n *SequenceNode) blockStyleString() string { diff --git a/vendor/github.com/goccy/go-yaml/decode.go b/vendor/github.com/goccy/go-yaml/decode.go index 43c317f8f..d490add63 100644 --- a/vendor/github.com/goccy/go-yaml/decode.go +++ b/vendor/github.com/goccy/go-yaml/decode.go @@ -288,7 +288,9 @@ func (d *Decoder) addSequenceNodeCommentToMap(node *ast.SequenceNode) { texts = append(texts, comment.Token.Value) } if len(texts) != 0 { - d.addCommentToMap(node.Values[0].GetPath(), HeadComment(texts...)) + if len(node.Values) != 0 { + d.addCommentToMap(node.Values[0].GetPath(), HeadComment(texts...)) + } } } } @@ -1750,14 +1752,11 @@ func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node return err } } else { - keyVal, err := d.nodeToValue(ctx, key) + keyVal, err := d.createDecodedNewValue(ctx, keyType, reflect.Value{}, key) if err != nil { return err } - k = reflect.ValueOf(keyVal) - if k.IsValid() && k.Type().ConvertibleTo(keyType) { - k = k.Convert(keyType) - } + k = keyVal } if k.IsValid() { diff --git a/vendor/github.com/goccy/go-yaml/internal/format/format.go b/vendor/github.com/goccy/go-yaml/internal/format/format.go index 2d55652ff..461dc36d2 100644 --- a/vendor/github.com/goccy/go-yaml/internal/format/format.go +++ b/vendor/github.com/goccy/go-yaml/internal/format/format.go @@ -351,8 +351,9 @@ func (f *Formatter) formatMapping(n *ast.MappingNode) string { var ret string if n.IsFlowStyle { ret = f.origin(n.Start) + } else { + ret += f.formatCommentGroup(n.Comment) } - ret += f.formatCommentGroup(n.Comment) for _, value := range n.Values { if value.CollectEntry != nil { ret += f.origin(value.CollectEntry) @@ -361,6 +362,7 @@ func (f *Formatter) formatMapping(n *ast.MappingNode) string { } if n.IsFlowStyle { ret += f.origin(n.End) + ret += f.formatCommentGroup(n.Comment) } return ret } @@ -377,8 +379,7 @@ func (f *Formatter) formatSequence(n *ast.SequenceNode) string { var ret string if n.IsFlowStyle { ret = f.origin(n.Start) - } - if n.Comment != nil { + } else { // add head comment. ret += f.formatCommentGroup(n.Comment) } @@ -387,6 +388,7 @@ func (f *Formatter) formatSequence(n *ast.SequenceNode) string { } if n.IsFlowStyle { ret += f.origin(n.End) + ret += f.formatCommentGroup(n.Comment) } ret += f.formatCommentGroup(n.FootComment) return ret diff --git a/vendor/github.com/goccy/go-yaml/parser/parser.go b/vendor/github.com/goccy/go-yaml/parser/parser.go index 2c79d3690..f5bfd1a96 100644 --- a/vendor/github.com/goccy/go-yaml/parser/parser.go +++ b/vendor/github.com/goccy/go-yaml/parser/parser.go @@ -426,6 +426,11 @@ func (p *parser) parseFlowMap(ctx *context) (*ast.MappingNode, error) { if node.End == nil { return nil, errors.ErrSyntax("could not find flow mapping end token '}'", node.Start) } + + // set line comment if exists. e.g.) } # comment + if err := setLineComment(ctx, node, ctx.currentToken()); err != nil { + return nil, err + } ctx.goNext() // skip mapping end token. return node, nil } @@ -1066,6 +1071,11 @@ func (p *parser) parseFlowSequence(ctx *context) (*ast.SequenceNode, error) { if node.End == nil { return nil, errors.ErrSyntax("sequence end token ']' not found", node.Start) } + + // set line comment if exists. e.g.) ] # comment + if err := setLineComment(ctx, node, ctx.currentToken()); err != nil { + return nil, err + } ctx.goNext() // skip sequence end token. return node, nil } diff --git a/vendor/modules.txt b/vendor/modules.txt index 98b440e2f..6672e5386 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1,4 +1,4 @@ -# github.com/goccy/go-yaml v1.19.0 +# github.com/goccy/go-yaml v1.19.1 ## explicit; go 1.21.0 github.com/goccy/go-yaml github.com/goccy/go-yaml/ast