From 43d86097d64911188e7482d8f3c71c34b6ab9c33 Mon Sep 17 00:00:00 2001 From: ljluestc Date: Wed, 26 Nov 2025 08:33:56 -0800 Subject: [PATCH] Support multiple hosts in provider configuration --- postgresql/proxy_driver.go | 80 +++++++++++++++++++++++++++++-- postgresql/proxy_driver_test.go | 83 +++++++++++++++++++++++++++++++++ 2 files changed, 158 insertions(+), 5 deletions(-) create mode 100644 postgresql/proxy_driver_test.go diff --git a/postgresql/proxy_driver.go b/postgresql/proxy_driver.go index f08c2b1a..c28fcfbc 100644 --- a/postgresql/proxy_driver.go +++ b/postgresql/proxy_driver.go @@ -4,7 +4,9 @@ import ( "context" "database/sql" "database/sql/driver" + "errors" "net" + "strings" "time" "github.com/lib/pq" @@ -20,14 +22,82 @@ func (d proxyDriver) Open(name string) (driver.Conn, error) { } func (d proxyDriver) Dial(network, address string) (net.Conn, error) { - dialer := proxy.FromEnvironment() - return dialer.Dial(network, address) + return d.DialTimeout(network, address, 0) } func (d proxyDriver) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { - ctx, cancel := context.WithTimeout(context.TODO(), timeout) - defer cancel() - return proxy.Dial(ctx, network, address) + var ctx context.Context + var cancel context.CancelFunc + if timeout > 0 { + ctx, cancel = context.WithTimeout(context.Background(), timeout) + defer cancel() + } else { + ctx = context.Background() + } + + // Only handle TCP networks for multi-host splitting + if !strings.HasPrefix(network, "tcp") { + return proxy.Dial(ctx, network, address) + } + + hosts, port, err := parseAddress(address) + if err != nil { + // If parsing fails, fall back to trying the original address + return proxy.Dial(ctx, network, address) + } + + var lastErr error + for _, host := range hosts { + addr := net.JoinHostPort(host, port) + conn, err := proxy.Dial(ctx, network, addr) + if err == nil { + return conn, nil + } + lastErr = err + + // Check if context expired + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + } + if lastErr != nil { + return nil, lastErr + } + return nil, errors.New("no hosts available") +} + +func parseAddress(address string) ([]string, string, error) { + host, port, err := net.SplitHostPort(address) + if err == nil { + if strings.Contains(host, ",") { + return strings.Split(host, ","), port, nil + } + return []string{host}, port, nil + } + + // Fallback for when net.SplitHostPort fails (e.g. mixed bracketed and unbracketed hosts) + lastColon := strings.LastIndex(address, ":") + if lastColon == -1 { + return nil, "", err + } + + port = address[lastColon+1:] + hostPart := address[:lastColon] + + if strings.Contains(hostPart, ",") { + hosts := strings.Split(hostPart, ",") + // Clean up brackets if present so net.JoinHostPort doesn't double them + for i, h := range hosts { + if len(h) > 2 && h[0] == '[' && h[len(h)-1] == ']' { + hosts[i] = h[1 : len(h)-1] + } + } + return hosts, port, nil + } + + return nil, "", err } func init() { diff --git a/postgresql/proxy_driver_test.go b/postgresql/proxy_driver_test.go new file mode 100644 index 00000000..b0b3f5c3 --- /dev/null +++ b/postgresql/proxy_driver_test.go @@ -0,0 +1,83 @@ +package postgresql + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseAddress(t *testing.T) { + tests := []struct { + input string + expectHosts []string + expectPort string + expectErr bool + }{ + { + input: "host1:5432", + expectHosts: []string{"host1"}, + expectPort: "5432", + expectErr: false, + }, + { + input: "host1,host2:5432", + expectHosts: []string{"host1", "host2"}, + expectPort: "5432", + expectErr: false, + }, + { + input: "[::1]:5432", + expectHosts: []string{"::1"}, // net.SplitHostPort strips brackets + expectPort: "5432", + expectErr: false, + }, + { + input: "[::1],localhost:5432", + expectHosts: []string{"::1", "localhost"}, // manual split strips brackets + expectPort: "5432", + expectErr: false, + }, + { + input: "host1,[::1]:5432", + expectHosts: []string{"host1", "::1"}, + expectPort: "5432", + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + hosts, port, err := parseAddress(tt.input) + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectHosts, hosts) + assert.Equal(t, tt.expectPort, port) + } + }) + } +} + +func TestReconstruction(t *testing.T) { + // Verify that net.JoinHostPort reconstructs correctly from what parseAddress returns + tests := []string{ + "host1:5432", + "host1,host2:5432", + "[::1]:5432", + "[::1],localhost:5432", + } + + for _, input := range tests { + hosts, port, err := parseAddress(input) + assert.NoError(t, err) + for _, h := range hosts { + addr := net.JoinHostPort(h, port) + // Sanity check on address format + _, _, err := net.SplitHostPort(addr) + assert.NoError(t, err, "JoinHostPort produced invalid address: %s", addr) + } + } +} +