Skip to content

Commit 020f38a

Browse files
committed
Refactor to make queries more like database/sql
1 parent 4c57a06 commit 020f38a

File tree

7 files changed

+187
-116
lines changed

7 files changed

+187
-116
lines changed

README.md

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,14 @@ func main() {
5353
Name: "example",
5454
}
5555

56-
_, err = db.Exec(ctx, postgres.NewQueryParams(
57-
`INSERT INTO test_table (id, name) VALUES ($1, $2)`, data.Id, data.Name),
58-
)
56+
_, err = db.Exec(ctx, `INSERT INTO test_table (id, name) VALUES ($1, $2)`, data.Id, data.Name)
5957
if err != nil {
6058
// ....
6159
}
6260

6361
// Read it
6462
var name string
65-
err = db.QueryRow(ctx, postgres.NewQueryParams(
66-
`SELECT name FROM test_table WHERE id = $1)`, 1),
67-
&name,
68-
)
63+
err = db.QueryRow(ctx, `SELECT name FROM test_table WHERE id = $1)`, 1).Scan(&name)
6964
if err != nil {
7065
// ....
7166
if postgres.IsNoRowsError(err) {
@@ -77,7 +72,6 @@ func main() {
7772
}
7873
```
7974

80-
8175
## LICENSE
8276

8377
See `LICENSE` file for details.

postgres.go

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414
// -----------------------------------------------------------------------------
1515

1616
type WithinTxCallback = func(ctx context.Context, tx Tx) error
17-
type QueryRowsCallback = func(ctx context.Context, rows RowGetter) (bool, error)
1817
type CopyCallback func(ctx context.Context, idx int) ([]interface{}, error)
1918

2019
// -----------------------------------------------------------------------------
@@ -122,8 +121,8 @@ func (db *Database) SetEventHandler(handler ErrorHandler) {
122121
}
123122

124123
// Exec executes an SQL statement on a new connection
125-
func (db *Database) Exec(ctx context.Context, params QueryParams) (int64, error) {
126-
ct, err := db.pool.Exec(ctx, params.sql, params.args...)
124+
func (db *Database) Exec(ctx context.Context, sql string, args ...interface{}) (int64, error) {
125+
ct, err := db.pool.Exec(ctx, sql, args...)
127126
return ct.RowsAffected(), db.processError(err)
128127
}
129128

@@ -138,34 +137,22 @@ func (db *Database) Exec(ctx context.Context, params QueryParams) (int64, error)
138137
// the field in the query.
139138
// 3. To avoid overflows on high uint64 values, store them in NUMERIC(24,0) fields.
140139
// 4. For time-only fields, date is set to Jan 1, 2000 by PGX in time.Time variables.
141-
func (db *Database) QueryRow(ctx context.Context, params QueryParams, dest ...interface{}) error {
142-
row := db.pool.QueryRow(ctx, params.sql, params.args...)
143-
err := row.Scan(dest...)
144-
return db.processError(err)
140+
func (db *Database) QueryRow(ctx context.Context, sql string, args ...interface{}) Row {
141+
return rowGetter{
142+
db: db,
143+
row: db.pool.QueryRow(ctx, sql, args...),
144+
}
145145
}
146146

147147
// QueryRows executes a SQL query on a new connection
148-
func (db *Database) QueryRows(ctx context.Context, params QueryParams, callback QueryRowsCallback) error {
149-
rows, err := db.pool.Query(ctx, params.sql, params.args...)
150-
if err == nil {
151-
// Scan returned rows
152-
rg := rowsGetter{
153-
db: db,
154-
rows: rows,
155-
}
156-
for rows.Next() {
157-
var cont bool
158-
159-
cont, err = callback(ctx, rg)
160-
if err != nil || (!cont) {
161-
break
162-
}
163-
}
164-
rows.Close()
148+
func (db *Database) QueryRows(ctx context.Context, sql string, args ...interface{}) Rows {
149+
rows, err := db.pool.Query(ctx, sql, args...)
150+
return rowsGetter{
151+
db: db,
152+
ctx: ctx,
153+
rows: rows,
154+
err: err,
165155
}
166-
167-
// Done
168-
return db.processError(err)
169156
}
170157

171158
// Copy executes a SQL copy query within the transaction.

postgres_test.go

Lines changed: 105 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,13 @@ func TestPostgres(t *testing.T) {
143143
t.Fatalf("%v", err.Error())
144144
}
145145

146+
t.Log("Reading test data (multi-row)")
147+
err = readMultiTestData(ctx, db)
148+
if err != nil {
149+
db.Close()
150+
t.Fatalf("%v", err.Error())
151+
}
152+
146153
db.Close()
147154
}
148155

@@ -168,13 +175,13 @@ func checkSettings(t *testing.T) {
168175

169176
func createTestTable(ctx context.Context, db *postgres.Database) error {
170177
// Destroy old test table if exists
171-
_, err := db.Exec(ctx, postgres.NewQueryParams(`DROP TABLE IF EXISTS go_postgres_test_table CASCADE`))
178+
_, err := db.Exec(ctx, `DROP TABLE IF EXISTS go_postgres_test_table CASCADE`)
172179
if err != nil {
173180
return fmt.Errorf("Unable to drop tables [err=%v]", err.Error())
174181
}
175182

176183
// Create the test table
177-
_, err = db.Exec(ctx, postgres.NewQueryParams(`CREATE TABLE go_postgres_test_table (
184+
_, err = db.Exec(ctx, `CREATE TABLE go_postgres_test_table (
178185
id INT NOT NULL,
179186
num NUMERIC(24, 0) NULL,
180187
sm SMALLINT NULL,
@@ -192,7 +199,7 @@ func createTestTable(ctx context.Context, db *postgres.Database) error {
192199
js JSONB NULL,
193200
194201
PRIMARY KEY (id)
195-
)`))
202+
)`)
196203
if err != nil {
197204
return fmt.Errorf("Unable to create test table [err=%v]", err.Error())
198205
}
@@ -202,7 +209,6 @@ func createTestTable(ctx context.Context, db *postgres.Database) error {
202209
}
203210

204211
func insertTestData(ctx context.Context, db *postgres.Database) error {
205-
206212
return db.WithinTx(ctx, func(ctx context.Context, tx postgres.Tx) error {
207213
for idx := 1; idx <= 2; idx++ {
208214
rd := genTestRowDef(idx, true)
@@ -220,21 +226,54 @@ func insertTestData(ctx context.Context, db *postgres.Database) error {
220226
// Done
221227
return nil
222228
})
223-
224229
}
225230

226231
func readTestData(ctx context.Context, db *postgres.Database) error {
227232
for idx := 1; idx <= 2; idx++ {
228-
rd := genTestRowDef(idx, false)
229-
err := readTestRowDef(ctx, db, rd)
233+
compareRd := genTestRowDef(idx, false)
234+
rd, err := readTestRowDef(ctx, db, compareRd.id)
230235
if err != nil {
231-
return fmt.Errorf("Unable to verify test data [id=%v/err=%v]", rd.id, err.Error())
236+
return fmt.Errorf("Unable to verify test data [id=%v/err=%v]", compareRd.id, err.Error())
237+
}
238+
// Do deep comparison
239+
if !reflect.DeepEqual(compareRd, rd) {
240+
return errors.New("data mismatch")
232241
}
233242

234-
nrd := genTestNullableRowDef(idx, false)
235-
err = readTestNullableRowDef(ctx, db, nrd)
243+
compareNrd := genTestNullableRowDef(idx, false)
244+
nrd, err := readTestNullableRowDef(ctx, db, compareNrd.id)
236245
if err != nil {
237-
return fmt.Errorf("Unable to verify test data [id=%v/err=%v]", nrd.id, err.Error())
246+
return fmt.Errorf("Unable to verify test data [id=%v/err=%v]", compareNrd.id, err.Error())
247+
}
248+
249+
// Do deep comparison
250+
if !reflect.DeepEqual(compareNrd, nrd) {
251+
return fmt.Errorf("Data mismatch while comparing test data [id=%v]", compareNrd.id)
252+
}
253+
}
254+
255+
// Done
256+
return nil
257+
}
258+
259+
func readMultiTestData(ctx context.Context, db *postgres.Database) error {
260+
compareRd := make([]TestRowDef, 0)
261+
for idx := 1; idx <= 2; idx++ {
262+
compareRd = append(compareRd, genTestRowDef(idx, false))
263+
}
264+
rd, err := readMultiTestRowDef(ctx, db, compareRd)
265+
if err != nil {
266+
return fmt.Errorf("Unable to verify test data [err=%v]", err.Error())
267+
}
268+
269+
// Do deep comparison
270+
if len(compareRd) != len(rd) {
271+
return fmt.Errorf("Data mismatch while comparing test data [len1=%d/len2=%d]", len(compareRd), len(rd))
272+
}
273+
274+
for idx := 0; idx < len(rd); idx++ {
275+
if !reflect.DeepEqual(compareRd[idx], rd[idx]) {
276+
return fmt.Errorf("Data mismatch while comparing test data [id=%v]", compareRd[idx].id)
238277
}
239278
}
240279

@@ -348,53 +387,89 @@ func genTestNullableRowDef(index int, write bool) TestNullableRowDef {
348387
}
349388

350389
func insertTestRowDef(ctx context.Context, tx postgres.Tx, rd TestRowDef) error {
351-
_, err := tx.Exec(ctx, postgres.NewQueryParams(`
390+
_, err := tx.Exec(ctx, `
352391
INSERT INTO go_postgres_test_table (
353392
id, num, sm, bi, bi2, dbl, va, chr, txt, blob, ts, dt, tim, b, js
354393
) VALUES (
355394
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15
356395
)
357396
`,
358397
rd.id, rd.num, rd.sm, rd.bi, rd.bi2, rd.dbl, rd.va, rd.chr, rd.txt, rd.blob, rd.ts, rd.dt, rd.tim, rd.b, rd.js,
359-
))
398+
)
360399
return err
361400
}
362401

