Skip to content

Commit 609aefe

Browse files
committedAug 23, 2023
add Queryable interface
1 parent 06d07f4 commit 609aefe

File tree

2 files changed

+121
-0
lines changed

2 files changed

+121
-0
lines changed
 

‎core.go

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package sqlx
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
)
7+
8+
var (
9+
_ Queryable = (*DB)(nil)
10+
_ Queryable = (*Tx)(nil)
11+
)
12+
13+
// Queryable includes all methods shared by sqlx.DB and sqlx.Tx, allowing
14+
// either type to be used interchangeably.
15+
type Queryable interface {
16+
Ext
17+
ExecIn
18+
QueryIn
19+
ExecerContext
20+
PreparerContext
21+
QueryerContext
22+
Preparer
23+
24+
GetContext(context.Context, interface{}, string, ...interface{}) error
25+
SelectContext(context.Context, interface{}, string, ...interface{}) error
26+
Get(interface{}, string, ...interface{}) error
27+
MustExecContext(context.Context, string, ...interface{}) sql.Result
28+
PreparexContext(context.Context, string) (*Stmt, error)
29+
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
30+
Select(interface{}, string, ...interface{}) error
31+
QueryRow(string, ...interface{}) *sql.Row
32+
PrepareNamedContext(context.Context, string) (*NamedStmt, error)
33+
PrepareNamed(string) (*NamedStmt, error)
34+
Preparex(string) (*Stmt, error)
35+
NamedExec(string, interface{}) (sql.Result, error)
36+
NamedExecContext(context.Context, string, interface{}) (sql.Result, error)
37+
MustExec(string, ...interface{}) sql.Result
38+
NamedQuery(string, interface{}) (*Rows, error)
39+
InGet(any, string, ...any) error
40+
InSelect(any, string, ...any) error
41+
InExec(query string, args ...any) (sql.Result, error)
42+
MustInExec(string, ...any) sql.Result
43+
}

‎core_test.go

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package sqlx
2+
3+
import (
4+
"database/sql"
5+
"reflect"
6+
"testing"
7+
)
8+
9+
func TestQueryable(t *testing.T) {
10+
sqlDBType := reflect.TypeOf(&sql.DB{})
11+
dbType := reflect.TypeOf(&DB{})
12+
sqlTxType := reflect.TypeOf(&sql.Tx{})
13+
txType := reflect.TypeOf(&Tx{})
14+
15+
dbMethods := exportableMethods(sqlDBType)
16+
for k, v := range exportableMethods(dbType) {
17+
dbMethods[k] = v
18+
}
19+
20+
txMethods := exportableMethods(sqlTxType)
21+
for k, v := range exportableMethods(txType) {
22+
txMethods[k] = v
23+
}
24+
25+
sharedMethods := make([]string, 0)
26+
27+
for name, dbMethod := range dbMethods {
28+
if txMethod, ok := txMethods[name]; ok {
29+
if methodsEqual(dbMethod.Type, txMethod.Type) {
30+
sharedMethods = append(sharedMethods, name)
31+
}
32+
}
33+
}
34+
35+
queryableType := reflect.TypeOf((*Queryable)(nil)).Elem()
36+
queryableMethods := exportableMethods(queryableType)
37+
38+
for _, sharedMethodName := range sharedMethods {
39+
if _, ok := queryableMethods[sharedMethodName]; !ok {
40+
t.Errorf("Queryable does not include shared DB/Tx method: %s", sharedMethodName)
41+
}
42+
}
43+
}
44+
45+
func exportableMethods(t reflect.Type) map[string]reflect.Method {
46+
methods := make(map[string]reflect.Method)
47+
48+
for i := 0; i < t.NumMethod(); i++ {
49+
method := t.Method(i)
50+
51+
if method.IsExported() {
52+
methods[method.Name] = method
53+
}
54+
}
55+
56+
return methods
57+
}
58+
59+
func methodsEqual(t reflect.Type, ot reflect.Type) bool {
60+
if t.NumIn() != ot.NumIn() || t.NumOut() != ot.NumOut() || t.IsVariadic() != ot.IsVariadic() {
61+
return false
62+
}
63+
64+
// Start at 1 to avoid comparing receiver argument
65+
for i := 1; i < t.NumIn(); i++ {
66+
if t.In(i) != ot.In(i) {
67+
return false
68+
}
69+
}
70+
71+
for i := 0; i < t.NumOut(); i++ {
72+
if t.Out(i) != ot.Out(i) {
73+
return false
74+
}
75+
}
76+
77+
return true
78+
}

0 commit comments

Comments
 (0)