Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 82 additions & 4 deletions pkg/driver/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"fmt"
"io"
"net/url"
"os/exec"
"regexp"
"strconv"
"strings"

"github.com/amacneil/dbmate/v2/pkg/dbmate"
Expand All @@ -15,6 +17,16 @@ import (
_ "github.com/go-sql-driver/mysql" // database/sql driver
)

// for mocking out during tests
type execCmd interface {
Output() ([]byte, error)
}

var execCommand = func(command string, args ...string) execCmd {
return exec.Command(command, args...)
}
var execLookPath = exec.LookPath

func init() {
dbmate.RegisterDriver(NewDriver, "mysql")
}
Expand Down Expand Up @@ -127,17 +139,82 @@ func (drv *Driver) DropDatabase() error {
return err
}

func (drv *Driver) mysqldumpArgs() []string {
type mysqldumpVersion struct {
DbType string // "mysql" or "mariadb"
Version float64 // major.minor version (e.g., 5.7, 8.4, 11.8)
Command string // "mysqldump" or "mariadb-dump"
}

var mysqldumpVersionRegexp = regexp.MustCompile(`(?:Ver \d+\.\d+ Distrib |Ver |from )(\d+\.\d+)`)

func getMysqldumpVersion() *mysqldumpVersion {
ver := &mysqldumpVersion{
DbType: "mysql",
Version: 5.0, // MariaDB 10.x is similar enough to MySQL 5.x
Command: "mysqldump",
}

if _, err := execLookPath("mariadb-dump"); err == nil {
// if we have mariadb-dump, we're at least MariaDB 11.x
ver.DbType = "mariadb"
ver.Version = 11.0
ver.Command = "mariadb-dump"
}

cmd := execCommand(ver.Command, "--version")
output, err := cmd.Output()
if err != nil {
return ver
}
outputStr := string(output)

if strings.Contains(outputStr, "MariaDB") {
ver.DbType = "mariadb"
ver.Version = 10.0
}

matches := mysqldumpVersionRegexp.FindStringSubmatch(outputStr)

if len(matches) < 2 {
return ver
}

version, err := strconv.ParseFloat(matches[1], 64)
if err != nil {
return ver
}

ver.Version = version

return ver
}

func (drv *Driver) mysqldumpArgs(ver *mysqldumpVersion) []string {
// generate CLI arguments
args := []string{"--opt", "--routines", "--no-data",
"--skip-dump-date", "--skip-add-drop-table"}

tls := drv.databaseURL.Query().Get("tls")

// Determine SSL flags based on database type and version
useSSLMode := ver.DbType == "mysql" && ver.Version >= 8.0

if tls == "" || strings.EqualFold(tls, "false") {
args = append(args, "--ssl=false")
if useSSLMode {
args = append(args, "--ssl-mode=DISABLED")
} else {
args = append(args, "--ssl=false")
}
}
if strings.EqualFold(tls, "skip-verify") {
args = append(args, "--ssl-verify-server-cert=false")
if useSSLMode {
args = append(args, "--ssl-mode=PREFERRED")
} else {
args = append(args, "--ssl-verify-server-cert=false")
}
}
if strings.EqualFold(tls, "true") && useSSLMode {
args = append(args, "--ssl-mode=REQUIRED")
}

socket := drv.databaseURL.Query().Get("socket")
Expand Down Expand Up @@ -194,7 +271,8 @@ func (drv *Driver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {

// DumpSchema returns the current database schema
func (drv *Driver) DumpSchema(db *sql.DB) ([]byte, error) {
schema, err := dbutil.RunCommand("mysqldump", drv.mysqldumpArgs()...)
ver := getMysqldumpVersion()
schema, err := dbutil.RunCommand(ver.Command, drv.mysqldumpArgs(ver)...)
if err != nil {
return nil, err
}
Expand Down
152 changes: 103 additions & 49 deletions pkg/driver/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package mysql

import (
"database/sql"
"fmt"
"net/url"
"os/exec"
"testing"

"github.com/amacneil/dbmate/v2/pkg/dbmate"
Expand All @@ -12,6 +14,14 @@ import (
"github.com/stretchr/testify/require"
)

type mockExecCmd struct {
output string
}

func (m *mockExecCmd) Output() ([]byte, error) {
return []byte(m.output), nil
}

func testMySQLDriver(t *testing.T) *Driver {
u := dbtest.GetenvURLOrSkip(t, "MYSQL_TEST_URL")
drv, err := dbmate.New(u).Driver()
Expand Down Expand Up @@ -145,56 +155,100 @@ func TestMySQLCreateDropDatabase(t *testing.T) {
}()
}

func TestMysqldumpVersion(t *testing.T) {
cases := []struct {
name string
command string
version string
expected *mysqldumpVersion
}{
{"MySQL 5.7.44", "mysqldump", "mysqldump Ver 10.13 Distrib 5.7.44, for Linux (x86_64)", &mysqldumpVersion{DbType: "mysql", Version: 5.7, Command: "mysqldump"}},
{"MySQL 8.4.7", "mysqldump", "mysqldump Ver 8.4.7 for Linux on x86_64 (MySQL Community Server - GPL)", &mysqldumpVersion{DbType: "mysql", Version: 8.4, Command: "mysqldump"}},
{"MariaDB 10.11.15", "mysqldump", "mysqldump Ver 10.19 Distrib 10.11.15-MariaDB, for debian-linux-gnu (x86_64)", &mysqldumpVersion{DbType: "mariadb", Version: 10.11, Command: "mysqldump"}},
{"MariaDB 11.8.5", "mariadb-dump", "mariadb-dump from 11.8.5-MariaDB, client 10.19 for debian-linux-gnu (x86_64)", &mysqldumpVersion{DbType: "mariadb", Version: 11.8, Command: "mariadb-dump"}},
{"MariaDB 12.0.2", "mariadb-dump", "mariadb-dump from 12.0.2-MariaDB, client 10.19 for debian-linux-gnu (x86_64)", &mysqldumpVersion{DbType: "mariadb", Version: 12.0, Command: "mariadb-dump"}},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
origExecCommand := execCommand
execCommand = func(_ string, _ ...string) execCmd {
return &mockExecCmd{
output: c.version,
}
}
origExecLookPath := execLookPath
execLookPath = func(file string) (string, error) {
if file == c.command {
return file, nil
}
return file, exec.ErrNotFound
}
defer func() {
execCommand = origExecCommand
execLookPath = origExecLookPath
}()

actual := getMysqldumpVersion()
require.Equal(t, c.expected, actual)
})
}
}

func TestMySQLDumpArgs(t *testing.T) {
drv := testMySQLDriver(t)
drv.databaseURL = dbtest.MustParseURL(t, "mysql://bob/mydb")

require.Equal(t, []string{"--opt",
"--routines",
"--no-data",
"--skip-dump-date",
"--skip-add-drop-table",
"--ssl=false",
"--host=bob",
"mydb"}, drv.mysqldumpArgs())

drv.databaseURL = dbtest.MustParseURL(t, "mysql://alice:pw@bob:5678/mydb")
require.Equal(t, []string{"--opt",
"--routines",
"--no-data",
"--skip-dump-date",
"--skip-add-drop-table",
"--ssl=false",
"--host=bob",
"--port=5678",
"--user=alice",
"--password=pw",
"mydb"}, drv.mysqldumpArgs())

drv.databaseURL = dbtest.MustParseURL(t, "mysql://alice:pw@bob:5678/mydb?tls=skip-verify")
require.Equal(t, []string{"--opt",
"--routines",
"--no-data",
"--skip-dump-date",
"--skip-add-drop-table",
"--ssl-verify-server-cert=false",
"--host=bob",
"--port=5678",
"--user=alice",
"--password=pw",
"mydb"}, drv.mysqldumpArgs())

drv.databaseURL = dbtest.MustParseURL(t, "mysql://alice:pw@bob:5678/mydb?socket=/var/run/mysqld/mysqld.sock")
require.Equal(t, []string{"--opt",
"--routines",
"--no-data",
"--skip-dump-date",
"--skip-add-drop-table",
"--ssl=false",
"--socket=/var/run/mysqld/mysqld.sock",
"--user=alice",
"--password=pw",
"mydb"}, drv.mysqldumpArgs())
cases := []struct {
name string
command string
version string
url string
expected []string
}{
// mysql://bob/mydb
{"MySQL 5.7.44", "mysqldump", "mysqldump Ver 10.13 Distrib 5.7.44, for Linux (x86_64)", "mysql://bob/mydb", []string{"--opt", "--routines", "--no-data", "--skip-dump-date", "--skip-add-drop-table", "--ssl=false", "--host=bob", "mydb"}},
{"MySQL 8.4.7", "mysqldump", "mysqldump Ver 8.4.7 for Linux on x86_64 (MySQL Community Server - GPL)", "mysql://bob/mydb", []string{"--opt", "--routines", "--no-data", "--skip-dump-date", "--skip-add-drop-table", "--ssl-mode=DISABLED", "--host=bob", "mydb"}},

// mysql://alice:pw@bob:5678/mydb
{"MySQL 5.7.44", "mysqldump", "mysqldump Ver 10.13 Distrib 5.7.44, for Linux (x86_64)", "mysql://alice:pw@bob:5678/mydb", []string{"--opt", "--routines", "--no-data", "--skip-dump-date", "--skip-add-drop-table", "--ssl=false", "--host=bob", "--port=5678", "--user=alice", "--password=pw", "mydb"}},
{"MySQL 8.4.7", "mysqldump", "mysqldump Ver 8.4.7 for Linux on x86_64 (MySQL Community Server - GPL)", "mysql://alice:pw@bob:5678/mydb", []string{"--opt", "--routines", "--no-data", "--skip-dump-date", "--skip-add-drop-table", "--ssl-mode=DISABLED", "--host=bob", "--port=5678", "--user=alice", "--password=pw", "mydb"}},

// mysql://alice:pw@bob:5678/mydb?tls=skip-verify
{"MySQL 5.7.44", "mysqldump", "mysqldump Ver 10.13 Distrib 5.7.44, for Linux (x86_64)", "mysql://alice:pw@bob:5678/mydb?tls=skip-verify", []string{"--opt", "--routines", "--no-data", "--skip-dump-date", "--skip-add-drop-table", "--ssl-verify-server-cert=false", "--host=bob", "--port=5678", "--user=alice", "--password=pw", "mydb"}},
{"MySQL 8.4.7", "mysqldump", "mysqldump Ver 8.4.7 for Linux on x86_64 (MySQL Community Server - GPL)", "mysql://alice:pw@bob:5678/mydb?tls=skip-verify", []string{"--opt", "--routines", "--no-data", "--skip-dump-date", "--skip-add-drop-table", "--ssl-mode=PREFERRED", "--host=bob", "--port=5678", "--user=alice", "--password=pw", "mydb"}},

// mysql://alice:pw@bob:5678/mydb?socket=/var/run/mysqld/mysqld.sock
{"MySQL 5.7.44", "mysqldump", "mysqldump Ver 10.13 Distrib 5.7.44, for Linux (x86_64)", "mysql://alice:pw@bob:5678/mydb?socket=/var/run/mysqld/mysqld.sock", []string{"--opt", "--routines", "--no-data", "--skip-dump-date", "--skip-add-drop-table", "--ssl=false", "--socket=/var/run/mysqld/mysqld.sock", "--user=alice", "--password=pw", "mydb"}},
{"MySQL 8.4.7", "mysqldump", "mysqldump Ver 8.4.7 for Linux on x86_64 (MySQL Community Server - GPL)", "mysql://alice:pw@bob:5678/mydb?socket=/var/run/mysqld/mysqld.sock", []string{"--opt", "--routines", "--no-data", "--skip-dump-date", "--skip-add-drop-table", "--ssl-mode=DISABLED", "--socket=/var/run/mysqld/mysqld.sock", "--user=alice", "--password=pw", "mydb"}},
}

for _, c := range cases {
t.Run(fmt.Sprintf("%s__%s", c.name, c.url), func(t *testing.T) {
origExecCommand := execCommand
execCommand = func(_ string, _ ...string) execCmd {
return &mockExecCmd{
output: c.version,
}
}
origExecLookPath := execLookPath
execLookPath = func(file string) (string, error) {
if file == c.command {
return file, nil
}
return file, exec.ErrNotFound
}
defer func() {
execCommand = origExecCommand
execLookPath = origExecLookPath
}()

ver := getMysqldumpVersion()

drv := testMySQLDriver(t)
drv.databaseURL = dbtest.MustParseURL(t, c.url)

actual := drv.mysqldumpArgs(ver)
require.Equal(t, c.expected, actual)
})
}
}

func TestMySQLDumpSchema(t *testing.T) {
Expand Down