1//! Versioned schema migrations for the task database.
2//!
3//! Each migration has up/down SQL and optional post-hook functions that run
4//! inside the same transaction. Schema version is tracked via SQLite's built-in
5//! `PRAGMA user_version`.
6
7use anyhow::{bail, Context, Result};
8use rusqlite::Connection;
9
10/// A single schema migration step.
11struct Migration {
12 up_sql: &'static str,
13 down_sql: &'static str,
14 post_hook_up: Option<fn(&rusqlite::Transaction) -> Result<()>>,
15 post_hook_down: Option<fn(&rusqlite::Transaction) -> Result<()>>,
16}
17
/// All migrations, in the order they are applied.
///
/// The migration at array index `i` upgrades the database from version `i`
/// to version `i + 1`, so index 0 brings a fresh database to version 1. A
/// fully migrated database therefore reports `MIGRATIONS.len()` as its
/// schema version.
static MIGRATIONS: &[Migration] = &[
    // 0 → 1: initial schema
    Migration {
        up_sql: include_str!("migrations/0001_initial_schema.up.sql"),
        down_sql: include_str!("migrations/0001_initial_schema.down.sql"),
        post_hook_up: None,
        post_hook_down: None,
    },
];
30
31/// Read the current schema version from the database.
32fn get_version(conn: &Connection) -> Result<u32> {
33 let v: u32 = conn.pragma_query_value(None, "user_version", |row| row.get(0))?;
34 Ok(v)
35}
36
37/// Set the schema version inside an open transaction.
38fn set_version(tx: &rusqlite::Transaction, version: u32) -> Result<()> {
39 // PRAGMA cannot be parameterised, but the value is a u32 we control.
40 tx.pragma_update(None, "user_version", version)?;
41 Ok(())
42}
43
44/// Apply all pending up-migrations to bring the database to the latest version.
45pub fn migrate_up(conn: &mut Connection) -> Result<()> {
46 let current = get_version(conn)?;
47 let latest = MIGRATIONS.len() as u32;
48
49 if current > latest {
50 bail!(
51 "database is at version {current} but this binary only knows up to {latest}; \
52 upgrade td or use a matching version"
53 );
54 }
55
56 for (idx, m) in MIGRATIONS.iter().enumerate().skip(current as usize) {
57 let target_version = (idx + 1) as u32;
58
59 let tx = conn
60 .transaction()
61 .context("failed to begin migration transaction")?;
62
63 if !m.up_sql.is_empty() {
64 tx.execute_batch(m.up_sql)
65 .with_context(|| format!("migration {target_version} up SQL failed"))?;
66 }
67
68 if let Some(hook) = m.post_hook_up {
69 hook(&tx)
70 .with_context(|| format!("migration {target_version} post-hook (up) failed"))?;
71 }
72
73 set_version(&tx, target_version)?;
74
75 tx.commit()
76 .with_context(|| format!("failed to commit migration {target_version}"))?;
77 }
78
79 Ok(())
80}
81
82/// Roll back migrations down to `target_version` (inclusive — the database
83/// will be at `target_version` when this returns).
84pub fn migrate_down(conn: &mut Connection, target_version: u32) -> Result<()> {
85 let current = get_version(conn)?;
86
87 if target_version >= current {
88 bail!("target version {target_version} is not below current version {current}");
89 }
90
91 if target_version > MIGRATIONS.len() as u32 {
92 bail!("target version {target_version} exceeds known migrations");
93 }
94
95 // Walk backwards: if we're at version 3 and want version 1, we undo
96 // migration index 2 (v3→v2) then index 1 (v2→v1).
97 for (idx, m) in MIGRATIONS
98 .iter()
99 .enumerate()
100 .rev()
101 .filter(|(i, _)| *i >= target_version as usize && *i < current as usize)
102 {
103 let from_version = (idx + 1) as u32;
104
105 let tx = conn
106 .transaction()
107 .context("failed to begin down-migration transaction")?;
108
109 if let Some(hook) = m.post_hook_down {
110 hook(&tx)
111 .with_context(|| format!("migration {from_version} post-hook (down) failed"))?;
112 }
113
114 if !m.down_sql.is_empty() {
115 tx.execute_batch(m.down_sql)
116 .with_context(|| format!("migration {from_version} down SQL failed"))?;
117 }
118
119 set_version(&tx, idx as u32)?;
120
121 tx.commit()
122 .with_context(|| format!("failed to commit down-migration {from_version}"))?;
123 }
124
125 Ok(())
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Open a brand-new in-memory database with no migrations applied.
    fn blank_db() -> Connection {
        Connection::open_in_memory().unwrap()
    }

    /// Schema version a fully migrated database is expected to report.
    fn latest_version() -> u32 {
        MIGRATIONS.len() as u32
    }

    #[test]
    fn migrate_up_from_empty() {
        let mut db = blank_db();
        migrate_up(&mut db).unwrap();

        assert_eq!(get_version(&db).unwrap(), latest_version());

        // Probe the tables the migrated schema is expected to contain.
        for probe in [
            "SELECT id FROM tasks LIMIT 0",
            "SELECT task_id FROM labels LIMIT 0",
            "SELECT task_id FROM blockers LIMIT 0",
        ] {
            db.execute_batch(probe).unwrap();
        }
    }

    #[test]
    fn migrate_up_is_idempotent() {
        let mut db = blank_db();
        migrate_up(&mut db).unwrap();
        // A second run must succeed and leave the version untouched.
        migrate_up(&mut db).unwrap();
        assert_eq!(get_version(&db).unwrap(), latest_version());
    }

    #[test]
    fn migrate_down_to_zero() {
        let mut db = blank_db();
        migrate_up(&mut db).unwrap();
        migrate_down(&mut db, 0).unwrap();
        assert_eq!(get_version(&db).unwrap(), 0);

        // After a full rollback the tasks table must no longer be queryable.
        assert!(db.execute_batch("SELECT id FROM tasks LIMIT 0").is_err());
    }

    #[test]
    fn rejects_future_version() {
        let mut db = blank_db();
        db.pragma_update(None, "user_version", 999).unwrap();

        let err = migrate_up(&mut db).unwrap_err();
        let message = err.to_string();
        assert!(
            message.contains("999"),
            "error should mention the version: {err}"
        );
    }
}