Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 48 additions & 19 deletions cmd/thv-operator/controllers/mcpserver_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"reflect"
"slices"
"strings"
"sync"
"time"

appsv1 "k8s.io/api/apps/v1"
Expand All @@ -22,20 +23,24 @@ import (
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/utils/ptr"
"k8s.io/client-go/rest"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
"sigs.k8s.io/controller-runtime/pkg/log"

mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1"
"github.com/stacklok/toolhive/pkg/container/kubernetes"
"github.com/stacklok/toolhive/pkg/logger"
)

// MCPServerReconciler reconciles a MCPServer object
type MCPServerReconciler struct {
client.Client
Scheme *runtime.Scheme
Scheme *runtime.Scheme
platformDetector kubernetes.PlatformDetector
detectedPlatform kubernetes.Platform
platformOnce sync.Once
}

// defaultRBACRules are the default RBAC rules that the
Expand Down Expand Up @@ -82,6 +87,35 @@ const (
authzLabelValueInline = "inline"
)

// detectPlatform detects the Kubernetes platform type (Kubernetes vs OpenShift)
// It uses sync.Once to ensure the detection is only performed once and cached
func (r *MCPServerReconciler) detectPlatform(ctx context.Context) (kubernetes.Platform, error) {
var err error
r.platformOnce.Do(func() {
// Initialize platform detector if not already done
if r.platformDetector == nil {
r.platformDetector = kubernetes.NewDefaultPlatformDetector()
}

cfg, configErr := rest.InClusterConfig()
if configErr != nil {
err = fmt.Errorf("failed to get in-cluster config for platform detection: %w", configErr)
return
}

r.detectedPlatform, err = r.platformDetector.DetectPlatform(cfg)
if err != nil {
err = fmt.Errorf("failed to detect platform: %w", err)
return
}

ctxLogger := log.FromContext(ctx)
ctxLogger.Info("Platform detected for MCPServer controller", "platform", r.detectedPlatform.String())
})

return r.detectedPlatform, err
}

// Reconcile is part of the main kubernetes reconciliation loop which aims to
// move the current state of the cluster closer to the desired state.
//
Expand Down Expand Up @@ -156,7 +190,7 @@ func (r *MCPServerReconciler) Reconcile(ctx context.Context, req ctrl.Request) (
err = r.Get(ctx, types.NamespacedName{Name: mcpServer.Name, Namespace: mcpServer.Namespace}, deployment)
if err != nil && errors.IsNotFound(err) {
// Define a new deployment
dep := r.deploymentForMCPServer(mcpServer)
dep := r.deploymentForMCPServer(ctx, mcpServer)
if dep == nil {
ctxLogger.Error(nil, "Failed to create Deployment object")
return ctrl.Result{}, fmt.Errorf("failed to create Deployment object")
Expand Down Expand Up @@ -225,7 +259,7 @@ func (r *MCPServerReconciler) Reconcile(ctx context.Context, req ctrl.Request) (
// Check if the deployment spec changed
if deploymentNeedsUpdate(deployment, mcpServer) {
// Update the deployment
newDeployment := r.deploymentForMCPServer(mcpServer)
newDeployment := r.deploymentForMCPServer(ctx, mcpServer)
deployment.Spec = newDeployment.Spec
err = r.Update(ctx, deployment)
if err != nil {
Expand Down Expand Up @@ -401,7 +435,7 @@ func (r *MCPServerReconciler) ensureRBACResources(ctx context.Context, mcpServer
// deploymentForMCPServer returns a MCPServer Deployment object
//
//nolint:gocyclo
func (r *MCPServerReconciler) deploymentForMCPServer(m *mcpv1alpha1.MCPServer) *appsv1.Deployment {
func (r *MCPServerReconciler) deploymentForMCPServer(ctx context.Context, m *mcpv1alpha1.MCPServer) *appsv1.Deployment {
ls := labelsForMCPServer(m.Name)
replicas := int32(1)

Expand Down Expand Up @@ -581,22 +615,17 @@ func (r *MCPServerReconciler) deploymentForMCPServer(m *mcpv1alpha1.MCPServer) *
}
}

// Prepare ProxyRunner's pod and container security context
proxyRunnerPodSecurityContext := &corev1.PodSecurityContext{
RunAsNonRoot: ptr.To(true),
RunAsUser: ptr.To(int64(1000)),
RunAsGroup: ptr.To(int64(1000)),
FSGroup: ptr.To(int64(1000)),
// Detect platform and prepare ProxyRunner's pod and container security context
_, err := r.detectPlatform(ctx)
if err != nil {
ctxLogger := log.FromContext(ctx)
ctxLogger.Error(err, "Failed to detect platform, defaulting to Kubernetes", "mcpserver", m.Name)
}

proxyRunnerContainerSecurityContext := &corev1.SecurityContext{
Privileged: ptr.To(false),
RunAsNonRoot: ptr.To(true),
RunAsUser: ptr.To(int64(1000)),
RunAsGroup: ptr.To(int64(1000)),
AllowPrivilegeEscalation: ptr.To(false),
ReadOnlyRootFilesystem: ptr.To(true),
}
// Use SecurityContextBuilder for platform-aware security context
securityBuilder := kubernetes.NewSecurityContextBuilder(r.detectedPlatform)
proxyRunnerPodSecurityContext := securityBuilder.BuildPodSecurityContext()
proxyRunnerContainerSecurityContext := securityBuilder.BuildContainerSecurityContext()

env = ensureRequiredEnvVars(env)

Expand Down
Loading
Loading