From a29c7a6a156835ac9dba9ae027fce8456360a541 Mon Sep 17 00:00:00 2001 From: Felix Eckhofer Date: Mon, 7 Oct 2024 15:24:49 +0200 Subject: [PATCH] Implement SSH host key verification --- cmd/litestream/main.go | 2 + cmd/litestream/main_test.go | 22 ++++++++ etc/litestream.yml | 8 +++ internal/testingutil/testingutil.go | 63 +++++++++++++++++++++++ replica_client_test.go | 78 +++++++++++++++++++++++++++++ sftp/replica_client.go | 15 +++++- 6 files changed, 187 insertions(+), 1 deletion(-) diff --git a/cmd/litestream/main.go b/cmd/litestream/main.go index 3bf00b194..654db8189 100644 --- a/cmd/litestream/main.go +++ b/cmd/litestream/main.go @@ -561,6 +561,7 @@ type ReplicaConfig struct { Password string `yaml:"password"` KeyPath string `yaml:"key-path"` ConcurrentWrites *bool `yaml:"concurrent-writes"` + HostKey string `yaml:"host-key"` // NATS settings JWT string `yaml:"jwt"` @@ -869,6 +870,7 @@ func newSFTPReplicaClientFromConfig(c *ReplicaConfig, _ *litestream.Replica) (_ client.Password = password client.Path = path client.KeyPath = c.KeyPath + client.HostKey = c.HostKey // Set concurrent writes if specified, otherwise use default (true) if c.ConcurrentWrites != nil { diff --git a/cmd/litestream/main_test.go b/cmd/litestream/main_test.go index dadbc3236..21a538d92 100644 --- a/cmd/litestream/main_test.go +++ b/cmd/litestream/main_test.go @@ -13,6 +13,7 @@ import ( "github.com/benbjohnson/litestream/file" "github.com/benbjohnson/litestream/gs" "github.com/benbjohnson/litestream/s3" + "github.com/benbjohnson/litestream/sftp" ) func TestOpenConfigFile(t *testing.T) { @@ -217,6 +218,27 @@ func TestNewGSReplicaFromConfig(t *testing.T) { } } +func TestNewSFTPReplicaFromConfig(t *testing.T) { + hostKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIAnK0+GdwOelXlAXdqLx/qvS7WHMr3rH7zW2+0DtmK5r" + r, err := main.NewReplicaFromConfig(&main.ReplicaConfig{ + URL: "sftp://user@example.com:2222/foo", + HostKey: hostKey, + }, nil) + if err != nil { + t.Fatal(err) + } else if client, ok := r.Client.(*sftp.ReplicaClient); !ok { + t.Fatal("unexpected replica type") + } else if got, want := client.HostKey, hostKey; got != want { + t.Fatalf("HostKey=%s, want %s", got, want) + } else if got, want := client.Host, "example.com:2222"; got != want { + t.Fatalf("Host=%s, want %s", got, want) + } else if got, want := client.User, "user"; got != want { + t.Fatalf("User=%s, want %s", got, want) + } else if got, want := client.Path, "/foo"; got != want { + t.Fatalf("Path=%s, want %s", got, want) + } +} + // TestConfig_Validate_SnapshotIntervals tests validation of snapshot intervals func TestConfig_Validate_SnapshotIntervals(t *testing.T) { t.Run("ValidInterval", func(t *testing.T) { diff --git a/etc/litestream.yml b/etc/litestream.yml index 6ead81cba..203beaa81 100644 --- a/etc/litestream.yml +++ b/etc/litestream.yml @@ -14,3 +14,11 @@ # # client-cert: /path/to/client.pem # # client-key: /path/to/client.key # # root-cas: [/path/to/ca.pem] +# - url: sftp://user@host:22/path # SFTP-based replication +# key-path: /etc/litestream/sftp_key +# # Strongly recommended: SSH host key for verification +# # Get this from the server's /etc/ssh/ssh_host_*.pub file +# # or use `ssh-keyscan hostname` +# # Example formats: +# # host-key: ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMvvypUkBrS9RCyV//p+UFCLg8yKNtTu/ew/cV6XXAAP +# # host-key: ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ... diff --git a/internal/testingutil/testingutil.go b/internal/testingutil/testingutil.go index 897fd20d4..52ea211f4 100644 --- a/internal/testingutil/testingutil.go +++ b/internal/testingutil/testingutil.go @@ -5,14 +5,19 @@ import ( "database/sql" "flag" "fmt" + "io" "log/slog" "math/rand/v2" + "net" "os" "path" "path/filepath" "strings" "testing" + sftpserver "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" + "github.com/benbjohnson/litestream" "github.com/benbjohnson/litestream/abs" "github.com/benbjohnson/litestream/file" @@ -295,3 +300,61 @@ func MustDeleteAll(tb testing.TB, c litestream.ReplicaClient) { } } } + +func MockSFTPServer(t *testing.T, hostKey ssh.Signer) string { + config := &ssh.ServerConfig{NoClientAuth: true} + config.AddHostKey(hostKey) + + listener, err := net.Listen("tcp", "127.0.0.1:0") // random available port + if err != nil { + t.Fatal(err) + } + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + + go func() { + _, chans, reqs, err := ssh.NewServerConn(conn, config) + if err != nil { + return + } + go ssh.DiscardRequests(reqs) + + for ch := range chans { + if ch.ChannelType() != "session" { + ch.Reject(ssh.UnknownChannelType, "unsupported") + continue + } + channel, requests, err := ch.Accept() + if err != nil { + return + } + + go func(in <-chan *ssh.Request) { + for req := range in { + if req.Type == "subsystem" && string(req.Payload[4:]) == "sftp" { + req.Reply(true, nil) + + server, err := sftpserver.NewServer(channel) + if err != nil { + return + } + if err := server.Serve(); err != nil && err != io.EOF { + t.Logf("SFTP server error: %v", err) + } + return + } + req.Reply(false, nil) + } + }(requests) + } + }() + } + }() + + return listener.Addr().String() +} diff --git a/replica_client_test.go b/replica_client_test.go index 0c9fa37ac..181b5706b 100644 --- a/replica_client_test.go +++ b/replica_client_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "log/slog" "os" "slices" "strings" @@ -11,6 +12,7 @@ import ( "time" "github.com/superfly/ltx" + "golang.org/x/crypto/ssh" "github.com/benbjohnson/litestream" "github.com/benbjohnson/litestream/internal/testingutil" @@ -308,3 +310,79 @@ func TestReplicaClient_S3_BucketValidation(t *testing.T) { t.Errorf("expected bucket validation error, got: %v", err) } } + +func TestReplicaClient_SFTP_HostKeyValidation(t *testing.T) { + testHostKeyPEM := `-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW +QyNTUxOQAAACAJytPhncDnpV5QF3ai8f6r0u1hzK96x+81tvtA7ZiuawAAAJAIcGGVCHBh +lQAAAAtzc2gtZWQyNTUxOQAAACAJytPhncDnpV5QF3ai8f6r0u1hzK96x+81tvtA7Ziuaw +AAAEDzV1D6COyvFGhSiZa6ll9aXZ2IMWED3KGrvCNjEEtYHwnK0+GdwOelXlAXdqLx/qvS +7WHMr3rH7zW2+0DtmK5rAAAADGZlbGl4QGJvcmVhcwE= +-----END OPENSSH PRIVATE KEY-----` + privateKey, err := ssh.ParsePrivateKey([]byte(testHostKeyPEM)) + if err != nil { + t.Fatal(err) + } + + t.Run("ValidHostKey", func(t *testing.T) { + addr := testingutil.MockSFTPServer(t, privateKey) + expectedHostKey := string(ssh.MarshalAuthorizedKey(privateKey.PublicKey())) + + c := testingutil.NewSFTPReplicaClient(t) + c.User = "foo" + c.Host = addr + c.HostKey = expectedHostKey + + _, err = c.Init(context.Background()) + if err != nil { + t.Fatalf("SFTP connection failed: %v", err) + } + }) + t.Run("InvalidHostKey", func(t *testing.T) { + addr := testingutil.MockSFTPServer(t, privateKey) + invalidHostKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIEqM2NkGvKKhR1oiKO0E72L3tOsYk+aX7H8Xn4bbZKsa" + + c := testingutil.NewSFTPReplicaClient(t) + c.User = "foo" + c.Host = addr + c.HostKey = invalidHostKey + + _, err = c.Init(context.Background()) + if err == nil { + t.Fatalf("SFTP connection established despite invalid host key") + } + if !strings.Contains(err.Error(), "ssh: host key mismatch") { + t.Errorf("expected host key validation error, got: %v", err) + } + }) + t.Run("IgnoreHostKey", func(t *testing.T) { + var captured []string + slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{ + Level: slog.LevelWarn, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if a.Key == slog.MessageKey { + captured = append(captured, a.Value.String()) + } + return a + }, + }))) + + addr := testingutil.MockSFTPServer(t, privateKey) + + c := testingutil.NewSFTPReplicaClient(t) + c.User = "foo" + c.Host = addr + + _, err = c.Init(context.Background()) + if err != nil { + t.Fatalf("SFTP connection failed: %v", err) + } + + if !slices.ContainsFunc(captured, func(msg string) bool { + return strings.Contains(msg, "sftp host key not verified") + }) { + t.Errorf("Expected warning not found") + } + + }) +} diff --git a/sftp/replica_client.go b/sftp/replica_client.go index f77cf8005..3a1532a89 100644 --- a/sftp/replica_client.go +++ b/sftp/replica_client.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "os" "path" @@ -41,6 +42,7 @@ type ReplicaClient struct { Password string Path string KeyPath string + HostKey string DialTimeout time.Duration // ConcurrentWrites enables concurrent writes for better performance. @@ -75,9 +77,20 @@ func (c *ReplicaClient) Init(ctx context.Context) (_ *sftp.Client, err error) { } // Build SSH configuration & auth methods + var hostkey ssh.HostKeyCallback + if c.HostKey != "" { + var pubkey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(c.HostKey)) + if err != nil { + return nil, fmt.Errorf("cannot parse sftp host key: %w", err) + } + hostkey = ssh.FixedHostKey(pubkey) + } else { + slog.Warn("sftp host key not verified", "host", c.Host) + hostkey = ssh.InsecureIgnoreHostKey() + } config := &ssh.ClientConfig{ User: c.User, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), + HostKeyCallback: hostkey, BannerCallback: ssh.BannerDisplayStderr(), } if c.Password != "" {