1//! Versioned schema migrations for the task database.
2//!
3//! Each migration has up/down SQL and optional post-hook functions that run
4//! inside the same transaction. Schema version is tracked via SQLite's built-in
5//! `PRAGMA user_version`.
6
7use anyhow::{bail, Context, Result};
8use rusqlite::Connection;
9
10/// A single schema migration step.
11struct Migration {
12 up_sql: &'static str,
13 down_sql: &'static str,
14 post_hook_up: Option<fn(&rusqlite::Transaction) -> Result<()>>,
15 post_hook_down: Option<fn(&rusqlite::Transaction) -> Result<()>>,
16}
17
18/// All migrations in order. The array index is the version the database will
19/// be at *after* the migration runs (1-indexed: migration 0 brings the DB to
20/// version 1).
21static MIGRATIONS: &[Migration] = &[
22 // 0 → 1: initial schema
23 Migration {
24 up_sql: include_str!("migrations/0001_initial_schema.up.sql"),
25 down_sql: include_str!("migrations/0001_initial_schema.down.sql"),
26 post_hook_up: None,
27 post_hook_down: None,
28 },
29 // 1 → 2: add effort column (integer-backed, default medium)
30 Migration {
31 up_sql: include_str!("migrations/0002_add_effort.up.sql"),
32 down_sql: include_str!("migrations/0002_add_effort.down.sql"),
33 post_hook_up: None,
34 post_hook_down: None,
35 },
36 Migration {
37 up_sql: include_str!("migrations/0003_blocker_fk.up.sql"),
38 down_sql: include_str!("migrations/0003_blocker_fk.down.sql"),
39 post_hook_up: None,
40 post_hook_down: None,
41 },
42 Migration {
43 up_sql: include_str!("migrations/0004_task_logs.up.sql"),
44 down_sql: include_str!("migrations/0004_task_logs.down.sql"),
45 post_hook_up: None,
46 post_hook_down: None,
47 },
48 Migration {
49 up_sql: include_str!("migrations/0005_cascade_fks.up.sql"),
50 down_sql: include_str!("migrations/0005_cascade_fks.down.sql"),
51 post_hook_up: None,
52 post_hook_down: None,
53 },
54];
55
56/// Read the current schema version from the database.
57fn get_version(conn: &Connection) -> Result<u32> {
58 let v: u32 = conn.pragma_query_value(None, "user_version", |row| row.get(0))?;
59 Ok(v)
60}
61
62/// Set the schema version inside an open transaction.
63fn set_version(tx: &rusqlite::Transaction, version: u32) -> Result<()> {
64 // PRAGMA cannot be parameterised, but the value is a u32 we control.
65 tx.pragma_update(None, "user_version", version)?;
66 Ok(())
67}
68
69/// Apply all pending up-migrations to bring the database to the latest version.
70pub fn migrate_up(conn: &mut Connection) -> Result<()> {
71 let current = get_version(conn)?;
72 let latest = MIGRATIONS.len() as u32;
73
74 if current > latest {
75 bail!(
76 "database is at version {current} but this binary only knows up to {latest}; \
77 upgrade td or use a matching version"
78 );
79 }
80
81 for (idx, m) in MIGRATIONS.iter().enumerate().skip(current as usize) {
82 let target_version = (idx + 1) as u32;
83
84 let tx = conn
85 .transaction()
86 .context("failed to begin migration transaction")?;
87
88 if !m.up_sql.is_empty() {
89 tx.execute_batch(m.up_sql)
90 .with_context(|| format!("migration {target_version} up SQL failed"))?;
91 }
92
93 if let Some(hook) = m.post_hook_up {
94 hook(&tx)
95 .with_context(|| format!("migration {target_version} post-hook (up) failed"))?;
96 }
97
98 set_version(&tx, target_version)?;
99
100 tx.commit()
101 .with_context(|| format!("failed to commit migration {target_version}"))?;
102 }
103
104 Ok(())
105}
106
107/// Roll back migrations down to `target_version` (inclusive — the database
108/// will be at `target_version` when this returns).
109pub fn migrate_down(conn: &mut Connection, target_version: u32) -> Result<()> {
110 let current = get_version(conn)?;
111
112 if target_version >= current {
113 bail!("target version {target_version} is not below current version {current}");
114 }
115
116 if target_version > MIGRATIONS.len() as u32 {
117 bail!("target version {target_version} exceeds known migrations");
118 }
119
120 // Walk backwards: if we're at version 3 and want version 1, we undo
121 // migration index 2 (v3→v2) then index 1 (v2→v1).
122 for (idx, m) in MIGRATIONS
123 .iter()
124 .enumerate()
125 .rev()
126 .filter(|(i, _)| *i >= target_version as usize && *i < current as usize)
127 {
128 let from_version = (idx + 1) as u32;
129
130 let tx = conn
131 .transaction()
132 .context("failed to begin down-migration transaction")?;
133
134 if let Some(hook) = m.post_hook_down {
135 hook(&tx)
136 .with_context(|| format!("migration {from_version} post-hook (down) failed"))?;
137 }
138
139 if !m.down_sql.is_empty() {
140 tx.execute_batch(m.down_sql)
141 .with_context(|| format!("migration {from_version} down SQL failed"))?;
142 }
143
144 set_version(&tx, idx as u32)?;
145
146 tx.commit()
147 .with_context(|| format!("failed to commit down-migration {from_version}"))?;
148 }
149
150 Ok(())
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Open a fresh in-memory database and migrate it to the latest version.
    fn migrated_conn() -> Connection {
        let mut conn = Connection::open_in_memory().unwrap();
        migrate_up(&mut conn).unwrap();
        conn
    }

    #[test]
    fn migrate_up_from_empty() {
        let conn = migrated_conn();
        assert_eq!(get_version(&conn).unwrap(), MIGRATIONS.len() as u32);

        // Verify tables exist by querying them.
        for probe in [
            "SELECT id FROM tasks LIMIT 0",
            "SELECT task_id FROM labels LIMIT 0",
            "SELECT task_id FROM blockers LIMIT 0",
            "SELECT task_id FROM task_logs LIMIT 0",
        ] {
            conn.execute_batch(probe)
                .unwrap_or_else(|e| panic!("{probe}: {e}"));
        }
    }

    #[test]
    fn migrate_up_is_idempotent() {
        let mut conn = migrated_conn();
        // Running again should be a no-op, not an error.
        migrate_up(&mut conn).unwrap();
        assert_eq!(get_version(&conn).unwrap(), MIGRATIONS.len() as u32);
    }

    #[test]
    fn migrate_down_to_zero() {
        let mut conn = migrated_conn();
        migrate_down(&mut conn, 0).unwrap();
        assert_eq!(get_version(&conn).unwrap(), 0);

        // Tables should be gone.
        let result = conn.execute_batch("SELECT id FROM tasks LIMIT 0");
        assert!(result.is_err());
    }

    #[test]
    fn migrate_down_rejects_target_not_below_current() {
        let mut conn = migrated_conn();
        let latest = MIGRATIONS.len() as u32;
        let err = migrate_down(&mut conn, latest).unwrap_err();
        assert!(
            err.to_string().contains("is not below"),
            "unexpected error: {err}"
        );
        // A rejected rollback must leave the version untouched.
        assert_eq!(get_version(&conn).unwrap(), latest);
    }

    #[test]
    fn migrate_down_then_up_round_trips_every_version() {
        let latest = MIGRATIONS.len() as u32;
        for target in 0..latest {
            let mut conn = migrated_conn();
            migrate_down(&mut conn, target).unwrap();
            assert_eq!(get_version(&conn).unwrap(), target);
            // Re-applying from the intermediate version must succeed.
            migrate_up(&mut conn).unwrap();
            assert_eq!(get_version(&conn).unwrap(), latest);
        }
    }

    #[test]
    fn rejects_future_version() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.pragma_update(None, "user_version", 999).unwrap();
        let err = migrate_up(&mut conn).unwrap_err();
        assert!(
            err.to_string().contains("999"),
            "error should mention the version: {err}"
        );
    }
}