Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions cmd/wonder/commands/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package commands
import (
"log/slog"
"os"
"strings"

"github.com/spf13/cobra"
"github.com/spf13/viper"
Expand All @@ -24,12 +25,14 @@ func NewCoordinatorCmd() *cobra.Command {
cmd.Flags().String("db-driver", "sqlite", "Database driver (sqlite or postgres)")
cmd.Flags().String("db-dsn", "", "Database connection string")
cmd.Flags().Bool("enable-admin-api", false, "Enable admin API endpoints")
cmd.Flags().StringArray("privileged-networks", nil, "Headscale usernames with hub-spoke access to all WonderNets (repeatable)")

_ = viper.BindPFlag("coordinator.listen", cmd.Flags().Lookup("listen"))
_ = viper.BindPFlag("coordinator.public_url", cmd.Flags().Lookup("public-url"))
_ = viper.BindPFlag("coordinator.database_driver", cmd.Flags().Lookup("db-driver"))
_ = viper.BindPFlag("coordinator.database_dsn", cmd.Flags().Lookup("db-dsn"))
_ = viper.BindPFlag("coordinator.enable_admin_api", cmd.Flags().Lookup("enable-admin-api"))
_ = viper.BindPFlag("coordinator.privileged_networks", cmd.Flags().Lookup("privileged-networks"))

_ = viper.BindEnv("coordinator.listen", "LISTEN")
_ = viper.BindEnv("coordinator.public_url", "PUBLIC_URL")
Expand All @@ -44,6 +47,7 @@ func NewCoordinatorCmd() *cobra.Command {
_ = viper.BindEnv("coordinator.keycloak_client_secret", "KEYCLOAK_CLIENT_SECRET")
_ = viper.BindEnv("coordinator.enable_admin_api", "ENABLE_ADMIN_API")
_ = viper.BindEnv("coordinator.admin_api_auth_token", "ADMIN_API_AUTH_TOKEN")
_ = viper.BindEnv("coordinator.privileged_networks", "PRIVILEGED_NETWORKS")

return cmd
}
Expand All @@ -66,6 +70,8 @@ func runCoordinator(cmd *cobra.Command, args []string) {
cfg.EnableAdminAPI = viper.GetBool("coordinator.enable_admin_api")
cfg.AdminAPIAuthToken = viper.GetString("coordinator.admin_api_auth_token")

cfg.PrivilegedNetworks = parseStringSlice(viper.Get("coordinator.privileged_networks"))

if cfg.HeadscaleURL == "" {
cfg.HeadscaleURL = coordinator.DefaultHeadscaleURL
}
Expand Down Expand Up @@ -101,6 +107,10 @@ func runCoordinator(cmd *cobra.Command, args []string) {
slog.Info("admin API enabled")
}

if len(cfg.PrivilegedNetworks) > 0 {
slog.Info("privileged networks configured", "networks", cfg.PrivilegedNetworks)
}

server, err := coordinator.BootstrapNewServer(&cfg)
if err != nil {
slog.Error("create server", "error", err)
Expand All @@ -111,3 +121,23 @@ func runCoordinator(cmd *cobra.Command, args []string) {
slog.Error("shutdown error", "error", err)
}
}

// parseStringSlice converts a viper value to []string.
// Handles []string from cobra StringArray flags and comma-separated string from env vars.
func parseStringSlice(val any) []string {
switch v := val.(type) {
case []string:
return v
case string:
var result []string
for _, n := range strings.Split(v, ",") {
n = strings.TrimSpace(n)
if n != "" {
result = append(result, n)
}
}
return result
default:
return nil
}
}
4 changes: 4 additions & 0 deletions internal/app/coordinator/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ type Config struct {
// AdminAPIAuthToken is the bearer token for admin API authentication.
// Required if EnableAdminAPI is true. Must be at least 32 characters.
AdminAPIAuthToken string `mapstructure:"admin_api_auth_token"`

// PrivilegedNetworks is the list of Headscale usernames that have access to all
// WonderNets (hub-spoke ACL model). When empty, pure isolation policy is used.
PrivilegedNetworks []string
}

const (
Expand Down
2 changes: 1 addition & 1 deletion internal/app/coordinator/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func BootstrapNewServer(config *Config) (*Server, error) {
meshBackend := tailscale.NewTailscaleMesh(headscaleClient, config.PublicURL)

// Create services
wonderNetService := service.NewWonderNetService(wonderNetRepository, wonderNetManager, aclManager, config.PublicURL)
wonderNetService := service.NewWonderNetService(wonderNetRepository, wonderNetManager, aclManager, config.PublicURL, config.PrivilegedNetworks)
workerService := service.NewWorkerService(tokenGenerator, config.JWTSecret, wonderNetRepository, meshBackend)
nodesService := service.NewNodesService(meshBackend)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, wonderNetRepository)
Expand Down
10 changes: 9 additions & 1 deletion internal/app/coordinator/service/wondernet.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type WonderNetService struct {
wonderNetManager *headscale.WonderNetManager
aclManager *headscale.ACLManager
publicURL string
privilegedNetworks []string
}

// NewWonderNetService creates a new WonderNetService.
Expand All @@ -31,12 +32,14 @@ func NewWonderNetService(
wonderNetManager *headscale.WonderNetManager,
aclManager *headscale.ACLManager,
publicURL string,
privilegedNetworks []string,
) *WonderNetService {
return &WonderNetService{
wonderNetRepository: wonderNetRepository,
wonderNetManager: wonderNetManager,
aclManager: aclManager,
publicURL: publicURL,
privilegedNetworks: privilegedNetworks,
}
}

Expand Down Expand Up @@ -93,8 +96,13 @@ func (s *WonderNetService) GetPublicURL() string {
return s.publicURL
}

// InitializeACLPolicy rebuilds the ACL policy from all existing WonderNets to enforce isolation.
// InitializeACLPolicy rebuilds the full ACL policy from all existing Headscale users.
// When a privileged network is configured, a hub-spoke policy is used;
// otherwise, pure isolation policy is applied.
func (s *WonderNetService) InitializeACLPolicy(ctx context.Context) error {
if len(s.privilegedNetworks) > 0 {
return s.aclManager.SetHubSpokePolicy(ctx, s.privilegedNetworks)
}
return s.aclManager.SetWonderNetIsolationPolicy(ctx)
}

Expand Down
61 changes: 61 additions & 0 deletions pkg/headscale/acl.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,33 @@ func GenerateWonderNetIsolationPolicy(usernames []string) *ACLPolicy {
}
}

// GenerateHubSpokePolicy generates an ACL policy where privileged namespaces
// can initiate connections to all nodes, while normal namespaces are isolated
// from each other. Tailscale ACLs are directional and control connection
// initiation only; reply traffic flows back over established connections
// without needing a separate rule.
func GenerateHubSpokePolicy(privilegedUsers []string, normalUsers []string) *ACLPolicy {
rules := make([]ACLRule, 0, len(privilegedUsers)+len(normalUsers))

for _, user := range privilegedUsers {
rules = append(rules, ACLRule{
Action: "accept",
Sources: []string{user + "@"},
Destinations: []string{"*:*"},
})
}

for _, username := range normalUsers {
rules = append(rules, ACLRule{
Action: "accept",
Sources: []string{username + "@"},
Destinations: []string{username + "@:*"},
})
}

return &ACLPolicy{ACLs: rules}
}

// ACLManager manages ACL policies in Headscale
type ACLManager struct {
client v1.HeadscaleServiceClient
Expand Down Expand Up @@ -78,6 +105,40 @@ func (am *ACLManager) SetWonderNetIsolationPolicy(ctx context.Context) error {
return err
}

// SetHubSpokePolicy sets an ACL policy where privileged namespaces can access
// all nodes while normal namespaces are isolated from each other.
func (am *ACLManager) SetHubSpokePolicy(ctx context.Context, privilegedUsers []string) error {
am.mu.Lock()
defer am.mu.Unlock()

resp, err := am.client.ListUsers(ctx, &v1.ListUsersRequest{})
if err != nil {
return fmt.Errorf("list users: %w", err)
}

privilegedSet := make(map[string]struct{}, len(privilegedUsers))
for _, u := range privilegedUsers {
privilegedSet[u] = struct{}{}
}

var normalUsers []string
for _, u := range resp.GetUsers() {
name := u.GetName()
if _, ok := privilegedSet[name]; !ok {
normalUsers = append(normalUsers, name)
}
}

policy := GenerateHubSpokePolicy(privilegedUsers, normalUsers)
policyJSON, err := json.Marshal(policy)
if err != nil {
return fmt.Errorf("marshal policy: %w", err)
}

_, err = am.client.SetPolicy(ctx, &v1.SetPolicyRequest{Policy: string(policyJSON)})
return err
}

// AddWonderNetToPolicy adds a wonder net to the isolation policy
func (am *ACLManager) AddWonderNetToPolicy(ctx context.Context, username string) error {
am.mu.Lock()
Expand Down
77 changes: 77 additions & 0 deletions pkg/headscale/acl_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package headscale

import (
"testing"
)

func TestGenerateWonderNetIsolationPolicy(t *testing.T) {
policy := GenerateWonderNetIsolationPolicy([]string{"user1", "user2"})

if len(policy.ACLs) != 2 {
t.Fatalf("expected 2 rules, got %d", len(policy.ACLs))
}

assertRule(t, policy.ACLs[0], "accept", []string{"user1@"}, []string{"user1@:*"})
assertRule(t, policy.ACLs[1], "accept", []string{"user2@"}, []string{"user2@:*"})
}

func TestGenerateHubSpokePolicy(t *testing.T) {
policy := GenerateHubSpokePolicy([]string{"zeabur"}, []string{"uuid1", "uuid2"})

// 1 rule for privileged (outbound only) + 2 rules for normal users
if len(policy.ACLs) != 3 {
t.Fatalf("expected 3 rules, got %d", len(policy.ACLs))
}

assertRule(t, policy.ACLs[0], "accept", []string{"zeabur@"}, []string{"*:*"})
assertRule(t, policy.ACLs[1], "accept", []string{"uuid1@"}, []string{"uuid1@:*"})
assertRule(t, policy.ACLs[2], "accept", []string{"uuid2@"}, []string{"uuid2@:*"})
}

func TestGenerateHubSpokePolicy_MultiplePrivileged(t *testing.T) {
policy := GenerateHubSpokePolicy([]string{"zeabur", "admin"}, []string{"uuid1"})

// 1 rule per privileged user (2) + 1 normal user
if len(policy.ACLs) != 3 {
t.Fatalf("expected 3 rules, got %d", len(policy.ACLs))
}

assertRule(t, policy.ACLs[0], "accept", []string{"zeabur@"}, []string{"*:*"})
assertRule(t, policy.ACLs[1], "accept", []string{"admin@"}, []string{"*:*"})
assertRule(t, policy.ACLs[2], "accept", []string{"uuid1@"}, []string{"uuid1@:*"})
}

func TestGenerateHubSpokePolicy_NoNormalUsers(t *testing.T) {
policy := GenerateHubSpokePolicy([]string{"zeabur"}, nil)

if len(policy.ACLs) != 1 {
t.Fatalf("expected 1 rule, got %d", len(policy.ACLs))
}

assertRule(t, policy.ACLs[0], "accept", []string{"zeabur@"}, []string{"*:*"})
}

func assertRule(t *testing.T, rule ACLRule, action string, src, dst []string) {
t.Helper()
if rule.Action != action {
t.Errorf("expected action %q, got %q", action, rule.Action)
}
if len(rule.Sources) != len(src) {
t.Errorf("expected %d sources, got %d", len(src), len(rule.Sources))
return
}
for i := range src {
if rule.Sources[i] != src[i] {
t.Errorf("source[%d]: expected %q, got %q", i, src[i], rule.Sources[i])
}
}
if len(rule.Destinations) != len(dst) {
t.Errorf("expected %d destinations, got %d", len(dst), len(rule.Destinations))
return
}
for i := range dst {
if rule.Destinations[i] != dst[i] {
t.Errorf("destination[%d]: expected %q, got %q", i, dst[i], rule.Destinations[i])
}
}
}
Loading