-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathvalidate.go
208 lines (189 loc) · 6.04 KB
/
validate.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package squibble
import (
"cmp"
"context"
"database/sql"
"fmt"
"slices"
"strings"
)
// Validate compares the schema currently present in db against the schema
// described by the given SQL text, and returns an error when they disagree.
//
// When the two schemas differ, the returned error has concrete type
// ValidationError. An error from compiling the schema text or from reading
// the schema of db is returned unchanged.
func Validate(ctx context.Context, db DBConn, schema string) error {
	want, err := schemaTextToRows(ctx, schema)
	if err != nil {
		return err
	}
	got, err := readSchema(ctx, db, "main")
	if err != nil {
		return err
	}
	diff := diffSchema(got, want)
	if diff == "" {
		return nil
	}
	return ValidationError{Diff: diff}
}
// schemaTextToRows executes the schema text against a scratch database
// opened with the "file::memory:" DSN and returns the schema rows read back
// from it. The schema is executed inside a transaction that is rolled back
// on return, and the scratch database is closed on return.
func schemaTextToRows(ctx context.Context, schema string) ([]schemaRow, error) {
	scratch, err := sql.Open("sqlite", "file::memory:")
	if err != nil {
		return nil, fmt.Errorf("create validation db: %w", err)
	}
	defer scratch.Close()

	tx, err := scratch.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	defer tx.Rollback()

	_, err = tx.ExecContext(ctx, schema)
	if err != nil {
		return nil, fmt.Errorf("compile schema: %w", err)
	}
	// Read through the transaction so the uncommitted schema is visible.
	return readSchema(ctx, tx, "main")
}
// ValidationError is the concrete error type returned by Validate when the
// database schema does not match the expected schema.
type ValidationError struct {
	// Diff is a human-readable description of how the schema found in the
	// database (the "-got" side) differs from the expected schema (the
	// "+want" side).
	Diff string
}

// Error implements the error interface. The message is a fixed header line
// followed by the text of Diff.
func (v ValidationError) Error() string {
	return fmt.Sprint("invalid schema (-got, +want):\n", v.Diff)
}
// schemaRow describes one entry read from a database's sqlite_schema table,
// together with column metadata when the entry is a table.
type schemaRow struct {
	Type      string      // e.g., "index", "table", "trigger", "view"
	Name      string      // name of the schema entry
	TableName string      // affiliated table name (== Name for tables and views)
	Columns   []schemaCol // for tables, the columns
	SQL       string      // the text of the definition; empty when the stored value is NULL
}
// mapKey identifies a schema row by its type and name.
type mapKey struct {
	Type string
	Name string
}

// mapKey returns the key for s, built from its Type and Name fields.
func (s schemaRow) mapKey() mapKey {
	return mapKey{Type: s.Type, Name: s.Name}
}
// schemaCol holds the metadata recorded for a single table column.
type schemaCol struct {
	Name       string // column name
	Type       string // type description
	NotNull    bool   // true when the column is marked NOT NULL
	Default    any    // the default value (nil is treated as "no default")
	PrimaryKey bool   // true when this column is part of the primary key
	Hidden     int    // 0=normal, 1=hidden, 2=generated virtual, 3=generated stored
}

// String renders the column as a single line: the quoted name and the type,
// then "not null" or "null", then the default value if one is set, then
// "primary key" if applicable. The Hidden field is not included.
func (c schemaCol) String() string {
	parts := []string{fmt.Sprintf("%q %s", c.Name, c.Type)}
	if c.NotNull {
		parts = append(parts, "not null")
	} else {
		parts = append(parts, "null")
	}
	if c.Default != nil {
		parts = append(parts, fmt.Sprintf("default=%v", c.Default))
	}
	if c.PrimaryKey {
		parts = append(parts, "primary key")
	}
	return strings.Join(parts, " ")
}
// compareSchemaRows orders schema rows by Type, then Name, then TableName,
// then SQL. It returns a negative value, zero, or a positive value in the
// manner of cmp.Compare. The Columns field does not take part in the
// comparison.
func compareSchemaRows(a, b schemaRow) int {
	keys := [...][2]string{
		{a.Type, b.Type},
		{a.Name, b.Name},
		{a.TableName, b.TableName},
		{a.SQL, b.SQL},
	}
	for _, k := range keys {
		if c := cmp.Compare(k[0], k[1]); c != 0 {
			return c
		}
	}
	return 0
}
// compareSchemaCols orders columns by Type, then Name, and finally by a
// textual rendering of the remaining attributes (NotNull, PrimaryKey,
// whether a default is present, and Hidden). Only the presence of a default
// is compared, not its value. Because the last step compares formatted
// strings, Hidden is ordered lexically rather than numerically.
func compareSchemaCols(a, b schemaCol) int {
	if c := cmp.Compare(a.Type, b.Type); c != 0 {
		return c
	}
	if c := cmp.Compare(a.Name, b.Name); c != 0 {
		return c
	}
	flags := func(c schemaCol) string {
		return fmt.Sprintf("%v %v %v %d", c.NotNull, c.PrimaryKey, c.Default != nil, c.Hidden)
	}
	return cmp.Compare(flags(a), flags(b))
}
// DBConn is the subset of the sql.DB method set that the functions in this
// package rely on.
type DBConn interface {
	// QueryContext executes a query that returns rows.
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)

	// ExecContext executes a statement and reports its result.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
