Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/litestream/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ type ReplicaConfig struct {
Password string `yaml:"password"`
KeyPath string `yaml:"key-path"`
ConcurrentWrites *bool `yaml:"concurrent-writes"`
HostKey string `yaml:"host-key"`

// NATS settings
JWT string `yaml:"jwt"`
Expand Down Expand Up @@ -869,6 +870,7 @@ func newSFTPReplicaClientFromConfig(c *ReplicaConfig, _ *litestream.Replica) (_
client.Password = password
client.Path = path
client.KeyPath = c.KeyPath
client.HostKey = c.HostKey

// Set concurrent writes if specified, otherwise use default (true)
if c.ConcurrentWrites != nil {
Expand Down
22 changes: 22 additions & 0 deletions cmd/litestream/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/benbjohnson/litestream/file"
"github.com/benbjohnson/litestream/gs"
"github.com/benbjohnson/litestream/s3"
"github.com/benbjohnson/litestream/sftp"
)

func TestOpenConfigFile(t *testing.T) {
Expand Down Expand Up @@ -217,6 +218,27 @@ func TestNewGSReplicaFromConfig(t *testing.T) {
}
}

func TestNewSFTPReplicaFromConfig(t *testing.T) {
hostKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIAnK0+GdwOelXlAXdqLx/qvS7WHMr3rH7zW2+0DtmK5r"
r, err := main.NewReplicaFromConfig(&main.ReplicaConfig{
URL: "sftp://user@example.com:2222/foo",
HostKey: hostKey,
}, nil)
if err != nil {
t.Fatal(err)
} else if client, ok := r.Client.(*sftp.ReplicaClient); !ok {
t.Fatal("unexpected replica type")
} else if got, want := client.HostKey, hostKey; got != want {
t.Fatalf("HostKey=%s, want %s", got, want)
} else if got, want := client.Host, "example.com:2222"; got != want {
t.Fatalf("Host=%s, want %s", got, want)
} else if got, want := client.User, "user"; got != want {
t.Fatalf("User=%s, want %s", got, want)
} else if got, want := client.Path, "/foo"; got != want {
t.Fatalf("Path=%s, want %s", got, want)
}
}

// TestConfig_Validate_SnapshotIntervals tests validation of snapshot intervals
func TestConfig_Validate_SnapshotIntervals(t *testing.T) {
t.Run("ValidInterval", func(t *testing.T) {
Expand Down
8 changes: 8 additions & 0 deletions etc/litestream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,11 @@
# # client-cert: /path/to/client.pem
# # client-key: /path/to/client.key
# # root-cas: [/path/to/ca.pem]
# - url: sftp://user@host:22/path # SFTP-based replication
# key-path: /etc/litestream/sftp_key
# # Strongly recommended: SSH host key for verification
# # Get this from the server's /etc/ssh/ssh_host_*.pub file
# # or use `ssh-keyscan hostname`
# # Example formats:
# # host-key: ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMvvypUkBrS9RCyV//p+UFCLg8yKNtTu/ew/cV6XXAAP
# # host-key: ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ...
63 changes: 63 additions & 0 deletions internal/testingutil/testingutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,19 @@ import (
"database/sql"
"flag"
"fmt"
"io"
"log/slog"
"math/rand/v2"
"net"
"os"
"path"
"path/filepath"
"strings"
"testing"

sftpserver "github.com/pkg/sftp"
"golang.org/x/crypto/ssh"

"github.com/benbjohnson/litestream"
"github.com/benbjohnson/litestream/abs"
"github.com/benbjohnson/litestream/file"
Expand Down Expand Up @@ -295,3 +300,61 @@ func MustDeleteAll(tb testing.TB, c litestream.ReplicaClient) {
}
}
}

func MockSFTPServer(t *testing.T, hostKey ssh.Signer) string {
config := &ssh.ServerConfig{NoClientAuth: true}
config.AddHostKey(hostKey)

listener, err := net.Listen("tcp", "127.0.0.1:0") // random available port
if err != nil {
t.Fatal(err)
}

go func() {
for {
conn, err := listener.Accept()
if err != nil {
return
}

go func() {
_, chans, reqs, err := ssh.NewServerConn(conn, config)
if err != nil {
return
}
go ssh.DiscardRequests(reqs)

for ch := range chans {
if ch.ChannelType() != "session" {
ch.Reject(ssh.UnknownChannelType, "unsupported")
continue
}
channel, requests, err := ch.Accept()
if err != nil {
return
}

go func(in <-chan *ssh.Request) {
for req := range in {
if req.Type == "subsystem" && string(req.Payload[4:]) == "sftp" {
req.Reply(true, nil)

server, err := sftpserver.NewServer(channel)
if err != nil {
return
}
if err := server.Serve(); err != nil && err != io.EOF {
t.Logf("SFTP server error: %v", err)
}
return
}
req.Reply(false, nil)
}
}(requests)
}
}()
}
}()

return listener.Addr().String()
}
78 changes: 78 additions & 0 deletions replica_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import (
"bytes"
"context"
"io"
"log/slog"
"os"
"slices"
"strings"
"testing"
"time"

"github.com/superfly/ltx"
"golang.org/x/crypto/ssh"

"github.com/benbjohnson/litestream"
"github.com/benbjohnson/litestream/internal/testingutil"
Expand Down Expand Up @@ -308,3 +310,79 @@ func TestReplicaClient_S3_BucketValidation(t *testing.T) {
t.Errorf("expected bucket validation error, got: %v", err)
}
}

func TestReplicaClient_SFTP_HostKeyValidation(t *testing.T) {
testHostKeyPEM := `-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW
QyNTUxOQAAACAJytPhncDnpV5QF3ai8f6r0u1hzK96x+81tvtA7ZiuawAAAJAIcGGVCHBh
lQAAAAtzc2gtZWQyNTUxOQAAACAJytPhncDnpV5QF3ai8f6r0u1hzK96x+81tvtA7Ziuaw
AAAEDzV1D6COyvFGhSiZa6ll9aXZ2IMWED3KGrvCNjEEtYHwnK0+GdwOelXlAXdqLx/qvS
7WHMr3rH7zW2+0DtmK5rAAAADGZlbGl4QGJvcmVhcwE=
-----END OPENSSH PRIVATE KEY-----`
privateKey, err := ssh.ParsePrivateKey([]byte(testHostKeyPEM))
if err != nil {
t.Fatal(err)
}

t.Run("ValidHostKey", func(t *testing.T) {
addr := testingutil.MockSFTPServer(t, privateKey)
expectedHostKey := string(ssh.MarshalAuthorizedKey(privateKey.PublicKey()))

c := testingutil.NewSFTPReplicaClient(t)
c.User = "foo"
c.Host = addr
c.HostKey = expectedHostKey

_, err = c.Init(context.Background())
if err != nil {
t.Fatalf("SFTP connection failed: %v", err)
}
})
t.Run("InvalidHostKey", func(t *testing.T) {
addr := testingutil.MockSFTPServer(t, privateKey)
invalidHostKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIEqM2NkGvKKhR1oiKO0E72L3tOsYk+aX7H8Xn4bbZKsa"

c := testingutil.NewSFTPReplicaClient(t)
c.User = "foo"
c.Host = addr
c.HostKey = invalidHostKey

_, err = c.Init(context.Background())
if err == nil {
t.Fatalf("SFTP connection established despite invalid host key")
}
if !strings.Contains(err.Error(), "ssh: host key mismatch") {
t.Errorf("expected host key validation error, got: %v", err)
}
})
t.Run("IgnoreHostKey", func(t *testing.T) {
var captured []string
slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{
Level: slog.LevelWarn,
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
if a.Key == slog.MessageKey {
captured = append(captured, a.Value.String())
}
return a
},
})))

addr := testingutil.MockSFTPServer(t, privateKey)

c := testingutil.NewSFTPReplicaClient(t)
c.User = "foo"
c.Host = addr

_, err = c.Init(context.Background())
if err != nil {
t.Fatalf("SFTP connection failed: %v", err)
}

if !slices.ContainsFunc(captured, func(msg string) bool {
return strings.Contains(msg, "sftp host key not verified")
}) {
t.Errorf("Expected warning not found")
}

})
}
15 changes: 14 additions & 1 deletion sftp/replica_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net"
"os"
"path"
Expand Down Expand Up @@ -41,6 +42,7 @@ type ReplicaClient struct {
Password string
Path string
KeyPath string
HostKey string
DialTimeout time.Duration

// ConcurrentWrites enables concurrent writes for better performance.
Expand Down Expand Up @@ -75,9 +77,20 @@ func (c *ReplicaClient) Init(ctx context.Context) (_ *sftp.Client, err error) {
}

// Build SSH configuration & auth methods
var hostkey ssh.HostKeyCallback
if c.HostKey != "" {
var pubkey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(c.HostKey))
if err != nil {
return nil, fmt.Errorf("cannot parse sftp host key: %w", err)
}
hostkey = ssh.FixedHostKey(pubkey)
} else {
slog.Warn("sftp host key not verified", "host", c.Host)
hostkey = ssh.InsecureIgnoreHostKey()
}
config := &ssh.ClientConfig{
User: c.User,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
HostKeyCallback: hostkey,
BannerCallback: ssh.BannerDisplayStderr(),
}
if c.Password != "" {
Expand Down