Skip to content

Commit d20f7c4

Browse files
bplunkett-stripe authored and aleclarson committed
Refactor SQL generation - graph-first approach (stripe#176)
* Refactor SQL generation such that the SQL generators take a graph-first approach
1 parent ae17d9f commit d20f7c4

File tree

5 files changed

+509
-402
lines changed

5 files changed

+509
-402
lines changed

pkg/diff/diff.go

-193
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,12 @@ import (
44
"fmt"
55
"sort"
66

7-
"github.com/stripe/pg-schema-diff/internal/graph"
87
"github.com/stripe/pg-schema-diff/internal/schema"
98
)
109

1110
var ErrNotImplemented = fmt.Errorf("not implemented")
1211
var errDuplicateIdentifier = fmt.Errorf("duplicate identifier")
1312

14-
type diffType string
15-
16-
const (
17-
diffTypeDelete diffType = "DELETE"
18-
diffTypeAddAlter diffType = "ADDALTER"
19-
)
20-
2113
type (
2214
diff[S schema.Object] interface {
2315
GetOld() S
@@ -32,79 +24,8 @@ type (
3224
// provided diff. Alter, e.g., with a table, might produce add/delete statements
3325
Alter(Diff) ([]Statement, error)
3426
}
35-
36-
// dependency indicates an edge between the SQL to resolve a diff for a source schema object and the SQL to resolve
37-
// the diff of a target schema object
38-
//
39-
// Most SchemaObjects will have two nodes in the SQL graph: a node for delete SQL and a node for add/alter SQL.
40-
// These nodes will almost always be present in the sqlGraph even if the schema object is not being deleted (or added/altered).
41-
// If a node is present for a schema object where the "diffType" is NOT occurring, it will just be a no-op (no SQL statements)
42-
dependency struct {
43-
sourceObjId string
44-
sourceType diffType
45-
46-
targetObjId string
47-
targetType diffType
48-
}
4927
)
5028

51-
type dependencyBuilder struct {
52-
valObjId string
53-
valType diffType
54-
}
55-
56-
func mustRun(schemaObjId string, schemaDiffType diffType) dependencyBuilder {
57-
return dependencyBuilder{
58-
valObjId: schemaObjId,
59-
valType: schemaDiffType,
60-
}
61-
}
62-
63-
func (d dependencyBuilder) before(valObjId string, valType diffType) dependency {
64-
return dependency{
65-
sourceType: d.valType,
66-
sourceObjId: d.valObjId,
67-
68-
targetType: valType,
69-
targetObjId: valObjId,
70-
}
71-
}
72-
73-
func (d dependencyBuilder) after(valObjId string, valType diffType) dependency {
74-
return dependency{
75-
sourceObjId: valObjId,
76-
sourceType: valType,
77-
78-
targetObjId: d.valObjId,
79-
targetType: d.valType,
80-
}
81-
}
82-
83-
// sqlVertexGenerator is used to generate SQL statements for schema objects that have dependency webs
84-
// with other schema objects. The schema object represents a vertex in the graph.
85-
type sqlVertexGenerator[S schema.Object, Diff diff[S]] interface {
86-
sqlGenerator[S, Diff]
87-
// GetSQLVertexId gets the canonical vertex id to represent the schema object
88-
GetSQLVertexId(S) string
89-
90-
// GetAddAlterDependencies gets the dependencies of the SQL generated to resolve the AddAlter diff for the
91-
// schema objects. Dependencies can be formed on any other nodes in the SQL graph, even if the node has
92-
// no statements. If the diff is just an add, then old will be the zero value
93-
//
94-
// These dependencies can also be built in reverse: the SQL returned by the sqlVertexGenerator to resolve the
95-
// diff for the object must always be run before the SQL required to resolve another SQL vertex diff
96-
GetAddAlterDependencies(new S, old S) ([]dependency, error)
97-
98-
// GetDeleteDependencies is the same as above but for deletes.
99-
// Invariant to maintain:
100-
// - If an object X depends on the delete for an object Y (generated by the sqlVertexGenerator), immediately after the
101-
// (Y, diffTypeDelete) sqlVertex's SQL is run, Y must no longer be present in the schema; either the
102-
// (Y, diffTypeDelete) statements deleted Y or something that vertex depended on deleted Y. In other words, if a
103-
// delete is cascaded by another delete (e.g., index dropped by table drop) and the index SQL is empty,
104-
// the index delete vertex must still have a dependency from itself to the object from which the delete cascades down
105-
GetDeleteDependencies(S) ([]dependency, error)
106-
}
107-
10829
type (
10930
// listDiff represents the differences between two lists.
11031
listDiff[S schema.Object, Diff diff[S]] struct {
@@ -158,120 +79,6 @@ func (ld listDiff[S, D]) resolveToSQLGroupedByEffect(sqlGenerator sqlGenerator[S
15879
}, nil
15980
}
16081

161-
func (ld listDiff[S, D]) resolveToSQLGraph(generator sqlVertexGenerator[S, D]) (*sqlGraph, error) {
162-
graph := graph.NewGraph[sqlVertex]()
163-
164-
for _, a := range ld.adds {
165-
statements, err := generator.Add(a)
166-
if err != nil {
167-
return nil, fmt.Errorf("generating SQL for add %s: %w", a.GetName(), err)
168-
}
169-
170-
deps, err := generator.GetAddAlterDependencies(a, *new(S))
171-
if err != nil {
172-
return nil, fmt.Errorf("getting dependencies for add %s: %w", a.GetName(), err)
173-
}
174-
if err := addSQLVertexToGraph(graph, sqlVertex{
175-
ObjId: generator.GetSQLVertexId(a),
176-
Statements: statements,
177-
DiffType: diffTypeAddAlter,
178-
}, deps); err != nil {
179-
return nil, fmt.Errorf("adding SQL Vertex for add %s: %w", a.GetName(), err)
180-
}
181-
}
182-
183-
for _, a := range ld.alters {
184-
statements, err := generator.Alter(a)
185-
if err != nil {
186-
return nil, fmt.Errorf("generating SQL for diff %+v: %w", a, err)
187-
}
188-
189-
vertexId := generator.GetSQLVertexId(a.GetOld())
190-
vertexIdAfterAlter := generator.GetSQLVertexId(a.GetNew())
191-
if vertexIdAfterAlter != vertexId {
192-
return nil, fmt.Errorf("an alter lead to a node with a different id: old=%s, new=%s", vertexId, vertexIdAfterAlter)
193-
}
194-
195-
deps, err := generator.GetAddAlterDependencies(a.GetNew(), a.GetOld())
196-
if err != nil {
197-
return nil, fmt.Errorf("getting dependencies for alter %s: %w", a.GetOld().GetName(), err)
198-
}
199-
200-
if err := addSQLVertexToGraph(graph, sqlVertex{
201-
ObjId: vertexId,
202-
Statements: statements,
203-
DiffType: diffTypeAddAlter,
204-
}, deps); err != nil {
205-
return nil, fmt.Errorf("adding SQL Vertex for alter %s: %w", a.GetOld().GetName(), err)
206-
}
207-
}
208-
209-
for _, d := range ld.deletes {
210-
statements, err := generator.Delete(d)
211-
if err != nil {
212-
return nil, fmt.Errorf("generating SQL for delete %s: %w", d.GetName(), err)
213-
}
214-
215-
deps, err := generator.GetDeleteDependencies(d)
216-
if err != nil {
217-
return nil, fmt.Errorf("getting dependencies for delete %s: %w", d.GetName(), err)
218-
}
219-
220-
if err := addSQLVertexToGraph(graph, sqlVertex{
221-
ObjId: generator.GetSQLVertexId(d),
222-
Statements: statements,
223-
DiffType: diffTypeDelete,
224-
}, deps); err != nil {
225-
return nil, fmt.Errorf("adding SQL Vertex for delete %s: %w", d.GetName(), err)
226-
}
227-
}
228-
229-
return (*sqlGraph)(graph), nil
230-
}
231-
232-
func addSQLVertexToGraph(graph *graph.Graph[sqlVertex], vertex sqlVertex, dependencies []dependency) error {
233-
// It's possible the node already exists. merge it if it does
234-
if graph.HasVertexWithId(vertex.GetId()) {
235-
vertex = mergeSQLVertices(graph.GetVertex(vertex.GetId()), vertex)
236-
}
237-
graph.AddVertex(vertex)
238-
for _, dep := range dependencies {
239-
if err := addDependency(graph, dep); err != nil {
240-
return fmt.Errorf("adding dependencies for %s: %w", vertex.GetId(), err)
241-
}
242-
}
243-
return nil
244-
}
245-
246-
func addDependency(graph *graph.Graph[sqlVertex], dep dependency) error {
247-
sourceVertex := sqlVertex{
248-
ObjId: dep.sourceObjId,
249-
DiffType: dep.sourceType,
250-
Statements: nil,
251-
}
252-
targetVertex := sqlVertex{
253-
ObjId: dep.targetObjId,
254-
DiffType: dep.targetType,
255-
Statements: nil,
256-
}
257-
258-
// To maintain the correctness of the graph, we will add a dummy vertex for the missing dependencies
259-
addVertexIfNotExists(graph, sourceVertex)
260-
addVertexIfNotExists(graph, targetVertex)
261-
262-
if err := graph.AddEdge(sourceVertex.GetId(), targetVertex.GetId()); err != nil {
263-
return fmt.Errorf("adding edge from %s to %s: %w", sourceVertex.GetId(), targetVertex.GetId(), err)
264-
}
265-
266-
return nil
267-
}
268-
269-
func addVertexIfNotExists(graph *graph.Graph[sqlVertex], vertex sqlVertex) {
270-
if !graph.HasVertexWithId(vertex.GetId()) {
271-
graph.AddVertex(vertex)
272-
}
273-
}
274-
27582
type schemaObjectEntry[S schema.Object] struct {
27683
index int // index is the index of the schema object in the list
27784
obj S

pkg/diff/policy_sql_generator.go

+13-13
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ type policyDiff struct {
104104
oldAndNew[schema.Policy]
105105
}
106106

107-
func buildPolicyDiffs(psg *policySQLVertexGenerator, old, new []schema.Policy) (listDiff[schema.Policy, policyDiff], error) {
107+
func buildPolicyDiffs(psg sqlVertexGenerator[schema.Policy, policyDiff], old, new []schema.Policy) (listDiff[schema.Policy, policyDiff], error) {
108108
return diffLists(old, new, func(old, new schema.Policy, _, _ int) (_ policyDiff, requiresRecreate bool, _ error) {
109109
diff := policyDiff{
110110
oldAndNew: oldAndNew[schema.Policy]{
@@ -131,7 +131,7 @@ type policySQLVertexGenerator struct {
131131
oldSchemaColumnsByName map[string]schema.Column
132132
}
133133

134-
func newPolicySQLVertexGenerator(oldTable *schema.Table, table schema.Table) (*policySQLVertexGenerator, error) {
134+
func newPolicySQLVertexGenerator(oldTable *schema.Table, table schema.Table) (sqlVertexGenerator[schema.Policy, policyDiff], error) {
135135
var oldSchemaColumnsByName map[string]schema.Column
136136
if oldTable != nil {
137137
if oldTable.SchemaQualifiedName != table.SchemaQualifiedName {
@@ -140,12 +140,12 @@ func newPolicySQLVertexGenerator(oldTable *schema.Table, table schema.Table) (*p
140140
oldSchemaColumnsByName = buildSchemaObjByNameMap(oldTable.Columns)
141141
}
142142

143-
return &policySQLVertexGenerator{
143+
return legacyToNewSqlVertexGenerator[schema.Policy, policyDiff](&policySQLVertexGenerator{
144144
table: table,
145145
newSchemaColumnsByName: buildSchemaObjByNameMap(table.Columns),
146146
oldTable: oldTable,
147147
oldSchemaColumnsByName: oldSchemaColumnsByName,
148-
}, nil
148+
}), nil
149149
}
150150

151151
func (psg *policySQLVertexGenerator) Add(p schema.Policy) ([]Statement, error) {
@@ -262,17 +262,17 @@ func (psg *policySQLVertexGenerator) Alter(diff policyDiff) ([]Statement, error)
262262
}}, nil
263263
}
264264

265-
func (psg *policySQLVertexGenerator) GetSQLVertexId(p schema.Policy) string {
266-
return buildPolicyVertexId(psg.table.SchemaQualifiedName, p.EscapedName)
265+
func (psg *policySQLVertexGenerator) GetSQLVertexId(p schema.Policy, diffType diffType) sqlVertexId {
266+
return buildPolicyVertexId(psg.table.SchemaQualifiedName, p.EscapedName, diffType)
267267
}
268268

269-
func buildPolicyVertexId(owningTable schema.SchemaQualifiedName, policyEscapedName string) string {
270-
return buildVertexId("policy", fmt.Sprintf("%s.%s", owningTable.GetFQEscapedName(), policyEscapedName))
269+
func buildPolicyVertexId(owningTable schema.SchemaQualifiedName, policyEscapedName string, diffType diffType) sqlVertexId {
270+
return buildSchemaObjVertexId("policy", fmt.Sprintf("%s.%s", owningTable.GetFQEscapedName(), policyEscapedName), diffType)
271271
}
272272

273273
func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolicy schema.Policy) ([]dependency, error) {
274274
deps := []dependency{
275-
mustRun(psg.GetSQLVertexId(newPolicy), diffTypeDelete).before(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter),
275+
mustRun(psg.GetSQLVertexId(newPolicy, diffTypeDelete)).before(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)),
276276
}
277277

278278
newTargetColumns, err := getTargetColumns(newPolicy.Columns, psg.newSchemaColumnsByName)
@@ -282,7 +282,7 @@ func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolic
282282

283283
// Run after the new columns are added/altered
284284
for _, tc := range newTargetColumns {
285-
deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter).after(buildColumnVertexId(tc.Name), diffTypeAddAlter))
285+
deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)).after(buildColumnVertexId(tc.Name, diffTypeAddAlter)))
286286
}
287287

288288
if !cmp.Equal(oldPolicy, schema.Policy{}) {
@@ -294,7 +294,7 @@ func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolic
294294
for _, tc := range oldTargetColumns {
295295
// It only needs to run before the delete if the column is actually being deleted
296296
if _, stillExists := psg.newSchemaColumnsByName[tc.GetName()]; !stillExists {
297-
deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter).before(buildColumnVertexId(tc.Name), diffTypeDelete))
297+
deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)).before(buildColumnVertexId(tc.Name, diffTypeDelete)))
298298
}
299299
}
300300
}
@@ -311,8 +311,8 @@ func (psg *policySQLVertexGenerator) GetDeleteDependencies(pol schema.Policy) ([
311311
}
312312
// The policy needs to be deleted before all the columns it references are deleted or add/altered
313313
for _, c := range columns {
314-
deps = append(deps, mustRun(psg.GetSQLVertexId(pol), diffTypeDelete).before(buildColumnVertexId(c.Name), diffTypeDelete))
315-
deps = append(deps, mustRun(psg.GetSQLVertexId(pol), diffTypeDelete).before(buildColumnVertexId(c.Name), diffTypeAddAlter))
314+
deps = append(deps, mustRun(psg.GetSQLVertexId(pol, diffTypeDelete)).before(buildColumnVertexId(c.Name, diffTypeDelete)))
315+
deps = append(deps, mustRun(psg.GetSQLVertexId(pol, diffTypeDelete)).before(buildColumnVertexId(c.Name, diffTypeAddAlter)))
316316
}
317317

318318
return deps, nil

0 commit comments

Comments
 (0)