Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 75 additions & 5 deletions postgresql/proxy_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"net"
"strings"
"time"

"github.com/lib/pq"
Expand All @@ -20,14 +22,82 @@ func (d proxyDriver) Open(name string) (driver.Conn, error) {
}

func (d proxyDriver) Dial(network, address string) (net.Conn, error) {
dialer := proxy.FromEnvironment()
return dialer.Dial(network, address)
return d.DialTimeout(network, address, 0)
}

func (d proxyDriver) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.TODO(), timeout)
defer cancel()
return proxy.Dial(ctx, network, address)
var ctx context.Context
var cancel context.CancelFunc
if timeout > 0 {
ctx, cancel = context.WithTimeout(context.Background(), timeout)
defer cancel()
} else {
ctx = context.Background()
}

// Only handle TCP networks for multi-host splitting
if !strings.HasPrefix(network, "tcp") {
return proxy.Dial(ctx, network, address)
}

hosts, port, err := parseAddress(address)
if err != nil {
// If parsing fails, fall back to trying the original address
return proxy.Dial(ctx, network, address)
}

var lastErr error
for _, host := range hosts {
addr := net.JoinHostPort(host, port)
conn, err := proxy.Dial(ctx, network, addr)
if err == nil {
return conn, nil
}
lastErr = err

// Check if context expired
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
}
if lastErr != nil {
return nil, lastErr
}
return nil, errors.New("no hosts available")
}

func parseAddress(address string) ([]string, string, error) {
host, port, err := net.SplitHostPort(address)
if err == nil {
if strings.Contains(host, ",") {
return strings.Split(host, ","), port, nil
}
return []string{host}, port, nil
}

// Fallback for when net.SplitHostPort fails (e.g. mixed bracketed and unbracketed hosts)
lastColon := strings.LastIndex(address, ":")
if lastColon == -1 {
return nil, "", err
}

port = address[lastColon+1:]
hostPart := address[:lastColon]

if strings.Contains(hostPart, ",") {
hosts := strings.Split(hostPart, ",")
// Clean up brackets if present so net.JoinHostPort doesn't double them
for i, h := range hosts {
if len(h) > 2 && h[0] == '[' && h[len(h)-1] == ']' {
hosts[i] = h[1 : len(h)-1]
}
}
return hosts, port, nil
}

return nil, "", err
}

func init() {
Expand Down
83 changes: 83 additions & 0 deletions postgresql/proxy_driver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package postgresql

import (
"net"
"testing"

"github.com/stretchr/testify/assert"
)

func TestParseAddress(t *testing.T) {
tests := []struct {
input string
expectHosts []string
expectPort string
expectErr bool
}{
{
input: "host1:5432",
expectHosts: []string{"host1"},
expectPort: "5432",
expectErr: false,
},
{
input: "host1,host2:5432",
expectHosts: []string{"host1", "host2"},
expectPort: "5432",
expectErr: false,
},
{
input: "[::1]:5432",
expectHosts: []string{"::1"}, // net.SplitHostPort strips brackets
expectPort: "5432",
expectErr: false,
},
{
input: "[::1],localhost:5432",
expectHosts: []string{"::1", "localhost"}, // manual split strips brackets
expectPort: "5432",
expectErr: false,
},
{
input: "host1,[::1]:5432",
expectHosts: []string{"host1", "::1"},
expectPort: "5432",
expectErr: false,
},
}

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
hosts, port, err := parseAddress(tt.input)
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectHosts, hosts)
assert.Equal(t, tt.expectPort, port)
}
})
}
}

func TestReconstruction(t *testing.T) {
// Verify that net.JoinHostPort reconstructs correctly from what parseAddress returns
tests := []string{
"host1:5432",
"host1,host2:5432",
"[::1]:5432",
"[::1],localhost:5432",
}

for _, input := range tests {
hosts, port, err := parseAddress(input)
assert.NoError(t, err)
for _, h := range hosts {
addr := net.JoinHostPort(h, port)
// Sanity check on address format
_, _, err := net.SplitHostPort(addr)
assert.NoError(t, err, "JoinHostPort produced invalid address: %s", addr)
}
}
}