@@ -22,14 +22,20 @@ import (
2222)
2323
2424const (
25- defaultMaxIdleConns = 2 // copied from database/sql
25+ defaultMaxIdleConns = 2 // copied from database/sql
26+ tableNameMaxLength = 32 // set to 32 to avoid table name and index name too long
2627)
2728
2829// explicit interface check
2930var _ server.Dialect = (* Generic )(nil )
3031
3132var (
32- columns = "kv.id AS theid, kv.name AS thename, kv.created, kv.deleted, kv.create_revision, kv.prev_revision, kv.lease, kv.value, kv.old_value"
33+ columns = "kv.id AS theid, kv.name AS thename, kv.created, kv.deleted, kv.create_revision, kv.prev_revision, kv.lease, kv.value, kv.old_value"
34+ revSQL string
35+ compactRevSQL string
36+ listSQL string
37+ tableName string
38+ quotedTableName string
3339)
3440
3541type ErrRetry func (error ) bool
@@ -45,7 +51,6 @@ type ConnectionPoolConfig struct {
4551type Generic struct {
4652 sync.Mutex
4753
48- TableName string
4954 LockWrites bool
5055 LastInsertID bool
5156 DB * sql.DB
@@ -72,45 +77,6 @@ type Generic struct {
7277 FillRetryDuration time.Duration
7378}
7479
75- // RevSQL use d.TableName to format the origin revSQL
76- func (d * Generic ) RevSQL () string {
77- return fmt .Sprintf (`
78- SELECT MAX(rkv.id) AS id
79- FROM %s AS rkv` , d .TableName )
80- }
81-
82- // CompactRevSQL use d.TableName to format the origin compactRevSQL
83- func (d * Generic ) CompactRevSQL () string {
84- return fmt .Sprintf (`
85- SELECT MAX(crkv.prev_revision) AS prev_revision
86- FROM %s AS crkv
87- WHERE crkv.name = 'compact_rev_key'` , d .TableName )
88- }
89-
90- // ListSQL use d.TableName to format the origin listSQL
91- func (d * Generic ) ListSQL () string {
92- return fmt .Sprintf (`
93- SELECT *
94- FROM (
95- SELECT (%s), (%s), %s
96- FROM %s AS kv
97- JOIN (
98- SELECT MAX(mkv.id) AS id
99- FROM %s AS mkv
100- WHERE
101- mkv.name LIKE ?
102- %%s
103- GROUP BY mkv.name) AS maxkv
104- ON maxkv.id = kv.id
105- WHERE
106- kv.deleted = 0 OR
107- ?
108- ) AS lkv
109- ORDER BY lkv.thename ASC
110- ` , d .RevSQL (), d .CompactRevSQL (), columns , d .TableName ,
111- d .TableName )
112- }
113-
11480func q (sql , param string , numbered bool ) string {
11581 if param == "?" && ! numbered {
11682 return sql
@@ -131,7 +97,7 @@ func (d *Generic) Migrate(ctx context.Context) {
13197 var (
13298 count = 0
13399 countKV = d .queryRow (ctx , "SELECT COUNT(*) FROM key_value" )
134- countKine = d .queryRow (ctx , fmt .Sprintf ("SELECT COUNT(*) FROM %s" , d . TableName ))
100+ countKine = d .queryRow (ctx , fmt .Sprintf ("SELECT COUNT(*) FROM %s" , quotedTableName ))
135101 )
136102
137103 if err := countKV .Scan (& count ); err != nil || count == 0 {
@@ -147,7 +113,7 @@ func (d *Generic) Migrate(ctx context.Context) {
147113 fmt .Sprintf (`INSERT INTO %s(deleted, create_revision, prev_revision, name, value, created, lease)
148114 SELECT 0, 0, 0, kv.name, kv.value, 1, CASE WHEN kv.ttl > 0 THEN 15 ELSE 0 END
149115 FROM key_value kv
150- WHERE kv.id IN (SELECT MAX(kvd.id) FROM key_value kvd GROUP BY kvd.name)` , d . TableName ))
116+ WHERE kv.id IN (SELECT MAX(kvd.id) FROM key_value kvd GROUP BY kvd.name)` , quotedTableName ))
151117 if err != nil {
152118 logrus .Errorf ("Migration failed: %v" , err )
153119 }
@@ -167,6 +133,54 @@ func configureConnectionPooling(connPoolConfig ConnectionPoolConfig, db *sql.DB,
167133 db .SetConnMaxLifetime (connPoolConfig .MaxLifetime )
168134}
169135
136+ func validateTableName (customTableName string ) error {
137+ if len (customTableName ) > tableNameMaxLength {
138+ return fmt .Errorf ("invalid table name '%s': must be less than %d characters" , customTableName , tableNameMaxLength )
139+ }
140+
141+ matched , err := regexp .MatchString (`^[a-zA-Z][a-zA-Z0-9_]*$` , customTableName )
142+ if err != nil {
143+ return fmt .Errorf ("failed to validate table name: %w" , err )
144+ }
145+ if ! matched {
146+ return fmt .Errorf ("invalid table name '%s': must contain only letters, numbers, underscores and start with letter" , customTableName )
147+ }
148+ return nil
149+ }
150+
151+ func buildSQLStatements () (rev , compactRev , list string ) {
152+ rev = fmt .Sprintf (`
153+ SELECT MAX(rkv.id) AS id
154+ FROM %s AS rkv` , quotedTableName )
155+
156+ compactRev = fmt .Sprintf (`
157+ SELECT MAX(crkv.prev_revision) AS prev_revision
158+ FROM %s AS crkv
159+ WHERE crkv.name = 'compact_rev_key'` , quotedTableName )
160+
161+ list = fmt .Sprintf (`
162+ SELECT *
163+ FROM (
164+ SELECT (%s), (%s), %s
165+ FROM %s AS kv
166+ JOIN (
167+ SELECT MAX(mkv.id) AS id
168+ FROM %s AS mkv
169+ WHERE
170+ mkv.name LIKE ?
171+ %%s
172+ GROUP BY mkv.name) AS maxkv
173+ ON maxkv.id = kv.id
174+ WHERE
175+ kv.deleted = 0 OR
176+ ?
177+ ) AS lkv
178+ ORDER BY lkv.thename ASC
179+ ` , rev , compactRev , columns , quotedTableName , quotedTableName )
180+
181+ return rev , compactRev , list
182+ }
183+
170184func openAndTest (driverName , dataSourceName string ) (* sql.DB , error ) {
171185 db , err := sql .Open (driverName , dataSourceName )
172186 if err != nil {
@@ -183,16 +197,29 @@ func openAndTest(driverName, dataSourceName string) (*sql.DB, error) {
183197 return db , nil
184198}
185199
186- func Open (ctx context.Context , driverName , dataSourceName string , connPoolConfig ConnectionPoolConfig , paramCharacter string , numbered bool , metricsRegisterer prometheus.Registerer , tableName string ) (* Generic , error ) {
200+ func Open (ctx context.Context , driverName , dataSourceName string , connPoolConfig ConnectionPoolConfig , paramCharacter string , numbered bool , metricsRegisterer prometheus.Registerer , customTableName string ) (* Generic , error ) {
187201 var (
188202 db * sql.DB
189203 err error
190204 )
191205
192- if tableName == "" {
193- tableName = "kine"
206+ if err := validateTableName ( customTableName ); err != nil {
207+ return nil , err
194208 }
195209
210+ tableName = customTableName
211+
212+ // In case of MySQL, we need to quote the table name using ` `
213+ // In case of SQLite and Postgres, we need to quote the table name using " "
214+ switch driverName {
215+ case "mysql" :
216+ quotedTableName = "`" + tableName + "`"
217+ default :
218+ quotedTableName = `"` + tableName + `"`
219+ }
220+
221+ revSQL , compactRevSQL , listSQL = buildSQLStatements ()
222+
196223 for i := 0 ; i < 300 ; i ++ {
197224 db , err = openAndTest (driverName , dataSourceName )
198225 if err == nil {
@@ -213,64 +240,57 @@ func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig
213240 metricsRegisterer .MustRegister (collectors .NewDBStatsCollector (db , "kine" ))
214241 }
215242
216- d := & Generic {
217- DB : db ,
218- TableName : tableName ,
219- }
243+ return & Generic {
244+ DB : db ,
220245
221- revSQL := d .RevSQL ()
222- compactRevSQL := d .CompactRevSQL ()
223- listSQL := d .ListSQL ()
246+ GetRevisionSQL : q (fmt .Sprintf (`
247+ SELECT
248+ 0, 0, %s
249+ FROM %s AS kv
250+ WHERE kv.id = ?` , columns , quotedTableName ), paramCharacter , numbered ),
224251
225- d .GetRevisionSQL = q (fmt .Sprintf (`
226- SELECT
227- 0, 0, %s
228- FROM %s AS kv
229- WHERE kv.id = ?` , columns , tableName ), paramCharacter , numbered )
252+ GetCurrentSQL : q (fmt .Sprintf (listSQL , "AND mkv.name > ?" ), paramCharacter , numbered ),
253+ ListRevisionStartSQL : q (fmt .Sprintf (listSQL , "AND mkv.id <= ?" ), paramCharacter , numbered ),
254+ GetRevisionAfterSQL : q (fmt .Sprintf (listSQL , "AND mkv.name > ? AND mkv.id <= ?" ), paramCharacter , numbered ),
230255
231- d .GetCurrentSQL = q (fmt .Sprintf (listSQL , "AND mkv.name > ?" ), paramCharacter , numbered )
232- d .ListRevisionStartSQL = q (fmt .Sprintf (listSQL , "AND mkv.id <= ?" ), paramCharacter , numbered )
233- d .GetRevisionAfterSQL = q (fmt .Sprintf (listSQL , "AND mkv.name > ? AND mkv.id <= ?" ), paramCharacter , numbered )
256+ CountCurrentSQL : q (fmt .Sprintf (`
257+ SELECT (%s), COUNT(c.theid)
258+ FROM (
259+ %s
260+ ) c` , revSQL , fmt .Sprintf (listSQL , "AND mkv.name > ?" )), paramCharacter , numbered ),
234261
235- d .CountCurrentSQL = q (fmt .Sprintf (`
236- SELECT (%s), COUNT(c.theid)
237- FROM (
238- %s
239- ) c` , revSQL , fmt .Sprintf (listSQL , "AND mkv.name > ?" )), paramCharacter , numbered )
240-
241- d .CountRevisionSQL = q (fmt .Sprintf (`
242- SELECT (%s), COUNT(c.theid)
243- FROM (
244- %s
245- ) c` , revSQL , fmt .Sprintf (listSQL , "AND mkv.name > ? AND mkv.id <= ?" )), paramCharacter , numbered )
262+ CountRevisionSQL : q (fmt .Sprintf (`
263+ SELECT (%s), COUNT(c.theid)
264+ FROM (
265+ %s
266+ ) c` , revSQL , fmt .Sprintf (listSQL , "AND mkv.name > ? AND mkv.id <= ?" )), paramCharacter , numbered ),
246267
247- d .AfterSQL = q (fmt .Sprintf (`
248- SELECT (%s), (%s), %s
249- FROM %s AS kv
250- WHERE
251- kv.name LIKE ? AND
252- kv.id > ?
253- ORDER BY kv.id ASC` , revSQL , compactRevSQL , columns , tableName ), paramCharacter , numbered )
254-
255- d .DeleteSQL = q (fmt .Sprintf (`
256- DELETE FROM %s AS kv
257- WHERE kv.id = ?` , tableName ), paramCharacter , numbered )
268+ AfterSQL : q (fmt .Sprintf (`
269+ SELECT (%s), (%s), %s
270+ FROM %s AS kv
271+ WHERE
272+ kv.name LIKE ? AND
273+ kv.id > ?
274+ ORDER BY kv.id ASC` , revSQL , compactRevSQL , columns , quotedTableName ), paramCharacter , numbered ),
258275
259- d .UpdateCompactSQL = q (fmt .Sprintf (`
260- UPDATE %s
261- SET prev_revision = ?
262- WHERE name = 'compact_rev_key'` , tableName ), paramCharacter , numbered )
276+ DeleteSQL : q (fmt .Sprintf (`
277+ DELETE FROM %s AS kv
278+ WHERE kv.id = ?` , quotedTableName ), paramCharacter , numbered ),
263279
264- d .InsertLastInsertIDSQL = q (fmt .Sprintf (`INSERT INTO %s(name, created, deleted, create_revision, prev_revision, lease, value, old_value)
265- values(?, ?, ?, ?, ?, ?, ?, ?)` , tableName ), paramCharacter , numbered )
280+ UpdateCompactSQL : q (fmt .Sprintf (`
281+ UPDATE %s
282+ SET prev_revision = ?
283+ WHERE name = 'compact_rev_key'` , quotedTableName ), paramCharacter , numbered ),
266284
267- d . InsertSQL = q (fmt .Sprintf (`INSERT INTO %s(name, created, deleted, create_revision, prev_revision, lease, value, old_value)
268- values(?, ?, ?, ?, ?, ?, ?, ?) RETURNING id ` , tableName ), paramCharacter , numbered )
285+ InsertLastInsertIDSQL : q (fmt .Sprintf (`INSERT INTO %s(name, created, deleted, create_revision, prev_revision, lease, value, old_value)
286+ values(?, ?, ?, ?, ?, ?, ?, ?)` , quotedTableName ), paramCharacter , numbered ),
269287
270- d . FillSQL = q (fmt .Sprintf (`INSERT INTO %s(id, name, created, deleted, create_revision, prev_revision, lease, value, old_value)
271- values(?, ?, ?, ?, ?, ?, ?, ?, ?) ` , tableName ), paramCharacter , numbered )
288+ InsertSQL : q (fmt .Sprintf (`INSERT INTO %s(name, created, deleted, create_revision, prev_revision, lease, value, old_value)
289+ values(?, ?, ?, ?, ?, ?, ?, ?) RETURNING id ` , quotedTableName ), paramCharacter , numbered ),
272290
273- return d , err
291+ FillSQL : q (fmt .Sprintf (`INSERT INTO %s(id, name, created, deleted, create_revision, prev_revision, lease, value, old_value)
292+ values(?, ?, ?, ?, ?, ?, ?, ?, ?)` , quotedTableName ), paramCharacter , numbered ),
293+ }, err
274294}
275295
276296func (d * Generic ) query (ctx context.Context , sql string , args ... interface {}) (result * sql.Rows , err error ) {
@@ -314,7 +334,7 @@ func (d *Generic) execute(ctx context.Context, sql string, args ...interface{})
314334
315335func (d * Generic ) GetCompactRevision (ctx context.Context ) (int64 , error ) {
316336 var id int64
317- row := d .queryRow (ctx , d . CompactRevSQL () )
337+ row := d .queryRow (ctx , compactRevSQL )
318338 err := row .Scan (& id )
319339 if err == sql .ErrNoRows {
320340 return 0 , nil
@@ -404,7 +424,7 @@ func (d *Generic) Count(ctx context.Context, prefix, startKey string, revision i
404424
405425func (d * Generic ) CurrentRevision (ctx context.Context ) (int64 , error ) {
406426 var id int64
407- row := d .queryRow (ctx , d . RevSQL () )
427+ row := d .queryRow (ctx , revSQL )
408428 err := row .Scan (& id )
409429 if err == sql .ErrNoRows {
410430 return 0 , nil
0 commit comments