From 24b6beb4a5394c9f9112adcdde9b595546571c90 Mon Sep 17 00:00:00 2001 From: bplunkett-stripe Date: Fri, 27 Sep 2024 00:47:57 -0700 Subject: [PATCH 1/3] Refactor SQL generation such that the SQL generators take a graph-first approach --- pkg/diff/diff.go | 193 -------------------------- pkg/diff/policy_sql_generator.go | 8 +- pkg/diff/sql_generator.go | 171 ++++++++++++----------- pkg/diff/sql_graph.go | 69 ++++++++-- pkg/diff/sql_vertex_generator.go | 225 +++++++++++++++++++++++++++++++ 5 files changed, 374 insertions(+), 292 deletions(-) create mode 100644 pkg/diff/sql_vertex_generator.go diff --git a/pkg/diff/diff.go b/pkg/diff/diff.go index 0f51f65..36b04c0 100644 --- a/pkg/diff/diff.go +++ b/pkg/diff/diff.go @@ -4,20 +4,12 @@ import ( "fmt" "sort" - "github.com/stripe/pg-schema-diff/internal/graph" "github.com/stripe/pg-schema-diff/internal/schema" ) var ErrNotImplemented = fmt.Errorf("not implemented") var errDuplicateIdentifier = fmt.Errorf("duplicate identifier") -type diffType string - -const ( - diffTypeDelete diffType = "DELETE" - diffTypeAddAlter diffType = "ADDALTER" -) - type ( diff[S schema.Object] interface { GetOld() S @@ -32,79 +24,8 @@ type ( // provided diff. Alter, e.g., with a table, might produce add/delete statements Alter(Diff) ([]Statement, error) } - - // dependency indicates an edge between the SQL to resolve a diff for a source schema object and the SQL to resolve - // the diff of a target schema object - // - // Most SchemaObjects will have two nodes in the SQL graph: a node for delete SQL and a node for add/alter SQL. - // These nodes will almost always be present in the sqlGraph even if the schema object is not being deleted (or added/altered). - // If a node is present for a schema object where the "diffType" is NOT occurring, it will just be a no-op (no SQl statements) - dependency struct { - sourceObjId string - sourceType diffType - - targetObjId string - targetType diffType - } ) -type dependencyBuilder struct { - valObjId string - valType diffType -} - -func mustRun(schemaObjId string, schemaDiffType diffType) dependencyBuilder { - return dependencyBuilder{ - valObjId: schemaObjId, - valType: schemaDiffType, - } -} - -func (d dependencyBuilder) before(valObjId string, valType diffType) dependency { - return dependency{ - sourceType: d.valType, - sourceObjId: d.valObjId, - - targetType: valType, - targetObjId: valObjId, - } -} - -func (d dependencyBuilder) after(valObjId string, valType diffType) dependency { - return dependency{ - sourceObjId: valObjId, - sourceType: valType, - - targetObjId: d.valObjId, - targetType: d.valType, - } -} - -// sqlVertexGenerator is used to generate SQL statements for schema objects that have dependency webs -// with other schema objects. The schema object represents a vertex in the graph. -type sqlVertexGenerator[S schema.Object, Diff diff[S]] interface { - sqlGenerator[S, Diff] - // GetSQLVertexId gets the canonical vertex id to represent the schema object - GetSQLVertexId(S) string - - // GetAddAlterDependencies gets the dependencies of the SQL generated to resolve the AddAlter diff for the - // schema objects. Dependencies can be formed on any other nodes in the SQL graph, even if the node has - // no statements. If the diff is just an add, then old will be the zero value - // - // These dependencies can also be built in reverse: the SQL returned by the sqlVertexGenerator to resolve the - // diff for the object must always be run before the SQL required to resolve another SQL vertex diff - GetAddAlterDependencies(new S, old S) ([]dependency, error) - - // GetDeleteDependencies is the same as above but for deletes. - // Invariant to maintain: - // - If an object X depends on the delete for an object Y (generated by the sqlVertexGenerator), immediately after the - // the (Y, diffTypeDelete) sqlVertex's SQL is run, Y must no longer be present in the schema; either the - // (Y, diffTypeDelete) statements deleted Y or something that vertex depended on deleted Y. In other words, if a - // delete is cascaded by another delete (e.g., index dropped by table drop) and the index SQL is empty, - // the index delete vertex must still have dependency from itself to the object from which the delete cascades down from - GetDeleteDependencies(S) ([]dependency, error) -} - type ( // listDiff represents the differences between two lists. listDiff[S schema.Object, Diff diff[S]] struct { @@ -158,120 +79,6 @@ func (ld listDiff[S, D]) resolveToSQLGroupedByEffect(sqlGenerator sqlGenerator[S }, nil } -func (ld listDiff[S, D]) resolveToSQLGraph(generator sqlVertexGenerator[S, D]) (*sqlGraph, error) { - graph := graph.NewGraph[sqlVertex]() - - for _, a := range ld.adds { - statements, err := generator.Add(a) - if err != nil { - return nil, fmt.Errorf("generating SQL for add %s: %w", a.GetName(), err) - } - - deps, err := generator.GetAddAlterDependencies(a, *new(S)) - if err != nil { - return nil, fmt.Errorf("getting dependencies for add %s: %w", a.GetName(), err) - } - if err := addSQLVertexToGraph(graph, sqlVertex{ - ObjId: generator.GetSQLVertexId(a), - Statements: statements, - DiffType: diffTypeAddAlter, - }, deps); err != nil { - return nil, fmt.Errorf("adding SQL Vertex for add %s: %w", a.GetName(), err) - } - } - - for _, a := range ld.alters { - statements, err := generator.Alter(a) - if err != nil { - return nil, fmt.Errorf("generating SQL for diff %+v: %w", a, err) - } - - vertexId := generator.GetSQLVertexId(a.GetOld()) - vertexIdAfterAlter := generator.GetSQLVertexId(a.GetNew()) - if vertexIdAfterAlter != vertexId { - return nil, fmt.Errorf("an alter lead to a node with a different id: old=%s, new=%s", vertexId, vertexIdAfterAlter) - } - - deps, err := generator.GetAddAlterDependencies(a.GetNew(), a.GetOld()) - if err != nil { - return nil, fmt.Errorf("getting dependencies for alter %s: %w", a.GetOld().GetName(), err) - } - - if err := addSQLVertexToGraph(graph, sqlVertex{ - ObjId: vertexId, - Statements: statements, - DiffType: diffTypeAddAlter, - }, deps); err != nil { - return nil, fmt.Errorf("adding SQL Vertex for alter %s: %w", a.GetOld().GetName(), err) - } - } - - for _, d := range ld.deletes { - statements, err := generator.Delete(d) - if err != nil { - return nil, fmt.Errorf("generating SQL for delete %s: %w", d.GetName(), err) - } - - deps, err := generator.GetDeleteDependencies(d) - if err != nil { - return nil, fmt.Errorf("getting dependencies for delete %s: %w", d.GetName(), err) - } - - if err := addSQLVertexToGraph(graph, sqlVertex{ - ObjId: generator.GetSQLVertexId(d), - Statements: statements, - DiffType: diffTypeDelete, - }, deps); err != nil { - return nil, fmt.Errorf("adding SQL Vertex for delete %s: %w", d.GetName(), err) - } - } - - return (*sqlGraph)(graph), nil -} - -func addSQLVertexToGraph(graph *graph.Graph[sqlVertex], vertex sqlVertex, dependencies []dependency) error { - // It's possible the node already exists. merge it if it does - if graph.HasVertexWithId(vertex.GetId()) { - vertex = mergeSQLVertices(graph.GetVertex(vertex.GetId()), vertex) - } - graph.AddVertex(vertex) - for _, dep := range dependencies { - if err := addDependency(graph, dep); err != nil { - return fmt.Errorf("adding dependencies for %s: %w", vertex.GetId(), err) - } - } - return nil -} - -func addDependency(graph *graph.Graph[sqlVertex], dep dependency) error { - sourceVertex := sqlVertex{ - ObjId: dep.sourceObjId, - DiffType: dep.sourceType, - Statements: nil, - } - targetVertex := sqlVertex{ - ObjId: dep.targetObjId, - DiffType: dep.targetType, - Statements: nil, - } - - // To maintain the correctness of the graph, we will add a dummy vertex for the missing dependencies - addVertexIfNotExists(graph, sourceVertex) - addVertexIfNotExists(graph, targetVertex) - - if err := graph.AddEdge(sourceVertex.GetId(), targetVertex.GetId()); err != nil { - return fmt.Errorf("adding edge from %s to %s: %w", sourceVertex.GetId(), targetVertex.GetId(), err) - } - - return nil -} - -func addVertexIfNotExists(graph *graph.Graph[sqlVertex], vertex sqlVertex) { - if !graph.HasVertexWithId(vertex.GetId()) { - graph.AddVertex(vertex) - } -} - type schemaObjectEntry[S schema.Object] struct { index int // index is the index the schema object in the list obj S diff --git a/pkg/diff/policy_sql_generator.go b/pkg/diff/policy_sql_generator.go index b99e2e5..8ef2a91 100644 --- a/pkg/diff/policy_sql_generator.go +++ b/pkg/diff/policy_sql_generator.go @@ -104,7 +104,7 @@ type policyDiff struct { oldAndNew[schema.Policy] } -func buildPolicyDiffs(psg *policySQLVertexGenerator, old, new []schema.Policy) (listDiff[schema.Policy, policyDiff], error) { +func buildPolicyDiffs(psg sqlVertexGenerator[schema.Policy, policyDiff], old, new []schema.Policy) (listDiff[schema.Policy, policyDiff], error) { return diffLists(old, new, func(old, new schema.Policy, _, _ int) (_ policyDiff, requiresRecreate bool, _ error) { diff := policyDiff{ oldAndNew: oldAndNew[schema.Policy]{ @@ -131,7 +131,7 @@ type policySQLVertexGenerator struct { oldSchemaColumnsByName map[string]schema.Column } -func newPolicySQLVertexGenerator(oldTable *schema.Table, table schema.Table) (*policySQLVertexGenerator, error) { +func newPolicySQLVertexGenerator(oldTable *schema.Table, table schema.Table) (sqlVertexGenerator[schema.Policy, policyDiff], error) { var oldSchemaColumnsByName map[string]schema.Column if oldTable != nil { if oldTable.SchemaQualifiedName != table.SchemaQualifiedName { @@ -140,12 +140,12 @@ func newPolicySQLVertexGenerator(oldTable *schema.Table, table schema.Table) (*p oldSchemaColumnsByName = buildSchemaObjByNameMap(oldTable.Columns) } - return &policySQLVertexGenerator{ + return legacyToNewSqlVertexGenerator[schema.Policy, policyDiff](&policySQLVertexGenerator{ table: table, newSchemaColumnsByName: buildSchemaObjByNameMap(table.Columns), oldTable: oldTable, oldSchemaColumnsByName: oldSchemaColumnsByName, - }, nil + }), nil } func (psg *policySQLVertexGenerator) Add(p schema.Policy) ([]Statement, error) { diff --git a/pkg/diff/sql_generator.go b/pkg/diff/sql_generator.go index abd9341..527a58b 100644 --- a/pkg/diff/sql_generator.go +++ b/pkg/diff/sql_generator.go @@ -505,110 +505,106 @@ func (schemaSQLGenerator) Alter(diff schemaDiff) ([]Statement, error) { return nil, fmt.Errorf("resolving named schema sql statements: %w", err) } - tableGraphs, err := diff.tableDiffs.resolveToSQLGraph(&tableSQLVertexGenerator{ + var partialGraph partialSQLGraph + + tablePartialGraph, err := generatePartialGraph(legacyToNewSqlVertexGenerator[schema.Table, tableDiff](&tableSQLVertexGenerator{ deletedTablesByName: deletedTablesByName, tablesInNewSchemaByName: tablesInNewSchemaByName, tableDiffsByName: buildDiffByNameMap[schema.Table, tableDiff](diff.tableDiffs.alters), - }) + }), diff.tableDiffs) if err != nil { - return nil, fmt.Errorf("resolving table sql graphs: %w", err) + return nil, fmt.Errorf("resolving table diff: %w", err) } + partialGraph = concatPartialGraphs(partialGraph, tablePartialGraph) extensionStatements, err := diff.extensionDiffs.resolveToSQLGroupedByEffect(&extensionSQLGenerator{}) if err != nil { - return nil, fmt.Errorf("resolving extension sql graphs: %w", err) + return nil, fmt.Errorf("resolving extension diff: %w", err) } enumStatements, err := diff.enumDiffs.resolveToSQLGroupedByEffect(&enumSQLGenerator{}) if err != nil { - return nil, fmt.Errorf("resolving enum sql graphs: %w", err) + return nil, fmt.Errorf("resolving enum diff: %w", err) } - attachPartitionSQLVertexGenerator := newAttachPartitionSQLVertexGenerator(diff.new.Indexes, diff.tableDiffs.adds) - attachPartitionGraphs, err := diff.tableDiffs.resolveToSQLGraph(attachPartitionSQLVertexGenerator) + attachPartitionGenerator := newAttachPartitionSQLVertexGenerator(diff.new.Indexes, diff.tableDiffs.adds) + attachPartitionsPartialGraph, err := generatePartialGraph(legacyToNewSqlVertexGenerator[schema.Table, tableDiff](attachPartitionGenerator), diff.tableDiffs) if err != nil { - return nil, fmt.Errorf("resolving attach partition sql graphs: %w", err) + return nil, fmt.Errorf("resolving attach partition diff: %w", err) } + partialGraph = concatPartialGraphs(partialGraph, attachPartitionsPartialGraph) - renameConflictingIndexSQLVertexGenerator := newRenameConflictingIndexSQLVertexGenerator(buildSchemaObjByNameMap(diff.old.Indexes)) - renameConflictingIndexGraphs, err := diff.indexDiffs.resolveToSQLGraph(renameConflictingIndexSQLVertexGenerator) + renameConflictingIndexesGenerator := newRenameConflictingIndexSQLVertexGenerator(buildSchemaObjByNameMap(diff.old.Indexes)) + renameConflictingIndexesPartialGraph, err := generatePartialGraph(legacyToNewSqlVertexGenerator[schema.Index, indexDiff](renameConflictingIndexesGenerator), diff.indexDiffs) if err != nil { - return nil, fmt.Errorf("resolving renaming conflicting indexes: %w", err) + return nil, fmt.Errorf("resolving renaming conflicting indexes diff: %w", err) } + partialGraph = concatPartialGraphs(partialGraph, renameConflictingIndexesPartialGraph) - indexGraphs, err := diff.indexDiffs.resolveToSQLGraph(&indexSQLVertexGenerator{ + indexGenerator := legacyToNewSqlVertexGenerator[schema.Index, indexDiff](&indexSQLVertexGenerator{ deletedTablesByName: deletedTablesByName, addedTablesByName: addedTablesByName, tablesInNewSchemaByName: tablesInNewSchemaByName, indexesInNewSchemaByName: buildSchemaObjByNameMap(diff.new.Indexes), - renameSQLVertexGenerator: renameConflictingIndexSQLVertexGenerator, - attachPartitionSQLVertexGenerator: attachPartitionSQLVertexGenerator, + renameSQLVertexGenerator: renameConflictingIndexesGenerator, + attachPartitionSQLVertexGenerator: attachPartitionGenerator, }) + indexesPartialGraph, err := generatePartialGraph(indexGenerator, diff.indexDiffs) if err != nil { - return nil, fmt.Errorf("resolving index sql graphs: %w", err) + return nil, fmt.Errorf("resolving index diff: %w", err) } + partialGraph = concatPartialGraphs(partialGraph, indexesPartialGraph) - fkConsGraphs, err := diff.foreignKeyConstraintDiffs.resolveToSQLGraph(newForeignKeyConstraintSQLVertexGenerator(diff.oldAndNew, diff.tableDiffs)) + foreignKeyGenerator := newForeignKeyConstraintSQLVertexGenerator(diff.oldAndNew, diff.tableDiffs) + fkConsPartialGraph, err := generatePartialGraph(foreignKeyGenerator, diff.foreignKeyConstraintDiffs) if err != nil { - return nil, fmt.Errorf("resolving foreign key constraint sql graphs: %w", err) + return nil, fmt.Errorf("resolving foreign key constraint diff: %w", err) } + partialGraph = concatPartialGraphs(partialGraph, fkConsPartialGraph) - sequenceGraphs, err := diff.sequenceDiffs.resolveToSQLGraph(&sequenceSQLVertexGenerator{ + sequenceGenerator := legacyToNewSqlVertexGenerator[schema.Sequence, sequenceDiff](&sequenceSQLVertexGenerator{ deletedTablesByName: deletedTablesByName, tableDiffsByName: buildDiffByNameMap[schema.Table, tableDiff](diff.tableDiffs.alters), }) + sequencesPartialGraph, err := generatePartialGraph(sequenceGenerator, diff.sequenceDiffs) if err != nil { - return nil, fmt.Errorf("resolving sequence sql graphs: %w", err) + return nil, fmt.Errorf("resolving sequence diff: %w", err) } - sequenceOwnershipGraphs, err := diff.sequenceDiffs.resolveToSQLGraph(&sequenceOwnershipSQLVertexGenerator{}) + partialGraph = concatPartialGraphs(partialGraph, sequencesPartialGraph) + + sequenceOwnershipGenerator := legacyToNewSqlVertexGenerator[schema.Sequence, sequenceDiff](&sequenceOwnershipSQLVertexGenerator{}) + sequenceOwnershipsPartialGraph, err := generatePartialGraph(sequenceOwnershipGenerator, diff.sequenceDiffs) if err != nil { - return nil, fmt.Errorf("resolving sequence ownership sql graphs: %w", err) + return nil, fmt.Errorf("resolving sequence ownership diff: %w", err) } + partialGraph = concatPartialGraphs(partialGraph, sequenceOwnershipsPartialGraph) functionsInNewSchemaByName := buildSchemaObjByNameMap(diff.new.Functions) - functionGraphs, err := diff.functionDiffs.resolveToSQLGraph(&functionSQLVertexGenerator{ + functionGenerator := legacyToNewSqlVertexGenerator[schema.Function, functionDiff](&functionSQLVertexGenerator{ functionsInNewSchemaByName: functionsInNewSchemaByName, }) + functionsPartialGraph, err := generatePartialGraph(functionGenerator, diff.functionDiffs) if err != nil { - return nil, fmt.Errorf("resolving function sql graphs: %w", err) + return nil, fmt.Errorf("resolving function diff: %w", err) } + partialGraph = concatPartialGraphs(partialGraph, functionsPartialGraph) - triggerGraphs, err := diff.triggerDiffs.resolveToSQLGraph(&triggerSQLVertexGenerator{ + triggerGenerator := legacyToNewSqlVertexGenerator[schema.Trigger, triggerDiff](&triggerSQLVertexGenerator{ functionsInNewSchemaByName: functionsInNewSchemaByName, }) + triggersPartialGraph, err := generatePartialGraph(triggerGenerator, diff.triggerDiffs) if err != nil { - return nil, fmt.Errorf("resolving trigger sql graphs: %w", err) - } - - if err := tableGraphs.union(attachPartitionGraphs); err != nil { - return nil, fmt.Errorf("unioning table and attach partition graphs: %w", err) - } - if err := tableGraphs.union(indexGraphs); err != nil { - return nil, fmt.Errorf("unioning table and index graphs: %w", err) + return nil, fmt.Errorf("resolving trigger diff: %w", err) } - if err := tableGraphs.union(renameConflictingIndexGraphs); err != nil { - return nil, fmt.Errorf("unioning table and rename conflicting index graphs: %w", err) - } - if err := tableGraphs.union(fkConsGraphs); err != nil { - return nil, fmt.Errorf("unioning table and foreign key constraint graphs: %w", err) - } - if err := tableGraphs.union(sequenceGraphs); err != nil { - return nil, fmt.Errorf("unioning table and sequence graphs: %w", err) - } - if err := tableGraphs.union(sequenceOwnershipGraphs); err != nil { - return nil, fmt.Errorf("unioning table and sequence ownership graphs: %w", err) - } - if err := tableGraphs.union(functionGraphs); err != nil { - return nil, fmt.Errorf("unioning table and function graphs: %w", err) - } - if err := tableGraphs.union(triggerGraphs); err != nil { - return nil, fmt.Errorf("unioning table and trigger graphs: %w", err) + partialGraph = concatPartialGraphs(partialGraph, triggersPartialGraph) + graph, err := graphFromPartials(partialGraph) + if err != nil { + return nil, fmt.Errorf("converting to graph: %w", err) } - - graphStatements, err := tableGraphs.toOrderedStatements() + graphStatements, err := graph.toOrderedStatements() if err != nil { - return nil, fmt.Errorf("getting ordered statements from tableGraph: %w", err) + return nil, fmt.Errorf("getting ordered statements: %w", err) } // We migrate schemas and extensions first and disable them last since their dependencies may span across @@ -773,17 +769,17 @@ func (t *tableSQLVertexGenerator) Add(table schema.Table) ([]Statement, error) { stmts = append(stmts, alterReplicaIdentityStmt) } - psg, err := newPolicySQLVertexGenerator(nil, table) + policyGenerator, err := newPolicySQLVertexGenerator(nil, table) if err != nil { return nil, fmt.Errorf("creating policy sql vertex generator: %w", err) } for _, policy := range table.Policies { - addPolicyStmts, err := psg.Add(policy) + addPolicyPartialGraph, err := policyGenerator.Add(policy) if err != nil { return nil, fmt.Errorf("generating add policy statements for policy %s: %w", policy.EscapedName, err) } // Remove hazards from statements since the table is brand new - stmts = append(stmts, stripMigrationHazards(addPolicyStmts...)...) + stmts = append(stmts, stripMigrationHazards(addPolicyPartialGraph.statements()...)...) } if table.RLSEnabled { @@ -887,54 +883,57 @@ func (t *tableSQLVertexGenerator) alterBaseTable(diff tableDiff) ([]Statement, e tempCCs = append(tempCCs, tempCC) } - columnSQLVertexGenerator := newColumnSQLVertexGenerator(diff.new.SchemaQualifiedName) - columnGraph, err := diff.columnsDiff.resolveToSQLGraph(columnSQLVertexGenerator) + var partialGraph partialSQLGraph + + columnGenerator := newColumnSQLVertexGenerator(diff.new.SchemaQualifiedName) + columnsPartialGraph, err := generatePartialGraph(columnGenerator, diff.columnsDiff) if err != nil { return nil, fmt.Errorf("resolving index diff: %w", err) } + partialGraph = concatPartialGraphs(partialGraph, columnsPartialGraph) - checkConSqlVertexGenerator := checkConstraintSQLVertexGenerator{ + checkConGenerator := legacyToNewSqlVertexGenerator[schema.CheckConstraint, checkConstraintDiff](&checkConstraintSQLVertexGenerator{ tableName: diff.new.SchemaQualifiedName, newSchemaColumnsByName: buildSchemaObjByNameMap(diff.new.Columns), oldSchemaColumnsByName: buildSchemaObjByNameMap(diff.old.Columns), addedColumnsByName: buildSchemaObjByNameMap(diff.columnsDiff.adds), deletedColumnsByName: buildSchemaObjByNameMap(diff.columnsDiff.deletes), isNewTable: false, - } - checkConGraphs, err := diff.checkConstraintDiff.resolveToSQLGraph(&checkConSqlVertexGenerator) + }) + checkConsPartialGraph, err := generatePartialGraph(checkConGenerator, diff.checkConstraintDiff) if err != nil { - return nil, fmt.Errorf("resolving check constraints diff: %w", err) + return nil, fmt.Errorf("resolving check constraints sql: %w", err) } + partialGraph = concatPartialGraphs(partialGraph, checkConsPartialGraph) + var dropTempCCs []Statement for _, tempCC := range tempCCs { - stmt, err := checkConSqlVertexGenerator.Delete(tempCC) + dropTempCCsPartialGraph, err := checkConGenerator.Delete(tempCC) if err != nil { return nil, fmt.Errorf("deleting temp check constraint: %w", err) } - dropTempCCs = append(dropTempCCs, stmt...) + dropTempCCs = append(dropTempCCs, dropTempCCsPartialGraph.statements()...) } var nilableOldTable *schema.Table if !cmp.Equal(diff.old, schema.Table{}) { nilableOldTable = &diff.old } - psg, err := newPolicySQLVertexGenerator(nilableOldTable, diff.new) + policyGenerator, err := newPolicySQLVertexGenerator(nilableOldTable, diff.new) if err != nil { return nil, fmt.Errorf("creating policy sql vertex generator: %w", err) } - policyGraph, err := diff.policiesDiff.resolveToSQLGraph(psg) + policiesPartialGraph, err := generatePartialGraph(policyGenerator, diff.policiesDiff) if err != nil { - return nil, fmt.Errorf("resolving policy diff: %w", err) + return nil, fmt.Errorf("resolving policy sql: %w", err) } + partialGraph = concatPartialGraphs(partialGraph, policiesPartialGraph) - if err := columnGraph.union(checkConGraphs); err != nil { - return nil, fmt.Errorf("unioning column and check constraint graphs: %w", err) - } - if err := columnGraph.union(policyGraph); err != nil { - return nil, fmt.Errorf("unioning column and policy graphs: %w", err) + graph, err := graphFromPartials(partialGraph) + if err != nil { + return nil, fmt.Errorf("converting to graph") } - - graphStmts, err := columnGraph.toOrderedStatements() + graphStmts, err := graph.toOrderedStatements() if err != nil { return nil, fmt.Errorf("getting ordered statements from columnGraphs: %w", err) } @@ -1129,8 +1128,8 @@ type columnSQLVertexGenerator struct { tableName schema.SchemaQualifiedName } -func newColumnSQLVertexGenerator(tableName schema.SchemaQualifiedName) *columnSQLVertexGenerator { - return &columnSQLVertexGenerator{tableName: tableName} +func newColumnSQLVertexGenerator(tableName schema.SchemaQualifiedName) sqlVertexGenerator[schema.Column, columnDiff] { + return legacyToNewSqlVertexGenerator[schema.Column, columnDiff](&columnSQLVertexGenerator{tableName: tableName}) } func (csg *columnSQLVertexGenerator) Add(column schema.Column) ([]Statement, error) { @@ -1395,13 +1394,18 @@ type renameConflictingIndexSQLVertexGenerator struct { oldSchemaIndexesByName map[string]schema.Index indexRenamesByOldName map[string]schema.SchemaQualifiedName + + sqlVertexGenerator[schema.Index, indexDiff] } func newRenameConflictingIndexSQLVertexGenerator(oldSchemaIndexesByName map[string]schema.Index) *renameConflictingIndexSQLVertexGenerator { - return &renameConflictingIndexSQLVertexGenerator{ + rsg := &renameConflictingIndexSQLVertexGenerator{ oldSchemaIndexesByName: oldSchemaIndexesByName, indexRenamesByOldName: make(map[string]schema.SchemaQualifiedName), } + generator := legacyToNewSqlVertexGenerator[schema.Index, indexDiff](rsg) + rsg.sqlVertexGenerator = generator + return rsg } func (rsg *renameConflictingIndexSQLVertexGenerator) Add(index schema.Index) ([]Statement, error) { @@ -2008,15 +2012,20 @@ type attachPartitionSQLVertexGenerator struct { // isPartitionAttachedAfterIdxBuildsByTableName is a map of table name to whether or not the table partition will be // attached after its indexes are built. This is useful for determining when indexes need to be attached isPartitionAttachedAfterIdxBuildsByTableName map[string]bool + + sqlVertexGenerator[schema.Table, tableDiff] } func newAttachPartitionSQLVertexGenerator(newSchemaIndexes []schema.Index, addedTables []schema.Table) *attachPartitionSQLVertexGenerator { - return &attachPartitionSQLVertexGenerator{ + asg := &attachPartitionSQLVertexGenerator{ indexesInNewSchemaByTableName: buildIndexesByTableNameMap(newSchemaIndexes), addedTablesByName: buildSchemaObjByNameMap(addedTables), isPartitionAttachedAfterIdxBuildsByTableName: make(map[string]bool), } + sqlVertexGenerator := legacyToNewSqlVertexGenerator[schema.Table, tableDiff](asg) + asg.sqlVertexGenerator = sqlVertexGenerator + return asg } func (*attachPartitionSQLVertexGenerator) Add(table schema.Table) ([]Statement, error) { @@ -2079,7 +2088,7 @@ func (a *attachPartitionSQLVertexGenerator) GetDeleteDependencies(_ schema.Table return nil, nil } -func buildForeignKeyConstraintDiff(fsg *foreignKeyConstraintSQLVertexGenerator, addedTablesByName map[string]schema.Table, old, new schema.ForeignKeyConstraint) (foreignKeyConstraintDiff, bool, error) { +func buildForeignKeyConstraintDiff(fsg sqlVertexGenerator[schema.ForeignKeyConstraint, foreignKeyConstraintDiff], addedTablesByName map[string]schema.Table, old, new schema.ForeignKeyConstraint) (foreignKeyConstraintDiff, bool, error) { if _, isOnNewTable := addedTablesByName[new.OwningTable.GetName()]; isOnNewTable { // If the owning table is new, then it must be re-created (this occurs if the base table has been // re-created). In other words, a foreign key constraint must be re-created if the owning table or referenced @@ -2127,15 +2136,15 @@ type foreignKeyConstraintSQLVertexGenerator struct { childrenInNewSchemaByPartitionedIndexName map[string][]schema.Index } -func newForeignKeyConstraintSQLVertexGenerator(oldAndNewSchema oldAndNew[schema.Schema], tableDiffs listDiff[schema.Table, tableDiff]) *foreignKeyConstraintSQLVertexGenerator { - return &foreignKeyConstraintSQLVertexGenerator{ +func newForeignKeyConstraintSQLVertexGenerator(oldAndNewSchema oldAndNew[schema.Schema], tableDiffs listDiff[schema.Table, tableDiff]) sqlVertexGenerator[schema.ForeignKeyConstraint, foreignKeyConstraintDiff] { + return legacyToNewSqlVertexGenerator[schema.ForeignKeyConstraint, foreignKeyConstraintDiff](&foreignKeyConstraintSQLVertexGenerator{ newSchemaTablesByName: buildSchemaObjByNameMap(oldAndNewSchema.new.Tables), addedTablesByName: buildSchemaObjByNameMap(tableDiffs.adds), indexInOldSchemaByTableName: buildIndexesByTableNameMap(oldAndNewSchema.old.Indexes), childrenInOldSchemaByPartitionedIndexName: buildChildrenByPartitionedIndexNameMap(oldAndNewSchema.old.Indexes), indexesInNewSchemaByTableName: buildIndexesByTableNameMap(oldAndNewSchema.new.Indexes), childrenInNewSchemaByPartitionedIndexName: buildChildrenByPartitionedIndexNameMap(oldAndNewSchema.new.Indexes), - } + }) } func (f *foreignKeyConstraintSQLVertexGenerator) Add(con schema.ForeignKeyConstraint) ([]Statement, error) { diff --git a/pkg/diff/sql_graph.go b/pkg/diff/sql_graph.go index a37c946..5c2e99d 100644 --- a/pkg/diff/sql_graph.go +++ b/pkg/diff/sql_graph.go @@ -8,9 +8,12 @@ import ( ) type sqlVertex struct { + // todo(bplunkett) - rename to statement id? + // these ids will need to be globally unique ObjId string Statements []Statement - DiffType diffType + // todo(bplunkett) - this should not affect the id + DiffType diffType } func (s sqlVertex) GetId() string { @@ -25,27 +28,65 @@ func buildIndexVertexId(name schema.SchemaQualifiedName) string { return fmt.Sprintf("index_%s", name) } -// sqlGraph represents two dependency webs of SQL statements -type sqlGraph graph.Graph[sqlVertex] +// dependency indicates an edge between the SQL to resolve a diff for a source schema object and the SQL to resolve +// the diff of a target schema object +// +// Most SchemaObjects will have two nodes in the SQL graph: a node for delete SQL and a node for add/alter SQL. +// These nodes will almost always be present in the sqlGraph even if the schema object is not being deleted (or added/altered). +// If a node is present for a schema object where the "diffType" is NOT occurring, it will just be a no-op (no SQl statements) +type dependency struct { + sourceObjId string + sourceType diffType + + targetObjId string + targetType diffType +} + +type dependencyBuilder struct { + valObjId string + valType diffType +} + +func mustRun(schemaObjId string, schemaDiffType diffType) dependencyBuilder { + return dependencyBuilder{ + valObjId: schemaObjId, + valType: schemaDiffType, + } +} -// union unions the two AddsAndAlters graphs and, separately, unions the two delete graphs -func (s *sqlGraph) union(sqlGraph *sqlGraph) error { - if err := (*graph.Graph[sqlVertex])(s).Union((*graph.Graph[sqlVertex])(sqlGraph), mergeSQLVertices); err != nil { - return fmt.Errorf("unioning the graphs: %w", err) +func (d dependencyBuilder) before(valObjId string, valType diffType) dependency { + return dependency{ + sourceType: d.valType, + sourceObjId: d.valObjId, + + targetType: valType, + targetObjId: valObjId, + } +} + +func (d dependencyBuilder) after(valObjId string, valType diffType) dependency { + return dependency{ + sourceObjId: valObjId, + sourceType: valType, + + targetObjId: d.valObjId, + targetType: d.valType, } - return nil } -func mergeSQLVertices(old, new sqlVertex) sqlVertex { - return sqlVertex{ - ObjId: old.ObjId, - DiffType: old.DiffType, - Statements: append(old.Statements, new.Statements...), +// sqlGraph represents two dependency webs of SQL statements +type sqlGraph struct { + *graph.Graph[sqlVertex] +} + +func newSqlGraph() *sqlGraph { + return &sqlGraph{ + Graph: graph.NewGraph[sqlVertex](), } } func (s *sqlGraph) toOrderedStatements() ([]Statement, error) { - vertices, err := (*graph.Graph[sqlVertex])(s).TopologicallySortWithPriority(graph.IsLowerPriorityFromGetPriority( + vertices, err := s.TopologicallySortWithPriority(graph.IsLowerPriorityFromGetPriority( func(vertex sqlVertex) int { multiplier := 1 if vertex.DiffType == diffTypeDelete { diff --git a/pkg/diff/sql_vertex_generator.go b/pkg/diff/sql_vertex_generator.go new file mode 100644 index 0000000..b453825 --- /dev/null +++ b/pkg/diff/sql_vertex_generator.go @@ -0,0 +1,225 @@ +package diff + +import ( + "fmt" + + "github.com/stripe/pg-schema-diff/internal/schema" +) + +type diffType string + +const ( + diffTypeDelete diffType = "DELETE" + diffTypeAddAlter diffType = "ADDALTER" +) + +// partialSQLGraph represents the SQL statements and their dependencies. Every SQL statement is a +// vertex in the graph. This is different from a graph because a dependency may exist between a node +// within this part and a node in another part. +type partialSQLGraph struct { + vertices []sqlVertex + dependencies []dependency +} + +func (s *partialSQLGraph) statements() []Statement { + var statements []Statement + for _, vertex := range s.vertices { + statements = append(statements, vertex.Statements...) + } + return statements +} + +func concatPartialGraphs(parts ...partialSQLGraph) partialSQLGraph { + var vertices []sqlVertex + var dependencies []dependency + for _, part := range parts { + vertices = append(vertices, part.vertices...) + dependencies = append(dependencies, part.dependencies...) + } + return partialSQLGraph{ + vertices: vertices, + dependencies: dependencies, + } +} + +func graphFromPartials(parts partialSQLGraph) (*sqlGraph, error) { + graph := newSqlGraph() + for _, vertex := range parts.vertices { + // It's possible the node already exists. merge it if it does + if graph.HasVertexWithId(vertex.GetId()) { + vertex = mergeVertices(graph.GetVertex(vertex.GetId()), vertex) + } + graph.AddVertex(vertex) + } + + for _, dep := range parts.dependencies { + sourceVertex := sqlVertex{ + ObjId: dep.sourceObjId, + DiffType: dep.sourceType, + Statements: nil, + } + targetVertex := sqlVertex{ + ObjId: dep.targetObjId, + DiffType: dep.targetType, + Statements: nil, + } + + // To maintain the correctness of the graph, we will add a dummy vertex for the missing dependencies + addVertexIfNotExists(graph, sourceVertex) + addVertexIfNotExists(graph, targetVertex) + + if err := graph.AddEdge(sourceVertex.GetId(), targetVertex.GetId()); err != nil { + return nil, fmt.Errorf("adding edge from %s to %s: %w", sourceVertex.GetId(), targetVertex.GetId(), err) + } + } + + return graph, nil +} + +func mergeVertices(old, new sqlVertex) sqlVertex { + return sqlVertex{ + ObjId: old.ObjId, + DiffType: old.DiffType, + Statements: append(old.Statements, new.Statements...), + } +} + +func addVertexIfNotExists(graph *sqlGraph, vertex sqlVertex) { + if !graph.HasVertexWithId(vertex.GetId()) { + graph.AddVertex(vertex) + } +} + +// sqlVertexGenerator generates SQL statements for a schema object and its diff. This is the canonical interface +// for SQL generation. +type sqlVertexGenerator[S schema.Object, Diff diff[S]] interface { + Add(S) (partialSQLGraph, error) + Delete(S) (partialSQLGraph, error) + // Alter generates the statements required to resolve the schema object to its new state using the + // provided diff. Alter, e.g., with a table, might produce add/delete statements + Alter(Diff) (partialSQLGraph, error) +} + +// generatePartialGraph generates a partial for the given schema object list diff using the inutted generator. +func generatePartialGraph[S schema.Object, Diff diff[S]](generator sqlVertexGenerator[S, Diff], listDiff listDiff[S, Diff]) (partialSQLGraph, error) { + var partialGraphs []partialSQLGraph + for _, a := range listDiff.adds { + v, err := generator.Add(a) + if err != nil { + return partialSQLGraph{}, fmt.Errorf("generating add statements for %s: %w", a.GetName(), err) + } + partialGraphs = append(partialGraphs, v) + } + for _, d := range listDiff.deletes { + v, err := generator.Delete(d) + if err != nil { + return partialSQLGraph{}, fmt.Errorf("generating delete statements for %s: %w", d.GetName(), err) + } + partialGraphs = append(partialGraphs, v) + } + for _, a := range listDiff.alters { + v, err := generator.Alter(a) + if err != nil { + return partialSQLGraph{}, fmt.Errorf("generating alter statements for %s: %w", a.GetNew().GetName(), err) + } + partialGraphs = append(partialGraphs, v) + } + return concatPartialGraphs(partialGraphs...), nil +} + +// deprecated legacySqlVertexGenerator represents the "old" style for generating SQL vertices where the Add/Delete/Alter functions +// return a flat list of statements. +type legacySqlVertexGenerator[S schema.Object, Diff diff[S]] interface { + sqlGenerator[S, Diff] + // GetSQLVertexId gets the canonical vertex id to represent the schema object + GetSQLVertexId(S) string + + // GetAddAlterDependencies gets the dependencies of the SQL generated to resolve the AddAlter diff for the + // schema objects. Dependencies can be formed on any other nodes in the SQL graph, even if the node has + // no statements. If the diff is just an add, then old will be the zero value + // + // These dependencies can also be built in reverse: the SQL returned by the sqlVertexGenerator to resolve the + // diff for the object must always be run before the SQL required to resolve another SQL vertex diff + GetAddAlterDependencies(new S, old S) ([]dependency, error) + + // GetDeleteDependencies is the same as above but for deletes. + // Invariant to maintain: + // - If an object X depends on the delete for an object Y (generated by the sqlVertexGenerator), immediately after the + // the (Y, diffTypeDelete) sqlVertex's SQL is run, Y must no longer be present in the schema; either the + // (Y, diffTypeDelete) statements deleted Y or something that vertex depended on deleted Y. In other words, if a + // delete is cascaded by another delete (e.g., index dropped by table drop) and the index SQL is empty, + // the index delete vertex must still have dependency from itself to the object from which the delete cascades down from + GetDeleteDependencies(S) ([]dependency, error) +} + +type wrappedLegacySqlVertexGenerator[S schema.Object, Diff diff[S]] struct { + generator legacySqlVertexGenerator[S, Diff] +} + +func legacyToNewSqlVertexGenerator[S schema.Object, Diff diff[S]](generator legacySqlVertexGenerator[S, Diff]) sqlVertexGenerator[S, Diff] { + return &wrappedLegacySqlVertexGenerator[S, Diff]{ + generator: generator, + } +} + +func (s *wrappedLegacySqlVertexGenerator[S, Diff]) Add(o S) (partialSQLGraph, error) { + statements, err := s.generator.Add(o) + if err != nil { + return partialSQLGraph{}, fmt.Errorf("generating sql: %w", err) + } + + var zeroVal S + deps, err := s.generator.GetAddAlterDependencies(o, zeroVal) + if err != nil { + return partialSQLGraph{}, fmt.Errorf("getting dependencies: %w", err) + } + + return partialSQLGraph{ + vertices: []sqlVertex{{ + DiffType: diffTypeAddAlter, + ObjId: s.generator.GetSQLVertexId(o), + Statements: statements, + }}, + dependencies: deps, + }, nil +} + +func (s *wrappedLegacySqlVertexGenerator[S, Diff]) Delete(o S) (partialSQLGraph, error) { + statements, err := s.generator.Delete(o) + if err != nil { + return partialSQLGraph{}, fmt.Errorf("generating sql: %w", err) + } + deps, err := s.generator.GetDeleteDependencies(o) + if err != nil { + return partialSQLGraph{}, fmt.Errorf("getting dependencies: %w", err) + } + + return partialSQLGraph{ + vertices: []sqlVertex{{ + DiffType: diffTypeDelete, + ObjId: s.generator.GetSQLVertexId(o), + Statements: statements, + }}, + dependencies: deps, + }, nil +} + +func (s *wrappedLegacySqlVertexGenerator[S, Diff]) Alter(d Diff) (partialSQLGraph, error) { + statements, err := s.generator.Alter(d) + if err != nil { + return partialSQLGraph{}, fmt.Errorf("generating sql: %w", err) + } + deps, err := s.generator.GetAddAlterDependencies(d.GetNew(), d.GetOld()) + if err != nil { + return partialSQLGraph{}, fmt.Errorf("getting dependencies: %w", err) + } + + return partialSQLGraph{ + vertices: []sqlVertex{{ + DiffType: diffTypeAddAlter, + ObjId: s.generator.GetSQLVertexId(d.GetNew()), + Statements: statements, + }}, + dependencies: deps, + }, nil +} From 51df1813cb3f7aef69150caad90e6bf30f025344 Mon Sep 17 00:00:00 2001 From: bplunkett-stripe Date: Sun, 29 Sep 2024 21:57:58 -0700 Subject: [PATCH 2/3] Update dependency system --- pkg/diff/policy_sql_generator.go | 18 ++-- pkg/diff/sql_generator.go | 175 ++++++++++++++++--------------- pkg/diff/sql_graph.go | 115 +++++++++++--------- pkg/diff/sql_vertex_generator.go | 44 ++++---- 4 files changed, 190 insertions(+), 162 deletions(-) diff --git a/pkg/diff/policy_sql_generator.go b/pkg/diff/policy_sql_generator.go index 8ef2a91..79ecd4b 100644 --- a/pkg/diff/policy_sql_generator.go +++ b/pkg/diff/policy_sql_generator.go @@ -262,17 +262,17 @@ func (psg *policySQLVertexGenerator) Alter(diff policyDiff) ([]Statement, error) }}, nil } -func (psg *policySQLVertexGenerator) GetSQLVertexId(p schema.Policy) string { - return buildPolicyVertexId(psg.table.SchemaQualifiedName, p.EscapedName) +func (psg *policySQLVertexGenerator) GetSQLVertexId(p schema.Policy, diffType diffType) sqlVertexId { + return buildPolicyVertexId(psg.table.SchemaQualifiedName, p.EscapedName, diffType) } -func buildPolicyVertexId(owningTable schema.SchemaQualifiedName, policyEscapedName string) string { - return buildVertexId("policy", fmt.Sprintf("%s.%s", owningTable.GetFQEscapedName(), policyEscapedName)) +func buildPolicyVertexId(owningTable schema.SchemaQualifiedName, policyEscapedName string, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("policy", fmt.Sprintf("%s.%s", owningTable.GetFQEscapedName(), policyEscapedName), diffType) } func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolicy schema.Policy) ([]dependency, error) { deps := []dependency{ - mustRun(psg.GetSQLVertexId(newPolicy), diffTypeDelete).before(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter), + mustRun(psg.GetSQLVertexId(newPolicy, diffTypeDelete)).before(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)), } newTargetColumns, err := getTargetColumns(newPolicy.Columns, psg.newSchemaColumnsByName) @@ -282,7 +282,7 @@ func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolic // Run after the new columns are added/altered for _, tc := range newTargetColumns { - deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter).after(buildColumnVertexId(tc.Name), diffTypeAddAlter)) + deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)).after(buildColumnVertexId(tc.Name, diffTypeAddAlter))) } if !cmp.Equal(oldPolicy, schema.Policy{}) { @@ -294,7 +294,7 @@ func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolic for _, tc := range oldTargetColumns { // It only needs to run before the delete if the column is actually being deleted if _, stillExists := psg.newSchemaColumnsByName[tc.GetName()]; !stillExists { - deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter).before(buildColumnVertexId(tc.Name), diffTypeDelete)) + deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)).before(buildColumnVertexId(tc.Name, diffTypeDelete))) } } } @@ -311,8 +311,8 @@ func (psg *policySQLVertexGenerator) GetDeleteDependencies(pol schema.Policy) ([ } // The policy needs to be deleted before all the columns it references are deleted or add/altered for _, c := range columns { - deps = append(deps, mustRun(psg.GetSQLVertexId(pol), diffTypeDelete).before(buildColumnVertexId(c.Name), diffTypeDelete)) - deps = append(deps, mustRun(psg.GetSQLVertexId(pol), diffTypeDelete).before(buildColumnVertexId(c.Name), diffTypeAddAlter)) + deps = append(deps, mustRun(psg.GetSQLVertexId(pol, diffTypeDelete)).before(buildColumnVertexId(c.Name, diffTypeDelete))) + deps = append(deps, mustRun(psg.GetSQLVertexId(pol, diffTypeDelete)).before(buildColumnVertexId(c.Name, diffTypeAddAlter))) } return deps, nil diff --git a/pkg/diff/sql_generator.go b/pkg/diff/sql_generator.go index 527a58b..cf786ab 100644 --- a/pkg/diff/sql_generator.go +++ b/pkg/diff/sql_generator.go @@ -598,11 +598,12 @@ func (schemaSQLGenerator) Alter(diff schemaDiff) ([]Statement, error) { return nil, fmt.Errorf("resolving trigger diff: %w", err) } partialGraph = concatPartialGraphs(partialGraph, triggersPartialGraph) - graph, err := graphFromPartials(partialGraph) + sqlGraph, err := graphFromPartials(partialGraph) if err != nil { return nil, fmt.Errorf("converting to graph: %w", err) } - graphStatements, err := graph.toOrderedStatements() + + graphStatements, err := sqlGraph.toOrderedStatements() if err != nil { return nil, fmt.Errorf("getting ordered statements: %w", err) } @@ -1039,18 +1040,22 @@ func replicaIdentityAlterType(identity schema.ReplicaIdentity) (string, error) { return "", fmt.Errorf("unknown/unsupported replica identity %s: %w", identity, ErrNotImplemented) } -func (t *tableSQLVertexGenerator) GetSQLVertexId(table schema.Table) string { - return buildTableVertexId(table.SchemaQualifiedName) +func (t *tableSQLVertexGenerator) GetSQLVertexId(table schema.Table, diffType diffType) sqlVertexId { + return buildTableVertexId(table.SchemaQualifiedName, diffType) +} + +func buildTableVertexId(name schema.SchemaQualifiedName, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("table", name.GetFQEscapedName(), diffType) } func (t *tableSQLVertexGenerator) GetAddAlterDependencies(table, _ schema.Table) ([]dependency, error) { deps := []dependency{ - mustRun(t.GetSQLVertexId(table), diffTypeAddAlter).after(t.GetSQLVertexId(table), diffTypeDelete), + mustRun(t.GetSQLVertexId(table, diffTypeAddAlter)).after(t.GetSQLVertexId(table, diffTypeDelete)), } if table.ParentTable != nil { deps = append(deps, - mustRun(t.GetSQLVertexId(table), diffTypeAddAlter).after(buildTableVertexId(*table.ParentTable), diffTypeAddAlter), + mustRun(t.GetSQLVertexId(table, diffTypeAddAlter)).after(buildTableVertexId(*table.ParentTable, diffTypeAddAlter)), ) } return deps, nil @@ -1118,7 +1123,7 @@ func (t *tableSQLVertexGenerator) GetDeleteDependencies(table schema.Table) ([]d var deps []dependency if table.ParentTable != nil { deps = append(deps, - mustRun(t.GetSQLVertexId(table), diffTypeDelete).after(buildTableVertexId(*table.ParentTable), diffTypeDelete), + mustRun(t.GetSQLVertexId(table, diffTypeDelete)).after(buildTableVertexId(*table.ParentTable, diffTypeDelete)), ) } return deps, nil @@ -1370,17 +1375,17 @@ func (csg *columnSQLVertexGenerator) alterColumnPrefix(col schema.Column) string return fmt.Sprintf("%s ALTER COLUMN %s", alterTablePrefix(csg.tableName), schema.EscapeIdentifier(col.Name)) } -func (csg *columnSQLVertexGenerator) GetSQLVertexId(column schema.Column) string { - return buildColumnVertexId(column.Name) +func (csg *columnSQLVertexGenerator) GetSQLVertexId(column schema.Column, diffType diffType) sqlVertexId { + return buildColumnVertexId(column.Name, diffType) } -func buildColumnVertexId(columnName string) string { - return buildVertexId("column", columnName) +func buildColumnVertexId(columnName string, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("column", columnName, diffType) } func (csg *columnSQLVertexGenerator) GetAddAlterDependencies(col, _ schema.Column) ([]dependency, error) { return []dependency{ - mustRun(csg.GetSQLVertexId(col), diffTypeDelete).before(csg.GetSQLVertexId(col), diffTypeAddAlter), + mustRun(csg.GetSQLVertexId(col, diffTypeDelete)).before(csg.GetSQLVertexId(col, diffTypeAddAlter)), }, nil } @@ -1474,8 +1479,8 @@ func (rsg *renameConflictingIndexSQLVertexGenerator) Alter(_ indexDiff) ([]State return nil, nil } -func (*renameConflictingIndexSQLVertexGenerator) GetSQLVertexId(index schema.Index) string { - return buildRenameConflictingIndexVertexId(index.GetSchemaQualifiedName()) +func (*renameConflictingIndexSQLVertexGenerator) GetSQLVertexId(index schema.Index, diffType diffType) sqlVertexId { + return buildRenameConflictingIndexVertexId(index.GetSchemaQualifiedName(), diffType) } func (rsg *renameConflictingIndexSQLVertexGenerator) GetAddAlterDependencies(_, _ schema.Index) ([]dependency, error) { @@ -1486,8 +1491,8 @@ func (rsg *renameConflictingIndexSQLVertexGenerator) GetDeleteDependencies(_ sch return nil, nil } -func buildRenameConflictingIndexVertexId(indexName schema.SchemaQualifiedName) string { - return buildVertexId("indexrename", indexName.GetName()) +func buildRenameConflictingIndexVertexId(indexName schema.SchemaQualifiedName, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("indexrename", indexName.GetName(), diffType) } type indexSQLVertexGenerator struct { @@ -1738,21 +1743,25 @@ func buildAttachIndex(index schema.Index) Statement { } } -func (*indexSQLVertexGenerator) GetSQLVertexId(index schema.Index) string { - return buildIndexVertexId(index.GetSchemaQualifiedName()) +func (*indexSQLVertexGenerator) GetSQLVertexId(index schema.Index, diffType diffType) sqlVertexId { + return buildIndexVertexId(index.GetSchemaQualifiedName(), diffType) +} + +func buildIndexVertexId(name schema.SchemaQualifiedName, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("index", name.GetFQEscapedName(), diffType) } func (isg *indexSQLVertexGenerator) GetAddAlterDependencies(index, _ schema.Index) ([]dependency, error) { dependencies := []dependency{ - mustRun(isg.GetSQLVertexId(index), diffTypeAddAlter).after(buildTableVertexId(index.OwningTable), diffTypeAddAlter), + mustRun(isg.GetSQLVertexId(index, diffTypeAddAlter)).after(buildTableVertexId(index.OwningTable, diffTypeAddAlter)), // To allow for online changes to indexes, rename the older version of the index (if it exists) before the new version is added - mustRun(isg.GetSQLVertexId(index), diffTypeAddAlter).after(buildRenameConflictingIndexVertexId(index.GetSchemaQualifiedName()), diffTypeAddAlter), + mustRun(isg.GetSQLVertexId(index, diffTypeAddAlter)).after(buildRenameConflictingIndexVertexId(index.GetSchemaQualifiedName(), diffTypeAddAlter)), } if index.ParentIdx != nil { // Partitions of indexes must be created after the parent index is created dependencies = append(dependencies, - mustRun(isg.GetSQLVertexId(index), diffTypeAddAlter).after(buildIndexVertexId(*index.ParentIdx), diffTypeAddAlter)) + mustRun(isg.GetSQLVertexId(index, diffTypeAddAlter)).after(buildIndexVertexId(*index.ParentIdx, diffTypeAddAlter))) } return dependencies, nil @@ -1760,16 +1769,16 @@ func (isg *indexSQLVertexGenerator) GetAddAlterDependencies(index, _ schema.Inde func (isg *indexSQLVertexGenerator) GetDeleteDependencies(index schema.Index) ([]dependency, error) { dependencies := []dependency{ - mustRun(isg.GetSQLVertexId(index), diffTypeDelete).after(buildTableVertexId(index.OwningTable), diffTypeDelete), + mustRun(isg.GetSQLVertexId(index, diffTypeDelete)).after(buildTableVertexId(index.OwningTable, diffTypeDelete)), // Drop the index after it has been potentially renamed - mustRun(isg.GetSQLVertexId(index), diffTypeDelete).after(buildRenameConflictingIndexVertexId(index.GetSchemaQualifiedName()), diffTypeAddAlter), + mustRun(isg.GetSQLVertexId(index, diffTypeDelete)).after(buildRenameConflictingIndexVertexId(index.GetSchemaQualifiedName(), diffTypeAddAlter)), } if index.ParentIdx != nil { // Since dropping the parent index will cause the partition of the index to drop, the parent drop should come // before dependencies = append(dependencies, - mustRun(isg.GetSQLVertexId(index), diffTypeDelete).after(buildIndexVertexId(*index.ParentIdx), diffTypeDelete)) + mustRun(isg.GetSQLVertexId(index, diffTypeDelete)).after(buildIndexVertexId(*index.ParentIdx, diffTypeDelete))) } dependencies = append(dependencies, isg.addDepsOnTableAddAlterIfNecessary(index)...) @@ -1787,14 +1796,14 @@ func (isg *indexSQLVertexGenerator) addDepsOnTableAddAlterIfNecessary(index sche // These dependencies will force the index deletion statement to come before the table AddAlter addAlterColumnDeps := []dependency{ - mustRun(isg.GetSQLVertexId(index), diffTypeDelete).before(buildTableVertexId(index.OwningTable), diffTypeAddAlter), + mustRun(isg.GetSQLVertexId(index, diffTypeDelete)).before(buildTableVertexId(index.OwningTable, diffTypeAddAlter)), } if parentTable.ParentTable != nil { // If the table is partitioned, columns modifications occur on the base table not the children. Thus, we // need the dependency to also be on the parent table add/alter statements addAlterColumnDeps = append( addAlterColumnDeps, - mustRun(isg.GetSQLVertexId(index), diffTypeDelete).before(buildTableVertexId(*parentTable.ParentTable), diffTypeAddAlter), + mustRun(isg.GetSQLVertexId(index, diffTypeDelete)).before(buildTableVertexId(*parentTable.ParentTable, diffTypeAddAlter)), ) } @@ -1914,13 +1923,13 @@ func (csg *checkConstraintSQLVertexGenerator) Alter(diff checkConstraintDiff) ([ return stmts, nil } -func (*checkConstraintSQLVertexGenerator) GetSQLVertexId(con schema.CheckConstraint) string { - return buildVertexId("checkconstraint", con.Name) +func (*checkConstraintSQLVertexGenerator) GetSQLVertexId(con schema.CheckConstraint, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("checkconstraint", con.Name, diffType) } func (csg *checkConstraintSQLVertexGenerator) GetAddAlterDependencies(con, _ schema.CheckConstraint) ([]dependency, error) { deps := []dependency{ - mustRun(csg.GetSQLVertexId(con), diffTypeDelete).before(csg.GetSQLVertexId(con), diffTypeAddAlter), + mustRun(csg.GetSQLVertexId(con, diffTypeDelete)).before(csg.GetSQLVertexId(con, diffTypeAddAlter)), } targetColumns, err := getTargetColumns(con.KeyColumns, csg.newSchemaColumnsByName) @@ -1939,10 +1948,10 @@ func (csg *checkConstraintSQLVertexGenerator) GetAddAlterDependencies(con, _ sch if isOnValidNotNullPreExistingColumn { // If the NOT NULL check constraint is on a pre-existing column, then we should ensure it is added before // the column alter. - deps = append(deps, mustRun(csg.GetSQLVertexId(con), diffTypeAddAlter).before(buildColumnVertexId(targetColumns[0].Name), diffTypeAddAlter)) + deps = append(deps, mustRun(csg.GetSQLVertexId(con, diffTypeAddAlter)).before(buildColumnVertexId(targetColumns[0].Name, diffTypeAddAlter))) } else { for _, tc := range targetColumns { - deps = append(deps, mustRun(csg.GetSQLVertexId(con), diffTypeAddAlter).after(buildColumnVertexId(tc.Name), diffTypeAddAlter)) + deps = append(deps, mustRun(csg.GetSQLVertexId(con, diffTypeAddAlter)).after(buildColumnVertexId(tc.Name, diffTypeAddAlter))) } } return deps, nil @@ -1968,16 +1977,16 @@ func (csg *checkConstraintSQLVertexGenerator) GetDeleteDependencies(con schema.C tc := targetColumns[0] if _, ok := csg.deletedColumnsByName[tc.Name]; ok { // If the column is being deleted, we should drop the not null check constraint before the column is deleted. - deps = append(deps, mustRun(csg.GetSQLVertexId(con), diffTypeDelete).before(buildColumnVertexId(tc.Name), diffTypeDelete)) + deps = append(deps, mustRun(csg.GetSQLVertexId(con, diffTypeDelete)).before(buildColumnVertexId(tc.Name, diffTypeDelete))) } else { // Otherwise, we should drop the not null check constraint after the column is altered. This dependency // doesn't need to be explicitly, since our topological sort prioritizes adds/alters over deletes. Nevertheless, // we'll add it for clarity and to ensure that an error is returned if the delete is not placed after the alter. - deps = append(deps, mustRun(csg.GetSQLVertexId(con), diffTypeDelete).after(buildColumnVertexId(tc.Name), diffTypeAddAlter)) + deps = append(deps, mustRun(csg.GetSQLVertexId(con, diffTypeDelete)).after(buildColumnVertexId(tc.Name, diffTypeAddAlter))) } } else { for _, tc := range targetColumns { - deps = append(deps, mustRun(csg.GetSQLVertexId(con), diffTypeDelete).before(buildColumnVertexId(tc.Name), diffTypeAddAlter)) + deps = append(deps, mustRun(csg.GetSQLVertexId(con, diffTypeDelete)).before(buildColumnVertexId(tc.Name, diffTypeAddAlter))) // This is a weird quirk of our graph system, where if a -> b and b -> c and b does-not-exist, b will be // implicitly created s.t. a -> b -> c (https://github.com/stripe/pg-schema-diff/issues/84) // @@ -1985,7 +1994,7 @@ func (csg *checkConstraintSQLVertexGenerator) GetDeleteDependencies(con schema.C // the column, and "c" is the alter/addition of the column. We do not want this behavior. We only want // a -> b -> c iff the column is being deleted. if _, ok := csg.deletedColumnsByName[tc.Name]; ok { - deps = append(deps, mustRun(csg.GetSQLVertexId(con), diffTypeDelete).before(buildColumnVertexId(tc.Name), diffTypeDelete)) + deps = append(deps, mustRun(csg.GetSQLVertexId(con, diffTypeDelete)).before(buildColumnVertexId(tc.Name, diffTypeDelete))) } } } @@ -2048,8 +2057,8 @@ func (*attachPartitionSQLVertexGenerator) Delete(_ schema.Table) ([]Statement, e return nil, nil } -func (*attachPartitionSQLVertexGenerator) GetSQLVertexId(table schema.Table) string { - return fmt.Sprintf("attachpartition_%s", table.GetName()) +func (*attachPartitionSQLVertexGenerator) GetSQLVertexId(table schema.Table, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("attachpartition", table.GetName(), diffType) } func (a *attachPartitionSQLVertexGenerator) GetAddAlterDependencies(table, old schema.Table) ([]dependency, error) { @@ -2059,7 +2068,7 @@ func (a *attachPartitionSQLVertexGenerator) GetAddAlterDependencies(table, old s } deps := []dependency{ - mustRun(a.GetSQLVertexId(table), diffTypeAddAlter).after(buildTableVertexId(table.SchemaQualifiedName), diffTypeAddAlter), + mustRun(a.GetSQLVertexId(table, diffTypeAddAlter)).after(buildTableVertexId(table.SchemaQualifiedName, diffTypeAddAlter)), } if _, baseTableIsNew := a.addedTablesByName[table.ParentTable.GetName()]; baseTableIsNew { @@ -2068,14 +2077,14 @@ func (a *attachPartitionSQLVertexGenerator) GetAddAlterDependencies(table, old s // have the PK (this is useful when creating the fresh database schema for migration validation) // If we attach the partition after the index is built, the index will be automatically built by Postgres for _, idx := range a.indexesInNewSchemaByTableName[table.ParentTable.GetName()] { - deps = append(deps, mustRun(a.GetSQLVertexId(table), diffTypeAddAlter).before(buildIndexVertexId(idx.GetSchemaQualifiedName()), diffTypeAddAlter)) + deps = append(deps, mustRun(a.GetSQLVertexId(table, diffTypeAddAlter)).before(buildIndexVertexId(idx.GetSchemaQualifiedName(), diffTypeAddAlter))) } return deps, nil } a.isPartitionAttachedAfterIdxBuildsByTableName[table.GetName()] = true for _, idx := range a.indexesInNewSchemaByTableName[table.GetName()] { - deps = append(deps, mustRun(a.GetSQLVertexId(table), diffTypeAddAlter).after(buildIndexVertexId(idx.GetSchemaQualifiedName()), diffTypeAddAlter)) + deps = append(deps, mustRun(a.GetSQLVertexId(table, diffTypeAddAlter)).after(buildIndexVertexId(idx.GetSchemaQualifiedName(), diffTypeAddAlter))) } return deps, nil } @@ -2221,25 +2230,25 @@ func (f *foreignKeyConstraintSQLVertexGenerator) Alter(diff foreignKeyConstraint return stmts, nil } -func (*foreignKeyConstraintSQLVertexGenerator) GetSQLVertexId(con schema.ForeignKeyConstraint) string { - return buildVertexId("fkconstraint", con.GetName()) +func (*foreignKeyConstraintSQLVertexGenerator) GetSQLVertexId(con schema.ForeignKeyConstraint, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("fkconstraint", con.GetName(), diffType) } func (f *foreignKeyConstraintSQLVertexGenerator) GetAddAlterDependencies(con, _ schema.ForeignKeyConstraint) ([]dependency, error) { deps := []dependency{ - mustRun(f.GetSQLVertexId(con), diffTypeAddAlter).after(f.GetSQLVertexId(con), diffTypeDelete), - mustRun(f.GetSQLVertexId(con), diffTypeAddAlter).after(buildTableVertexId(con.OwningTable), diffTypeAddAlter), - mustRun(f.GetSQLVertexId(con), diffTypeAddAlter).after(buildTableVertexId(con.ForeignTable), diffTypeAddAlter), + mustRun(f.GetSQLVertexId(con, diffTypeAddAlter)).after(f.GetSQLVertexId(con, diffTypeDelete)), + mustRun(f.GetSQLVertexId(con, diffTypeAddAlter)).after(buildTableVertexId(con.OwningTable, diffTypeAddAlter)), + mustRun(f.GetSQLVertexId(con, diffTypeAddAlter)).after(buildTableVertexId(con.ForeignTable, diffTypeAddAlter)), } // This is the slightly lazy way of ensuring the foreign key constraint is added after the requisite index is // built and marked as valid. // We __could__ do this just for the index the fk depends on, but that's slightly more wiring than we need right now // because of partitioned indexes, which are only valid when all child indexes have been built for _, i := range f.indexesInNewSchemaByTableName[con.ForeignTable.GetName()] { - deps = append(deps, mustRun(f.GetSQLVertexId(con), diffTypeAddAlter).after(buildIndexVertexId(i.GetSchemaQualifiedName()), diffTypeAddAlter)) + deps = append(deps, mustRun(f.GetSQLVertexId(con, diffTypeAddAlter)).after(buildIndexVertexId(i.GetSchemaQualifiedName(), diffTypeAddAlter))) // Build a dependency on any child index if the index is partitioned for _, c := range f.childrenInNewSchemaByPartitionedIndexName[i.GetName()] { - deps = append(deps, mustRun(f.GetSQLVertexId(con), diffTypeAddAlter).after(buildIndexVertexId(c.GetSchemaQualifiedName()), diffTypeAddAlter)) + deps = append(deps, mustRun(f.GetSQLVertexId(con, diffTypeAddAlter)).after(buildIndexVertexId(c.GetSchemaQualifiedName(), diffTypeAddAlter))) } } @@ -2248,17 +2257,17 @@ func (f *foreignKeyConstraintSQLVertexGenerator) GetAddAlterDependencies(con, _ func (f *foreignKeyConstraintSQLVertexGenerator) GetDeleteDependencies(con schema.ForeignKeyConstraint) ([]dependency, error) { deps := []dependency{ - mustRun(f.GetSQLVertexId(con), diffTypeDelete).before(buildTableVertexId(con.OwningTable), diffTypeDelete), - mustRun(f.GetSQLVertexId(con), diffTypeDelete).before(buildTableVertexId(con.ForeignTable), diffTypeDelete), + mustRun(f.GetSQLVertexId(con, diffTypeDelete)).before(buildTableVertexId(con.OwningTable, diffTypeDelete)), + mustRun(f.GetSQLVertexId(con, diffTypeDelete)).before(buildTableVertexId(con.ForeignTable, diffTypeDelete)), } // This is the slightly lazy way of ensuring the foreign key constraint is deleted before the index it depends on is deleted // We __could__ do this just for the index the fk depends on, but that's slightly more wiring than we need right now // because of partitioned indexes, which are only valid when all child indexes have been built for _, i := range f.indexInOldSchemaByTableName[con.ForeignTable.GetName()] { - deps = append(deps, mustRun(f.GetSQLVertexId(con), diffTypeDelete).before(buildIndexVertexId(i.GetSchemaQualifiedName()), diffTypeDelete)) + deps = append(deps, mustRun(f.GetSQLVertexId(con, diffTypeDelete)).before(buildIndexVertexId(i.GetSchemaQualifiedName(), diffTypeDelete))) // Build a dependency on any child index if the index is partitioned for _, c := range f.childrenInOldSchemaByPartitionedIndexName[i.GetName()] { - deps = append(deps, mustRun(f.GetSQLVertexId(con), diffTypeDelete).before(buildIndexVertexId(c.GetSchemaQualifiedName()), diffTypeDelete)) + deps = append(deps, mustRun(f.GetSQLVertexId(con, diffTypeDelete)).before(buildIndexVertexId(c.GetSchemaQualifiedName(), diffTypeDelete))) } } return deps, nil @@ -2355,17 +2364,21 @@ func (s *sequenceSQLVertexGenerator) buildAddAlterSequenceStatement(seq schema.S } } -func (s *sequenceSQLVertexGenerator) GetSQLVertexId(seq schema.Sequence) string { - return buildSequenceVertexId(seq.SchemaQualifiedName) +func (s *sequenceSQLVertexGenerator) GetSQLVertexId(seq schema.Sequence, diffType diffType) sqlVertexId { + return buildSequenceVertexId(seq.SchemaQualifiedName, diffType) +} + +func buildSequenceVertexId(name schema.SchemaQualifiedName, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("sequence", name.GetFQEscapedName(), diffType) } func (s *sequenceSQLVertexGenerator) GetAddAlterDependencies(new schema.Sequence, _ schema.Sequence) ([]dependency, error) { deps := []dependency{ - mustRun(s.GetSQLVertexId(new), diffTypeAddAlter).after(s.GetSQLVertexId(new), diffTypeDelete), + mustRun(s.GetSQLVertexId(new, diffTypeAddAlter)).after(s.GetSQLVertexId(new, diffTypeDelete)), } if new.Owner != nil { // Sequences should be added/altered before the table they are owned by - deps = append(deps, mustRun(s.GetSQLVertexId(new), diffTypeAddAlter).before(buildTableVertexId(new.Owner.TableName), diffTypeAddAlter)) + deps = append(deps, mustRun(s.GetSQLVertexId(new, diffTypeAddAlter)).before(buildTableVertexId(new.Owner.TableName, diffTypeAddAlter))) } return deps, nil } @@ -2379,7 +2392,7 @@ func (s *sequenceSQLVertexGenerator) GetDeleteDependencies(seq schema.Sequence) // old owner column delete (equivalent to add/alter) and the sequence add/alter. We can get away with this because // we, so far, no columns are ever "re-created". If we ever do support that, we'll need to revisit this. if seq.Owner != nil { - deps = append(deps, mustRun(s.GetSQLVertexId(seq), diffTypeDelete).after(buildTableVertexId(seq.Owner.TableName), diffTypeDelete)) + deps = append(deps, mustRun(s.GetSQLVertexId(seq, diffTypeDelete)).after(buildTableVertexId(seq.Owner.TableName, diffTypeDelete))) } return deps, nil } @@ -2402,10 +2415,6 @@ func (s *sequenceSQLVertexGenerator) isDeletedWithColumns(seq schema.Sequence) b return false } -func buildSequenceVertexId(name schema.SchemaQualifiedName) string { - return buildVertexId("sequence", name.GetFQEscapedName()) -} - type sequenceOwnershipSQLVertexGenerator struct{} func (s sequenceOwnershipSQLVertexGenerator) Add(seq schema.Sequence) ([]Statement, error) { @@ -2440,8 +2449,8 @@ func (s sequenceOwnershipSQLVertexGenerator) buildAlterOwnershipStmt(new schema. } } -func (s sequenceOwnershipSQLVertexGenerator) GetSQLVertexId(seq schema.Sequence) string { - return fmt.Sprintf("%s-ownership", buildSequenceVertexId(seq.SchemaQualifiedName)) +func (s sequenceOwnershipSQLVertexGenerator) GetSQLVertexId(seq schema.Sequence, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("sequence_ownership", seq.SchemaQualifiedName.GetFQEscapedName(), diffType) } func (s sequenceOwnershipSQLVertexGenerator) GetAddAlterDependencies(new schema.Sequence, old schema.Sequence) ([]dependency, error) { @@ -2451,17 +2460,17 @@ func (s sequenceOwnershipSQLVertexGenerator) GetAddAlterDependencies(new schema. deps := []dependency{ // Always change ownership after the sequence has been added/altered - mustRun(s.GetSQLVertexId(new), diffTypeAddAlter).after(buildSequenceVertexId(new.SchemaQualifiedName), diffTypeAddAlter), + mustRun(s.GetSQLVertexId(new, diffTypeAddAlter)).after(buildSequenceVertexId(new.SchemaQualifiedName, diffTypeAddAlter)), } if old.Owner != nil { // Always update ownership before the old owner has been deleted - deps = append(deps, mustRun(s.GetSQLVertexId(new), diffTypeAddAlter).before(buildTableVertexId(old.Owner.TableName), diffTypeDelete)) + deps = append(deps, mustRun(s.GetSQLVertexId(new, diffTypeAddAlter)).before(buildTableVertexId(old.Owner.TableName, diffTypeDelete))) } if new.Owner != nil { // Always update ownership after the new owner has been created - deps = append(deps, mustRun(s.GetSQLVertexId(new), diffTypeAddAlter).after(buildTableVertexId(new.Owner.TableName), diffTypeAddAlter)) + deps = append(deps, mustRun(s.GetSQLVertexId(new, diffTypeAddAlter)).after(buildTableVertexId(new.Owner.TableName, diffTypeAddAlter))) } return deps, nil @@ -2591,8 +2600,12 @@ func canFunctionDependenciesBeTracked(function schema.Function) bool { return function.Language == "sql" } -func (f *functionSQLVertexGenerator) GetSQLVertexId(function schema.Function) string { - return buildFunctionVertexId(function.SchemaQualifiedName) +func (f *functionSQLVertexGenerator) GetSQLVertexId(function schema.Function, diffType diffType) sqlVertexId { + return buildFunctionVertexId(function.SchemaQualifiedName, diffType) +} + +func buildFunctionVertexId(name schema.SchemaQualifiedName, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("function", name.GetFQEscapedName(), diffType) } func (f *functionSQLVertexGenerator) GetAddAlterDependencies(newFunction, oldFunction schema.Function) ([]dependency, error) { @@ -2601,7 +2614,7 @@ func (f *functionSQLVertexGenerator) GetAddAlterDependencies(newFunction, oldFun // because there won't be one if it is being added/altered var deps []dependency for _, depFunction := range newFunction.DependsOnFunctions { - deps = append(deps, mustRun(f.GetSQLVertexId(newFunction), diffTypeAddAlter).after(buildFunctionVertexId(depFunction), diffTypeAddAlter)) + deps = append(deps, mustRun(f.GetSQLVertexId(newFunction, diffTypeAddAlter)).after(buildFunctionVertexId(depFunction, diffTypeAddAlter))) } if !cmp.Equal(oldFunction, schema.Function{}) { @@ -2609,7 +2622,7 @@ func (f *functionSQLVertexGenerator) GetAddAlterDependencies(newFunction, oldFun // If the old version of the function calls other functions that are being deleted come, those deletions // must come after the function is altered, so it is no longer dependent on those dropped functions for _, depFunction := range oldFunction.DependsOnFunctions { - deps = append(deps, mustRun(f.GetSQLVertexId(newFunction), diffTypeAddAlter).before(buildFunctionVertexId(depFunction), diffTypeDelete)) + deps = append(deps, mustRun(f.GetSQLVertexId(newFunction, diffTypeAddAlter)).before(buildFunctionVertexId(depFunction, diffTypeDelete))) } } @@ -2619,15 +2632,11 @@ func (f *functionSQLVertexGenerator) GetAddAlterDependencies(newFunction, oldFun func (f *functionSQLVertexGenerator) GetDeleteDependencies(function schema.Function) ([]dependency, error) { var deps []dependency for _, depFunction := range function.DependsOnFunctions { - deps = append(deps, mustRun(f.GetSQLVertexId(function), diffTypeDelete).before(buildFunctionVertexId(depFunction), diffTypeDelete)) + deps = append(deps, mustRun(f.GetSQLVertexId(function, diffTypeDelete)).before(buildFunctionVertexId(depFunction, diffTypeDelete))) } return deps, nil } -func buildFunctionVertexId(name schema.SchemaQualifiedName) string { - return buildVertexId("function", name.GetFQEscapedName()) -} - type triggerSQLVertexGenerator struct { // functionsInNewSchemaByName is a map of function new to functions in the new schema. // These functions are not necessarily new @@ -2666,8 +2675,8 @@ func (t *triggerSQLVertexGenerator) Alter(diff triggerDiff) ([]Statement, error) }}, nil } -func (t *triggerSQLVertexGenerator) GetSQLVertexId(trigger schema.Trigger) string { - return buildVertexId("trigger", trigger.GetName()) +func (t *triggerSQLVertexGenerator) GetSQLVertexId(trigger schema.Trigger, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("trigger", trigger.GetName(), diffType) } func (t *triggerSQLVertexGenerator) GetAddAlterDependencies(newTrigger, oldTrigger schema.Trigger) ([]dependency, error) { @@ -2675,8 +2684,8 @@ func (t *triggerSQLVertexGenerator) GetAddAlterDependencies(newTrigger, oldTrigg // added and dropped in the same migration. Thus, we don't need a dependency on the delete node of a function // because there won't be one if it is being added/altered deps := []dependency{ - mustRun(t.GetSQLVertexId(newTrigger), diffTypeAddAlter).after(buildFunctionVertexId(newTrigger.Function), diffTypeAddAlter), - mustRun(t.GetSQLVertexId(newTrigger), diffTypeAddAlter).after(buildTableVertexId(newTrigger.OwningTable), diffTypeAddAlter), + mustRun(t.GetSQLVertexId(newTrigger, diffTypeAddAlter)).after(buildFunctionVertexId(newTrigger.Function, diffTypeAddAlter)), + mustRun(t.GetSQLVertexId(newTrigger, diffTypeAddAlter)).after(buildTableVertexId(newTrigger.OwningTable, diffTypeAddAlter)), } if !cmp.Equal(oldTrigger, schema.Trigger{}) { @@ -2684,7 +2693,7 @@ func (t *triggerSQLVertexGenerator) GetAddAlterDependencies(newTrigger, oldTrigg // If the old version of the trigger called a function being deleted, the function deletion must come after the // trigger is altered, so the trigger no longer has a dependency on the function deps = append(deps, - mustRun(t.GetSQLVertexId(newTrigger), diffTypeAddAlter).before(buildFunctionVertexId(oldTrigger.Function), diffTypeDelete), + mustRun(t.GetSQLVertexId(newTrigger, diffTypeAddAlter)).before(buildFunctionVertexId(oldTrigger.Function, diffTypeDelete)), ) } @@ -2693,15 +2702,11 @@ func (t *triggerSQLVertexGenerator) GetAddAlterDependencies(newTrigger, oldTrigg func (t *triggerSQLVertexGenerator) GetDeleteDependencies(trigger schema.Trigger) ([]dependency, error) { return []dependency{ - mustRun(t.GetSQLVertexId(trigger), diffTypeDelete).before(buildFunctionVertexId(trigger.Function), diffTypeDelete), - mustRun(t.GetSQLVertexId(trigger), diffTypeDelete).before(buildTableVertexId(trigger.OwningTable), diffTypeDelete), + mustRun(t.GetSQLVertexId(trigger, diffTypeDelete)).before(buildFunctionVertexId(trigger.Function, diffTypeDelete)), + mustRun(t.GetSQLVertexId(trigger, diffTypeDelete)).before(buildTableVertexId(trigger.OwningTable, diffTypeDelete)), }, nil } -func buildVertexId(objType string, id string) string { - return fmt.Sprintf("%s_%s", objType, id) -} - func stripMigrationHazards(stmts ...Statement) []Statement { var noHazardsStmts []Statement for _, stmt := range stmts { diff --git a/pkg/diff/sql_graph.go b/pkg/diff/sql_graph.go index 5c2e99d..2951cbb 100644 --- a/pkg/diff/sql_graph.go +++ b/pkg/diff/sql_graph.go @@ -4,28 +4,66 @@ import ( "fmt" "github.com/stripe/pg-schema-diff/internal/graph" - "github.com/stripe/pg-schema-diff/internal/schema" ) -type sqlVertex struct { - // todo(bplunkett) - rename to statement id? - // these ids will need to be globally unique - ObjId string - Statements []Statement - // todo(bplunkett) - this should not affect the id - DiffType diffType +// sqlVertexId is an interface for a vertex id in the SQL graph +type sqlVertexId interface { + fmt.Stringer } -func (s sqlVertex) GetId() string { - return fmt.Sprintf("%s_%s", s.DiffType, s.ObjId) +// sqlPriority is an enum for the priority of a statement in the SQL graph, i.e., whether it should be run sooner +// or later in the topological sort of the graph +type sqlPriority int + +const ( + // Indicates a statement should run as soon as possible + sqlPrioritySooner sqlPriority = 1 + // sqlPriorityUnset is the default priority for a statement + sqlPriorityUnset sqlPriority = 0 + // Indicates a statement should run as late as possible + sqlPriorityLater sqlPriority = -1 +) + +// schemaObjSqlVertexId is a vertex id for a standard schema object node, i.e., indicating the creation/deletiion +// of a schema object. It slots into the legacySqlVertexGenerator system. +type schemaObjSqlVertexId struct { + objType string + objId string + diffType diffType +} + +func buildSchemaObjVertexId(objType string, id string, diffType diffType) sqlVertexId { + return schemaObjSqlVertexId{ + objType: objType, + objId: id, + diffType: diffType, + } +} + +func (s schemaObjSqlVertexId) String() string { + return fmt.Sprintf("%s:%s:%s", s.objType, s.objId, s.diffType) } -func buildTableVertexId(name schema.SchemaQualifiedName) string { - return fmt.Sprintf("table_%s", name) +type sqlVertex struct { + // id is used to identify the sql vertex + id sqlVertexId + + // priority is used to determine if the sql vertex should be included sooner or later in the topological + // sort of the graph + priority sqlPriority + + // statements is the set of statements to run for this vertex + statements []Statement } -func buildIndexVertexId(name schema.SchemaQualifiedName) string { - return fmt.Sprintf("index_%s", name) +func (s sqlVertex) GetId() string { + return s.id.String() +} + +func (s sqlVertex) GetPriority() int { + // Prioritize adds/alters over deletes. Weight by number of statements. A 0 statement delete should be + // prioritized over a 1 statement delete + return len(s.statements) * int(s.priority) } // dependency indicates an edge between the SQL to resolve a diff for a source schema object and the SQL to resolve @@ -35,42 +73,31 @@ func buildIndexVertexId(name schema.SchemaQualifiedName) string { // These nodes will almost always be present in the sqlGraph even if the schema object is not being deleted (or added/altered). // If a node is present for a schema object where the "diffType" is NOT occurring, it will just be a no-op (no SQl statements) type dependency struct { - sourceObjId string - sourceType diffType - - targetObjId string - targetType diffType + source sqlVertexId + target sqlVertexId } type dependencyBuilder struct { - valObjId string - valType diffType + base sqlVertexId } -func mustRun(schemaObjId string, schemaDiffType diffType) dependencyBuilder { +func mustRun(id sqlVertexId) dependencyBuilder { return dependencyBuilder{ - valObjId: schemaObjId, - valType: schemaDiffType, + base: id, } } -func (d dependencyBuilder) before(valObjId string, valType diffType) dependency { +func (d dependencyBuilder) before(id sqlVertexId) dependency { return dependency{ - sourceType: d.valType, - sourceObjId: d.valObjId, - - targetType: valType, - targetObjId: valObjId, + source: d.base, + target: id, } } -func (d dependencyBuilder) after(valObjId string, valType diffType) dependency { +func (d dependencyBuilder) after(id sqlVertexId) dependency { return dependency{ - sourceObjId: valObjId, - sourceType: valType, - - targetObjId: d.valObjId, - targetType: d.valType, + source: id, + target: d.base, } } @@ -86,23 +113,15 @@ func newSqlGraph() *sqlGraph { } func (s *sqlGraph) toOrderedStatements() ([]Statement, error) { - vertices, err := s.TopologicallySortWithPriority(graph.IsLowerPriorityFromGetPriority( - func(vertex sqlVertex) int { - multiplier := 1 - if vertex.DiffType == diffTypeDelete { - multiplier = -1 - } - // Prioritize adds/alters over deletes. Weight by number of statements. A 0 statement delete should be - // prioritized over a 1 statement delete - return len(vertex.Statements) * multiplier - }), - ) + vertices, err := s.TopologicallySortWithPriority(graph.IsLowerPriorityFromGetPriority(func(v sqlVertex) int { + return v.GetPriority() + })) if err != nil { return nil, fmt.Errorf("topologically sorting graph: %w", err) } var stmts []Statement for _, v := range vertices { - stmts = append(stmts, v.Statements...) + stmts = append(stmts, v.statements...) } return stmts, nil } diff --git a/pkg/diff/sql_vertex_generator.go b/pkg/diff/sql_vertex_generator.go index b453825..7f39f66 100644 --- a/pkg/diff/sql_vertex_generator.go +++ b/pkg/diff/sql_vertex_generator.go @@ -24,7 +24,7 @@ type partialSQLGraph struct { func (s *partialSQLGraph) statements() []Statement { var statements []Statement for _, vertex := range s.vertices { - statements = append(statements, vertex.Statements...) + statements = append(statements, vertex.statements...) } return statements } @@ -54,14 +54,14 @@ func graphFromPartials(parts partialSQLGraph) (*sqlGraph, error) { for _, dep := range parts.dependencies { sourceVertex := sqlVertex{ - ObjId: dep.sourceObjId, - DiffType: dep.sourceType, - Statements: nil, + id: dep.source, + priority: sqlPriorityUnset, + statements: nil, } targetVertex := sqlVertex{ - ObjId: dep.targetObjId, - DiffType: dep.targetType, - Statements: nil, + id: dep.target, + priority: sqlPriorityUnset, + statements: nil, } // To maintain the correctness of the graph, we will add a dummy vertex for the missing dependencies @@ -77,10 +77,14 @@ func graphFromPartials(parts partialSQLGraph) (*sqlGraph, error) { } func mergeVertices(old, new sqlVertex) sqlVertex { + priority := old.priority + if old.priority == sqlPriorityUnset { + priority = new.priority + } return sqlVertex{ - ObjId: old.ObjId, - DiffType: old.DiffType, - Statements: append(old.Statements, new.Statements...), + id: old.id, + priority: priority, + statements: append(old.statements, new.statements...), } } @@ -132,7 +136,7 @@ func generatePartialGraph[S schema.Object, Diff diff[S]](generator sqlVertexGene type legacySqlVertexGenerator[S schema.Object, Diff diff[S]] interface { sqlGenerator[S, Diff] // GetSQLVertexId gets the canonical vertex id to represent the schema object - GetSQLVertexId(S) string + GetSQLVertexId(S, diffType) sqlVertexId // GetAddAlterDependencies gets the dependencies of the SQL generated to resolve the AddAlter diff for the // schema objects. Dependencies can be formed on any other nodes in the SQL graph, even if the node has @@ -176,9 +180,9 @@ func (s *wrappedLegacySqlVertexGenerator[S, Diff]) Add(o S) (partialSQLGraph, er return partialSQLGraph{ vertices: []sqlVertex{{ - DiffType: diffTypeAddAlter, - ObjId: s.generator.GetSQLVertexId(o), - Statements: statements, + id: s.generator.GetSQLVertexId(o, diffTypeAddAlter), + priority: sqlPrioritySooner, + statements: statements, }}, dependencies: deps, }, nil @@ -196,9 +200,9 @@ func (s *wrappedLegacySqlVertexGenerator[S, Diff]) Delete(o S) (partialSQLGraph, return partialSQLGraph{ vertices: []sqlVertex{{ - DiffType: diffTypeDelete, - ObjId: s.generator.GetSQLVertexId(o), - Statements: statements, + id: s.generator.GetSQLVertexId(o, diffTypeDelete), + priority: sqlPriorityLater, + statements: statements, }}, dependencies: deps, }, nil @@ -216,9 +220,9 @@ func (s *wrappedLegacySqlVertexGenerator[S, Diff]) Alter(d Diff) (partialSQLGrap return partialSQLGraph{ vertices: []sqlVertex{{ - DiffType: diffTypeAddAlter, - ObjId: s.generator.GetSQLVertexId(d.GetNew()), - Statements: statements, + id: s.generator.GetSQLVertexId(d.GetNew(), diffTypeAddAlter), + priority: sqlPrioritySooner, + statements: statements, }}, dependencies: deps, }, nil From 238c26749227e920a4cd48743789b37d0a7d7499 Mon Sep 17 00:00:00 2001 From: bplunkett-stripe Date: Tue, 1 Oct 2024 16:08:53 -0700 Subject: [PATCH 3/3] Suggested changes --- pkg/diff/sql_graph.go | 15 +++++++------- pkg/diff/sql_vertex_generator.go | 34 ++++++++++++++------------------ 2 files changed, 23 insertions(+), 26 deletions(-) diff --git a/pkg/diff/sql_graph.go b/pkg/diff/sql_graph.go index 2951cbb..e5bf681 100644 --- a/pkg/diff/sql_graph.go +++ b/pkg/diff/sql_graph.go @@ -12,15 +12,15 @@ type sqlVertexId interface { } // sqlPriority is an enum for the priority of a statement in the SQL graph, i.e., whether it should be run sooner -// or later in the topological sort of the graph +// or later in the topological sort of the graph. It can be thought of as a unary unit vector. type sqlPriority int const ( - // Indicates a statement should run as soon as possible + // Indicates a statement should run as soon as possible. Usually, most adds will have this priority. sqlPrioritySooner sqlPriority = 1 // sqlPriorityUnset is the default priority for a statement sqlPriorityUnset sqlPriority = 0 - // Indicates a statement should run as late as possible + // Indicates a statement should run as late as possible. Usually, most deletes will have this priority. sqlPriorityLater sqlPriority = -1 ) @@ -61,19 +61,20 @@ func (s sqlVertex) GetId() string { } func (s sqlVertex) GetPriority() int { - // Prioritize adds/alters over deletes. Weight by number of statements. A 0 statement delete should be - // prioritized over a 1 statement delete + // Weight the priority (which is just a "direction") by the number of statements return len(s.statements) * int(s.priority) } // dependency indicates an edge between the SQL to resolve a diff for a source schema object and the SQL to resolve -// the diff of a target schema object +// the diff of a target schema object. // // Most SchemaObjects will have two nodes in the SQL graph: a node for delete SQL and a node for add/alter SQL. // These nodes will almost always be present in the sqlGraph even if the schema object is not being deleted (or added/altered). // If a node is present for a schema object where the "diffType" is NOT occurring, it will just be a no-op (no SQl statements) type dependency struct { + // source must run before target source sqlVertexId + // target must run after source target sqlVertexId } @@ -101,7 +102,7 @@ func (d dependencyBuilder) after(id sqlVertexId) dependency { } } -// sqlGraph represents two dependency webs of SQL statements +// sqlGraph represents a dependency web of SQL statements type sqlGraph struct { *graph.Graph[sqlVertex] } diff --git a/pkg/diff/sql_vertex_generator.go b/pkg/diff/sql_vertex_generator.go index 7f39f66..3b54459 100644 --- a/pkg/diff/sql_vertex_generator.go +++ b/pkg/diff/sql_vertex_generator.go @@ -53,23 +53,12 @@ func graphFromPartials(parts partialSQLGraph) (*sqlGraph, error) { } for _, dep := range parts.dependencies { - sourceVertex := sqlVertex{ - id: dep.source, - priority: sqlPriorityUnset, - statements: nil, - } - targetVertex := sqlVertex{ - id: dep.target, - priority: sqlPriorityUnset, - statements: nil, - } - // To maintain the correctness of the graph, we will add a dummy vertex for the missing dependencies - addVertexIfNotExists(graph, sourceVertex) - addVertexIfNotExists(graph, targetVertex) + addVertexIfNotExists(graph, dep.source) + addVertexIfNotExists(graph, dep.target) - if err := graph.AddEdge(sourceVertex.GetId(), targetVertex.GetId()); err != nil { - return nil, fmt.Errorf("adding edge from %s to %s: %w", sourceVertex.GetId(), targetVertex.GetId(), err) + if err := graph.AddEdge(dep.source.String(), dep.target.String()); err != nil { + return nil, fmt.Errorf("adding edge from %s to %s: %w", dep.source, dep.target, err) } } @@ -78,9 +67,11 @@ func graphFromPartials(parts partialSQLGraph) (*sqlGraph, error) { func mergeVertices(old, new sqlVertex) sqlVertex { priority := old.priority - if old.priority == sqlPriorityUnset { + if new.priority != sqlPriorityUnset && (priority == sqlPriorityUnset || new.priority > priority) { + // If one is unset, use the other. If both are set, use the higher priority. priority = new.priority } + return sqlVertex{ id: old.id, priority: priority, @@ -88,9 +79,14 @@ func mergeVertices(old, new sqlVertex) sqlVertex { } } -func addVertexIfNotExists(graph *sqlGraph, vertex sqlVertex) { - if !graph.HasVertexWithId(vertex.GetId()) { - graph.AddVertex(vertex) +func addVertexIfNotExists(graph *sqlGraph, id sqlVertexId) { + if !graph.HasVertexWithId(id.String()) { + // Create a filler node + graph.AddVertex(sqlVertex{ + id: id, + priority: sqlPriorityUnset, + statements: nil, + }) } }