Skip to content

Commit efbae4a

Browse files
refactor(rust): move the logic of how to sort and apply migrations into the NextMigration enum
1 parent 605f2c7 commit efbae4a

File tree

6 files changed

+228
-131
lines changed

6 files changed

+228
-131
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
use ockam_core::compat::collections::{HashMap, HashSet};
1+
use ockam_core::compat::collections::HashSet;
22
use ockam_core::compat::time::now;
33
use ockam_core::errcode::{Kind, Origin};
4-
use sqlx::migrate::{Migrate, Migration};
4+
use sqlx::migrate::{AppliedMigration, Migrate, Migration as SqlxMigration};
55
use sqlx::sqlite::SqliteRow;
66
use sqlx::{query, Row, SqliteConnection, SqlitePool};
77
use std::cmp::Ordering;
@@ -67,29 +67,10 @@ impl Migrator {
6767
}
6868

6969
impl Migrator {
70-
fn is_inside_version_range(
71-
version: i64,
72-
from_version: i64, // including
73-
to_version: i64,
74-
including_to: bool,
75-
) -> bool {
76-
if from_version <= version && version < to_version {
77-
return true;
78-
}
79-
80-
if version == to_version && including_to {
81-
return true;
82-
}
83-
84-
false
85-
}
86-
8770
async fn run_migrations(
8871
&self,
8972
connection: &mut SqliteConnection,
90-
from_version: i64, // including
91-
to_version: i64, // not including
92-
run_last_sql: bool, // will run sql migration with verion == to_version
73+
up_to: Version,
9374
) -> Result<()> {
9475
connection.ensure_migrations_table().await.into_core()?;
9576

@@ -102,6 +83,26 @@ impl Migrator {
10283
));
10384
}
10485

