Skip to content

Commit 59bd281

Browse files
author
Jason Abbott
committed
Incorporate original PR 271 from https://github.com/brokensandals
1 parent 8a4c825 commit 59bd281

File tree

4 files changed

+139
-0
lines changed

4 files changed

+139
-0
lines changed

Diff for: _example/hook/hook.go

+6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ func main() {
1414
&sqlite3.SQLiteDriver{
1515
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
1616
sqlite3conn = append(sqlite3conn, conn)
17+
conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
18+
switch op {
19+
case sqlite3.SQLITE_INSERT:
20+
log.Println("Notified of insert on db", db, "table", table, "rowid", rowid)
21+
}
22+
})
1723
return nil
1824
},
1925
})

Diff for: callback.go

+18
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,24 @@ func doneTrampoline(ctx *C.sqlite3_context) {
5353
ai.Done(ctx)
5454
}
5555

56+
//export commitHookTrampoline
57+
func commitHookTrampoline(handle uintptr) int {
58+
callback := lookupHandle(handle).(func() int)
59+
return callback()
60+
}
61+
62+
//export rollbackHookTrampoline
63+
func rollbackHookTrampoline(handle uintptr) {
64+
callback := lookupHandle(handle).(func())
65+
callback()
66+
}
67+
68+
//export updateHookTrampoline
69+
func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, rowid int64) {
70+
callback := lookupHandle(handle).(func(int, string, string, int64))
71+
callback(op, C.GoString(db), C.GoString(table), rowid)
72+
}
73+
5674
// Use handles to avoid passing Go pointers to C.
5775

5876
type handleVal struct {

Diff for: sqlite3.go

+54
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ int _sqlite3_create_function(
100100
}
101101
102102
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
103+
int commitHookTrampoline(void*);
104+
void rollbackHookTrampoline(void*);
105+
void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64);
103106
*/
104107
import "C"
105108
import (
@@ -150,6 +153,12 @@ func Version() (libVersion string, libVersionNumber int, sourceID string) {
150153
return libVersion, libVersionNumber, sourceID
151154
}
152155

156+
const (
157+
SQLITE_DELETE = C.SQLITE_DELETE
158+
SQLITE_INSERT = C.SQLITE_INSERT
159+
SQLITE_UPDATE = C.SQLITE_UPDATE
160+
)
161+
153162
// SQLiteDriver implement sql.Driver.
154163
type SQLiteDriver struct {
155164
Extensions []string
@@ -315,6 +324,51 @@ func (tx *SQLiteTx) Rollback() error {
315324
return err
316325
}
317326

327+
// RegisterCommitHook sets the commit hook for a connection.
328+
//
329+
// If the callback returns non-zero the transaction will become a rollback.
330+
//
331+
// If there is an existing commit hook for this connection, it will be
332+
// removed. If callback is nil the existing hook (if any) will be removed
333+
// without creating a new one.
334+
func (c *SQLiteConn) RegisterCommitHook(callback func() int) {
335+
if callback == nil {
336+
C.sqlite3_commit_hook(c.db, nil, nil)
337+
} else {
338+
C.sqlite3_commit_hook(c.db, (*[0]byte)(unsafe.Pointer(C.commitHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
339+
}
340+
}
341+
342+
// RegisterRollbackHook sets the rollback hook for a connection.
343+
//
344+
// If there is an existing rollback hook for this connection, it will be
345+
// removed. If callback is nil the existing hook (if any) will be removed
346+
// without creating a new one.
347+
func (c *SQLiteConn) RegisterRollbackHook(callback func()) {
348+
if callback == nil {
349+
C.sqlite3_rollback_hook(c.db, nil, nil)
350+
} else {
351+
C.sqlite3_rollback_hook(c.db, (*[0]byte)(unsafe.Pointer(C.rollbackHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
352+
}
353+
}
354+
355+
// RegisterUpdateHook sets the update hook for a connection.
356+
//
357+
// The parameters to the callback are the operation (one of the constants
358+
// SQLITE_INSERT, SQLITE_DELETE, or SQLITE_UPDATE), the database name, the
359+
// table name, and the rowid.
360+
//
361+
// If there is an existing update hook for this connection, it will be
362+
// removed. If callback is nil the existing hook (if any) will be removed
363+
// without creating a new one.
364+
func (c *SQLiteConn) RegisterUpdateHook(callback func(int, string, string, int64)) {
365+
if callback == nil {
366+
C.sqlite3_update_hook(c.db, nil, nil)
367+
} else {
368+
C.sqlite3_update_hook(c.db, (*[0]byte)(unsafe.Pointer(C.updateHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
369+
}
370+
}
371+
318372
// RegisterFunc makes a Go function available as a SQLite function.
319373
//
320374
// The Go function can have arguments of the following types: any

Diff for: sqlite3_test.go

+61
Original file line numberDiff line numberDiff line change
@@ -1265,6 +1265,67 @@ func TestPinger(t *testing.T) {
12651265
}
12661266
}
12671267

1268+
func TestUpdateAndTransactionHooks(t *testing.T) {
1269+
var events []string
1270+
var commitHookReturn = 0
1271+
1272+
sql.Register("sqlite3_UpdateHook", &SQLiteDriver{
1273+
ConnectHook: func(conn *SQLiteConn) error {
1274+
conn.RegisterCommitHook(func() int {
1275+
events = append(events, "commit")
1276+
return commitHookReturn
1277+
})
1278+
conn.RegisterRollbackHook(func() {
1279+
events = append(events, "rollback")
1280+
})
1281+
conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
1282+
events = append(events, fmt.Sprintf("update(op=%v db=%v table=%v rowid=%v)", op, db, table, rowid))
1283+
})
1284+
return nil
1285+
},
1286+
})
1287+
db, err := sql.Open("sqlite3_UpdateHook", ":memory:")
1288+
if err != nil {
1289+
t.Fatal("Failed to open database:", err)
1290+
}
1291+
defer db.Close()
1292+
1293+
statements := []string{
1294+
"create table foo (id integer primary key)",
1295+
"insert into foo values (9)",
1296+
"update foo set id = 99 where id = 9",
1297+
"delete from foo where id = 99",
1298+
}
1299+
for _, statement := range statements {
1300+
_, err = db.Exec(statement)
1301+
if err != nil {
1302+
t.Fatalf("Unable to prepare test data [%v]: %v", statement, err)
1303+
}
1304+
}
1305+
1306+
commitHookReturn = 1
1307+
_, err = db.Exec("insert into foo values (5)")
1308+
if err == nil {
1309+
t.Error("Commit hook failed to rollback transaction")
1310+
}
1311+
1312+
var expected = []string{
1313+
"commit",
1314+
fmt.Sprintf("update(op=%v db=main table=foo rowid=9)", SQLITE_INSERT),
1315+
"commit",
1316+
fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_UPDATE),
1317+
"commit",
1318+
fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_DELETE),
1319+
"commit",
1320+
fmt.Sprintf("update(op=%v db=main table=foo rowid=5)", SQLITE_INSERT),
1321+
"commit",
1322+
"rollback",
1323+
}
1324+
if !reflect.DeepEqual(events, expected) {
1325+
t.Errorf("Expected notifications %v but got %v", expected, events)
1326+
}
1327+
}
1328+
12681329
var customFunctionOnce sync.Once
12691330

12701331
func BenchmarkCustomFunctions(b *testing.B) {

0 commit comments

Comments
 (0)