Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions mcp/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,18 @@ import (
"time"
)

const (
defaultTerminateDuration = 5 * time.Second
)

// A CommandTransport is a [Transport] that runs a command and communicates
// with it over stdin/stdout, using newline-delimited JSON.
type CommandTransport struct {
Command *exec.Cmd
// TerminateDuration controls how long Close waits after closing stdin
// for the process to exit before sending SIGTERM.
// If zero or negative, the default of 5s is used.
TerminateDuration time.Duration
}

// NewCommandTransport returns a [CommandTransport] that runs the given command
Expand Down Expand Up @@ -46,15 +54,20 @@ func (t *CommandTransport) Connect(ctx context.Context) (Connection, error) {
if err := t.Command.Start(); err != nil {
return nil, err
}
return newIOConn(&pipeRWC{t.Command, stdout, stdin}), nil
td := t.TerminateDuration
if td <= 0 {
td = defaultTerminateDuration
}
return newIOConn(&pipeRWC{t.Command, stdout, stdin, td}), nil
}

// A pipeRWC is an io.ReadWriteCloser that communicates with a subprocess over
// stdin/stdout pipes.
type pipeRWC struct {
cmd *exec.Cmd
stdout io.ReadCloser
stdin io.WriteCloser
cmd *exec.Cmd
stdout io.ReadCloser
stdin io.WriteCloser
terminateDuration time.Duration
}

func (s *pipeRWC) Read(p []byte) (n int, err error) {
Expand Down Expand Up @@ -85,7 +98,7 @@ func (s *pipeRWC) Close() error {
select {
case err := <-resChan:
return err, true
case <-time.After(5 * time.Second):
case <-time.After(s.terminateDuration):
}
return nil, false
}
Expand Down
68 changes: 68 additions & 0 deletions mcp/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,74 @@ func createServerCommand(t *testing.T, serverName string) *exec.Cmd {
return cmd
}

func TestCommandTransportTerminateDuration(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("requires POSIX signals")
}
requireExec(t)

tests := []struct {
name string
duration time.Duration
wantMaxDuration time.Duration
}{
{
name: "default duration (zero)",
duration: 0,
wantMaxDuration: 6 * time.Second, // default 5s + buffer
},
{
name: "below minimum duration",
duration: 500 * time.Millisecond,
wantMaxDuration: 6 * time.Second, // should use default 5s + buffer
},
{
name: "custom valid duration",
duration: 2 * time.Second,
wantMaxDuration: 3 * time.Second, // custom 2s + buffer
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Use a command that won't exit when stdin is closed
cmd := exec.Command("sleep", "20")
transport := &mcp.CommandTransport{
Command: cmd,
TerminateDuration: tt.duration,
}

conn, err := transport.Connect(ctx)
if err != nil {
t.Fatal(err)
}

start := time.Now()
err = conn.Close()
elapsed := time.Since(start)

if err != nil {
var exitErr *exec.ExitError
if !errors.As(err, &exitErr) {
t.Fatalf("Close() failed with unexpected error: %v", err)
}
}

if elapsed > tt.wantMaxDuration {
t.Errorf("Close() took %v, expected at most %v", elapsed, tt.wantMaxDuration)
}

// Ensure the process was actually terminated
if cmd.Process != nil {
cmd.Process.Kill()
}
})
}
}

func requireExec(t *testing.T) {
t.Helper()

Expand Down