363-
func readTestRowDef(ctx context.Context, db *postgres.Database, compareRd TestRowDef) error {
402+
func readTestRowDef(ctx context.Context, db *postgres.Database, id int) (TestRowDef, error) {
364403
rd := TestRowDef{}
365-
err := db.QueryRow(ctx, postgres.NewQueryParams(`
404+
err := db.QueryRow(ctx, `
366405
SELECT
367406
id, num, sm, bi, bi2, dbl, va, chr, txt, blob, ts, dt, tim, b, js
368407
FROM
369408
go_postgres_test_table
370409
WHERE
371410
id = $1
372-
`, compareRd.id),
411+
`, id).Scan(
373412
&rd.id, &rd.num, &rd.sm, &rd.bi, &rd.bi2, &rd.dbl, &rd.va, &rd.chr, &rd.txt, &rd.blob, &rd.ts, &rd.dt, &rd.tim,
374413
&rd.b, &rd.js,
375414
)
376415
if err != nil {
377-
return err
416+
return TestRowDef{}, err
378417
}
379418

380419
// JSON data returned by Postgres can contain spaces and other encoding so re-encode the returned string
381420
// for comparison
382421
rd.js, err = jsonReEncode(rd.js)
383422
if err != nil {
384-
return err
423+
return TestRowDef{}, err
385424
}
386425

387-
// Do deep comparison
388-
if !reflect.DeepEqual(compareRd, rd) {
389-
return errors.New("data mismatch")
426+
// Done
427+
return rd, nil
428+
}
429+
430+
func readMultiTestRowDef(ctx context.Context, db *postgres.Database, compareRd []TestRowDef) ([]TestRowDef, error) {
431+
// Populate ids
432+
ids := make([]int, len(compareRd))
433+
for idx := 0; idx < len(compareRd); idx++ {
434+
ids[idx] = compareRd[idx].id
435+
}
436+
437+
rd := make([]TestRowDef, 0)
438+
err := db.QueryRows(ctx, `
439+
SELECT
440+
id, num, sm, bi, bi2, dbl, va, chr, txt, blob, ts, dt, tim, b, js
441+
FROM
442+
go_postgres_test_table
443+
WHERE
444+
id = ANY($1)
445+
`, ids).Do(func(ctx context.Context, row postgres.Row) (bool, error) {
446+
item := TestRowDef{}
447+
err := row.Scan(&item.id, &item.num, &item.sm, &item.bi, &item.bi2, &item.dbl, &item.va, &item.chr, &item.txt,
448+
&item.blob, &item.ts, &item.dt, &item.tim, &item.b, &item.js)
449+
if err == nil {
450+
rd = append(rd, item)
451+
}
452+
return true, err
453+
})
454+
if err != nil {
455+
return nil, err
456+
}
457+
458+
// JSON data returned by Postgres can contain spaces and other encoding so re-encode the returned string
459+
// for comparison
460+
for idx := range rd {
461+
rd[idx].js, err = jsonReEncode(rd[idx].js)
462+
if err != nil {
463+
return nil, err
464+
}
390465
}
391466

392467
// Done
393-
return nil
468+
return rd, nil
394469
}
395470

396471
func insertTestNullableRowDef(ctx context.Context, tx postgres.Tx, nrd TestNullableRowDef) error {
397-
_, err := tx.Exec(ctx, postgres.NewQueryParams(`
472+
_, err := tx.Exec(ctx, `
398473
INSERT INTO go_postgres_test_table (
399474
id, num, sm, bi, bi2, dbl, va, chr, txt, blob, ts, dt, tim, b, js
400475
) VALUES (
@@ -403,25 +478,25 @@ func insertTestNullableRowDef(ctx context.Context, tx postgres.Tx, nrd TestNulla
403478
`,
404479
nrd.id, nrd.num, nrd.sm, nrd.bi, nrd.bi2, nrd.dbl, nrd.va, nrd.chr, nrd.txt, nrd.blob, nrd.ts, nrd.dt, nrd.tim,
405480
nrd.b, nrd.js,
406-
))
481+
)
407482
return err
408483
}
409484

410-
func readTestNullableRowDef(ctx context.Context, db *postgres.Database, compareNrd TestNullableRowDef) error {
485+
func readTestNullableRowDef(ctx context.Context, db *postgres.Database, id int) (TestNullableRowDef, error) {
411486
nrd := TestNullableRowDef{}
412-
err := db.QueryRow(ctx, postgres.NewQueryParams(`
487+
err := db.QueryRow(ctx, `
413488
SELECT
414489
id, num, sm, bi, bi2, dbl, va, chr, txt, blob, ts, dt, tim, b, js::text
415490
FROM
416491
go_postgres_test_table
417492
WHERE
418493
id = $1
419-
`, compareNrd.id),
494+
`, id).Scan(
420495
&nrd.id, &nrd.num, &nrd.sm, &nrd.bi, &nrd.bi2, &nrd.dbl, &nrd.va, &nrd.chr, &nrd.txt, &nrd.blob, &nrd.ts,
421496
&nrd.dt, &nrd.tim, &nrd.b, &nrd.js,
422497
)
423498
if err != nil {
424-
return err
499+
return TestNullableRowDef{}, err
425500
}
426501

427502
// JSON data returned by Postgres can contain spaces and other encoding so re-encode the returned string
@@ -431,18 +506,13 @@ func readTestNullableRowDef(ctx context.Context, db *postgres.Database, compareN
431506

432507
js, err = jsonReEncode(*nrd.js)
433508
if err != nil {
434-
return err
509+
return TestNullableRowDef{}, err
435510
}
436511
nrd.js = &js
437512
}
438513

439-
// Do deep comparison
440-
if !reflect.DeepEqual(compareNrd, nrd) {
441-
return errors.New("data mismatch")
442-
}
443-
444514
// Done
445-
return nil
515+
return nrd, nil
446516
}
447517

448518
func addressOf[T any](x T) *T {

queryparams.go

Lines changed: 0 additions & 17 deletions
This file was deleted.

0 commit comments

Comments
 (0)