Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SQL generation - graph-first approach #176

Merged
merged 3 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 0 additions & 193 deletions pkg/diff/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,12 @@ import (
"fmt"
"sort"

"github.com/stripe/pg-schema-diff/internal/graph"
"github.com/stripe/pg-schema-diff/internal/schema"
)

var ErrNotImplemented = fmt.Errorf("not implemented")
var errDuplicateIdentifier = fmt.Errorf("duplicate identifier")

// diffType identifies which kind of change a SQL vertex resolves for a schema object.
type diffType string

const (
// diffTypeDelete tags the vertex holding the SQL that deletes a schema object.
diffTypeDelete diffType = "DELETE"
// diffTypeAddAlter tags the vertex holding the SQL that adds or alters a schema object.
// Adds and alters share a single vertex type.
diffTypeAddAlter diffType = "ADDALTER"
)

type (
diff[S schema.Object] interface {
GetOld() S
Expand All @@ -32,79 +24,8 @@ type (
// provided diff. Alter, e.g., with a table, might produce add/delete statements
Alter(Diff) ([]Statement, error)
}

// dependency indicates an edge between the SQL to resolve a diff for a source schema object and the SQL to resolve
// the diff of a target schema object
//
// Most SchemaObjects will have two nodes in the SQL graph: a node for delete SQL and a node for add/alter SQL.
// These nodes will almost always be present in the sqlGraph even if the schema object is not being deleted (or added/altered).
// If a node is present for a schema object where the "diffType" is NOT occurring, it will just be a no-op (no SQL statements)
dependency struct {
// sourceObjId and sourceType identify the vertex the edge starts from. When built via
// mustRun(x).before(y), the source is x, i.e., the vertex that must run first.
sourceObjId string
sourceType diffType

// targetObjId and targetType identify the vertex the edge points to. When built via
// mustRun(x).before(y), the target is y.
targetObjId string
targetType diffType
}
)

// dependencyBuilder holds one endpoint of a dependency (an object id plus a diff type) until
// before or after supplies the other endpoint and produces the finished dependency.
// It is created through mustRun.
type dependencyBuilder struct {
// valObjId is the id of the schema object the dependency is being declared for.
valObjId string
// valType is the diff type (delete or add/alter) of that object's vertex.
valType diffType
}

// mustRun begins a dependency declaration for the vertex identified by schemaObjId and
// schemaDiffType. Chain the result with before or after to name the other vertex, e.g.
// mustRun(x, diffTypeDelete).before(y, diffTypeAddAlter).
func mustRun(schemaObjId string, schemaDiffType diffType) dependencyBuilder {
	var builder dependencyBuilder
	builder.valObjId = schemaObjId
	builder.valType = schemaDiffType
	return builder
}

// before completes the dependency so that the builder's vertex is the source of the edge and
// the vertex identified by (valObjId, valType) is the target.
func (d dependencyBuilder) before(valObjId string, valType diffType) dependency {
	var dep dependency
	// The builder's own vertex is the source endpoint.
	dep.sourceObjId, dep.sourceType = d.valObjId, d.valType
	// The vertex passed in is the target endpoint.
	dep.targetObjId, dep.targetType = valObjId, valType
	return dep
}

// after completes the dependency so that the vertex identified by (valObjId, valType) is the
// source of the edge and the builder's vertex is the target. It is the mirror image of before.
func (d dependencyBuilder) after(valObjId string, valType diffType) dependency {
	var dep dependency
	// The vertex passed in is the source endpoint.
	dep.sourceObjId, dep.sourceType = valObjId, valType
	// The builder's own vertex is the target endpoint.
	dep.targetObjId, dep.targetType = d.valObjId, d.valType
	return dep
}

// sqlVertexGenerator is used to generate SQL statements for schema objects that have dependency webs
// with other schema objects. The schema object represents a vertex in the graph.
//
// It embeds sqlGenerator, so implementations also provide the SQL generation methods of that interface.
type sqlVertexGenerator[S schema.Object, Diff diff[S]] interface {
sqlGenerator[S, Diff]
// GetSQLVertexId gets the canonical vertex id to represent the schema object
GetSQLVertexId(S) string

// GetAddAlterDependencies gets the dependencies of the SQL generated to resolve the AddAlter diff for the
// schema objects. Dependencies can be formed on any other nodes in the SQL graph, even if the node has
// no statements. If the diff is just an add, then old will be the zero value
//
// These dependencies can also be built in reverse: the SQL returned by the sqlVertexGenerator to resolve the
// diff for the object must always be run before the SQL required to resolve another SQL vertex diff
GetAddAlterDependencies(new S, old S) ([]dependency, error)

// GetDeleteDependencies is the same as above but for deletes.
// Invariant to maintain:
// - If an object X depends on the delete for an object Y (generated by the sqlVertexGenerator), immediately after
// the (Y, diffTypeDelete) sqlVertex's SQL is run, Y must no longer be present in the schema; either the
// (Y, diffTypeDelete) statements deleted Y or something that vertex depended on deleted Y. In other words, if a
// delete is cascaded by another delete (e.g., index dropped by table drop) and the index SQL is empty,
// the index delete vertex must still have a dependency from itself to the object from which the delete cascades
GetDeleteDependencies(S) ([]dependency, error)
}

type (
// listDiff represents the differences between two lists.
listDiff[S schema.Object, Diff diff[S]] struct {
Expand Down Expand Up @@ -158,120 +79,6 @@ func (ld listDiff[S, D]) resolveToSQLGroupedByEffect(sqlGenerator sqlGenerator[S
}, nil
}

// resolveToSQLGraph converts the list diff into a SQL graph using the given generator.
//
// Every add, alter, and delete becomes a vertex carrying the statements produced by the
// generator, and the generator's dependencies become edges. Adds and alters are registered
// as diffTypeAddAlter vertices, deletes as diffTypeDelete vertices. Entries are processed in
// the order adds, alters, deletes; the first error aborts the build and is returned wrapped.
//
// An alter whose old and new objects map to different vertex ids is rejected with an error.
func (ld listDiff[S, D]) resolveToSQLGraph(generator sqlVertexGenerator[S, D]) (*sqlGraph, error) {
	// Named g so the local does not shadow the imported graph package.
	g := graph.NewGraph[sqlVertex]()

	for _, added := range ld.adds {
		stmts, err := generator.Add(added)
		if err != nil {
			return nil, fmt.Errorf("generating SQL for add %s: %w", added.GetName(), err)
		}

		// A pure add has no previous version, so the zero value stands in for "old".
		var noOld S
		deps, err := generator.GetAddAlterDependencies(added, noOld)
		if err != nil {
			return nil, fmt.Errorf("getting dependencies for add %s: %w", added.GetName(), err)
		}

		vertex := sqlVertex{
			ObjId:      generator.GetSQLVertexId(added),
			Statements: stmts,
			DiffType:   diffTypeAddAlter,
		}
		if err := addSQLVertexToGraph(g, vertex, deps); err != nil {
			return nil, fmt.Errorf("adding SQL Vertex for add %s: %w", added.GetName(), err)
		}
	}

	for _, altered := range ld.alters {
		stmts, err := generator.Alter(altered)
		if err != nil {
			return nil, fmt.Errorf("generating SQL for diff %+v: %w", altered, err)
		}

		// The old and new versions must resolve to the same vertex, otherwise the alter
		// cannot be represented by a single add/alter vertex.
		oldId := generator.GetSQLVertexId(altered.GetOld())
		newId := generator.GetSQLVertexId(altered.GetNew())
		if oldId != newId {
			return nil, fmt.Errorf("an alter lead to a node with a different id: old=%s, new=%s", oldId, newId)
		}

		deps, err := generator.GetAddAlterDependencies(altered.GetNew(), altered.GetOld())
		if err != nil {
			return nil, fmt.Errorf("getting dependencies for alter %s: %w", altered.GetOld().GetName(), err)
		}

		vertex := sqlVertex{
			ObjId:      oldId,
			Statements: stmts,
			DiffType:   diffTypeAddAlter,
		}
		if err := addSQLVertexToGraph(g, vertex, deps); err != nil {
			return nil, fmt.Errorf("adding SQL Vertex for alter %s: %w", altered.GetOld().GetName(), err)
		}
	}

	for _, deleted := range ld.deletes {
		stmts, err := generator.Delete(deleted)
		if err != nil {
			return nil, fmt.Errorf("generating SQL for delete %s: %w", deleted.GetName(), err)
		}

		deps, err := generator.GetDeleteDependencies(deleted)
		if err != nil {
			return nil, fmt.Errorf("getting dependencies for delete %s: %w", deleted.GetName(), err)
		}

		vertex := sqlVertex{
			ObjId:      generator.GetSQLVertexId(deleted),
			Statements: stmts,
			DiffType:   diffTypeDelete,
		}
		if err := addSQLVertexToGraph(g, vertex, deps); err != nil {
			return nil, fmt.Errorf("adding SQL Vertex for delete %s: %w", deleted.GetName(), err)
		}
	}

	return (*sqlGraph)(g), nil
}

// addSQLVertexToGraph registers vertex in graph and then wires up each of the given
// dependencies as an edge.
//
// If a vertex with the same id is already present (for example a placeholder created while
// adding an earlier dependency), the two are combined via mergeSQLVertices before being stored.
// The first dependency that cannot be added aborts the call and its error is returned wrapped.
func addSQLVertexToGraph(graph *graph.Graph[sqlVertex], vertex sqlVertex, dependencies []dependency) error {
	if id := vertex.GetId(); graph.HasVertexWithId(id) {
		existing := graph.GetVertex(id)
		vertex = mergeSQLVertices(existing, vertex)
	}
	graph.AddVertex(vertex)

	for i := range dependencies {
		err := addDependency(graph, dependencies[i])
		if err == nil {
			continue
		}
		return fmt.Errorf("adding dependencies for %s: %w", vertex.GetId(), err)
	}
	return nil
}

// addDependency adds an edge to graph from the dependency's source vertex to its target vertex.
//
// Either endpoint may not have been registered yet, so a statement-less placeholder vertex is
// inserted for any endpoint that is missing before the edge is created. An error from adding
// the edge is returned wrapped with both vertex ids.
func addDependency(graph *graph.Graph[sqlVertex], dep dependency) error {
	// Placeholders carry no statements; Statements is left at its zero value (nil).
	source := sqlVertex{ObjId: dep.sourceObjId, DiffType: dep.sourceType}
	target := sqlVertex{ObjId: dep.targetObjId, DiffType: dep.targetType}

	// To keep the graph well-formed, make sure both endpoints exist before linking them.
	for _, endpoint := range []sqlVertex{source, target} {
		addVertexIfNotExists(graph, endpoint)
	}

	sourceId, targetId := source.GetId(), target.GetId()
	if err := graph.AddEdge(sourceId, targetId); err != nil {
		return fmt.Errorf("adding edge from %s to %s: %w", sourceId, targetId, err)
	}
	return nil
}

// addVertexIfNotExists inserts vertex into graph only when no vertex with the same id is
// already present; an existing vertex is left untouched.
func addVertexIfNotExists(graph *graph.Graph[sqlVertex], vertex sqlVertex) {
	if graph.HasVertexWithId(vertex.GetId()) {
		return
	}
	graph.AddVertex(vertex)
}

type schemaObjectEntry[S schema.Object] struct {
index int // index is the index the schema object in the list
obj S
Expand Down
26 changes: 13 additions & 13 deletions pkg/diff/policy_sql_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ type policyDiff struct {
oldAndNew[schema.Policy]
}

func buildPolicyDiffs(psg *policySQLVertexGenerator, old, new []schema.Policy) (listDiff[schema.Policy, policyDiff], error) {
func buildPolicyDiffs(psg sqlVertexGenerator[schema.Policy, policyDiff], old, new []schema.Policy) (listDiff[schema.Policy, policyDiff], error) {
return diffLists(old, new, func(old, new schema.Policy, _, _ int) (_ policyDiff, requiresRecreate bool, _ error) {
diff := policyDiff{
oldAndNew: oldAndNew[schema.Policy]{
Expand All @@ -131,7 +131,7 @@ type policySQLVertexGenerator struct {
oldSchemaColumnsByName map[string]schema.Column
}

func newPolicySQLVertexGenerator(oldTable *schema.Table, table schema.Table) (*policySQLVertexGenerator, error) {
func newPolicySQLVertexGenerator(oldTable *schema.Table, table schema.Table) (sqlVertexGenerator[schema.Policy, policyDiff], error) {
var oldSchemaColumnsByName map[string]schema.Column
if oldTable != nil {
if oldTable.SchemaQualifiedName != table.SchemaQualifiedName {
Expand All @@ -140,12 +140,12 @@ func newPolicySQLVertexGenerator(oldTable *schema.Table, table schema.Table) (*p
oldSchemaColumnsByName = buildSchemaObjByNameMap(oldTable.Columns)
}

return &policySQLVertexGenerator{
return legacyToNewSqlVertexGenerator[schema.Policy, policyDiff](&policySQLVertexGenerator{
table: table,
newSchemaColumnsByName: buildSchemaObjByNameMap(table.Columns),
oldTable: oldTable,
oldSchemaColumnsByName: oldSchemaColumnsByName,
}, nil
}), nil
}

func (psg *policySQLVertexGenerator) Add(p schema.Policy) ([]Statement, error) {
Expand Down Expand Up @@ -262,17 +262,17 @@ func (psg *policySQLVertexGenerator) Alter(diff policyDiff) ([]Statement, error)
}}, nil
}

func (psg *policySQLVertexGenerator) GetSQLVertexId(p schema.Policy) string {
return buildPolicyVertexId(psg.table.SchemaQualifiedName, p.EscapedName)
func (psg *policySQLVertexGenerator) GetSQLVertexId(p schema.Policy, diffType diffType) sqlVertexId {
return buildPolicyVertexId(psg.table.SchemaQualifiedName, p.EscapedName, diffType)
}

func buildPolicyVertexId(owningTable schema.SchemaQualifiedName, policyEscapedName string) string {
return buildVertexId("policy", fmt.Sprintf("%s.%s", owningTable.GetFQEscapedName(), policyEscapedName))
func buildPolicyVertexId(owningTable schema.SchemaQualifiedName, policyEscapedName string, diffType diffType) sqlVertexId {
return buildSchemaObjVertexId("policy", fmt.Sprintf("%s.%s", owningTable.GetFQEscapedName(), policyEscapedName), diffType)
}

func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolicy schema.Policy) ([]dependency, error) {
deps := []dependency{
mustRun(psg.GetSQLVertexId(newPolicy), diffTypeDelete).before(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter),
mustRun(psg.GetSQLVertexId(newPolicy, diffTypeDelete)).before(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)),
}

newTargetColumns, err := getTargetColumns(newPolicy.Columns, psg.newSchemaColumnsByName)
Expand All @@ -282,7 +282,7 @@ func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolic

// Run after the new columns are added/altered
for _, tc := range newTargetColumns {
deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter).after(buildColumnVertexId(tc.Name), diffTypeAddAlter))
deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)).after(buildColumnVertexId(tc.Name, diffTypeAddAlter)))
}

if !cmp.Equal(oldPolicy, schema.Policy{}) {
Expand All @@ -294,7 +294,7 @@ func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolic
for _, tc := range oldTargetColumns {
// It only needs to run before the delete if the column is actually being deleted
if _, stillExists := psg.newSchemaColumnsByName[tc.GetName()]; !stillExists {
deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter).before(buildColumnVertexId(tc.Name), diffTypeDelete))
deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)).before(buildColumnVertexId(tc.Name, diffTypeDelete)))
}
}
}
Expand All @@ -311,8 +311,8 @@ func (psg *policySQLVertexGenerator) GetDeleteDependencies(pol schema.Policy) ([
}
// The policy needs to be deleted before all the columns it references are deleted or add/altered
for _, c := range columns {
deps = append(deps, mustRun(psg.GetSQLVertexId(pol), diffTypeDelete).before(buildColumnVertexId(c.Name), diffTypeDelete))
deps = append(deps, mustRun(psg.GetSQLVertexId(pol), diffTypeDelete).before(buildColumnVertexId(c.Name), diffTypeAddAlter))
deps = append(deps, mustRun(psg.GetSQLVertexId(pol, diffTypeDelete)).before(buildColumnVertexId(c.Name, diffTypeDelete)))
deps = append(deps, mustRun(psg.GetSQLVertexId(pol, diffTypeDelete)).before(buildColumnVertexId(c.Name, diffTypeAddAlter)))
}

return deps, nil
Expand Down
Loading
Loading