1
- use ockam_core:: compat:: collections:: { HashMap , HashSet } ;
1
+ use ockam_core:: compat:: collections:: HashSet ;
2
2
use ockam_core:: compat:: time:: now;
3
3
use ockam_core:: errcode:: { Kind , Origin } ;
4
- use sqlx:: migrate:: { Migrate , Migration } ;
4
+ use sqlx:: migrate:: { AppliedMigration , Migrate , Migration as SqlxMigration } ;
5
5
use sqlx:: sqlite:: SqliteRow ;
6
6
use sqlx:: { query, Row , SqliteConnection , SqlitePool } ;
7
7
use std:: cmp:: Ordering ;
@@ -67,29 +67,10 @@ impl Migrator {
67
67
}
68
68
69
69
impl Migrator {
70
- fn is_inside_version_range (
71
- version : i64 ,
72
- from_version : i64 , // including
73
- to_version : i64 ,
74
- including_to : bool ,
75
- ) -> bool {
76
- if from_version <= version && version < to_version {
77
- return true ;
78
- }
79
-
80
- if version == to_version && including_to {
81
- return true ;
82
- }
83
-
84
- false
85
- }
86
-
87
70
async fn run_migrations (
88
71
& self ,
89
72
connection : & mut SqliteConnection ,
90
- from_version : i64 , // including
91
- to_version : i64 , // not including
92
- run_last_sql : bool , // will run sql migration with verion == to_version
73
+ up_to : Version ,
93
74
) -> Result < ( ) > {
94
75
connection. ensure_migrations_table ( ) . await . into_core ( ) ?;
95
76
@@ -102,6 +83,26 @@ impl Migrator {
102
83
) ) ;
103
84
}
104
85
86
+ let migrations = {
87
+ let sql_iterator = self . sql_migrator . migrations . iter ( ) . filter_map ( |m| {
88
+ if m. version <= up_to {
89
+ Some ( NextMigration :: Sql ( m) )
90
+ } else {
91
+ None
92
+ }
93
+ } ) ;
94
+ let rust_iterator = self . rust_migrations . iter ( ) . filter_map ( |m| {
95
+ if m. version ( ) <= up_to {
96
+ Some ( NextMigration :: Rust ( m. as_ref ( ) ) )
97
+ } else {
98
+ None
99
+ }
100
+ } ) ;
101
+ let mut migrations: Vec < NextMigration > = sql_iterator. chain ( rust_iterator) . collect ( ) ;
102
+ migrations. sort ( ) ;
103
+ migrations
104
+ } ;
105
+
105
106
// sqlx Migrator also optionally checks for missing migrations (ones that had been run and
106
107
// marked as migrated in the db but now don't exist). Skipping that check for now.
107
108
// WARNING: the check if re-enabled can potentially fail because of renaming
@@ -110,92 +111,18 @@ impl Migrator {
110
111
// before the _rust_migrations table existed
111
112
let applied_migrations = connection. list_applied_migrations ( ) . await . into_core ( ) ?;
112
113
113
- let applied_migrations: HashMap < _ , _ > = applied_migrations
114
- . into_iter ( )
115
- . map ( |m| ( m. version , m) )
116
- . collect ( ) ;
117
-
118
- enum NextMigration < ' a > {
119
- Sql ( & ' a Migration ) ,
120
- #[ allow( clippy:: borrowed_box) ]
121
- Rust ( & ' a Box < dyn RustMigration > ) ,
122
- }
123
-
124
- impl NextMigration < ' _ > {
125
- fn is_sql ( & self ) -> bool {
126
- match self {
127
- NextMigration :: Sql ( _) => true ,
128
- NextMigration :: Rust ( _) => false ,
129
- }
130
- }
131
- }
132
-
133
- let sql_iterator = self . sql_migrator . migrations . iter ( ) . filter_map ( |m| {
134
- let version = m. version ;
135
-
136
- if !Self :: is_inside_version_range ( version, from_version, to_version, run_last_sql) {
137
- return None ;
138
- }
139
-
140
- Some ( ( version, NextMigration :: Sql ( m) ) )
141
- } ) ;
142
- let rust_iterator = self . rust_migrations . iter ( ) . filter_map ( |m| {
143
- let version = m. version ( ) ;
144
-
145
- if !Self :: is_inside_version_range ( version, from_version, to_version, false ) {
146
- return None ;
147
- }
148
-
149
- Some ( ( version, NextMigration :: Rust ( m) ) )
150
- } ) ;
151
-
152
- let mut all_migrations: Vec < ( i64 , NextMigration ) > =
153
- sql_iterator. chain ( rust_iterator) . collect ( ) ;
154
- all_migrations. sort_by ( |m1, m2| match m1. 0 . cmp ( & m2. 0 ) {
155
- Ordering :: Less => Ordering :: Less ,
156
- Ordering :: Equal => {
157
- // Sql migrations go first
158
- if m1. 1 . is_sql ( ) {
159
- Ordering :: Less
160
- } else {
161
- Ordering :: Greater
162
- }
163
- }
164
- Ordering :: Greater => Ordering :: Greater ,
165
- } ) ;
166
-
167
- for migration in all_migrations. iter ( ) . map ( |( _version, m) | m) {
114
+ for migration in migrations. into_iter ( ) {
168
115
match migration {
169
116
NextMigration :: Sql ( sql_migration) => {
170
- if sql_migration. migration_type . is_down_migration ( ) {
171
- return Ok ( ( ) ) ;
172
- }
173
-
174
- match applied_migrations. get ( & sql_migration. version ) {
175
- Some ( applied_migration) => {
176
- if sql_migration. checksum != applied_migration. checksum {
177
- return Err ( ockam_core:: Error :: new (
178
- Origin :: Node ,
179
- Kind :: Conflict ,
180
- format ! (
181
- "Checksum mismatch for sql migration for version {}" ,
182
- sql_migration. version
183
- ) ,
184
- ) ) ;
185
- }
186
- }
187
- None => {
188
- connection. apply ( sql_migration) . await . into_core ( ) ?;
189
- }
190
- }
117
+ NextMigration :: apply_sql_migration (
118
+ sql_migration,
119
+ connection,
120
+ & applied_migrations,
121
+ )
122
+ . await ?;
191
123
}
192
124
NextMigration :: Rust ( rust_migration) => {
193
- if Self :: has_migrated ( connection, rust_migration. name ( ) ) . await ? {
194
- continue ;
195
- }
196
- if rust_migration. migrate ( connection) . await ? {
197
- Self :: mark_as_migrated ( connection, rust_migration. name ( ) ) . await ?;
198
- }
125
+ NextMigration :: apply_rust_migration ( rust_migration, connection) . await ?;
199
126
}
200
127
}
201
128
}
@@ -240,22 +167,14 @@ impl Migrator {
240
167
}
241
168
242
169
impl Migrator {
243
- /// Run migrations
244
- pub async fn migrate_partial (
245
- & self ,
246
- pool : & SqlitePool ,
247
- from_version : i64 , // including
248
- to_version : i64 , // not including
249
- run_last_sql : bool , // Will run `to_version` version of the sql migration
250
- ) -> Result < ( ) > {
170
+ /// Run migrations up to the specified version (inclusive)
171
+ pub ( crate ) async fn migrate_up_to ( & self , pool : & SqlitePool , up_to : Version ) -> Result < ( ) > {
251
172
let mut connection = pool. acquire ( ) . await . into_core ( ) ?;
252
173
253
174
// Apparently does nothing for sqlite...
254
175
connection. lock ( ) . await . into_core ( ) ?;
255
176
256
- let res = self
257
- . run_migrations ( & mut connection, from_version, to_version, run_last_sql)
258
- . await ;
177
+ let res = self . run_migrations ( & mut connection, up_to) . await ;
259
178
260
179
connection. unlock ( ) . await . into_core ( ) ?;
261
180
@@ -266,19 +185,192 @@ impl Migrator {
266
185
267
186
/// Run all migrations
268
187
pub async fn migrate ( & self , pool : & SqlitePool ) -> Result < ( ) > {
269
- self . migrate_partial ( pool, 0 , i64:: MAX , false ) . await
188
+ self . migrate_up_to ( pool, i64:: MAX ) . await
270
189
}
271
190
}
272
191
273
192
#[ cfg( test) ]
274
193
impl Migrator {
275
- /// Migrate the schema of the database right before the specified version
276
- pub ( crate ) async fn migrate_before (
277
- & self ,
194
+ /// Run migrations up to the specified version (inclusive) but skip the last rust migration
195
+ pub ( crate ) async fn migrate_up_to_skip_last_rust_migration (
196
+ mut self ,
278
197
pool : & SqlitePool ,
279
- version : i64 , // not including
280
- run_last_sql : bool ,
198
+ up_to : Version ,
199
+ ) -> Result < ( ) > {
200
+ self . rust_migrations . retain ( |m| m. version ( ) < up_to) ;
201
+ self . migrate_up_to ( pool, up_to) . await
202
+ }
203
+ }
204
+
205
+ type Version = i64 ;
206
+
207
+ #[ derive( Debug ) ]
208
+ enum NextMigration < ' a > {
209
+ Sql ( & ' a SqlxMigration ) ,
210
+ Rust ( & ' a dyn RustMigration ) ,
211
+ }
212
+
213
+ impl NextMigration < ' _ > {
214
+ fn is_sql ( & self ) -> bool {
215
+ matches ! ( self , Self :: Sql ( _) )
216
+ }
217
+
218
+ fn version ( & self ) -> Version {
219
+ match self {
220
+ Self :: Sql ( m) => m. version ,
221
+ Self :: Rust ( m) => m. version ( ) ,
222
+ }
223
+ }
224
+
225
+ async fn apply_sql_migration < ' a > (
226
+ migration : & ' a SqlxMigration ,
227
+ connection : & mut SqliteConnection ,
228
+ applied_migrations : & [ AppliedMigration ] ,
281
229
) -> Result < ( ) > {
282
- self . migrate_partial ( pool, 0 , version, run_last_sql) . await
230
+ if migration. migration_type . is_down_migration ( ) {
231
+ return Ok ( ( ) ) ;
232
+ }
233
+ match applied_migrations
234
+ . iter ( )
235
+ . find ( |m| m. version == migration. version )
236
+ {
237
+ Some ( applied_migration) => {
238
+ if migration. checksum != applied_migration. checksum {
239
+ return Err ( ockam_core:: Error :: new (
240
+ Origin :: Node ,
241
+ Kind :: Conflict ,
242
+ format ! (
243
+ "Checksum mismatch for sql migration for version {}" ,
244
+ migration. version
245
+ ) ,
246
+ ) ) ;
247
+ }
248
+ }
249
+ None => {
250
+ connection. apply ( migration) . await . into_core ( ) ?;
251
+ }
252
+ }
253
+ Ok ( ( ) )
254
+ }
255
+
256
+ async fn apply_rust_migration (
257
+ migration : & dyn RustMigration ,
258
+ connection : & mut SqliteConnection ,
259
+ ) -> Result < ( ) > {
260
+ if Migrator :: has_migrated ( connection, migration. name ( ) ) . await ? {
261
+ return Ok ( ( ) ) ;
262
+ }
263
+ if migration. migrate ( connection) . await ? {
264
+ Migrator :: mark_as_migrated ( connection, migration. name ( ) ) . await ?;
265
+ }
266
+ Ok ( ( ) )
267
+ }
268
+ }
269
+
270
+ impl Eq for NextMigration < ' _ > { }
271
+
272
+ impl PartialEq < Self > for NextMigration < ' _ > {
273
+ fn eq ( & self , other : & Self ) -> bool {
274
+ let same_kind = matches ! (
275
+ ( self , other) ,
276
+ ( Self :: Sql ( _) , Self :: Sql ( _) ) | ( Self :: Rust ( _) , Self :: Rust ( _) )
277
+ ) ;
278
+ same_kind && self . version ( ) == other. version ( )
279
+ }
280
+ }
281
+
282
+ impl Ord for NextMigration < ' _ > {
283
+ fn cmp ( & self , other : & Self ) -> Ordering {
284
+ match self . version ( ) . cmp ( & other. version ( ) ) {
285
+ Ordering :: Equal => {
286
+ // Sql migrations go first
287
+ match ( self . is_sql ( ) , other. is_sql ( ) ) {
288
+ ( true , true ) => Ordering :: Equal ,
289
+ ( true , false ) => Ordering :: Less ,
290
+ ( false , true ) => Ordering :: Greater ,
291
+ _ => unreachable ! ( ) ,
292
+ }
293
+ }
294
+ ord => ord,
295
+ }
296
+ }
297
+ }
298
+
299
+ impl PartialOrd for NextMigration < ' _ > {
300
+ fn partial_cmp ( & self , other : & Self ) -> Option < Ordering > {
301
+ Some ( self . cmp ( other) )
302
+ }
303
+ }
304
+
305
+ #[ cfg( test) ]
306
+ mod tests {
307
+ use super :: * ;
308
+ use ockam_core:: async_trait;
309
+ use sqlx:: migrate:: MigrationType ;
310
+
311
+ #[ test]
312
+ fn ordering_of_migrations ( ) {
313
+ let sql_1 = SqlxMigration :: new ( 1 , "sql_1" . into ( ) , MigrationType :: Simple , "1" . into ( ) ) ;
314
+ let sql_2 = SqlxMigration :: new ( 2 , "sql_2" . into ( ) , MigrationType :: Simple , "2" . into ( ) ) ;
315
+ let rust_1: Box < dyn RustMigration > = Box :: new ( DummyRustMigration :: new ( 1 ) ) ;
316
+ let rust_2: Box < dyn RustMigration > = Box :: new ( DummyRustMigration :: new ( 2 ) ) ;
317
+ let rust_3: Box < dyn RustMigration > = Box :: new ( DummyRustMigration :: new ( 3 ) ) ;
318
+
319
+ let mut migrations = vec ! [
320
+ NextMigration :: Sql ( & sql_2) ,
321
+ NextMigration :: Sql ( & sql_1) ,
322
+ NextMigration :: Rust ( rust_1. as_ref( ) ) ,
323
+ NextMigration :: Rust ( rust_3. as_ref( ) ) ,
324
+ NextMigration :: Rust ( rust_2. as_ref( ) ) ,
325
+ ] ;
326
+ migrations. sort ( ) ;
327
+
328
+ for m in & migrations {
329
+ match m {
330
+ NextMigration :: Sql ( _) => {
331
+ assert ! ( m. is_sql( ) ) ;
332
+ }
333
+ NextMigration :: Rust ( _) => {
334
+ assert ! ( !m. is_sql( ) ) ;
335
+ }
336
+ }
337
+ }
338
+
339
+ assert_eq ! (
340
+ migrations,
341
+ vec![
342
+ NextMigration :: Sql ( & sql_1) ,
343
+ NextMigration :: Rust ( rust_1. as_ref( ) ) ,
344
+ NextMigration :: Sql ( & sql_2) ,
345
+ NextMigration :: Rust ( rust_2. as_ref( ) ) ,
346
+ NextMigration :: Rust ( rust_3. as_ref( ) )
347
+ ]
348
+ ) ;
349
+ }
350
+
351
+ #[ derive( Debug ) ]
352
+ struct DummyRustMigration {
353
+ version : Version ,
354
+ }
355
+
356
+ impl DummyRustMigration {
357
+ fn new ( version : Version ) -> Self {
358
+ Self { version }
359
+ }
360
+ }
361
+
362
+ #[ async_trait]
363
+ impl RustMigration for DummyRustMigration {
364
+ fn name ( & self ) -> & str {
365
+ "DummyRustMigration"
366
+ }
367
+
368
+ fn version ( & self ) -> Version {
369
+ self . version
370
+ }
371
+
372
+ async fn migrate ( & self , _connection : & mut SqliteConnection ) -> Result < bool > {
373
+ Ok ( true )
374
+ }
283
375
}
284
376
}
0 commit comments