Skip to content

Commit e33ab8b

Browse files
committed
refactor(generic): Modify SQL statements generation
- Move SQL statement generation from struct methods to package-level variables - Validate table name - Quote table name - Reverts the pattern of setting fields after variable declaration Signed-off-by: happy-game <[email protected]>
1 parent 9d45ba7 commit e33ab8b

File tree

2 files changed

+119
-99
lines changed

2 files changed

+119
-99
lines changed

pkg/drivers/generic/generic.go

Lines changed: 117 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,20 @@ import (
2222
)
2323

2424
const (
25-
defaultMaxIdleConns = 2 // copied from database/sql
25+
defaultMaxIdleConns = 2 // copied from database/sql
26+
tableNameMaxLength = 32 // set to 32 to avoid table name and index name too long
2627
)
2728

2829
// explicit interface check
2930
var _ server.Dialect = (*Generic)(nil)
3031

3132
var (
32-
columns = "kv.id AS theid, kv.name AS thename, kv.created, kv.deleted, kv.create_revision, kv.prev_revision, kv.lease, kv.value, kv.old_value"
33+
columns = "kv.id AS theid, kv.name AS thename, kv.created, kv.deleted, kv.create_revision, kv.prev_revision, kv.lease, kv.value, kv.old_value"
34+
revSQL string
35+
compactRevSQL string
36+
listSQL string
37+
tableName string
38+
quotedTableName string
3339
)
3440

