migrate.rs

  1//! Versioned schema migrations for the task database.
  2//!
  3//! Each migration has up/down SQL and optional post-hook functions that run
  4//! inside the same transaction. Schema version is tracked via SQLite's built-in
  5//! `PRAGMA user_version`.
  6
  7use anyhow::{bail, Context, Result};
  8use rusqlite::Connection;
  9
 10/// A single schema migration step.
 11struct Migration {
 12    up_sql: &'static str,
 13    down_sql: &'static str,
 14    post_hook_up: Option<fn(&rusqlite::Transaction) -> Result<()>>,
 15    post_hook_down: Option<fn(&rusqlite::Transaction) -> Result<()>>,
 16}
 17
 18/// All migrations in order.  The array index is the version the database will
 19/// be at *after* the migration runs (1-indexed: migration 0 brings the DB to
 20/// version 1).
 21static MIGRATIONS: &[Migration] = &[
 22    // 0 → 1: initial schema
 23    Migration {
 24        up_sql: include_str!("migrations/0001_initial_schema.up.sql"),
 25        down_sql: include_str!("migrations/0001_initial_schema.down.sql"),
 26        post_hook_up: None,
 27        post_hook_down: None,
 28    },
 29    // 1 → 2: add effort column (integer-backed, default medium)
 30    Migration {
 31        up_sql: include_str!("migrations/0002_add_effort.up.sql"),
 32        down_sql: include_str!("migrations/0002_add_effort.down.sql"),
 33        post_hook_up: None,
 34        post_hook_down: None,
 35    },
 36    Migration {
 37        up_sql: include_str!("migrations/0003_blocker_fk.up.sql"),
 38        down_sql: include_str!("migrations/0003_blocker_fk.down.sql"),
 39        post_hook_up: None,
 40        post_hook_down: None,
 41    },
 42    Migration {
 43        up_sql: include_str!("migrations/0004_task_logs.up.sql"),
 44        down_sql: include_str!("migrations/0004_task_logs.down.sql"),
 45        post_hook_up: None,
 46        post_hook_down: None,
 47    },
 48    Migration {
 49        up_sql: include_str!("migrations/0005_cascade_fks.up.sql"),
 50        down_sql: include_str!("migrations/0005_cascade_fks.down.sql"),
 51        post_hook_up: None,
 52        post_hook_down: None,
 53    },
 54];
 55
 56/// Read the current schema version from the database.
 57fn get_version(conn: &Connection) -> Result<u32> {
 58    let v: u32 = conn.pragma_query_value(None, "user_version", |row| row.get(0))?;
 59    Ok(v)
 60}
 61
 62/// Set the schema version inside an open transaction.
 63fn set_version(tx: &rusqlite::Transaction, version: u32) -> Result<()> {
 64    // PRAGMA cannot be parameterised, but the value is a u32 we control.
 65    tx.pragma_update(None, "user_version", version)?;
 66    Ok(())
 67}
 68
 69/// Apply all pending up-migrations to bring the database to the latest version.
 70pub fn migrate_up(conn: &mut Connection) -> Result<()> {
 71    let current = get_version(conn)?;
 72    let latest = MIGRATIONS.len() as u32;
 73
 74    if current > latest {
 75        bail!(
 76            "database is at version {current} but this binary only knows up to {latest}; \
 77             upgrade td or use a matching version"
 78        );
 79    }
 80
 81    for (idx, m) in MIGRATIONS.iter().enumerate().skip(current as usize) {
 82        let target_version = (idx + 1) as u32;
 83
 84        let tx = conn
 85            .transaction()
 86            .context("failed to begin migration transaction")?;
 87
 88        if !m.up_sql.is_empty() {
 89            tx.execute_batch(m.up_sql)
 90                .with_context(|| format!("migration {target_version} up SQL failed"))?;
 91        }
 92
 93        if let Some(hook) = m.post_hook_up {
 94            hook(&tx)
 95                .with_context(|| format!("migration {target_version} post-hook (up) failed"))?;
 96        }
 97
 98        set_version(&tx, target_version)?;
 99
100        tx.commit()
101            .with_context(|| format!("failed to commit migration {target_version}"))?;
102    }
103
104    Ok(())
105}
106
107/// Roll back migrations down to `target_version` (inclusive — the database
108/// will be at `target_version` when this returns).
109pub fn migrate_down(conn: &mut Connection, target_version: u32) -> Result<()> {
110    let current = get_version(conn)?;
111
112    if target_version >= current {
113        bail!("target version {target_version} is not below current version {current}");
114    }
115
116    if target_version > MIGRATIONS.len() as u32 {
117        bail!("target version {target_version} exceeds known migrations");
118    }
119
120    // Walk backwards: if we're at version 3 and want version 1, we undo
121    // migration index 2 (v3→v2) then index 1 (v2→v1).
122    for (idx, m) in MIGRATIONS
123        .iter()
124        .enumerate()
125        .rev()
126        .filter(|(i, _)| *i >= target_version as usize && *i < current as usize)
127    {
128        let from_version = (idx + 1) as u32;
129
130        let tx = conn
131            .transaction()
132            .context("failed to begin down-migration transaction")?;
133
134        if let Some(hook) = m.post_hook_down {
135            hook(&tx)
136                .with_context(|| format!("migration {from_version} post-hook (down) failed"))?;
137        }
138
139        if !m.down_sql.is_empty() {
140            tx.execute_batch(m.down_sql)
141                .with_context(|| format!("migration {from_version} down SQL failed"))?;
142        }
143
144        set_version(&tx, idx as u32)?;
145
146        tx.commit()
147            .with_context(|| format!("failed to commit down-migration {from_version}"))?;
148    }
149
150    Ok(())
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Open a fresh in-memory database and migrate it to the latest schema.
    fn fully_migrated() -> Connection {
        let mut conn = Connection::open_in_memory().unwrap();
        migrate_up(&mut conn).unwrap();
        conn
    }

    #[test]
    fn migrate_up_from_empty() {
        let conn = fully_migrated();
        assert_eq!(get_version(&conn).unwrap(), MIGRATIONS.len() as u32);

        // Each probe fails to prepare if its table or column is missing.
        for probe in [
            "SELECT id FROM tasks LIMIT 0",
            "SELECT task_id FROM labels LIMIT 0",
            "SELECT task_id FROM blockers LIMIT 0",
            "SELECT task_id FROM task_logs LIMIT 0",
        ] {
            conn.execute_batch(probe)
                .unwrap_or_else(|e| panic!("probe `{probe}` failed: {e}"));
        }
    }

    #[test]
    fn migrate_up_is_idempotent() {
        let mut conn = fully_migrated();
        // Nothing is pending on the second run, so it must succeed as a no-op.
        migrate_up(&mut conn).unwrap();
        assert_eq!(get_version(&conn).unwrap(), MIGRATIONS.len() as u32);
    }

    #[test]
    fn migrate_down_to_zero() {
        let mut conn = fully_migrated();
        migrate_down(&mut conn, 0).unwrap();
        assert_eq!(get_version(&conn).unwrap(), 0);

        // With every migration undone, the tasks table must no longer exist.
        assert!(conn.execute_batch("SELECT id FROM tasks LIMIT 0").is_err());
    }

    #[test]
    fn rejects_future_version() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.pragma_update(None, "user_version", 999).unwrap();

        let message = migrate_up(&mut conn).unwrap_err().to_string();
        assert!(
            message.contains("999"),
            "error should mention the version: {message}"
        );
    }
}