diff --git a/client/v3/concurrency/session.go b/client/v3/concurrency/session.go index 2275e96c972c..d0127170057b 100644 --- a/client/v3/concurrency/session.go +++ b/client/v3/concurrency/session.go @@ -40,14 +40,25 @@ type Session struct { // NewSession gets the leased session for a client. func NewSession(client *v3.Client, opts ...SessionOption) (*Session, error) { lg := client.GetLogger() - ops := &sessionOptions{ttl: defaultSessionTTL, ctx: client.Ctx()} + ops := &sessionOptions{ + ttl: defaultSessionTTL, + ctx: client.Ctx(), + sessionCreationTimeout: 0, + } for _, opt := range opts { opt(ops, lg) } + var cancel context.CancelFunc + sessionCreationCtx := ops.ctx + if ops.sessionCreationTimeout > 0 { + clientDeadline := time.Now().Add(time.Duration(ops.sessionCreationTimeout) * time.Millisecond) + sessionCreationCtx, cancel = context.WithDeadline(ops.ctx, clientDeadline) + } + id := ops.leaseID if id == v3.NoLease { - resp, err := client.Grant(ops.ctx, int64(ops.ttl)) + resp, err := client.Grant(sessionCreationCtx, int64(ops.ttl)) if err != nil { return nil, err } @@ -115,9 +126,10 @@ func (s *Session) Close() error { } type sessionOptions struct { - ttl int - leaseID v3.LeaseID - ctx context.Context + ttl int + leaseID v3.LeaseID + ctx context.Context + sessionCreationTimeout int } // SessionOption configures Session. @@ -135,6 +147,18 @@ func WithTTL(ttl int) SessionOption { } } +// WithSessionCreationTimeout configures the timeout for creating a new session +// in milliseconds. If timeout is <= 0, no timeout will be used. +func WithSessionCreationTimeout(timeout int) SessionOption { + return func(so *sessionOptions, lg *zap.Logger) { + if timeout > 0 { + so.sessionCreationTimeout = timeout + } else { + lg.Warn("WithSessionCreationTimeout(): timeout should be > 0, preserving current timeout", zap.Int64("current-session-timeout", int64(so.sessionCreationTimeout))) + } + } +} + // WithLease specifies the existing leaseID to be used for the session. // This is useful in process restart scenario, for example, to reclaim // leadership from an election prior to restart. diff --git a/tests/integration/clientv3/concurrency/session_test.go b/tests/integration/clientv3/concurrency/session_test.go index b17991179751..0161cb916f52 100644 --- a/tests/integration/clientv3/concurrency/session_test.go +++ b/tests/integration/clientv3/concurrency/session_test.go @@ -23,6 +23,7 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "go.etcd.io/etcd/client/v3/concurrency" integration2 "go.etcd.io/etcd/tests/v3/framework/integration" + clientv3test "go.etcd.io/etcd/tests/v3/integration/clientv3" ) func TestSessionOptions(t *testing.T) { @@ -111,3 +112,82 @@ func TestSessionCtx(t *testing.T) { } assert.Equal(t, childCtx.Err(), context.Canceled) } + +// TestSessionCreationTimeout checks that the option WithSessionCreationTimeout +// sets a timeout for the creation of new sessions +func TestSessionCreationTimeout(t *testing.T) { + integration2.BeforeTest(t) + + // create new cluster + clus := integration2.NewCluster(t, &integration2.ClusterConfig{ + Size: 3, + }) + + eps := []string{clus.Members[0].GRPCURL(), clus.Members[1].GRPCURL(), clus.Members[2].GRPCURL()} + lead := clus.WaitLeader(t) + + // create new client + cli, err := integration2.NewClient(t, clientv3.Config{Endpoints: []string{eps[lead]}}) + if err != nil { + t.Fatal(err) + } + defer cli.Close() + + // wait for eps[lead] to be pinned + clientv3test.MustWaitPinReady(t, cli) + + // add all eps to list, so that when the original pined one fails + // the client can switch to other available eps + cli.SetEndpoints(eps...) + + clus.Terminate(t) + + _, err = concurrency.NewSession(cli, concurrency.WithSessionCreationTimeout(50)) + assert.Equal(t, err, context.DeadlineExceeded) +} + +// TestTimeoutDoesntAffectSubsequentConnections checks that the option WithSessionCreationTimeout +// does not set the timeout of subsequent connections to the server. This means after successful +// session creation, if servers are unavailable, requests would be blocked waiting for the +// cluster to recover +func TestTimeoutDoesntAffectSubsequentConnections(t *testing.T) { + integration2.BeforeTest(t) + + // create new cluster + clus := integration2.NewCluster(t, &integration2.ClusterConfig{ + Size: 3, + }) + + eps := []string{clus.Members[0].GRPCURL(), clus.Members[1].GRPCURL(), clus.Members[2].GRPCURL()} + lead := clus.WaitLeader(t) + + // create new client + cli, err := integration2.NewClient(t, clientv3.Config{Endpoints: []string{eps[lead]}}) + if err != nil { + t.Fatal(err) + } + defer cli.Close() + + // wait for eps[lead] to be pinned + clientv3test.MustWaitPinReady(t, cli) + + // add all eps to list, so that when the original pined one fails + // the client can switch to other available eps + cli.SetEndpoints(eps...) + + s, err := concurrency.NewSession(cli, concurrency.WithSessionCreationTimeout(50)) + if err != nil { + t.Fatal(err) + } + defer s.Close() + + clus.Terminate(t) + key := "sample_key" + value := "sample_value" + _, err = cli.Put(s.Ctx(), key, value, clientv3.WithLease(s.Lease())) + + // assert that the request is cancelled because the test case timed out + // from the infinite hanging of the request, instead of the request + // itself timing out + assert.Equal(t, err, context.Canceled) +}