diff --git a/neo4j/directrouter.go b/neo4j/directrouter.go index dbebd9b3..7f6dac90 100644 --- a/neo4j/directrouter.go +++ b/neo4j/directrouter.go @@ -28,6 +28,10 @@ type directRouter struct { address string } +func (r *directRouter) IsMultiServer() bool { + return false +} + func (r *directRouter) InvalidateWriter(string, string) {} func (r *directRouter) InvalidateReader(string, string) {} diff --git a/neo4j/driver_with_context.go b/neo4j/driver_with_context.go index 6cff2cb5..0c1734d1 100644 --- a/neo4j/driver_with_context.go +++ b/neo4j/driver_with_context.go @@ -312,6 +312,7 @@ type sessionRouter interface { InvalidateWriter(db string, server string) InvalidateReader(db string, server string) InvalidateServer(server string) + IsMultiServer() bool } type driverWithContext struct { diff --git a/neo4j/driver_with_context_testkit.go b/neo4j/driver_with_context_testkit.go index 5e645bce..7b8c6e8e 100644 --- a/neo4j/driver_with_context_testkit.go +++ b/neo4j/driver_with_context_testkit.go @@ -51,6 +51,10 @@ func ForceRoutingTableUpdate(d DriverWithContext, database string, bookmarks []s return errorutil.WrapError(err) } +func RegisterDnsResolver(d DriverWithContext, hook func(address string) []string) { + d.(*driverWithContext).connector.TestKitResolver = hook +} + func GetRoutingTable(d DriverWithContext, database string) (*RoutingTable, error) { driver := d.(*driverWithContext) router, ok := driver.router.(*router.Router) diff --git a/neo4j/internal/bolt/bolt3.go b/neo4j/internal/bolt/bolt3.go index 90874491..43990dae 100644 --- a/neo4j/internal/bolt/bolt3.go +++ b/neo4j/internal/bolt/bolt3.go @@ -151,6 +151,15 @@ func (b *bolt3) ServerName() string { return b.serverName } +func (b *bolt3) AdvertisedServerName() string { + // Advertised address not supported by this protocol version + return "" +} + +func (b *bolt3) SetServerName(serverName string) { + b.serverName = serverName +} + func (b *bolt3) ServerVersion() string { return b.serverVersion } diff --git a/neo4j/internal/bolt/bolt4.go b/neo4j/internal/bolt/bolt4.go index 5dbdf1a7..0f891a76 100644 --- a/neo4j/internal/bolt/bolt4.go +++ b/neo4j/internal/bolt/bolt4.go @@ -174,6 +174,15 @@ func (b *bolt4) ServerName() string { return b.serverName } +func (b *bolt4) AdvertisedServerName() string { + // Advertised address not supported by this protocol version + return "" +} + +func (b *bolt4) SetServerName(serverName string) { + b.serverName = serverName +} + func (b *bolt4) ServerVersion() string { return b.serverVersion } @@ -985,7 +994,7 @@ func (b *bolt4) GetCurrentAuth() (auth.TokenManager, iauth.Token) { } func (b *bolt4) Telemetry(telemetry.API, func()) { - // TELEMETRY not support by this protocol version, so we ignore it. + // TELEMETRY not supported by this protocol version, so we ignore it. } func (b *bolt4) helloResponseHandler(checkUtcPatch bool) responseHandler { diff --git a/neo4j/internal/bolt/bolt5.go b/neo4j/internal/bolt/bolt5.go index 54f57e59..8d9d1f9e 100644 --- a/neo4j/internal/bolt/bolt5.go +++ b/neo4j/internal/bolt/bolt5.go @@ -93,28 +93,29 @@ func (i *internalTx5) toMeta(logger log.Logger, logId string, version db.Protoco } type bolt5 struct { - state int - txId idb.TxHandle - streams openstreams - conn io.ReadWriteCloser - serverName string - queue messageQueue - connId string - logId string - serverVersion string - bookmark string // Last bookmark - birthDate time.Time - log log.Logger - databaseName string - err error // Last fatal error - minor int - lastQid int64 // Last seen qid - idleDate time.Time - auth map[string]any - authManager auth.TokenManager - resetAuth bool - errorListener ConnectionErrorListener - telemetryEnabled bool + state int + txId idb.TxHandle + streams openstreams + conn io.ReadWriteCloser + serverName string // Initial server name + advertisedServerName string // Preferred server name + queue messageQueue + connId string + logId string + serverVersion string + bookmark string // Last bookmark + birthDate time.Time + log log.Logger + databaseName string + err error // Last fatal error + minor int + lastQid int64 // Last seen qid + idleDate time.Time + auth map[string]any + authManager auth.TokenManager + resetAuth bool + errorListener ConnectionErrorListener + telemetryEnabled bool } func NewBolt5( @@ -178,6 +179,14 @@ func (b *bolt5) ServerName() string { return b.serverName } +func (b *bolt5) AdvertisedServerName() string { + return b.advertisedServerName +} + +func (b *bolt5) SetServerName(serverName string) { + b.serverName = serverName +} + func (b *bolt5) ServerVersion() string { return b.serverVersion } @@ -989,6 +998,9 @@ func (b *bolt5) logoffResponseHandler() responseHandler { } func (b *bolt5) logonResponseHandler() responseHandler { + if b.Version().Major >= 5 && b.Version().Minor >= 8 { + return b.expectedSuccessHandler(b.onLogonSuccess) + } return b.expectedSuccessHandler(onSuccessNoOp) } @@ -1127,6 +1139,10 @@ func (b *bolt5) onHelloSuccess(helloSuccess *success) { b.initializeTelemetryHint(helloSuccess.configurationHints) } +func (b *bolt5) onLogonSuccess(logonSuccess *success) { + b.advertisedServerName = logonSuccess.advertisedAddress +} + func (b *bolt5) onCommitSuccess(commitSuccess *success) { if len(commitSuccess.bookmark) > 0 { b.bookmark = commitSuccess.bookmark diff --git a/neo4j/internal/bolt/connect.go b/neo4j/internal/bolt/connect.go index 20d36adc..74c99cc3 100644 --- a/neo4j/internal/bolt/connect.go +++ b/neo4j/internal/bolt/connect.go @@ -37,7 +37,7 @@ type protocolVersion struct { // Supported versions in priority order var versions = [4]protocolVersion{ - {major: 5, minor: 7, back: 7}, + {major: 5, minor: 8, back: 8}, {major: 4, minor: 4, back: 2}, {major: 4, minor: 1}, {major: 3, minor: 0}, diff --git a/neo4j/internal/bolt/hydrator.go b/neo4j/internal/bolt/hydrator.go index 7f57a8c2..379be1a2 100644 --- a/neo4j/internal/bolt/hydrator.go +++ b/neo4j/internal/bolt/hydrator.go @@ -56,6 +56,7 @@ type success struct { num uint32 configurationHints map[string]any patches []string + advertisedAddress string } func (s *success) String() string { @@ -302,6 +303,8 @@ func (h *hydrator) success(n uint32) *success { case "patch_bolt": patches := h.strings() succ.patches = patches + case "advertised_address": + succ.advertisedAddress = h.unp.String() default: // Unknown key, waste it h.trash() diff --git a/neo4j/internal/connector/connector.go b/neo4j/internal/connector/connector.go index 97775c13..a355cc6c 100644 --- a/neo4j/internal/connector/connector.go +++ b/neo4j/internal/connector/connector.go @@ -41,9 +41,10 @@ type Connector struct { Network string Config *config.Config SupplyConnection func(context.Context, string) (net.Conn, error) + TestKitResolver func(string) []string } -func (c Connector) Connect( +func (c *Connector) Connect( ctx context.Context, address string, auth *db.ReAuthToken, @@ -138,7 +139,24 @@ func (c Connector) createConnection(ctx context.Context, address string) (net.Co dialer.KeepAlive = -1 * time.Second // Turns keep-alive off } - return dialer.DialContext(ctx, c.Network, address) + if c.TestKitResolver == nil { + return dialer.DialContext(ctx, c.Network, address) + } + + addresses := c.TestKitResolver(address) + + if len(addresses) == 0 { + return nil, errors.New("TestKit DNS resolver returned no address") + } + + var err error = nil + for _, address := range addresses { + con, err := dialer.DialContext(ctx, c.Network, address) + if err == nil { + return con, nil + } + } + return nil, err } func (c Connector) tlsConfig(serverName string) *tls.Config { diff --git a/neo4j/internal/db/connection.go b/neo4j/internal/db/connection.go index 5e358c9d..2d4e7943 100644 --- a/neo4j/internal/db/connection.go +++ b/neo4j/internal/db/connection.go @@ -124,6 +124,10 @@ type Connection interface { Bookmark() string // ServerName returns the name of the remote server ServerName() string + // AdvertisedServerName returns the advertised name of the remote server. + AdvertisedServerName() string + // SetServerName updates the server name to given value. + SetServerName(serverName string) // ServerVersion returns the server version on pattern Neo4j/1.2.3 ServerVersion() string // IsAlive returns true if the connection is fully functional. diff --git a/neo4j/internal/pool/pool.go b/neo4j/internal/pool/pool.go index cfe04652..65f9addb 100644 --- a/neo4j/internal/pool/pool.go +++ b/neo4j/internal/pool/pool.go @@ -47,6 +47,7 @@ type poolRouter interface { InvalidateWriter(db string, server string) InvalidateReader(db string, server string) InvalidateServer(server string) + IsMultiServer() bool } type qitem struct { @@ -303,7 +304,7 @@ func (p *Pool) tryBorrow( if healthy { return connection, nil } - p.unreg(ctx, serverName, connection, itime.Now()) + p.unreg(ctx, serverName, connection, itime.Now(), true) if err != nil { p.log.Debugf(log.Pool, p.logId, "Health check failed for %s: %s", serverName, err) return nil, err @@ -343,16 +344,18 @@ func (p *Pool) tryBorrow( return c, nil } -func (p *Pool) unreg(ctx context.Context, serverName string, c idb.Connection, now time.Time) { +func (p *Pool) unreg(ctx context.Context, serverName string, c idb.Connection, now time.Time, close bool) { p.serversMut.Lock() defer p.serversMut.Unlock() - p.unregLocked(ctx, serverName, c, now) + p.unregLocked(ctx, serverName, c, now, close) } -func (p *Pool) unregLocked(ctx context.Context, serverName string, c idb.Connection, now time.Time) { +func (p *Pool) unregLocked(ctx context.Context, serverName string, c idb.Connection, now time.Time, close bool) { defer func() { // Close connection in another thread to avoid potential long blocking operation during close. - go c.Close(ctx) + if close { + go c.Close(ctx) + } }() server := p.servers[serverName] @@ -384,16 +387,33 @@ func (p *Pool) Return(ctx context.Context, c idb.Connection) { return } - // Get the name of the server that the connection belongs to. - serverName := c.ServerName() - isAlive := c.IsAlive() - p.log.Debugf(log.Pool, p.logId, "Returning connection to %s {alive:%t}", serverName, isAlive) - // If the connection is dead, remove all other idle connections on the same server that older // or of the same age as the dead connection, otherwise perform normal cleanup of old connections maxAge := p.config.MaxConnectionLifetime now := itime.Now() age := now.Sub(c.Birthdate()) + + // Check if we have an advertised server name and if so replace connection from initial server. + // Only do this when routing is enabled. + if p.router.IsMultiServer() && c.AdvertisedServerName() != "" && c.ServerName() != c.AdvertisedServerName() { + // Remove connection from busy list of initial server. + p.unreg(ctx, c.ServerName(), c, now, false) + p.log.Debugf(log.Pool, p.logId, "Transferring connection from %s to advertised server %s", c.ServerName(), c.AdvertisedServerName()) + // Update connection server name to that of the advertised address. + c.SetServerName(c.AdvertisedServerName()) + // Create a fresh server. + p.serversMut.Lock() + if _, ok := p.servers[c.ServerName()]; !ok { + p.servers[c.ServerName()] = NewServer() + } + p.serversMut.Unlock() + } + + // Get the name of the server that the connection belongs to + serverName := c.ServerName() + isAlive := c.IsAlive() + p.log.Debugf(log.Pool, p.logId, "Returning connection to %s {alive:%t}", serverName, isAlive) + if !isAlive { // Since this connection has died all other connections that connected before this one // might also be bad, remove the idle ones. @@ -418,7 +438,7 @@ func (p *Pool) Return(ctx context.Context, c idb.Connection) { // Fix for race condition where expired connections could be reused or closed concurrently. // See: https://github.com/neo4j/neo4j-go-driver/issues/574 isAlive = false - p.unreg(ctx, serverName, c, now) + p.unreg(ctx, serverName, c, now, true) p.log.Infof(log.Pool, p.logId, "Unregistering dead or too old connection to %s", serverName) } diff --git a/neo4j/internal/pool/pool_test.go b/neo4j/internal/pool/pool_test.go index c6bc889f..325d898d 100644 --- a/neo4j/internal/pool/pool_test.go +++ b/neo4j/internal/pool/pool_test.go @@ -22,6 +22,7 @@ package pool import ( "context" "errors" + "fmt" "math/rand" "sync" "testing" @@ -42,6 +43,17 @@ var logger = log.ToVoid() var ctx = context.Background() var reAuthToken = &idb.ReAuthToken{FromSession: false, Manager: iauth.Token{Tokens: map[string]any{"scheme": "none"}}} +type connectFunc = func(ctx context.Context, name string, _ *idb.ReAuthToken, _ bolt.ConnectionErrorListener, _ log.BoltLogger) (idb.Connection, error) + +func newPool( + conf *config.Config, + connect connectFunc, +) *Pool { + pool := New(conf, connect, logger, "pool id") + pool.router = &RouterFake{} + return pool +} + func TestPoolBorrowReturn(outer *testing.T) { maxAge := 1 * time.Second birthdate := time.Now() @@ -59,7 +71,7 @@ func TestPoolBorrowReturn(outer *testing.T) { itime.ForceFreezeTime() defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id") + p := newPool(&conf, succeedingConnect) defer func() { p.Close(ctx) }() @@ -82,7 +94,7 @@ func TestPoolBorrowReturn(outer *testing.T) { itime.ForceFreezeTime() defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id") + p := newPool(&conf, succeedingConnect) defer func() { p.Close(ctx) }() @@ -114,7 +126,7 @@ func TestPoolBorrowReturn(outer *testing.T) { itime.ForceFreezeTime() defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id") + p := newPool(&conf, succeedingConnect) defer func() { p.Close(ctx) }() @@ -136,7 +148,7 @@ func TestPoolBorrowReturn(outer *testing.T) { itime.ForceFreezeTime() defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: maxConnections} - p := New(&conf, succeedingConnect, logger, "pool id") + p := newPool(&conf, succeedingConnect) serverNames := []string{"srv1"} numWorkers := 5 wg := sync.WaitGroup{} @@ -170,7 +182,7 @@ func TestPoolBorrowReturn(outer *testing.T) { itime.ForceFreezeTime() defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 2} - p := New(&conf, failingConnect, logger, "pool id") + p := newPool(&conf, failingConnect) p.SetRouter(&RouterFake{}) serverNames := []string{"srv1"} c, err := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) @@ -185,7 +197,7 @@ func TestPoolBorrowReturn(outer *testing.T) { itime.ForceFreezeTime() defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id") + p := newPool(&conf, succeedingConnect) c1, _ := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) cancelableCtx, cancel := context.WithCancel(ctx) wg := sync.WaitGroup{} @@ -216,7 +228,7 @@ func TestPoolBorrowReturn(outer *testing.T) { t.Errorf("y u call me?") }} conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - pool := New(&conf, nil, logger, "pool id") + pool := newPool(&conf, nil) setIdleConnections(pool, map[string][]idb.Connection{"a server": { deadAfterReset, stayingAlive, @@ -239,7 +251,7 @@ func TestPoolBorrowReturn(outer *testing.T) { t.Errorf("force reset should not be called on new connections") }} conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - pool := New(&conf, connectTo(healthyConnection), logger, "pool id") + pool := newPool(&conf, connectTo(healthyConnection)) setIdleConnections(pool, map[string][]idb.Connection{serverName: {deadAfterReset1, deadAfterReset2}}) result, err := pool.tryBorrow(ctx, serverName, nil, idlenessThreshold, reAuthToken) @@ -254,7 +266,7 @@ func TestPoolBorrowReturn(outer *testing.T) { itime.ForceFreezeTime() defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id") + p := newPool(&conf, succeedingConnect) c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertConnection(t, c1, err) ctx = context.Background() @@ -283,7 +295,7 @@ func TestPoolBorrowReturn(outer *testing.T) { itime.ForceFreezeTime() defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id") + p := newPool(&conf, succeedingConnect) c1, err := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertConnection(t, c1, err) ctx = context.Background() @@ -308,6 +320,61 @@ func TestPoolBorrowReturn(outer *testing.T) { wg.Wait() AssertTrue(t, reAuthCalled) }) + + outer.Run("Connection is transferred to advertised server on return", func(t *testing.T) { + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() + advertisedServerName := "advertised-server" + conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} + p := newPool(&conf, succeedingConnect) + p.router.(*RouterFake).IsMultiServerReturn = true + defer func() { + p.Close(ctx) + }() + serverNames := []string{"srvA"} + c, _ := p.Borrow(ctx, getServers(serverNames), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) + c.(*ConnFake).AdvertisedName = advertisedServerName + p.Return(ctx, c) + servers := p.getServers() + + if len(servers) != 1 { + t.Errorf("Expected only 1 server, but %v were found", len(servers)) + } + serv, exists := servers[advertisedServerName] + if !exists { + t.Errorf("Expected connection to be transferred to %s, but server was not found", advertisedServerName) + } else if serv.numIdle() != 1 { + t.Errorf("Expected 1 idle connection in %s, found %d", advertisedServerName, servers[advertisedServerName].numIdle()) + } + }) + + outer.Run("Connection is not transferred to advertised server on return for direct pool", func(t *testing.T) { + itime.ForceFreezeTime() + defer itime.ForceUnfreezeTime() + advertisedServerName := "advertised-server" + conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} + p := newPool(&conf, succeedingConnect) + p.router.(*RouterFake).IsMultiServerReturn = false + defer func() { + p.Close(ctx) + }() + serverName := "srvA" + c, _ := p.Borrow(ctx, getServers([]string{serverName}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) + c.(*ConnFake).AdvertisedName = advertisedServerName + p.Return(ctx, c) + servers := p.getServers() + + fmt.Printf("%v\n", servers["foo"]) + if len(servers) != 1 { + t.Errorf("Expected only 1 server, but %v were found", len(servers)) + } + serv, exists := servers[serverName] + if !exists { + t.Errorf("Expected connection not to be transferred to %s, but server was not found", serverName) + } else if serv.numIdle() != 1 { + t.Errorf("Expected 1 idle connection in %s, found %d", serverName, servers[serverName].numIdle()) + } + }) } // Resource usage scenarios @@ -323,7 +390,7 @@ func TestPoolResourceUsage(ot *testing.T) { itime.ForceFreezeTime() defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id") + p := newPool(&conf, succeedingConnect) defer func() { p.Close(ctx) }() @@ -338,7 +405,7 @@ func TestPoolResourceUsage(ot *testing.T) { itime.ForceFreezeTime() defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 2} - p := New(&conf, succeedingConnect, logger, "pool id") + p := newPool(&conf, succeedingConnect) defer func() { p.Close(ctx) }() @@ -356,7 +423,7 @@ func TestPoolResourceUsage(ot *testing.T) { itime.ForceFreezeTime() defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 2} - p := New(&conf, succeedingConnect, logger, "pool id") + p := newPool(&conf, succeedingConnect) defer func() { p.Close(ctx) }() @@ -374,7 +441,7 @@ func TestPoolResourceUsage(ot *testing.T) { itime.ForceFreezeTime() defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: 1<<63 - 1, MaxConnectionPoolSize: 3} - p := New(&conf, succeedingConnect, logger, "pool id") + p := newPool(&conf, succeedingConnect) // Trigger creation of three connections on the same server c1, _ := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) c2, _ := p.Borrow(ctx, getServers([]string{"A"}), true, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) @@ -402,7 +469,7 @@ func TestPoolResourceUsage(ot *testing.T) { itime.ForceFreezeTime() defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id") + p := newPool(&conf, succeedingConnect) defer func() { p.Close(ctx) }() @@ -423,7 +490,7 @@ func TestPoolResourceUsage(ot *testing.T) { itime.ForceFreezeTime() defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxAge, MaxConnectionPoolSize: 1} - p := New(&conf, succeedingConnect, logger, "pool id") + p := newPool(&conf, succeedingConnect) defer func() { p.Close(ctx) }() @@ -455,7 +522,7 @@ func TestPoolCleanup(ot *testing.T) { itime.ForceFreezeTime() defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxLife, MaxConnectionPoolSize: 0} - p := New(&conf, succeedingConnect, logger, "pool id") + p := newPool(&conf, succeedingConnect) defer func() { p.Close(ctx) }() @@ -476,7 +543,7 @@ func TestPoolCleanup(ot *testing.T) { itime.ForceFreezeTime() defer itime.ForceUnfreezeTime() conf := config.Config{MaxConnectionLifetime: maxLife, MaxConnectionPoolSize: 0} - p := New(&conf, succeedingConnect, logger, "pool id") + p := newPool(&conf, succeedingConnect) defer func() { p.Close(ctx) }() @@ -499,7 +566,7 @@ func TestPoolCleanup(ot *testing.T) { return nil, errors.New("an error") } conf := config.Config{MaxConnectionLifetime: maxLife, MaxConnectionPoolSize: 0} - p := New(&conf, failingConnect, logger, "pool id") + p := newPool(&conf, failingConnect) p.SetRouter(&RouterFake{}) defer func() { p.Close(ctx) @@ -529,7 +596,7 @@ func TestPoolCleanup(ot *testing.T) { MaxConnectionLifetime: maxLife, MaxConnectionPoolSize: 1, } - p := New(&conf, succeedingConnect, logger, "pool id") + p := newPool(&conf, succeedingConnect) servers := getServers([]string{"example.com"}) conn, err := p.Borrow(ctx, servers, false, nil, DefaultConnectionLivenessCheckTimeout, reAuthToken) assertConnection(t, conn, err) diff --git a/neo4j/internal/router/router.go b/neo4j/internal/router/router.go index a257221f..2dcbe1b7 100644 --- a/neo4j/internal/router/router.go +++ b/neo4j/internal/router/router.go @@ -81,6 +81,10 @@ func New(rootRouter string, getRouters func() []string, routerContext map[string return r } +func (r *Router) IsMultiServer() bool { + return true +} + func (r *Router) readTable( ctx context.Context, dbRouter *databaseRouter, diff --git a/neo4j/internal/testutil/connfake.go b/neo4j/internal/testutil/connfake.go index 249055a9..d37d6ed0 100644 --- a/neo4j/internal/testutil/connfake.go +++ b/neo4j/internal/testutil/connfake.go @@ -45,6 +45,7 @@ type RecordedTx struct { type ConnFake struct { Name string + AdvertisedName string ConnectionVersion db.ProtocolVersion Alive bool Birth time.Time @@ -90,6 +91,14 @@ func (c *ConnFake) ServerName() string { return c.Name } +func (c *ConnFake) AdvertisedServerName() string { + return c.AdvertisedName +} + +func (c *ConnFake) SetServerName(serverName string) { + c.Name = serverName +} + func (c *ConnFake) IsAlive() bool { return c.Alive } diff --git a/neo4j/internal/testutil/routerfake.go b/neo4j/internal/testutil/routerfake.go index 2a17c74b..05dfba39 100644 --- a/neo4j/internal/testutil/routerfake.go +++ b/neo4j/internal/testutil/routerfake.go @@ -35,6 +35,11 @@ type RouterFake struct { Err error CleanUpHook func() GetNameOfDefaultDbHook func(user string) (string, error) + IsMultiServerReturn bool +} + +func (r *RouterFake) IsMultiServer() bool { + return r.IsMultiServerReturn } func (r *RouterFake) InvalidateReader(database string, server string) { diff --git a/testkit-backend/backend.go b/testkit-backend/backend.go index 9106674d..ba7b0a34 100644 --- a/testkit-backend/backend.go +++ b/testkit-backend/backend.go @@ -52,6 +52,7 @@ type backend struct { explicitTransactions map[string]neo4j.ExplicitTransaction recordedErrors map[string]error resolvedAddresses map[string][]any + dnsResolutions map[string][]any authTokenManagers map[string]auth.TokenManager resolvedGetAuthTokens map[string]neo4j.AuthToken resolvedHandleSecurityException map[string]bool @@ -148,6 +149,7 @@ func newBackend(rd *bufio.Reader, wr io.Writer) *backend { explicitTransactions: make(map[string]neo4j.ExplicitTransaction), recordedErrors: make(map[string]error), resolvedAddresses: make(map[string][]any), + dnsResolutions: make(map[string][]any), authTokenManagers: make(map[string]auth.TokenManager), resolvedGetAuthTokens: make(map[string]neo4j.AuthToken), resolvedHandleSecurityException: make(map[string]bool), @@ -496,6 +498,27 @@ func (b *backend) customAddressResolverFunction() config.ServerAddressResolver { } } +func (b *backend) dnsResolverFunction() func(address string) []string { + return func(address string) []string { + id := b.nextId() + b.writeResponse("DomainNameResolutionRequired", map[string]string{ + "id": id, + "name": address, + }) + for { + b.process() + if addresses, ok := b.dnsResolutions[id]; ok { + delete(b.dnsResolutions, id) + result := make([]string, len(addresses)) + for i, address := range addresses { + result[i] = address.(string) + } + return result + } + } + } +} + type serverAddress struct { hostname string port string @@ -533,6 +556,11 @@ func (b *backend) handleRequest(req map[string]any) { fmt.Printf("REQ: %s %s\n", name, dataJson) switch name { + case "DomainNameResolutionCompleted": + requestId := data["requestId"].(string) + addresses := data["addresses"].([]any) + b.dnsResolutions[requestId] = addresses + case "ResolverResolutionCompleted": requestId := data["requestId"].(string) addresses := data["addresses"].([]any) @@ -637,6 +665,11 @@ func (b *backend) handleRequest(req map[string]any) { b.writeError(err) return } + + if data["domainNameResolverRegistered"] != nil && data["domainNameResolverRegistered"].(bool) { + neo4j.RegisterDnsResolver(driver, b.dnsResolverFunction()) + } + idKey := b.nextId() b.drivers[idKey] = driver b.writeResponse("Driver", map[string]any{"id": idKey}) @@ -1302,6 +1335,7 @@ func (b *backend) handleRequest(req map[string]any) { "Feature:Bolt:5.5", "Feature:Bolt:5.6", "Feature:Bolt:5.7", + "Feature:Bolt:5.8", "Feature:Bolt:Patch:UTC", "Feature:Impersonation", //"Feature:TLS:1.1", @@ -1329,6 +1363,7 @@ func (b *backend) handleRequest(req map[string]any) { "ConfHint:connection.recv_timeout_seconds", // === BACKEND FEATURES FOR TESTING === + "Backend:DNSResolver", "Backend:MockTime", "Backend:RTFetch", "Backend:RTForceUpdate", @@ -1691,10 +1726,6 @@ func testSkips() map[string]string { "stub.routing.test_routing_v3.RoutingV3.test_should_fail_when_writing_on_unexpectedly_interrupting_writer_on_run_using_tx_run": "Won't fix - only Bolt 3 affected (not officially supported by this driver): broken servers are not removed from routing table", "stub.routing.test_routing_v3.RoutingV3.test_should_fail_when_writing_on_unexpectedly_interrupting_writer_using_tx_run": "Won't fix - only Bolt 3 affected (not officially supported by this driver): broken servers are not removed from routing table", - // Missing message support in testkit backend - "stub.routing.*.*.test_should_request_rt_from_all_initial_routers_until_successful_on_unknown_failure": "Add DNS resolver TestKit message and connection timeout support", - "stub.routing.*.*.test_should_request_rt_from_all_initial_routers_until_successful_on_authorization_expired": "Add DNS resolver TestKit message and connection timeout support", - // To fix/to decide whether to fix "stub.routing.test_routing_v*.RoutingV*.test_should_revert_to_initial_router_if_known_router_throws_protocol_errors": "Driver always uses configured URL first and custom resolver only if that fails", "stub.routing.test_routing_v*.RoutingV*.test_should_read_successfully_from_reachable_db_after_trying_unreachable_db": "Driver retries to fetch a routing table up to 100 times if it's empty", diff --git a/testkit/testkit.json b/testkit/testkit.json index 93190035..307f1091 100644 --- a/testkit/testkit.json +++ b/testkit/testkit.json @@ -1,6 +1,6 @@ { "testkit": { "uri": "https://github.com/neo4j-drivers/testkit.git", - "ref": "5.0" + "ref": "advertised-address" } }