@@ -17,15 +17,15 @@ package crdb
17
17
import (
18
18
"context"
19
19
"database/sql"
20
+ "errors"
20
21
"fmt"
21
22
"testing"
22
23
23
24
"github.com/cockroachdb/cockroach-go/v2/testserver"
24
25
)
25
26
26
27
// TestExecuteCtx verifies that ExecuteCtx correctly handles different retry limits
27
- // when executing database operations. It tests both successful operations and
28
- // retry behavior.
28
+ // and context cancellation when executing database operations.
29
29
//
30
30
// TODO(seanc@): Add test cases that force retryable errors by simulating
31
31
// transaction conflicts or network failures. Consider using the same write skew
@@ -44,20 +44,43 @@ func TestExecuteCtx(t *testing.T) {
44
44
name string
45
45
maxRetries int
46
46
id int
47
+ withCancel bool
48
+ wantErr error
47
49
}{
48
- {"no retries" , 0 , 0 },
49
- {"single retry" , 1 , 1 },
50
+ {"no retries" , 0 , 0 , false , nil },
51
+ {"single retry" , 1 , 1 , false , nil },
52
+ {"cancelled context" , 1 , 2 , true , context .Canceled },
53
+ {"no args" , 1 , 3 , false , nil },
54
+ }
55
+
56
+ fn := func (ctx context.Context , args ... interface {}) error {
57
+ if len (args ) == 0 {
58
+ _ , err := db .ExecContext (ctx , `INSERT INTO test_retry VALUES (3)` )
59
+ return err
60
+ }
61
+ id := args [0 ].(int )
62
+ _ , err := db .ExecContext (ctx , `INSERT INTO test_retry VALUES ($1)` , id )
63
+ return err
50
64
}
51
65
52
66
for _ , tc := range testCases {
53
67
t .Run (tc .name , func (t * testing.T ) {
54
68
limitedCtx := WithMaxRetries (ctx , tc .maxRetries )
55
- err := ExecuteCtx (limitedCtx , func (ctx context.Context ) error {
56
- _ , err := db .ExecContext (ctx , `INSERT INTO test_retry VALUES ($1)` , tc .id )
57
- return err
58
- })
59
- if err != nil {
60
- t .Errorf ("expected success with retry limit %d, got: %v" , tc .maxRetries , err )
69
+ if tc .withCancel {
70
+ var cancel context.CancelFunc
71
+ limitedCtx , cancel = context .WithCancel (limitedCtx )
72
+ cancel ()
73
+ }
74
+
75
+ var err error
76
+ if tc .name == "no args" {
77
+ err = ExecuteCtx (limitedCtx , fn )
78
+ } else {
79
+ err = ExecuteCtx (limitedCtx , fn , tc .id )
80
+ }
81
+
82
+ if ! errors .Is (err , tc .wantErr ) {
83
+ t .Errorf ("got error %v, want %v" , err , tc .wantErr )
61
84
}
62
85
})
63
86
}
0 commit comments