// readSchema reads the schema for the specified database and returns the
// resulting rows sorted into a stable order (see compareSchemaRows).
//
// Rows affiliated with the "_schema_history" and "sqlite_sequence" tables
// (the tables themselves and their indices) are filtered out, as are entries
// whose name starts with "sqlite_autoindex_". For each remaining row of type
// "table", the column metadata is read and stored in Columns.
//
// root names the schema to read (for example "main"). It is interpolated
// directly into the query text, so it must be a trusted identifier.
//
// An error is returned if the query fails, a row cannot be scanned, the
// columns of a table cannot be read, or iteration over the result ends
// with an error.
func readSchema(ctx context.Context, db DBConn, root string) ([]schemaRow, error) {
	rows, err := db.QueryContext(ctx,
		fmt.Sprintf(`SELECT type, name, tbl_name, sql FROM %s.sqlite_schema`, root),
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []schemaRow
	for rows.Next() {
		var rtype, name, tblName string
		var text sql.NullString // tolerate a NULL sql column
		if err := rows.Scan(&rtype, &name, &tblName, &text); err != nil {
			return nil, fmt.Errorf("scan %s schema: %w", root, err)
		}
		if tblName == "_schema_history" || tblName == "sqlite_sequence" {
			continue // skip the history and sequence tables and their indices
		}
		if strings.HasPrefix(name, "sqlite_autoindex_") {
			continue // skip auto-generated SQLite indices
		}

		row := schemaRow{Type: rtype, Name: name, TableName: tblName, SQL: text.String}

		// For tables: Read out the column information.
		if rtype == "table" {
			cols, err := readColumns(ctx, db, root, name)
			if err != nil {
				return nil, err
			}
			row.Columns = cols
		}
		out = append(out, row)
	}
	// Next reports false both at the end of the results and on failure, so
	// the iteration error must be checked to avoid returning a silently
	// truncated schema.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("read %s schema: %w", root, err)
	}

	slices.SortFunc(out, compareSchemaRows)
	return out, nil
}
// readColumns reads the schema metadata for the columns of the specified table.
func readColumns(ctx context.Context, db DBConn, root, table string) ([]schemaCol, error) {
rows, err := db.QueryContext(ctx, fmt.Sprintf(`PRAGMA %s.table_xinfo('%s')`, root, table))
if err != nil {
return nil, err
}
defer rows.Close()
var out []schemaCol
for rows.Next() {
var idIgnored, notNull, isPK, hidden int
var name, ctype string
var defValue any
if err := rows.Scan(&idIgnored, &name, &ctype, ¬Null, &defValue, &isPK, &hidden); err != nil {
return nil, fmt.Errorf("scan %s columns: %w", table, err)
}
out = append(out, schemaCol{
Name: name,
Type: strings.ToUpper(ctype), // normalize
NotNull: notNull != 0,
Default: defValue,
PrimaryKey: isPK != 0,
Hidden: hidden,
})
}
slices.SortFunc(out, compareSchemaCols)
return out, nil
}
// schemaIsEmpty reports whether the schema of the specified database is
// essentially empty: readSchema returns either no rows at all, or a single
// row whose name is historyTableName. If the schema cannot be read, the
// result is false.
func schemaIsEmpty(ctx context.Context, db DBConn, root string) bool {
	rows, err := readSchema(ctx, db, root)
	if err != nil {
		return false
	}
	switch len(rows) {
	case 0:
		return true
	case 1:
		return rows[0].Name == historyTableName
	default:
		return false
	}
}