Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add --pre-plan-file flag #168

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cmd/pg-schema-diff/apply_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,14 @@ func runPlan(ctx context.Context, connConfig *pgx.ConnConfig, plan diff.Plan) er
}
defer conn.Close()

if plan.PrePlanDDL != "" {
fmt.Println(header("Executing pre-plan DDL"))
fmt.Printf("%s\n\n", plan.PrePlanDDL)
if _, err := conn.ExecContext(ctx, plan.PrePlanDDL); err != nil {
return fmt.Errorf("executing pre-plan DDL: %w", err)
}
}

// Due to the way *sql.Db works, when a statement_timeout is set for the session, it will NOT reset
// by default when it's returned to the pool.
//
Expand Down
8 changes: 7 additions & 1 deletion cmd/pg-schema-diff/plan_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ type (

schemaSourceFlags struct {
schemaDirs []string
prePlanFile string
targetDatabaseDSN string
}

Expand Down Expand Up @@ -149,6 +150,11 @@ func schemaSourceFlagsVar(cmd *cobra.Command, p *schemaSourceFlags) {
if err := cmd.MarkFlagDirname("schema-dir"); err != nil {
panic(err)
}
cmd.Flags().StringVar(&p.prePlanFile, "pre-plan-file", "", "File path to a file containing DDL statements to prepend to the generated plan.")
if err := cmd.MarkFlagFilename("pre-plan-file"); err != nil {
panic(err)
}

cmd.Flags().StringVar(&p.targetDatabaseDSN, "schema-source-dsn", "", "DSN for the database to use as the schema source. Use to generate a diff between the target database and the schema in this database.")

cmd.MarkFlagsMutuallyExclusive("schema-dir", "schema-source-dsn")
Expand Down Expand Up @@ -232,7 +238,7 @@ func parseSchemaSource(p schemaSourceFlags) (schemaSourceFactory, error) {
ddl = append(ddl, stmts...)
}
return func() (diff.SchemaSource, io.Closer, error) {
return diff.DDLSchemaSource(ddl), nil, nil
return diff.DDLSchemaSource(ddl, p.prePlanFile), nil, nil
}, nil
}

Expand Down
3 changes: 3 additions & 0 deletions pkg/diff/plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ type Plan struct {
// plan on running them later, you should verify that the current schema hash matches the current schema hash.
// To get the current schema hash, you can use schema.GetPublicSchemaHash(ctx, conn)
CurrentSchemaHash string
// PrePlanDDL is a string containing DDL statements that should be executed before the plan is applied.
// This can be used for setup operations or preliminary changes that need to occur before the main migration.
PrePlanDDL string
}

// ApplyStatementTimeoutModifier applies the given timeout to all statements that match the given regex
Expand Down
29 changes: 23 additions & 6 deletions pkg/diff/plan_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"os"
"strings"
"time"

Expand Down Expand Up @@ -106,7 +107,7 @@ func WithGetSchemaOpts(getSchemaOpts ...externalschema.GetSchemaOpt) PlanOpt {
// newDDL: DDL encoding the new schema
// opts: Additional options to configure the plan generation
func GeneratePlan(ctx context.Context, queryable sqldb.Queryable, tempdbFactory tempdb.Factory, newDDL []string, opts ...PlanOpt) (Plan, error) {
return Generate(ctx, queryable, DDLSchemaSource(newDDL), append(opts, WithTempDbFactory(tempdbFactory), WithIncludeSchemas("public"))...)
return Generate(ctx, queryable, DDLSchemaSource(newDDL, ""), append(opts, WithTempDbFactory(tempdbFactory), WithIncludeSchemas("public"))...)
}

// Generate generates a migration plan to migrate the database to the target schema
Expand Down Expand Up @@ -155,9 +156,20 @@ func Generate(
return Plan{}, fmt.Errorf("generating current schema hash: %w", err)
}

prePlanDDL := ""
// Prepend pre-plan file statements if available
if ddlSource, ok := targetSchema.(*ddlSchemaSource); ok && ddlSource.prePlanFile != "" {
content, err := os.ReadFile(ddlSource.prePlanFile)
if err != nil {
return Plan{}, fmt.Errorf("reading pre-plan file: %w", err)
}
prePlanDDL = string(content)
}

plan := Plan{
Statements: statements,
CurrentSchemaHash: hash,
PrePlanDDL: prePlanDDL,
}

if planOptions.validatePlan {
Expand Down Expand Up @@ -216,11 +228,11 @@ func assertValidPlan(ctx context.Context,
// on the database.
setMaxConnectionsIfNotSet(tempDb.ConnPool, tempDbMaxConnections)

if err := setSchemaForEmptyDatabase(ctx, tempDb, currentSchema, planOptions); err != nil {
if err := setSchemaForEmptyDatabase(ctx, tempDb, currentSchema, planOptions, plan.PrePlanDDL); err != nil {
return fmt.Errorf("inserting schema in temporary database: %w", err)
}

if err := executeStatementsIgnoreTimeouts(ctx, tempDb.ConnPool, plan.Statements); err != nil {
if err := executeStatementsIgnoreTimeouts(ctx, tempDb.ConnPool, plan.Statements, plan.PrePlanDDL); err != nil {
return fmt.Errorf("running migration plan: %w", err)
}

Expand All @@ -238,7 +250,7 @@ func setMaxConnectionsIfNotSet(db *sql.DB, defaultMax int) {
}
}

func setSchemaForEmptyDatabase(ctx context.Context, emptyDb *tempdb.Database, targetSchema schema.Schema, options *planOptions) error {
func setSchemaForEmptyDatabase(ctx context.Context, emptyDb *tempdb.Database, targetSchema schema.Schema, options *planOptions, prePlanDDL string) error {
// We can't create invalid indexes. We'll mark them valid in the schema, which should be functionally
// equivalent for the sake of DDL and other statements.
//
Expand All @@ -261,7 +273,7 @@ func setSchemaForEmptyDatabase(ctx context.Context, emptyDb *tempdb.Database, ta
if err != nil {
return fmt.Errorf("building schema diff: %w", err)
}
if err := executeStatementsIgnoreTimeouts(ctx, emptyDb.ConnPool, statements); err != nil {
if err := executeStatementsIgnoreTimeouts(ctx, emptyDb.ConnPool, statements, prePlanDDL); err != nil {
return fmt.Errorf("executing statements: %w\n%# v", err, pretty.Formatter(statements))
}
return nil
Expand Down Expand Up @@ -290,7 +302,7 @@ func assertMigratedSchemaMatchesTarget(migratedSchema, targetSchema schema.Schem

// executeStatementsIgnoreTimeouts executes the statements using the sql connection but ignores any provided timeouts.
// This function is currently used to validate migration plans.
func executeStatementsIgnoreTimeouts(ctx context.Context, connPool *sql.DB, statements []Statement) error {
func executeStatementsIgnoreTimeouts(ctx context.Context, connPool *sql.DB, statements []Statement, prePlanDDL string) error {
conn, err := connPool.Conn(ctx)
if err != nil {
return fmt.Errorf("getting connection from pool: %w", err)
Expand All @@ -301,6 +313,11 @@ func executeStatementsIgnoreTimeouts(ctx context.Context, connPool *sql.DB, stat
if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION statement_timeout = %d", (10*time.Second).Milliseconds())); err != nil {
return fmt.Errorf("setting statement timeout: %w", err)
}
if prePlanDDL != "" {
if _, err := conn.ExecContext(ctx, prePlanDDL); err != nil {
return fmt.Errorf("executing pre-plan DDL: %w", err)
}
}
// Due to the way *sql.Db works, when a statement_timeout is set for the session, it will NOT reset
// by default when it's returned to the pool.
//
Expand Down
19 changes: 16 additions & 3 deletions pkg/diff/schema_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package diff
import (
"context"
"fmt"
"os"

"github.com/stripe/pg-schema-diff/internal/schema"
"github.com/stripe/pg-schema-diff/pkg/log"
Expand All @@ -21,13 +22,14 @@ type SchemaSource interface {
}

type ddlSchemaSource struct {
ddl []string
ddl []string
prePlanFile string
}

// DDLSchemaSource returns a SchemaSource that returns a schema based on the provided DDL. You must provide a tempDBFactory
// via the WithTempDbFactory option.
func DDLSchemaSource(ddl []string) SchemaSource {
return &ddlSchemaSource{ddl: ddl}
func DDLSchemaSource(ddl []string, prePlanFile string) SchemaSource {
return &ddlSchemaSource{ddl: ddl, prePlanFile: prePlanFile}
}

func (s *ddlSchemaSource) GetSchema(ctx context.Context, deps schemaSourcePlanDeps) (schema.Schema, error) {
Expand All @@ -45,6 +47,17 @@ func (s *ddlSchemaSource) GetSchema(ctx context.Context, deps schemaSourcePlanDe
}
}(tempDb.ContextualCloser)

if s.prePlanFile != "" {
prePlanDDL, err := os.ReadFile(s.prePlanFile)
if err != nil {
return schema.Schema{}, fmt.Errorf("opening pre-plan file: %w", err)
}

if _, err := tempDb.ConnPool.ExecContext(ctx, string(prePlanDDL)); err != nil {
return schema.Schema{}, fmt.Errorf("running pre-plan DDL: %w", err)
}
}

for _, stmt := range s.ddl {
if _, err := tempDb.ConnPool.ExecContext(ctx, stmt); err != nil {
return schema.Schema{}, fmt.Errorf("running DDL: %w", err)
Expand Down