Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions cmd/run/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/pingidentity/pingone-mcp-server/internal/sdk"
"github.com/pingidentity/pingone-mcp-server/internal/sdk/legacy"
"github.com/pingidentity/pingone-mcp-server/internal/testutils"
authtestutils "github.com/pingidentity/pingone-mcp-server/internal/testutils/auth"
mcptestutils "github.com/pingidentity/pingone-mcp-server/internal/testutils/mcp"
"github.com/pingidentity/pingone-mcp-server/internal/tokenstore"
"github.com/pingidentity/pingone-mcp-server/internal/tools/environments"
Expand Down Expand Up @@ -87,7 +88,7 @@ func TestRunCommand_FromSubcommand_RunServer(t *testing.T) {
// Run the server in a goroutine so the test doesn't block.
var wg sync.WaitGroup
wg.Go(func() {
err := testutils.ExecuteCliRunCommand(t, ctx, tokenStoreFactory, sdk.NewEmptyClientFactory(), legacy.NewEmptyClientFactory(), testutils.NewEmptyMockAuthClientFactory(), &mcp.StdioTransport{})
err := testutils.ExecuteCliRunCommand(t, ctx, tokenStoreFactory, sdk.NewEmptyClientFactory(), legacy.NewEmptyClientFactory(), authtestutils.NewEmptyMockAuthClientFactory(), &mcp.StdioTransport{})
assert.ErrorIs(t, err, context.Canceled, "server should stop due to context cancellation")
})

Expand All @@ -114,7 +115,7 @@ func TestRunCommand_FromSubcommand_NoValidSession(t *testing.T) {
// Run the server in a goroutine so the test doesn't block.
var wg sync.WaitGroup
wg.Go(func() {
err := testutils.ExecuteCliRunCommand(t, ctx, tokenStoreFactory, sdk.NewEmptyClientFactory(), legacy.NewEmptyClientFactory(), testutils.NewEmptyMockAuthClientFactory(), &mcp.StdioTransport{})
err := testutils.ExecuteCliRunCommand(t, ctx, tokenStoreFactory, sdk.NewEmptyClientFactory(), legacy.NewEmptyClientFactory(), authtestutils.NewEmptyMockAuthClientFactory(), &mcp.StdioTransport{})
assert.ErrorIs(t, err, context.Canceled, "server should stop due to context cancellation")
})

Expand All @@ -134,7 +135,7 @@ func TestRunCommand_FromSubcommand_TokenStoreFactoryError(t *testing.T) {
expectedError := assert.AnError
tokenStoreFactory := testutils.NewMockTokenStoreFactoryWithError(expectedError)

err := testutils.ExecuteCliRunCommand(t, ctx, tokenStoreFactory, sdk.NewEmptyClientFactory(), legacy.NewEmptyClientFactory(), testutils.NewEmptyMockAuthClientFactory(), &mcp.StdioTransport{})
err := testutils.ExecuteCliRunCommand(t, ctx, tokenStoreFactory, sdk.NewEmptyClientFactory(), legacy.NewEmptyClientFactory(), authtestutils.NewEmptyMockAuthClientFactory(), &mcp.StdioTransport{})
require.Error(t, err, "Run should fail when token store factory returns error")
assert.Contains(t, err.Error(), expectedError.Error(), "Error should contain the factory error")
tokenStoreFactory.AssertExpectations(t)
Expand Down Expand Up @@ -183,7 +184,7 @@ func TestRunCommand_FromSubcommand_StoreTypeSelection(t *testing.T) {
// Run the server in a goroutine so the test doesn't block.
var wg sync.WaitGroup
wg.Go(func() {
err := testutils.ExecuteCliRunCommand(t, ctx, tokenStoreFactory, sdk.NewEmptyClientFactory(), legacy.NewEmptyClientFactory(), testutils.NewEmptyMockAuthClientFactory(), &mcp.StdioTransport{}, tt.args...)
err := testutils.ExecuteCliRunCommand(t, ctx, tokenStoreFactory, sdk.NewEmptyClientFactory(), legacy.NewEmptyClientFactory(), authtestutils.NewEmptyMockAuthClientFactory(), &mcp.StdioTransport{}, tt.args...)
assert.ErrorIs(t, err, context.Canceled, "server should stop due to context cancellation")
})

Expand All @@ -205,7 +206,7 @@ func TestRunCommand_FromSubcommand_InvalidStoreType(t *testing.T) {

tokenStoreFactory := testutils.NewMockTokenStoreFactory()

err := testutils.ExecuteCliRunCommand(t, ctx, tokenStoreFactory, sdk.NewEmptyClientFactory(), legacy.NewEmptyClientFactory(), testutils.NewEmptyMockAuthClientFactory(), &mcp.StdioTransport{}, "--store-type", "invalid")
err := testutils.ExecuteCliRunCommand(t, ctx, tokenStoreFactory, sdk.NewEmptyClientFactory(), legacy.NewEmptyClientFactory(), authtestutils.NewEmptyMockAuthClientFactory(), &mcp.StdioTransport{}, "--store-type", "invalid")
require.Error(t, err, "Run should fail with invalid store type")
assert.Contains(t, err.Error(), "unable to parse store type from string: invalid", "Error should indicate invalid store type")
}
Expand Down Expand Up @@ -336,7 +337,7 @@ func TestRunCommand_FromSubcommand_ToolFiltering(t *testing.T) {

var wg sync.WaitGroup
wg.Go(func() {
err := testutils.ExecuteCliRunCommand(t, ctx, tokenStoreFactory, sdk.NewEmptyClientFactory(), legacy.NewEmptyClientFactory(), testutils.NewEmptyMockAuthClientFactory(), serverTransport, tt.args...)
err := testutils.ExecuteCliRunCommand(t, ctx, tokenStoreFactory, sdk.NewEmptyClientFactory(), legacy.NewEmptyClientFactory(), authtestutils.NewEmptyMockAuthClientFactory(), serverTransport, tt.args...)
assert.ErrorIs(t, err, context.Canceled, "server should stop due to context cancellation")
})

Expand Down
85 changes: 85 additions & 0 deletions internal/auth/middleware/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright © 2025 Ping Identity Corporation

package middleware

import (
"context"
"fmt"
"log/slog"

"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/pingidentity/pingone-mcp-server/internal/auth"
"github.com/pingidentity/pingone-mcp-server/internal/auth/client"
"github.com/pingidentity/pingone-mcp-server/internal/logger"
"github.com/pingidentity/pingone-mcp-server/internal/tokenstore"
"github.com/pingidentity/pingone-mcp-server/internal/tools/initialize"
)

// AuthMiddleware ensures all tool calls have proper authentication context.
// It intercepts tool call requests and initializes the auth context before the tool handler executes.
//
// This middleware should be added to the MCP server via AddReceivingMiddleware.
// It runs the initializeAuthContext function to establish authentication, which may:
// 1. Check for an existing session
// 2. Trigger browser-based login if necessary
// 3. Add session information to the context
type AuthMiddleware struct {
authClientFactory client.AuthClientFactory
tokenStore tokenstore.TokenStore
grantType auth.GrantType
}

// NewAuthMiddleware creates middleware with auth dependencies.
// The authClientFactory is used to create auth clients for login flows.
// The tokenStore manages session persistence.
// The grantType determines the authentication method (authorization_code or device_code).
func NewAuthMiddleware(
authClientFactory client.AuthClientFactory,
tokenStore tokenstore.TokenStore,
grantType auth.GrantType,
) *AuthMiddleware {
return &AuthMiddleware{
authClientFactory: authClientFactory,
tokenStore: tokenStore,
grantType: grantType,
}
}

// Handler implements the middleware pattern by returning a MethodHandler that wraps the next handler.
// This handler intercepts all MCP method calls and ensures tool calls have proper authentication context.
func (m *AuthMiddleware) Handler(next mcp.MethodHandler) mcp.MethodHandler {
return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) {
// Only authenticate tool calls, not other MCP methods (initialize, list_tools, etc.)
if method != "tools/call" {
return next(ctx, method, req)
}

// Extract tool call details for logging
callToolReq, ok := req.(*mcp.CallToolRequest)
if !ok {
// Should never happen for tools/call method, but fail safe
return nil, fmt.Errorf("authentication failed: invalid tool call request")
}

toolName := callToolReq.Params.Name

logger.FromContext(ctx).Debug("Initializing authentication for tool",
slog.String("tool", toolName))

// Initialize auth context using the same logic as individual tool handlers
initializeAuthContext := initialize.AuthContextInitializer(m.authClientFactory, m.tokenStore, m.grantType)
authenticatedCtx, err := initializeAuthContext(ctx)
if err != nil {
logger.FromContext(ctx).Error("Authentication initialization failed",
slog.String("tool", toolName),
slog.String("error", err.Error()))
return nil, fmt.Errorf("authentication failed: %w", err)
}

logger.FromContext(ctx).Debug("Authentication initialized successfully",
slog.String("tool", toolName))

// Authentication successful, continue to next handler with authenticated context
return next(authenticatedCtx, method, req)
}
}
Loading