Skip to content

Commit f47d004

Browse files
authored
Merge pull request #7 from randlabs/improvements
Improvements and new features
2 parents fa37e31 + 4438472 commit f47d004

File tree

11 files changed

+950
-146
lines changed

11 files changed

+950
-146
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ app
2323
demo
2424

2525
vendor/*
26+
qodana.yml

common_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package postgres_test
2+
3+
import (
4+
"crypto/rand"
5+
"encoding/json"
6+
"flag"
7+
"testing"
8+
)
9+
10+
// -----------------------------------------------------------------------------
11+
12+
var (
13+
pgUrl string
14+
pgHost string
15+
pgPort uint
16+
pgUsername string
17+
pgPassword string
18+
pgDatabaseName string
19+
)
20+
21+
var (
22+
testJSON TestJSON
23+
testBLOB []byte
24+
testJSONBytes []byte
25+
)
26+
27+
// -----------------------------------------------------------------------------
28+
29+
func init() {
30+
flag.StringVar(&pgUrl, "url", "", "Specifies the Postgres URL.")
31+
flag.StringVar(&pgHost, "host", "127.0.0.1", "Specifies the Postgres server host. (Defaults to '127.0.0.1')")
32+
flag.UintVar(&pgPort, "port", 5432, "Specifies the Postgres server port. (Defaults to 5432)")
33+
flag.StringVar(&pgUsername, "user", "postgres", "Specifies the user name. (Defaults to 'postgres')")
34+
flag.StringVar(&pgPassword, "password", "", "Specifies the user password.")
35+
flag.StringVar(&pgDatabaseName, "db", "", "Specifies the database name.")
36+
37+
testJSON = TestJSON{
38+
Id: 1,
39+
Text: "demo",
40+
}
41+
42+
testBLOB = make([]byte, 1024)
43+
_, _ = rand.Read(testBLOB)
44+
45+
testJSONBytes, _ = json.Marshal(testJSON)
46+
}
47+
48+
// -----------------------------------------------------------------------------
49+
50+
func checkSettings(t *testing.T) {
51+
if len(pgHost) == 0 {
52+
t.Fatalf("Server host not specified")
53+
}
54+
if pgPort > 65535 {
55+
t.Fatalf("Server port not specified or invalid")
56+
}
57+
if len(pgUsername) == 0 {
58+
t.Fatalf("User name to access database server not specified")
59+
}
60+
if len(pgPassword) == 0 {
61+
t.Fatalf("User password to access database server not specified")
62+
}
63+
if len(pgDatabaseName) == 0 {
64+
t.Fatalf("Database name not specified")
65+
}
66+
}
67+
68+
func addressOf[T any](x T) *T {
69+
return &x
70+
}
71+
72+
func jsonReEncode(src string) (string, error) {
73+
var v interface{}
74+
75+
err := json.Unmarshal([]byte(src), &v)
76+
if err == nil {
77+
var reencoded []byte
78+
79+
reencoded, err = json.Marshal(v)
80+
if err == nil {
81+
return string(reencoded), nil
82+
}
83+
}
84+
return "", err
85+
}

connection.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package postgres
2+
3+
import (
4+
"context"
5+
6+
"github.com/jackc/pgx/v5"
7+
"github.com/jackc/pgx/v5/pgxpool"
8+
)
9+
10+
// -----------------------------------------------------------------------------
11+
12+
// Conn encloses a single connection object.
13+
type Conn struct {
14+
db *Database
15+
conn *pgxpool.Conn
16+
}
17+
18+
// -----------------------------------------------------------------------------
19+
20+
// DB returns the underlying database driver.
21+
func (c *Conn) DB() *Database {
22+
return c.db
23+
}
24+
25+
// Exec executes an SQL statement within the single connection.
26+
func (c *Conn) Exec(ctx context.Context, sql string, args ...interface{}) (int64, error) {
27+
affectedRows := int64(0)
28+
ct, err := c.conn.Exec(ctx, sql, args...)
29+
if err == nil {
30+
affectedRows = ct.RowsAffected()
31+
}
32+
return affectedRows, c.db.processError(err)
33+
}
34+
35+
// QueryRow executes a SQL query within the single connection.
36+
func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Row {
37+
return &rowGetter{
38+
db: c.db,
39+
row: c.conn.QueryRow(ctx, sql, args...),
40+
}
41+
}
42+
43+
// QueryRows executes a SQL query within the single connection.
44+
func (c *Conn) QueryRows(ctx context.Context, sql string, args ...interface{}) Rows {
45+
rows, err := c.conn.Query(ctx, sql, args...)
46+
return &rowsGetter{
47+
db: c.db,
48+
ctx: ctx,
49+
rows: rows,
50+
err: err,
51+
}
52+
}
53+
54+
// Copy executes a SQL copy query within the single connection.
55+
func (c *Conn) Copy(ctx context.Context, tableName string, columnNames []string, callback CopyCallback) (int64, error) {
56+
n, err := c.conn.CopyFrom(
57+
ctx,
58+
pgx.Identifier{tableName},
59+
columnNames,
60+
&copyWithCallback{
61+
ctx: ctx,
62+
callback: callback,
63+
},
64+
)
65+
66+
// Done
67+
return n, c.db.processError(err)
68+
}
69+
70+
// WithinTx executes a callback function within the context of a single connection.
71+
func (c *Conn) WithinTx(ctx context.Context, cb WithinTxCallback) error {
72+
innerTx, err := c.conn.BeginTx(ctx, pgx.TxOptions{
73+
IsoLevel: pgx.ReadCommitted, //pgx.Serializable,
74+
AccessMode: pgx.ReadWrite,
75+
DeferrableMode: pgx.NotDeferrable,
76+
})
77+
if err == nil {
78+
err = cb(ctx, Tx{
79+
db: c.db,
80+
tx: innerTx,
81+
})
82+
if err == nil {
83+
err = innerTx.Commit(ctx)
84+
if err != nil {
85+
err = newError(err, "unable to commit db transaction")
86+
}
87+
}
88+
if err != nil {
89+
_ = innerTx.Rollback(context.Background()) // Using context.Background() on purpose
90+
}
91+
} else {
92+
err = newError(err, "unable to start transaction")
93+
}
94+
return c.db.processError(err)
95+
}

helpers.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,7 @@ func newError(wrappedErr error, text string) *Error {
4848
func encodeDSN(s string) string {
4949
return strings.ReplaceAll(s, "'", "\\'")
5050
}
51+
52+
func quoteIdentifier(s string) string {
53+
return "\"" + strings.ReplaceAll(s, "\"", "\"\"") + "\""
54+
}

internal.go

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package postgres
22

33
import (
4-
"context"
54
"errors"
65

76
"github.com/jackc/pgx/v5"
@@ -13,31 +12,18 @@ var errNoRows = &NoRowsError{}
1312

1413
// -----------------------------------------------------------------------------
1514

16-
// Gets a connection from the pool and initiates a transaction.
17-
func (db *Database) getTx(ctx context.Context) (pgx.Tx, error) {
18-
tx, err := db.pool.BeginTx(ctx, pgx.TxOptions{
19-
IsoLevel: pgx.ReadCommitted, //pgx.Serializable,
20-
AccessMode: pgx.ReadWrite,
21-
DeferrableMode: pgx.NotDeferrable,
22-
})
23-
if err != nil {
24-
return nil, newError(err, "unable to start transaction")
25-
}
26-
27-
//Done
28-
return tx, nil
29-
}
30-
3115
func (db *Database) processError(err error) error {
16+
isNoRows := false
3217
if errors.Is(err, pgx.ErrNoRows) {
3318
err = errNoRows
19+
isNoRows = true
3420
}
3521

3622
// Only deal with fatal database errors. Cancellation, timeouts and empty result sets are not considered fatal.
3723
db.err.mutex.Lock()
3824
defer db.err.mutex.Unlock()
3925

40-
if err != nil && IsDatabaseError(err) && err != errNoRows {
26+
if err != nil && (!isNoRows) && IsDatabaseError(err) {
4127
if db.err.last == nil {
4228
db.err.last = err
4329
if db.err.handler != nil {

0 commit comments

Comments
 (0)