86+
let migrations = {
87+
let sql_iterator = self.sql_migrator.migrations.iter().filter_map(|m| {
88+
if m.version <= up_to {
89+
Some(NextMigration::Sql(m))
90+
} else {
91+
None
92+
}
93+
});
94+
let rust_iterator = self.rust_migrations.iter().filter_map(|m| {
95+
if m.version() <= up_to {
96+
Some(NextMigration::Rust(m.as_ref()))
97+
} else {
98+
None
99+
}
100+
});
101+
let mut migrations: Vec<NextMigration> = sql_iterator.chain(rust_iterator).collect();
102+
migrations.sort();
103+
migrations
104+
};
105+
105106
// sqlx Migrator also optionally checks for missing migrations (ones that had been run and
106107
// marked as migrated in the db but now don't exist). Skipping that check for now.
107108
// WARNING: the check if re-enabled can potentially fail because of renaming
@@ -110,92 +111,18 @@ impl Migrator {
110111
// before the _rust_migrations table existed
111112
let applied_migrations = connection.list_applied_migrations().await.into_core()?;
112113

113-
let applied_migrations: HashMap<_, _> = applied_migrations
114-
.into_iter()
115-
.map(|m| (m.version, m))
116-
.collect();
117-
118-
enum NextMigration<'a> {
119-
Sql(&'a Migration),
120-
#[allow(clippy::borrowed_box)]
121-
Rust(&'a Box<dyn RustMigration>),
122-
}
123-
124-
impl NextMigration<'_> {
125-
fn is_sql(&self) -> bool {
126-
match self {
127-
NextMigration::Sql(_) => true,
128-
NextMigration::Rust(_) => false,
129-
}
130-
}
131-
}
132-
133-
let sql_iterator = self.sql_migrator.migrations.iter().filter_map(|m| {
134-
let version = m.version;
135-
136-
if !Self::is_inside_version_range(version, from_version, to_version, run_last_sql) {
137-
return None;
138-
}
139-
140-
Some((version, NextMigration::Sql(m)))
141-
});
142-
let rust_iterator = self.rust_migrations.iter().filter_map(|m| {
143-
let version = m.version();
144-
145-
if !Self::is_inside_version_range(version, from_version, to_version, false) {
146-
return None;
147-
}
148-
149-
Some((version, NextMigration::Rust(m)))
150-
});
151-
152-
let mut all_migrations: Vec<(i64, NextMigration)> =
153-
sql_iterator.chain(rust_iterator).collect();
154-
all_migrations.sort_by(|m1, m2| match m1.0.cmp(&m2.0) {
155-
Ordering::Less => Ordering::Less,
156-
Ordering::Equal => {
157-
// Sql migrations go first
158-
if m1.1.is_sql() {
159-
Ordering::Less
160-
} else {
161-
Ordering::Greater
162-
}
163-
}
164-
Ordering::Greater => Ordering::Greater,
165-
});
166-
167-
for migration in all_migrations.iter().map(|(_version, m)| m) {
114+
for migration in migrations.into_iter() {
168115
match migration {
169116
NextMigration::Sql(sql_migration) => {
170-
if sql_migration.migration_type.is_down_migration() {
171-
return Ok(());
172-
}
173-
174-
match applied_migrations.get(&sql_migration.version) {
175-
Some(applied_migration) => {
176-
if sql_migration.checksum != applied_migration.checksum {
177-
return Err(ockam_core::Error::new(
178-
Origin::Node,
179-
Kind::Conflict,
180-
format!(
181-
"Checksum mismatch for sql migration for version {}",
182-
sql_migration.version
183-
),
184-
));
185-
}
186-
}
187-
None => {
188-
connection.apply(sql_migration).await.into_core()?;
189-
}
190-
}
117+
NextMigration::apply_sql_migration(
118+
sql_migration,
119+
connection,
120+
&applied_migrations,
121+
)
122+
.await?;
191123
}
192124
NextMigration::Rust(rust_migration) => {
193-
if Self::has_migrated(connection, rust_migration.name()).await? {
194-
continue;
195-
}
196-
if rust_migration.migrate(connection).await? {
197-
Self::mark_as_migrated(connection, rust_migration.name()).await?;
198-
}
125+
NextMigration::apply_rust_migration(rust_migration, connection).await?;
199126
}
200127
}
201128
}
@@ -240,22 +167,14 @@ impl Migrator {
240167
}
241168

242169
impl Migrator {
243-
/// Run migrations
244-
pub async fn migrate_partial(
245-
&self,
246-
pool: &SqlitePool,
247-
from_version: i64, // including
248-
to_version: i64, // not including
249-
run_last_sql: bool, // Will run `to_version` version of the sql migration
250-
) -> Result<()> {
170+
/// Run migrations up to the specified version (inclusive)
171+
pub(crate) async fn migrate_up_to(&self, pool: &SqlitePool, up_to: Version) -> Result<()> {
251172
let mut connection = pool.acquire().await.into_core()?;
252173

253174
// Apparently does nothing for sqlite...
254175
connection.lock().await.into_core()?;
255176

256-
let res = self
257-
.run_migrations(&mut connection, from_version, to_version, run_last_sql)
258-
.await;
177+
let res = self.run_migrations(&mut connection, up_to).await;
259178

260179
connection.unlock().await.into_core()?;
261180

@@ -266,19 +185,192 @@ impl Migrator {
266185

267186
/// Run all migrations
268187
pub async fn migrate(&self, pool: &SqlitePool) -> Result<()> {
269-
self.migrate_partial(pool, 0, i64::MAX, false).await
188+
self.migrate_up_to(pool, i64::MAX).await
270189
}
271190
}
272191

273192
#[cfg(test)]
274193
impl Migrator {
275-
/// Migrate the schema of the database right before the specified version
276-
pub(crate) async fn migrate_before(
277-
&self,
194+
/// Run migrations up to the specified version (inclusive) but skip the last rust migration
195+
pub(crate) async fn migrate_up_to_skip_last_rust_migration(
196+
mut self,
278197
pool: &SqlitePool,
279-
version: i64, // not including
280-
run_last_sql: bool,
198+
up_to: Version,
199+
) -> Result<()> {
200+
self.rust_migrations.retain(|m| m.version() < up_to);
201+
self.migrate_up_to(pool, up_to).await
202+
}
203+
}
204+
205+
type Version = i64;
206+
207+
#[derive(Debug)]
208+
enum NextMigration<'a> {
209+
Sql(&'a SqlxMigration),
210+
Rust(&'a dyn RustMigration),
211+
}
212+
213+
impl NextMigration<'_> {
214+
fn is_sql(&self) -> bool {
215+
matches!(self, Self::Sql(_))
216+
}
217+
218+
fn version(&self) -> Version {
219+
match self {
220+
Self::Sql(m) => m.version,
221+
Self::Rust(m) => m.version(),
222+
}
223+
}
224+
225+
async fn apply_sql_migration<'a>(
226+
migration: &'a SqlxMigration,
227+
connection: &mut SqliteConnection,
228+
applied_migrations: &[AppliedMigration],
281229
) -> Result<()> {
282-
self.migrate_partial(pool, 0, version, run_last_sql).await
230+
if migration.migration_type.is_down_migration() {
231+
return Ok(());
232+
}
233+
match applied_migrations
234+
.iter()
235+
.find(|m| m.version == migration.version)
236+
{
237+
Some(applied_migration) => {
238+
if migration.checksum != applied_migration.checksum {
239+
return Err(ockam_core::Error::new(
240+
Origin::Node,
241+
Kind::Conflict,
242+
format!(
243+
"Checksum mismatch for sql migration for version {}",
244+
migration.version
245+
),
246+
));
247+
}
248+
}
249+
None => {
250+
connection.apply(migration).await.into_core()?;
251+
}
252+
}
253+
Ok(())
254+
}
255+
256+
async fn apply_rust_migration(
257+
migration: &dyn RustMigration,
258+
connection: &mut SqliteConnection,
259+
) -> Result<()> {
260+
if Migrator::has_migrated(connection, migration.name()).await? {
261+
return Ok(());
262+
}
263+
if migration.migrate(connection).await? {
264+
Migrator::mark_as_migrated(connection, migration.name()).await?;
265+
}
266+
Ok(())
267+
}
268+
}
269+
270+
impl Eq for NextMigration<'_> {}
271+
272+
impl PartialEq<Self> for NextMigration<'_> {
273+
fn eq(&self, other: &Self) -> bool {
274+
let same_kind = matches!(
275+
(self, other),
276+
(Self::Sql(_), Self::Sql(_)) | (Self::Rust(_), Self::Rust(_))
277+
);
278+
same_kind && self.version() == other.version()
279+
}
280+
}
281+
282+
impl Ord for NextMigration<'_> {
283+
fn cmp(&self, other: &Self) -> Ordering {
284+
match self.version().cmp(&other.version()) {
285+
Ordering::Equal => {
286+
// Sql migrations go first
287+
match (self.is_sql(), other.is_sql()) {
288+
(true, true) => Ordering::Equal,
289+
(true, false) => Ordering::Less,
290+
(false, true) => Ordering::Greater,
291+
_ => unreachable!(),
292+
}
293+
}
294+
ord => ord,
295+
}
296+
}
297+
}
298+
299+
impl PartialOrd for NextMigration<'_> {
300+
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
301+
Some(self.cmp(other))
302+
}
303+
}
304+
305+
#[cfg(test)]
306+
mod tests {
307+
use super::*;
308+
use ockam_core::async_trait;
309+
use sqlx::migrate::MigrationType;
310+
311+
#[test]
312+
fn ordering_of_migrations() {
313+
let sql_1 = SqlxMigration::new(1, "sql_1".into(), MigrationType::Simple, "1".into());
314+
let sql_2 = SqlxMigration::new(2, "sql_2".into(), MigrationType::Simple, "2".into());
315+
let rust_1: Box<dyn RustMigration> = Box::new(DummyRustMigration::new(1));
316+
let rust_2: Box<dyn RustMigration> = Box::new(DummyRustMigration::new(2));
317+
let rust_3: Box<dyn RustMigration> = Box::new(DummyRustMigration::new(3));
318+
319+
let mut migrations = vec![
320+
NextMigration::Sql(&sql_2),
321+
NextMigration::Sql(&sql_1),
322+
NextMigration::Rust(rust_1.as_ref()),
323+
NextMigration::Rust(rust_3.as_ref()),
324+
NextMigration::Rust(rust_2.as_ref()),
325+
];
326+
migrations.sort();
327+
328+
for m in &migrations {
329+
match m {
330+
NextMigration::Sql(_) => {
331+
assert!(m.is_sql());
332+
}
333+
NextMigration::Rust(_) => {
334+
assert!(!m.is_sql());
335+
}
336+
}
337+
}
338+
339+
assert_eq!(
340+
migrations,
341+
vec![
342+
NextMigration::Sql(&sql_1),
343+
NextMigration::Rust(rust_1.as_ref()),
344+
NextMigration::Sql(&sql_2),
345+
NextMigration::Rust(rust_2.as_ref()),
346+
NextMigration::Rust(rust_3.as_ref())
347+
]
348+
);
349+
}
350+
351+
#[derive(Debug)]
352+
struct DummyRustMigration {
353+
version: Version,
354+
}
355+
356+
impl DummyRustMigration {
357+
fn new(version: Version) -> Self {
358+
Self { version }
359+
}
360+
}
361+
362+
#[async_trait]
363+
impl RustMigration for DummyRustMigration {
364+
fn name(&self) -> &str {
365+
"DummyRustMigration"
366+
}
367+
368+
fn version(&self) -> Version {
369+
self.version
370+
}
371+
372+
async fn migrate(&self, _connection: &mut SqliteConnection) -> Result<bool> {
373+
Ok(true)
374+
}
283375
}
284376
}

0 commit comments

Comments
 (0)