@@ -3,6 +3,9 @@ package diff
3
3
import (
4
4
"context"
5
5
"fmt"
6
+ "os"
7
+ "path/filepath"
8
+ "strings"
6
9
7
10
"github.com/stripe/pg-schema-diff/internal/schema"
8
11
"github.com/stripe/pg-schema-diff/pkg/log"
@@ -20,13 +23,79 @@ type SchemaSource interface {
20
23
GetSchema (ctx context.Context , deps schemaSourcePlanDeps ) (schema.Schema , error )
21
24
}
22
25
23
- type ddlSchemaSource struct {
24
- ddl []string
26
+ type (
27
+ ddlStatement struct {
28
+ // stmt is the DDL statement to run.
29
+ stmt string
30
+ // file is an optional field that can be used to store the file name from which the DDL was read.
31
+ file string
32
+ }
33
+
34
+ ddlSchemaSource struct {
35
+ ddl []ddlStatement
36
+ }
37
+ )
38
+
39
+ // DirSchemaSource returns a SchemaSource that returns a schema based on the provided directories. You must provide a tempDBFactory
40
+ // via the WithTempDbFactory option.
41
+ func DirSchemaSource (dirs []string ) (SchemaSource , error ) {
42
+ var ddl []ddlStatement
43
+ for _ , dir := range dirs {
44
+ stmts , err := getDDLFromPath (dir )
45
+ if err != nil {
46
+ return & ddlSchemaSource {}, err
47
+ }
48
+ ddl = append (ddl , stmts ... )
49
+
50
+ }
51
+ return & ddlSchemaSource {
52
+ ddl : ddl ,
53
+ }, nil
54
+ }
55
+
56
+ // getDDLFromPath reads all .sql files under the given path (including sub-directories) and returns the DDL
57
+ // in lexical order.
58
+ func getDDLFromPath (path string ) ([]ddlStatement , error ) {
59
+ var ddl []ddlStatement
60
+ if err := filepath .Walk (path , func (path string , entry os.FileInfo , err error ) error {
61
+ if err != nil {
62
+ return fmt .Errorf ("walking path %q: %w" , path , err )
63
+ }
64
+ if strings .ToLower (filepath .Ext (entry .Name ())) != ".sql" {
65
+ return nil
66
+ }
67
+
68
+ fileContents , err := os .ReadFile (path )
69
+ if err != nil {
70
+ return fmt .Errorf ("reading file %q: %w" , entry .Name (), err )
71
+ }
72
+
73
+ // In the future, it would make sense to split the file contents into individual DDL statements; however,
74
+ // that would require fully parsing the SQL. Naively splitting on `;` would not work because `;` can be
75
+ // used in comments, strings, and escaped identifiers.
76
+ ddl = append (ddl , ddlStatement {
77
+ stmt : string (fileContents ),
78
+ file : path ,
79
+ })
80
+ return nil
81
+ }); err != nil {
82
+ return nil , err
83
+ }
84
+ return ddl , nil
25
85
}
26
86
27
87
// DDLSchemaSource returns a SchemaSource that returns a schema based on the provided DDL. You must provide a tempDBFactory
28
88
// via the WithTempDbFactory option.
29
- func DDLSchemaSource (ddl []string ) SchemaSource {
89
+ func DDLSchemaSource (stmts []string ) SchemaSource {
90
+ var ddl []ddlStatement
91
+ for _ , stmt := range stmts {
92
+ ddl = append (ddl , ddlStatement {
93
+ stmt : stmt ,
94
+ // There is no file name associated with the DDL statement.
95
+ file : "" },
96
+ )
97
+ }
98
+
30
99
return & ddlSchemaSource {ddl : ddl }
31
100
}
32
101
@@ -45,9 +114,13 @@ func (s *ddlSchemaSource) GetSchema(ctx context.Context, deps schemaSourcePlanDe
45
114
}
46
115
}(tempDb .ContextualCloser )
47
116
48
- for _ , stmt := range s .ddl {
49
- if _ , err := tempDb .ConnPool .ExecContext (ctx , stmt ); err != nil {
50
- return schema.Schema {}, fmt .Errorf ("running DDL: %w" , err )
117
+ for _ , ddlStmt := range s .ddl {
118
+ if _ , err := tempDb .ConnPool .ExecContext (ctx , ddlStmt .stmt ); err != nil {
119
+ debugInfo := ""
120
+ if ddlStmt .file != "" {
121
+ debugInfo = fmt .Sprintf (" (from %s)" , ddlStmt .file )
122
+ }
123
+ return schema.Schema {}, fmt .Errorf ("running DDL%s: %w" , debugInfo , err )
51
124
}
52
125
}
53
126
0 commit comments