3541
type ErrRetry func(error) bool
@@ -45,7 +51,6 @@ type ConnectionPoolConfig struct {
4551
type Generic struct {
4652
sync.Mutex
4753

48-
TableName string
4954
LockWrites bool
5055
LastInsertID bool
5156
DB *sql.DB
@@ -72,45 +77,6 @@ type Generic struct {
7277
FillRetryDuration time.Duration
7378
}
7479

75-
// RevSQL use d.TableName to format the origin revSQL
76-
func (d *Generic) RevSQL() string {
77-
return fmt.Sprintf(`
78-
SELECT MAX(rkv.id) AS id
79-
FROM %s AS rkv`, d.TableName)
80-
}
81-
82-
// CompactRevSQL use d.TableName to format the origin compactRevSQL
83-
func (d *Generic) CompactRevSQL() string {
84-
return fmt.Sprintf(`
85-
SELECT MAX(crkv.prev_revision) AS prev_revision
86-
FROM %s AS crkv
87-
WHERE crkv.name = 'compact_rev_key'`, d.TableName)
88-
}
89-
90-
// ListSQL use d.TableName to format the origin listSQL
91-
func (d *Generic) ListSQL() string {
92-
return fmt.Sprintf(`
93-
SELECT *
94-
FROM (
95-
SELECT (%s), (%s), %s
96-
FROM %s AS kv
97-
JOIN (
98-
SELECT MAX(mkv.id) AS id
99-
FROM %s AS mkv
100-
WHERE
101-
mkv.name LIKE ?
102-
%%s
103-
GROUP BY mkv.name) AS maxkv
104-
ON maxkv.id = kv.id
105-
WHERE
106-
kv.deleted = 0 OR
107-
?
108-
) AS lkv
109-
ORDER BY lkv.thename ASC
110-
`, d.RevSQL(), d.CompactRevSQL(), columns, d.TableName,
111-
d.TableName)
112-
}
113-
11480
func q(sql, param string, numbered bool) string {
11581
if param == "?" && !numbered {
11682
return sql
@@ -131,7 +97,7 @@ func (d *Generic) Migrate(ctx context.Context) {
13197
var (
13298
count = 0
13399
countKV = d.queryRow(ctx, "SELECT COUNT(*) FROM key_value")
134-
countKine = d.queryRow(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", d.TableName))
100+
countKine = d.queryRow(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", quotedTableName))
135101
)
136102

137103
if err := countKV.Scan(&count); err != nil || count == 0 {
@@ -147,7 +113,7 @@ func (d *Generic) Migrate(ctx context.Context) {
147113
fmt.Sprintf(`INSERT INTO %s(deleted, create_revision, prev_revision, name, value, created, lease)
148114
SELECT 0, 0, 0, kv.name, kv.value, 1, CASE WHEN kv.ttl > 0 THEN 15 ELSE 0 END
149115
FROM key_value kv
150-
WHERE kv.id IN (SELECT MAX(kvd.id) FROM key_value kvd GROUP BY kvd.name)`, d.TableName))
116+
WHERE kv.id IN (SELECT MAX(kvd.id) FROM key_value kvd GROUP BY kvd.name)`, quotedTableName))
151117
if err != nil {
152118
logrus.Errorf("Migration failed: %v", err)
153119
}
@@ -167,6 +133,54 @@ func configureConnectionPooling(connPoolConfig ConnectionPoolConfig, db *sql.DB,
167133
db.SetConnMaxLifetime(connPoolConfig.MaxLifetime)
168134
}
169135

136+
func validateTableName(customTableName string) error {
137+
if len(customTableName) > tableNameMaxLength {
138+
return fmt.Errorf("invalid table name '%s': must be less than %d characters", customTableName, tableNameMaxLength)
139+
}
140+
141+
matched, err := regexp.MatchString(`^[a-zA-Z][a-zA-Z0-9_]*$`, customTableName)
142+
if err != nil {
143+
return fmt.Errorf("failed to validate table name: %w", err)
144+
}
145+
if !matched {
146+
return fmt.Errorf("invalid table name '%s': must contain only letters, numbers, underscores and start with letter", customTableName)
147+
}
148+
return nil
149+
}
150+
151+
func buildSQLStatements() (rev, compactRev, list string) {
152+
rev = fmt.Sprintf(`
153+
SELECT MAX(rkv.id) AS id
154+
FROM %s AS rkv`, quotedTableName)
155+
156+
compactRev = fmt.Sprintf(`
157+
SELECT MAX(crkv.prev_revision) AS prev_revision
158+
FROM %s AS crkv
159+
WHERE crkv.name = 'compact_rev_key'`, quotedTableName)
160+
161+
list = fmt.Sprintf(`
162+
SELECT *
163+
FROM (
164+
SELECT (%s), (%s), %s
165+
FROM %s AS kv
166+
JOIN (
167+
SELECT MAX(mkv.id) AS id
168+
FROM %s AS mkv
169+
WHERE
170+
mkv.name LIKE ?
171+
%%s
172+
GROUP BY mkv.name) AS maxkv
173+
ON maxkv.id = kv.id
174+
WHERE
175+
kv.deleted = 0 OR
176+
?
177+
) AS lkv
178+
ORDER BY lkv.thename ASC
179+
`, rev, compactRev, columns, quotedTableName, quotedTableName)
180+
181+
return rev, compactRev, list
182+
}
183+
170184
func openAndTest(driverName, dataSourceName string) (*sql.DB, error) {
171185
db, err := sql.Open(driverName, dataSourceName)
172186
if err != nil {
@@ -183,16 +197,29 @@ func openAndTest(driverName, dataSourceName string) (*sql.DB, error) {
183197
return db, nil
184198
}
185199

186-
func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig ConnectionPoolConfig, paramCharacter string, numbered bool, metricsRegisterer prometheus.Registerer, tableName string) (*Generic, error) {
200+
func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig ConnectionPoolConfig, paramCharacter string, numbered bool, metricsRegisterer prometheus.Registerer, customTableName string) (*Generic, error) {
187201
var (
188202
db *sql.DB
189203
err error
190204
)
191205

192-
if tableName == "" {
193-
tableName = "kine"
206+
if err := validateTableName(customTableName); err != nil {
207+
return nil, err
194208
}
195209

210+
tableName = customTableName
211+
212+
// In case of MySQL, we need to quote the table name using ` `
213+
// In case of SQLite and Postgres, we need to quote the table name using " "
214+
switch driverName {
215+
case "mysql":
216+
quotedTableName = "`" + tableName + "`"
217+
default:
218+
quotedTableName = `"` + tableName + `"`
219+
}
220+
221+
revSQL, compactRevSQL, listSQL = buildSQLStatements()
222+
196223
for i := 0; i < 300; i++ {
197224
db, err = openAndTest(driverName, dataSourceName)
198225
if err == nil {
@@ -213,64 +240,57 @@ func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig
213240
metricsRegisterer.MustRegister(collectors.NewDBStatsCollector(db, "kine"))
214241
}
215242

216-
d := &Generic{
217-
DB: db,
218-
TableName: tableName,
219-
}
243+
return &Generic{
244+
DB: db,
220245

221-
revSQL := d.RevSQL()
222-
compactRevSQL := d.CompactRevSQL()
223-
listSQL := d.ListSQL()
246+
GetRevisionSQL: q(fmt.Sprintf(`
247+
SELECT
248+
0, 0, %s
249+
FROM %s AS kv
250+
WHERE kv.id = ?`, columns, quotedTableName), paramCharacter, numbered),
224251

225-
d.GetRevisionSQL = q(fmt.Sprintf(`
226-
SELECT
227-
0, 0, %s
228-
FROM %s AS kv
229-
WHERE kv.id = ?`, columns, tableName), paramCharacter, numbered)
252+
GetCurrentSQL: q(fmt.Sprintf(listSQL, "AND mkv.name > ?"), paramCharacter, numbered),
253+
ListRevisionStartSQL: q(fmt.Sprintf(listSQL, "AND mkv.id <= ?"), paramCharacter, numbered),
254+
GetRevisionAfterSQL: q(fmt.Sprintf(listSQL, "AND mkv.name > ? AND mkv.id <= ?"), paramCharacter, numbered),
230255

231-
d.GetCurrentSQL = q(fmt.Sprintf(listSQL, "AND mkv.name > ?"), paramCharacter, numbered)
232-
d.ListRevisionStartSQL = q(fmt.Sprintf(listSQL, "AND mkv.id <= ?"), paramCharacter, numbered)
233-
d.GetRevisionAfterSQL = q(fmt.Sprintf(listSQL, "AND mkv.name > ? AND mkv.id <= ?"), paramCharacter, numbered)
256+
CountCurrentSQL: q(fmt.Sprintf(`
257+
SELECT (%s), COUNT(c.theid)
258+
FROM (
259+
%s
260+
) c`, revSQL, fmt.Sprintf(listSQL, "AND mkv.name > ?")), paramCharacter, numbered),
234261

235-
d.CountCurrentSQL = q(fmt.Sprintf(`
236-
SELECT (%s), COUNT(c.theid)
237-
FROM (
238-
%s
239-
) c`, revSQL, fmt.Sprintf(listSQL, "AND mkv.name > ?")), paramCharacter, numbered)
240-
241-
d.CountRevisionSQL = q(fmt.Sprintf(`
242-
SELECT (%s), COUNT(c.theid)
243-
FROM (
244-
%s
245-
) c`, revSQL, fmt.Sprintf(listSQL, "AND mkv.name > ? AND mkv.id <= ?")), paramCharacter, numbered)
262+
CountRevisionSQL: q(fmt.Sprintf(`
263+
SELECT (%s), COUNT(c.theid)
264+
FROM (
265+
%s
266+
) c`, revSQL, fmt.Sprintf(listSQL, "AND mkv.name > ? AND mkv.id <= ?")), paramCharacter, numbered),
246267

247-
d.AfterSQL = q(fmt.Sprintf(`
248-
SELECT (%s), (%s), %s
249-
FROM %s AS kv
250-
WHERE
251-
kv.name LIKE ? AND
252-
kv.id > ?
253-
ORDER BY kv.id ASC`, revSQL, compactRevSQL, columns, tableName), paramCharacter, numbered)
254-
255-
d.DeleteSQL = q(fmt.Sprintf(`
256-
DELETE FROM %s AS kv
257-
WHERE kv.id = ?`, tableName), paramCharacter, numbered)
268+
AfterSQL: q(fmt.Sprintf(`
269+
SELECT (%s), (%s), %s
270+
FROM %s AS kv
271+
WHERE
272+
kv.name LIKE ? AND
273+
kv.id > ?
274+
ORDER BY kv.id ASC`, revSQL, compactRevSQL, columns, quotedTableName), paramCharacter, numbered),
258275

259-
d.UpdateCompactSQL = q(fmt.Sprintf(`
260-
UPDATE %s
261-
SET prev_revision = ?
262-
WHERE name = 'compact_rev_key'`, tableName), paramCharacter, numbered)
276+
DeleteSQL: q(fmt.Sprintf(`
277+
DELETE FROM %s AS kv
278+
WHERE kv.id = ?`, quotedTableName), paramCharacter, numbered),
263279

264-
d.InsertLastInsertIDSQL = q(fmt.Sprintf(`INSERT INTO %s(name, created, deleted, create_revision, prev_revision, lease, value, old_value)
265-
values(?, ?, ?, ?, ?, ?, ?, ?)`, tableName), paramCharacter, numbered)
280+
UpdateCompactSQL: q(fmt.Sprintf(`
281+
UPDATE %s
282+
SET prev_revision = ?
283+
WHERE name = 'compact_rev_key'`, quotedTableName), paramCharacter, numbered),
266284

267-
d.InsertSQL = q(fmt.Sprintf(`INSERT INTO %s(name, created, deleted, create_revision, prev_revision, lease, value, old_value)
268-
values(?, ?, ?, ?, ?, ?, ?, ?) RETURNING id`, tableName), paramCharacter, numbered)
285+
InsertLastInsertIDSQL: q(fmt.Sprintf(`INSERT INTO %s(name, created, deleted, create_revision, prev_revision, lease, value, old_value)
286+
values(?, ?, ?, ?, ?, ?, ?, ?)`, quotedTableName), paramCharacter, numbered),
269287

270-
d.FillSQL = q(fmt.Sprintf(`INSERT INTO %s(id, name, created, deleted, create_revision, prev_revision, lease, value, old_value)
271-
values(?, ?, ?, ?, ?, ?, ?, ?, ?)`, tableName), paramCharacter, numbered)
288+
InsertSQL: q(fmt.Sprintf(`INSERT INTO %s(name, created, deleted, create_revision, prev_revision, lease, value, old_value)
289+
values(?, ?, ?, ?, ?, ?, ?, ?) RETURNING id`, quotedTableName), paramCharacter, numbered),
272290

273-
return d, err
291+
FillSQL: q(fmt.Sprintf(`INSERT INTO %s(id, name, created, deleted, create_revision, prev_revision, lease, value, old_value)
292+
values(?, ?, ?, ?, ?, ?, ?, ?, ?)`, quotedTableName), paramCharacter, numbered),
293+
}, err
274294
}
275295

276296
func (d *Generic) query(ctx context.Context, sql string, args ...interface{}) (result *sql.Rows, err error) {
@@ -314,7 +334,7 @@ func (d *Generic) execute(ctx context.Context, sql string, args ...interface{})
314334

315335
func (d *Generic) GetCompactRevision(ctx context.Context) (int64, error) {
316336
var id int64
317-
row := d.queryRow(ctx, d.CompactRevSQL())
337+
row := d.queryRow(ctx, compactRevSQL)
318338
err := row.Scan(&id)
319339
if err == sql.ErrNoRows {
320340
return 0, nil
@@ -404,7 +424,7 @@ func (d *Generic) Count(ctx context.Context, prefix, startKey string, revision i
404424

405425
func (d *Generic) CurrentRevision(ctx context.Context) (int64, error) {
406426
var id int64
407-
row := d.queryRow(ctx, d.RevSQL())
427+
row := d.queryRow(ctx, revSQL)
408428
err := row.Scan(&id)
409429
if err == sql.ErrNoRows {
410430
return 0, nil

pkg/drivers/generic/tx.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func (t *Tx) MustRollback() {
5757

5858
func (t *Tx) GetCompactRevision(ctx context.Context) (int64, error) {
5959
var id int64
60-
row := t.queryRow(ctx, t.d.CompactRevSQL())
60+
row := t.queryRow(ctx, compactRevSQL)
6161
err := row.Scan(&id)
6262
if err == sql.ErrNoRows {
6363
return 0, nil
@@ -92,7 +92,7 @@ func (t *Tx) DeleteRevision(ctx context.Context, revision int64) error {
9292

9393
func (t *Tx) CurrentRevision(ctx context.Context) (int64, error) {
9494
var id int64
95-
row := t.queryRow(ctx, t.d.RevSQL())
95+
row := t.queryRow(ctx, revSQL)
9696
err := row.Scan(&id)
9797
if err == sql.ErrNoRows {
9898
return 0, nil

0 commit comments

Comments
 (0)