Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pkg/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ func New() *cli.App {
Usage: "Storage endpoint (default is sqlite)",
Destination: &config.Endpoint,
},
&cli.StringFlag{
Name: "table-name",
Usage: "The table name for the selected backend. Defaults to 'kine'.",
Destination: &config.TableName,
Value: "kine",
},
&cli.StringFlag{
Name: "ca-file",
Usage: "CA cert for DB connection",
Expand Down
1 change: 1 addition & 0 deletions pkg/drivers/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
type Config struct {
MetricsRegisterer prometheus.Registerer
Endpoint string
TableName string
Scheme string
DataSourceName string
ConnectionPoolConfig generic.ConnectionPoolConfig
Expand Down
132 changes: 82 additions & 50 deletions pkg/drivers/generic/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,42 +22,19 @@ import (
)

const (
defaultMaxIdleConns = 2 // copied from database/sql
defaultMaxIdleConns = 2 // copied from database/sql
tableNameMaxLength = 32 // set to 32 to avoid table name and index name too long
)

// explicit interface check
var _ server.Dialect = (*Generic)(nil)

var (
columns = "kv.id AS theid, kv.name AS thename, kv.created, kv.deleted, kv.create_revision, kv.prev_revision, kv.lease, kv.value, kv.old_value"
revSQL = `
SELECT MAX(rkv.id) AS id
FROM kine AS rkv`

compactRevSQL = `
SELECT MAX(crkv.prev_revision) AS prev_revision
FROM kine AS crkv
WHERE crkv.name = 'compact_rev_key'`

listSQL = fmt.Sprintf(`
SELECT *
FROM (
SELECT (%s), (%s), %s
FROM kine AS kv
JOIN (
SELECT MAX(mkv.id) AS id
FROM kine AS mkv
WHERE
mkv.name LIKE ?
%%s
GROUP BY mkv.name) AS maxkv
ON maxkv.id = kv.id
WHERE
kv.deleted = 0 OR
?
) AS lkv
ORDER BY lkv.thename ASC
`, revSQL, compactRevSQL, columns)
columns = `kv.id AS theid, kv.name AS thename, kv.created, kv.deleted, kv.create_revision, kv.prev_revision, kv.lease, kv.value, kv.old_value`
revSQL string
compactRevSQL string
listSQL string
tableName string
)

type ErrRetry func(error) bool
Expand Down Expand Up @@ -118,8 +95,8 @@ func q(sql, param string, numbered bool) string {
func (d *Generic) Migrate(ctx context.Context) {
var (
count = 0
countKV = d.queryRow(ctx, "SELECT COUNT(*) FROM key_value")
countKine = d.queryRow(ctx, "SELECT COUNT(*) FROM kine")
countKV = d.queryRow(ctx, `SELECT COUNT(*) FROM key_value`)
countKine = d.queryRow(ctx, `SELECT COUNT(*) FROM "`+tableName+`"`)
)

if err := countKV.Scan(&count); err != nil || count == 0 {
Expand All @@ -132,7 +109,7 @@ func (d *Generic) Migrate(ctx context.Context) {

logrus.Infof("Migrating content from old table")
_, err := d.execute(ctx,
`INSERT INTO kine(deleted, create_revision, prev_revision, name, value, created, lease)
`INSERT INTO "`+tableName+`"(deleted, create_revision, prev_revision, name, value, created, lease)
SELECT 0, 0, 0, kv.name, kv.value, 1, CASE WHEN kv.ttl > 0 THEN 15 ELSE 0 END
FROM key_value kv
WHERE kv.id IN (SELECT MAX(kvd.id) FROM key_value kvd GROUP BY kvd.name)`)
Expand All @@ -155,6 +132,54 @@ func configureConnectionPooling(connPoolConfig ConnectionPoolConfig, db *sql.DB,
db.SetConnMaxLifetime(connPoolConfig.MaxLifetime)
}

func validateTableName(customTableName string) error {
if len(customTableName) > tableNameMaxLength {
return fmt.Errorf("invalid table name '%s': must be less than %d characters", customTableName, tableNameMaxLength)
}

matched, err := regexp.MatchString(`^[a-zA-Z][a-zA-Z0-9_\$]*$`, customTableName)
if err != nil {
return fmt.Errorf("failed to validate table name: %w", err)
}
if !matched {
return fmt.Errorf("invalid table name '%s': must contain only letters, numbers, underscores, dollar signs and start with letter", customTableName)
}
return nil
}

func buildSQLStatements() (rev, compactRev, list string) {
rev = fmt.Sprintf(`
SELECT MAX(rkv.id) AS id
FROM "%s" AS rkv`, tableName)

compactRev = fmt.Sprintf(`
SELECT MAX(crkv.prev_revision) AS prev_revision
FROM "%s" AS crkv
WHERE crkv.name = 'compact_rev_key'`, tableName)

list = fmt.Sprintf(`
SELECT *
FROM (
SELECT (%s), (%s), %s
FROM "%s" AS kv
JOIN (
SELECT MAX(mkv.id) AS id
FROM "%s" AS mkv
WHERE
mkv.name LIKE ?
%%s
GROUP BY mkv.name) AS maxkv
ON maxkv.id = kv.id
WHERE
kv.deleted = 0 OR
?
) AS lkv
ORDER BY lkv.thename ASC
`, rev, compactRev, columns, tableName, tableName)

return rev, compactRev, list
}

func openAndTest(driverName, dataSourceName string) (*sql.DB, error) {
db, err := sql.Open(driverName, dataSourceName)
if err != nil {
Expand All @@ -171,12 +196,19 @@ func openAndTest(driverName, dataSourceName string) (*sql.DB, error) {
return db, nil
}

func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig ConnectionPoolConfig, paramCharacter string, numbered bool, metricsRegisterer prometheus.Registerer) (*Generic, error) {
func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig ConnectionPoolConfig, paramCharacter string, numbered bool, metricsRegisterer prometheus.Registerer, customTableName string) (*Generic, error) {
var (
db *sql.DB
err error
)

if err := validateTableName(customTableName); err != nil {
return nil, err
}

tableName = customTableName
revSQL, compactRevSQL, listSQL = buildSQLStatements()

for i := 0; i < 300; i++ {
db, err = openAndTest(driverName, dataSourceName)
if err == nil {
Expand All @@ -203,8 +235,8 @@ func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig
GetRevisionSQL: q(fmt.Sprintf(`
SELECT
0, 0, %s
FROM kine AS kv
WHERE kv.id = ?`, columns), paramCharacter, numbered),
FROM "%s" AS kv
WHERE kv.id = ?`, columns, tableName), paramCharacter, numbered),

GetCurrentSQL: q(fmt.Sprintf(listSQL, "AND mkv.name > ?"), paramCharacter, numbered),
ListRevisionStartSQL: q(fmt.Sprintf(listSQL, "AND mkv.id <= ?"), paramCharacter, numbered),
Expand All @@ -224,29 +256,29 @@ func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig

AfterSQL: q(fmt.Sprintf(`
SELECT (%s), (%s), %s
FROM kine AS kv
FROM "%s" AS kv
WHERE
kv.name LIKE ? AND
kv.id > ?
ORDER BY kv.id ASC`, revSQL, compactRevSQL, columns), paramCharacter, numbered),
ORDER BY kv.id ASC`, revSQL, compactRevSQL, columns, tableName), paramCharacter, numbered),

DeleteSQL: q(`
DELETE FROM kine AS kv
WHERE kv.id = ?`, paramCharacter, numbered),
DeleteSQL: q(fmt.Sprintf(`
DELETE FROM "%s" AS kv
WHERE kv.id = ?`, tableName), paramCharacter, numbered),

UpdateCompactSQL: q(`
UPDATE kine
UpdateCompactSQL: q(fmt.Sprintf(`
UPDATE "%s"
SET prev_revision = ?
WHERE name = 'compact_rev_key'`, paramCharacter, numbered),
WHERE name = 'compact_rev_key'`, tableName), paramCharacter, numbered),

InsertLastInsertIDSQL: q(`INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value)
values(?, ?, ?, ?, ?, ?, ?, ?)`, paramCharacter, numbered),
InsertLastInsertIDSQL: q(fmt.Sprintf(`INSERT INTO "%s"(name, created, deleted, create_revision, prev_revision, lease, value, old_value)
values(?, ?, ?, ?, ?, ?, ?, ?)`, tableName), paramCharacter, numbered),

InsertSQL: q(`INSERT INTO kine(name, created, deleted, create_revision, prev_revision, lease, value, old_value)
values(?, ?, ?, ?, ?, ?, ?, ?) RETURNING id`, paramCharacter, numbered),
InsertSQL: q(fmt.Sprintf(`INSERT INTO "%s"(name, created, deleted, create_revision, prev_revision, lease, value, old_value)
values(?, ?, ?, ?, ?, ?, ?, ?) RETURNING id`, tableName), paramCharacter, numbered),

FillSQL: q(`INSERT INTO kine(id, name, created, deleted, create_revision, prev_revision, lease, value, old_value)
values(?, ?, ?, ?, ?, ?, ?, ?, ?)`, paramCharacter, numbered),
FillSQL: q(fmt.Sprintf(`INSERT INTO "%s"(id, name, created, deleted, create_revision, prev_revision, lease, value, old_value)
values(?, ?, ?, ?, ?, ?, ?, ?, ?)`, tableName), paramCharacter, numbered),
}, err
}

Expand Down
79 changes: 55 additions & 24 deletions pkg/drivers/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"os"
"strconv"
"strings"

"github.com/go-sql-driver/mysql"
"github.com/sirupsen/logrus"
Expand All @@ -24,9 +25,11 @@ const (
defaultHostDSN = "root@tcp(127.0.0.1)/"
)

var (
schema = []string{
`CREATE TABLE IF NOT EXISTS kine
var createDB = "CREATE DATABASE IF NOT EXISTS `%s`;"

func getSchema(tableName string) []string {
return []string{
`CREATE TABLE IF NOT EXISTS "` + tableName + `"
(
id BIGINT UNSIGNED AUTO_INCREMENT,
name VARCHAR(630) CHARACTER SET ascii,
Expand All @@ -39,20 +42,22 @@ var (
old_value MEDIUMBLOB,
PRIMARY KEY (id)
);`,
`CREATE INDEX kine_name_index ON kine (name)`,
`CREATE INDEX kine_name_id_index ON kine (name,id)`,
`CREATE INDEX kine_id_deleted_index ON kine (id,deleted)`,
`CREATE INDEX kine_prev_revision_index ON kine (prev_revision)`,
`CREATE UNIQUE INDEX kine_name_prev_revision_uindex ON kine (name, prev_revision)`,
}
schemaMigrations = []string{
`ALTER TABLE kine MODIFY COLUMN id BIGINT UNSIGNED AUTO_INCREMENT NOT NULL UNIQUE, MODIFY COLUMN create_revision BIGINT UNSIGNED, MODIFY COLUMN prev_revision BIGINT UNSIGNED`,
`CREATE INDEX "` + tableName + `_name_index" ON "` + tableName + `" (name)`,
`CREATE INDEX "` + tableName + `_name_id_index" ON "` + tableName + `" (name,id)`,
`CREATE INDEX "` + tableName + `_id_deleted_index" ON "` + tableName + `" (id,deleted)`,
`CREATE INDEX "` + tableName + `_prev_revision_index" ON "` + tableName + `" (prev_revision)`,
`CREATE UNIQUE INDEX "` + tableName + `_name_prev_revision_uindex" ON "` + tableName + `" (name, prev_revision)`,
}
}

func getSchemaMigrations(tableName string) []string {
return []string{
`ALTER TABLE "` + tableName + `" MODIFY COLUMN id BIGINT UNSIGNED AUTO_INCREMENT NOT NULL UNIQUE, MODIFY COLUMN create_revision BIGINT UNSIGNED, MODIFY COLUMN prev_revision BIGINT UNSIGNED`,
// Creating an empty migration to ensure that postgresql and mysql migrations match up
// with each other for a give value of KINE_SCHEMA_MIGRATION env var
``,
}
createDB = "CREATE DATABASE IF NOT EXISTS `%s`;"
)
}

func New(ctx context.Context, cfg *drivers.Config) (bool, server.Backend, error) {
tlsConfig, err := cfg.BackendTLSConfig.ClientConfig()
Expand All @@ -73,7 +78,12 @@ func New(ctx context.Context, cfg *drivers.Config) (bool, server.Backend, error)
return false, nil, err
}

dialect, err := generic.Open(ctx, "mysql", parsedDSN, cfg.ConnectionPoolConfig, "?", false, cfg.MetricsRegisterer)
tableName := cfg.TableName
if tableName == "" {
tableName = "kine"
}

dialect, err := generic.Open(ctx, "mysql", parsedDSN, cfg.ConnectionPoolConfig, "?", false, cfg.MetricsRegisterer, tableName)
if err != nil {
return false, nil, err
}
Expand All @@ -82,19 +92,19 @@ func New(ctx context.Context, cfg *drivers.Config) (bool, server.Backend, error)
dialect.GetSizeSQL = `
SELECT SUM(data_length + index_length)
FROM information_schema.TABLES
WHERE table_schema = DATABASE() AND table_name = 'kine'`
WHERE table_schema = DATABASE() AND table_name = '` + tableName + `'`
dialect.CompactSQL = `
DELETE kv FROM kine AS kv
DELETE kv FROM "` + tableName + `" AS kv
INNER JOIN (
SELECT kp.prev_revision AS id
FROM kine AS kp
FROM "` + tableName + `" AS kp
WHERE
kp.name != 'compact_rev_key' AND
kp.prev_revision != 0 AND
kp.id <= ?
UNION
SELECT kd.id AS id
FROM kine AS kd
FROM "` + tableName + `" AS kd
WHERE
kd.deleted != 0 AND
kd.id <= ?
Expand All @@ -115,24 +125,24 @@ func New(ctx context.Context, cfg *drivers.Config) (bool, server.Backend, error)
}
return err.Error()
}
if err := setup(dialect.DB); err != nil {
if err := setup(dialect.DB, tableName); err != nil {
return false, nil, err
}

dialect.Migrate(context.Background())
return true, logstructured.New(sqllog.New(dialect, cfg.CompactInterval, cfg.CompactIntervalJitter, cfg.CompactTimeout, cfg.CompactMinRetain, cfg.CompactBatchSize, cfg.PollBatchSize)), nil
}

func setup(db *sql.DB) error {
func setup(db *sql.DB, tableName string) error {
logrus.Infof("Configuring database table schema and indexes, this may take a moment...")
var exists bool
err := db.QueryRow("SELECT 1 FROM information_schema.TABLES WHERE table_schema = DATABASE() AND table_name = ?", "kine").Scan(&exists)
err := db.QueryRow("SELECT 1 FROM information_schema.TABLES WHERE table_schema = DATABASE() AND table_name = ?", tableName).Scan(&exists)
if err != nil && err != sql.ErrNoRows {
logrus.Warnf("Failed to check existence of database table %s, going to attempt create: %v", "kine", err)
logrus.Warnf("Failed to check existence of database table %s, going to attempt create: %v", tableName, err)
}

if !exists {
for _, stmt := range schema {
for _, stmt := range getSchema(tableName) {
logrus.Tracef("SETUP EXEC : %v", util.Stripped(stmt))
if _, err := db.Exec(stmt); err != nil {
if mysqlError, ok := err.(*mysql.MySQLError); !ok || mysqlError.Number != 1061 {
Expand All @@ -146,7 +156,7 @@ func setup(db *sql.DB) error {
// Note that the schema created by the `schema` var is always the latest revision;
// migrations should handle deltas between prior schema versions.
schemaVersion, _ := strconv.ParseUint(os.Getenv("KINE_SCHEMA_MIGRATION"), 10, 64)
for i, stmt := range schemaMigrations {
for i, stmt := range getSchemaMigrations(tableName) {
if i >= int(schemaVersion) {
break
}
Expand Down Expand Up @@ -216,6 +226,27 @@ func prepareDSN(dataSourceName string, tlsConfig *cryptotls.Config) (string, err
if err != nil {
return "", err
}

// ensure that ASNI_QUOTES is set in sql_mode
// this is required for using "" for quoting identifiers
// https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sqlmode_ansi_quotes
if config.Params == nil {
config.Params = map[string]string{}
}

if mode, exists := config.Params["sql_mode"]; exists {
// check if ANSI_QUOTES is already set
if !strings.Contains(strings.ToUpper(mode), "ANSI_QUOTES") {
if mode == "" {
config.Params["sql_mode"] = "ANSI_QUOTES"
} else {
config.Params["sql_mode"] = mode + ",ANSI_QUOTES"
}
}
} else {
config.Params["sql_mode"] = "ANSI_QUOTES"
}

// setting up tlsConfig
if tlsConfig != nil {
if err := mysql.RegisterTLSConfig("kine", tlsConfig); err != nil {
Expand Down
Loading