1use anyhow::{bail, Context, Result};
2use util::{iife, ResultExt};
3
4use std::{
5 fmt::Debug,
6 os::unix::prelude::OsStrExt,
7 path::{Path, PathBuf},
8};
9
10use indoc::indoc;
11use sqlez::{
12 bindable::{Bind, Column},
13 connection::Connection,
14 migrations::Migration,
15 statement::Statement,
16};
17
18use crate::pane::{SerializedDockPane, SerializedPaneGroup};
19
20use super::Db;
21
// Schema migration for persisted workspaces.
//
// `worktree_roots.worktree_root` is stored as a BLOB containing the raw OS bytes of each
// root path. If you need to debug the worktree root code, you can change 'BLOB' here to
// 'TEXT' for easier inspection; you will then also need to update the parsing code (the
// string-based variants have been left in, commented out). This will panic if run
// against an existing database that has already been migrated.
//
// NOTE(review): the FOREIGN KEY and PRIMARY KEY table constraints in `worktree_roots`
// are not separated by a comma. SQLite's grammar appears to accept this, but do not
// edit this SQL without checking how already-migrated databases validate the stored
// migration text — confirm before changing.
pub(crate) const WORKSPACES_MIGRATION: Migration = Migration::new(
    "workspace",
    &[indoc! {"
 CREATE TABLE workspaces(
 workspace_id INTEGER PRIMARY KEY,
 dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
 dock_visible INTEGER, -- Boolean
 timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL
 ) STRICT;

 CREATE TABLE worktree_roots(
 worktree_root BLOB NOT NULL,
 workspace_id INTEGER NOT NULL,
 FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) ON DELETE CASCADE
 PRIMARY KEY(worktree_root, workspace_id)
 ) STRICT;"}],
);
42
43#[derive(Debug, PartialEq, Eq, Copy, Clone, Default)]
44pub(crate) struct WorkspaceId(i64);
45
46impl Bind for WorkspaceId {
47 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
48 self.0.bind(statement, start_index)
49 }
50}
51
52impl Column for WorkspaceId {
53 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
54 i64::column(statement, start_index).map(|(id, next_index)| (Self(id), next_index))
55 }
56}
57
58#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
59pub enum DockAnchor {
60 #[default]
61 Bottom,
62 Right,
63 Expanded,
64}
65
66impl Bind for DockAnchor {
67 fn bind(&self, statement: &Statement, start_index: i32) -> anyhow::Result<i32> {
68 match self {
69 DockAnchor::Bottom => "Bottom",
70 DockAnchor::Right => "Right",
71 DockAnchor::Expanded => "Expanded",
72 }
73 .bind(statement, start_index)
74 }
75}
76
77impl Column for DockAnchor {
78 fn column(statement: &mut Statement, start_index: i32) -> anyhow::Result<(Self, i32)> {
79 String::column(statement, start_index).and_then(|(anchor_text, next_index)| {
80 Ok((
81 match anchor_text.as_ref() {
82 "Bottom" => DockAnchor::Bottom,
83 "Right" => DockAnchor::Right,
84 "Expanded" => DockAnchor::Expanded,
85 _ => bail!("Stored dock anchor is incorrect"),
86 },
87 next_index,
88 ))
89 })
90 }
91}
92
/// Columns read from the `workspaces` table, in query order:
/// `(workspace_id, dock_anchor, dock_visible)`.
type WorkspaceRow = (WorkspaceId, DockAnchor, bool);

/// Everything persisted for one workspace, as assembled by `Db::workspace_for_roots`.
#[derive(Default, Debug)]
pub struct SerializedWorkspace {
    /// Center pane group, loaded via `get_center_group`.
    pub center_group: SerializedPaneGroup,
    /// Value of the `workspaces.dock_anchor` column.
    pub dock_anchor: DockAnchor,
    /// Value of the `workspaces.dock_visible` column (stored as an INTEGER boolean).
    pub dock_visible: bool,
    /// Dock pane, loaded via `get_dock_pane`.
    pub dock_pane: SerializedDockPane,
}
102
103impl Db {
104 /// Finds or creates a workspace id for the given set of worktree roots. If the passed worktree roots is empty,
105 /// returns the last workspace which was updated
106 pub fn workspace_for_roots<P>(&self, worktree_roots: &[P]) -> Option<SerializedWorkspace>
107 where
108 P: AsRef<Path> + Debug,
109 {
110 // Find the workspace id which is uniquely identified by this set of paths
111 // return it if found
112 let mut workspace_row = get_workspace(worktree_roots, &self)
113 .log_err()
114 .unwrap_or_default();
115 if workspace_row.is_none() && worktree_roots.len() == 0 {
116 workspace_row = self.prepare(
117 "SELECT workspace_id, dock_anchor, dock_visible FROM workspaces ORDER BY timestamp DESC LIMIT 1"
118 ).and_then(|mut stmt| stmt.maybe_row::<WorkspaceRow>())
119 .log_err()
120 .flatten()
121 }
122
123 workspace_row.and_then(|(workspace_id, dock_anchor, dock_visible)| {
124 Some(SerializedWorkspace {
125 dock_pane: self.get_dock_pane(workspace_id)?,
126 center_group: self.get_center_group(workspace_id),
127 dock_anchor,
128 dock_visible,
129 })
130 })
131 }
132
133 /// TODO: Change to be 'update workspace' and to serialize the whole workspace in one go.
134 ///
135 /// Updates the open paths for the given workspace id. Will garbage collect items from
136 /// any workspace ids which are no replaced by the new workspace id. Updates the timestamps
137 /// in the workspace id table
138 pub fn update_worktrees<P>(&self, workspace_id: &WorkspaceId, worktree_roots: &[P])
139 where
140 P: AsRef<Path> + Debug,
141 {
142 self.with_savepoint("update_worktrees", |conn| {
143 // Lookup any old WorkspaceIds which have the same set of roots, and delete them.
144 let preexisting_workspace = get_workspace(worktree_roots, &conn)?;
145 if let Some((preexisting_workspace_id, _, _)) = preexisting_workspace {
146 if preexisting_workspace_id != *workspace_id {
147 // Should also delete fields in other tables with cascading updates
148 conn.prepare("DELETE FROM workspaces WHERE workspace_id = ?")?
149 .with_bindings(preexisting_workspace_id)?
150 .exec()?;
151 }
152 }
153
154 conn.prepare("DELETE FROM worktree_roots WHERE workspace_id = ?")?
155 .with_bindings(workspace_id.0)?
156 .exec()?;
157
158 for root in worktree_roots {
159 let path = root.as_ref().as_os_str().as_bytes();
160 // If you need to debug this, here's the string parsing:
161 // let path = root.as_ref().to_string_lossy().to_string();
162
163 conn.prepare(
164 "INSERT INTO worktree_roots(workspace_id, worktree_root) VALUES (?, ?)",
165 )?
166 .with_bindings((workspace_id.0, path))?
167 .exec()?;
168 }
169
170 conn.prepare(
171 "UPDATE workspaces SET timestamp = CURRENT_TIMESTAMP WHERE workspace_id = ?",
172 )?
173 .with_bindings(workspace_id.0)?
174 .exec()?;
175
176 Ok(())
177 })
178 .context("Update workspace {workspace_id:?} with roots {worktree_roots:?}")
179 .log_err();
180 }
181
182 /// Returns the previous workspace ids sorted by last modified along with their opened worktree roots
183 pub fn recent_workspaces(&self, limit: usize) -> Vec<Vec<PathBuf>> {
184 self.with_savepoint("recent_workspaces", |conn| {
185 let mut stmt =
186 conn.prepare("SELECT worktree_root FROM worktree_roots WHERE workspace_id = ?")?;
187
188 conn.prepare("SELECT workspace_id FROM workspaces ORDER BY timestamp DESC LIMIT ?")?
189 .with_bindings(limit)?
190 .rows::<WorkspaceId>()?
191 .iter()
192 .map(|workspace_id| stmt.with_bindings(workspace_id.0)?.rows::<PathBuf>())
193 .collect::<Result<_>>()
194 })
195 .log_err()
196 .unwrap_or_default()
197 }
198}
199
200fn get_workspace<P>(worktree_roots: &[P], connection: &Connection) -> Result<Option<WorkspaceRow>>
201where
202 P: AsRef<Path> + Debug,
203{
204 // Short circuit if we can
205 if worktree_roots.len() == 0 {
206 return Ok(None);
207 }
208
209 // Prepare the array binding string. SQL doesn't have syntax for this, so
210 // we have to do it ourselves.
211 let array_binding_stmt = format!(
212 "({})",
213 (0..worktree_roots.len())
214 .map(|index| format!("?{}", index + 1))
215 .collect::<Vec<_>>()
216 .join(", ")
217 );
218
219 // Any workspace can have multiple independent paths, and these paths
220 // can overlap in the database. Take this test data for example:
221 //
222 // [/tmp, /tmp2] -> 1
223 // [/tmp] -> 2
224 // [/tmp2, /tmp3] -> 3
225 //
226 // This would be stred in the database like so:
227 //
228 // ID PATH
229 // 1 /tmp
230 // 1 /tmp2
231 // 2 /tmp
232 // 3 /tmp2
233 // 3 /tmp3
234 //
235 // Note how both /tmp and /tmp2 are associated with multiple workspace IDs.
236 // So, given an array of worktree roots, how can we find the exactly matching ID?
237 // Let's analyze what happens when querying for [/tmp, /tmp2], from the inside out:
238 // - We start with a join of this table on itself, generating every possible
239 // pair of ((path, ID), (path, ID)), and filtering the join down to just the
240 // *overlapping but non-matching* workspace IDs. For this small data set,
241 // this would look like:
242 //
243 // wt1.ID wt1.PATH | wt2.ID wt2.PATH
244 // 3 /tmp3 3 /tmp2
245 //
246 // - Moving one SELECT out, we use the first pair's ID column to invert the selection,
247 // meaning we now have a list of all the entries for our array, minus overlapping sets,
248 // but including *subsets* of our worktree roots:
249 //
250 // ID PATH
251 // 1 /tmp
252 // 1 /tmp2
253 // 2 /tmp
254 //
255 // - To trim out the subsets, we can to exploit the PRIMARY KEY constraint that there are no
256 // duplicate entries in this table. Using a GROUP BY and a COUNT we can find the subsets of
257 // our keys:
258 //
259 // ID num_matching
260 // 1 2
261 // 2 1
262 //
263 // - And with one final WHERE num_matching = $num_of_worktree_roots, we're done! We've found the
264 // matching ID correctly :D
265 //
266 // Note: due to limitations in SQLite's query binding, we have to generate the prepared
267 // statement with string substitution (the {array_bind}) below, and then bind the
268 // parameters by number.
269 let query = format!(
270 r#"
271 SELECT workspaces.workspace_id, workspaces.dock_anchor, workspaces.dock_visible
272 FROM (SELECT workspace_id
273 FROM (SELECT count(workspace_id) as num_matching, workspace_id FROM worktree_roots
274 WHERE worktree_root in {array_bind} AND workspace_id NOT IN
275 (SELECT wt1.workspace_id FROM worktree_roots as wt1
276 JOIN worktree_roots as wt2
277 ON wt1.workspace_id = wt2.workspace_id
278 WHERE wt1.worktree_root NOT in {array_bind} AND wt2.worktree_root in {array_bind})
279 GROUP BY workspace_id)
280 WHERE num_matching = ?) as matching_workspace
281 JOIN workspaces ON workspaces.workspace_id = matching_workspace.workspace_id
282 "#,
283 array_bind = array_binding_stmt
284 );
285
286 // This will only be called on start up and when root workspaces change, no need to waste memory
287 // caching it.
288 let mut stmt = connection.prepare(&query)?;
289
290 // Make sure we bound the parameters correctly
291 debug_assert!(worktree_roots.len() as i32 + 1 == stmt.parameter_count());
292
293 let root_bytes: Vec<&[u8]> = worktree_roots
294 .iter()
295 .map(|root| root.as_ref().as_os_str().as_bytes())
296 .collect();
297
298 let num_of_roots = root_bytes.len();
299
300 stmt.with_bindings((root_bytes, num_of_roots))?
301 .maybe_row::<WorkspaceRow>()
302}
303
#[cfg(test)]
mod tests {
    // Tests for workspace lookup and worktree-root bookkeeping, run against an
    // in-memory database (`Db::open_in_memory`).
    //
    // NOTE(review): these tests do not match the `Db` API visible in this file — confirm:
    // - `workspace_for_roots` returns `Option<SerializedWorkspace>`, and
    //   `SerializedWorkspace` has no `workspace_id` field, so expressions such as
    //   `workspace_1.workspace_id` do not type-check against this file;
    // - `make_new_workspace`, `workspace` and `last_workspace` are not defined in this
    //   file (they may live in another `impl Db` block, or may have been removed).

    use std::{path::PathBuf, thread::sleep, time::Duration};

    use crate::Db;

    use super::WorkspaceId;

    #[test]
    fn test_new_worktrees_for_roots() {
        env_logger::init();
        let db = Db::open_in_memory("test_new_worktrees_for_roots");

        // Test creation in 0 case
        // NOTE(review): `workspace_for_roots` in this file only looks workspaces up;
        // it never creates one.
        let workspace_1 = db.workspace_for_roots::<String>(&[]);
        assert_eq!(workspace_1.workspace_id, WorkspaceId(1));

        // Test pulling from recent workspaces
        let workspace_1 = db.workspace_for_roots::<String>(&[]);
        assert_eq!(workspace_1.workspace_id, WorkspaceId(1));

        // Ensure the timestamps are different (presumably because CURRENT_TIMESTAMP
        // only has one-second resolution)
        sleep(Duration::from_secs(1));
        db.make_new_workspace::<String>(&[]);

        // Test pulling another value from recent workspaces
        let workspace_2 = db.workspace_for_roots::<String>(&[]);
        assert_eq!(workspace_2.workspace_id, WorkspaceId(2));

        // Ensure the timestamps are different
        sleep(Duration::from_secs(1));

        // Test creating a new workspace that doesn't exist already
        let workspace_3 = db.workspace_for_roots(&["/tmp", "/tmp2"]);
        assert_eq!(workspace_3.workspace_id, WorkspaceId(3));

        // Make sure it's in the recent workspaces....
        let workspace_3 = db.workspace_for_roots::<String>(&[]);
        assert_eq!(workspace_3.workspace_id, WorkspaceId(3));

        // And that it can be pulled out again
        let workspace_3 = db.workspace_for_roots(&["/tmp", "/tmp2"]);
        assert_eq!(workspace_3.workspace_id, WorkspaceId(3));
    }

    #[test]
    fn test_empty_worktrees() {
        let db = Db::open_in_memory("test_empty_worktrees");

        assert_eq!(None, db.workspace::<String>(&[]));

        db.make_new_workspace::<String>(&[]); //ID 1
        db.make_new_workspace::<String>(&[]); //ID 2
        db.update_worktrees(&WorkspaceId(1), &["/tmp", "/tmp2"]);

        // Sanity check
        assert_eq!(db.workspace(&["/tmp", "/tmp2"]).unwrap().0, WorkspaceId(1));

        db.update_worktrees::<String>(&WorkspaceId(1), &[]);

        // Make sure 'no worktrees' fails correctly. Returning [1, 2] from this
        // call would be semantically correct (as those are the workspaces that
        // don't have roots), but I'd prefer this API to return either exactly one
        // workspace or None.
        assert_eq!(db.workspace::<String>(&[]), None,);

        assert_eq!(db.last_workspace().unwrap().0, WorkspaceId(1));

        assert_eq!(
            db.recent_workspaces(2),
            vec![Vec::<PathBuf>::new(), Vec::<PathBuf>::new()],
        )
    }

    #[test]
    fn test_more_workspace_ids() {
        // Overlapping root sets: each lookup must resolve to the workspace whose
        // roots match exactly, never a subset or superset.
        let data = &[
            (WorkspaceId(1), vec!["/tmp1"]),
            (WorkspaceId(2), vec!["/tmp1", "/tmp2"]),
            (WorkspaceId(3), vec!["/tmp1", "/tmp2", "/tmp3"]),
            (WorkspaceId(4), vec!["/tmp2", "/tmp3"]),
            (WorkspaceId(5), vec!["/tmp2", "/tmp3", "/tmp4"]),
            (WorkspaceId(6), vec!["/tmp2", "/tmp4"]),
            (WorkspaceId(7), vec!["/tmp2"]),
        ];

        let db = Db::open_in_memory("test_more_workspace_ids");

        for (workspace_id, entries) in data {
            db.make_new_workspace::<String>(&[]);
            db.update_worktrees(workspace_id, entries);
        }

        assert_eq!(WorkspaceId(1), db.workspace(&["/tmp1"]).unwrap().0);
        assert_eq!(db.workspace(&["/tmp1", "/tmp2"]).unwrap().0, WorkspaceId(2));
        assert_eq!(
            db.workspace(&["/tmp1", "/tmp2", "/tmp3"]).unwrap().0,
            WorkspaceId(3)
        );
        assert_eq!(db.workspace(&["/tmp2", "/tmp3"]).unwrap().0, WorkspaceId(4));
        assert_eq!(
            db.workspace(&["/tmp2", "/tmp3", "/tmp4"]).unwrap().0,
            WorkspaceId(5)
        );
        assert_eq!(db.workspace(&["/tmp2", "/tmp4"]).unwrap().0, WorkspaceId(6));
        assert_eq!(db.workspace(&["/tmp2"]).unwrap().0, WorkspaceId(7));

        // Root sets that were never registered must not match anything.
        assert_eq!(db.workspace(&["/tmp1", "/tmp5"]), None);
        assert_eq!(db.workspace(&["/tmp5"]), None);
        assert_eq!(db.workspace(&["/tmp2", "/tmp3", "/tmp4", "/tmp5"]), None);
    }

    #[test]
    fn test_detect_workspace_id() {
        // Nested root sets sharing "/tmp": subsets of a registered set must not match.
        let data = &[
            (WorkspaceId(1), vec!["/tmp"]),
            (WorkspaceId(2), vec!["/tmp", "/tmp2"]),
            (WorkspaceId(3), vec!["/tmp", "/tmp2", "/tmp3"]),
        ];

        let db = Db::open_in_memory("test_detect_workspace_id");

        for (workspace_id, entries) in data {
            db.make_new_workspace::<String>(&[]);
            db.update_worktrees(workspace_id, entries);
        }

        assert_eq!(db.workspace(&["/tmp2"]), None);
        assert_eq!(db.workspace(&["/tmp2", "/tmp3"]), None);
        assert_eq!(db.workspace(&["/tmp"]).unwrap().0, WorkspaceId(1));
        assert_eq!(db.workspace(&["/tmp", "/tmp2"]).unwrap().0, WorkspaceId(2));
        assert_eq!(
            db.workspace(&["/tmp", "/tmp2", "/tmp3"]).unwrap().0,
            WorkspaceId(3)
        );
    }

    #[test]
    fn test_tricky_overlapping_updates() {
        // DB state:
        // (/tree) -> ID: 1
        // (/tree, /tree2) -> ID: 2
        // (/tree2, /tree3) -> ID: 3

        // -> User updates 2 to: (/tree2, /tree3)

        // DB state:
        // (/tree) -> ID: 1
        // (/tree2, /tree3) -> ID: 2
        // Get rid of 3 for garbage collection

        let data = &[
            (WorkspaceId(1), vec!["/tmp"]),
            (WorkspaceId(2), vec!["/tmp", "/tmp2"]),
            (WorkspaceId(3), vec!["/tmp2", "/tmp3"]),
        ];

        let db = Db::open_in_memory("test_tricky_overlapping_update");

        // Load in the test data
        for (workspace_id, entries) in data {
            db.make_new_workspace::<String>(&[]);
            db.update_worktrees(workspace_id, entries);
        }

        // Ensure workspace 2's new timestamp is strictly newer than the others
        sleep(Duration::from_secs(1));
        // Execute the update
        db.update_worktrees(&WorkspaceId(2), &["/tmp2", "/tmp3"]);

        // Make sure the roots now resolve to workspace 2 (workspace 3 was garbage collected)
        assert_eq!(db.workspace(&["/tmp2", "/tmp3"]).unwrap().0, WorkspaceId(2));

        // And that workspace 1 was untouched
        assert_eq!(db.workspace(&["/tmp"]).unwrap().0, WorkspaceId(1));

        // And that workspace 2 is no longer registered under these roots
        assert_eq!(db.workspace(&["/tmp", "/tmp2"]), None);

        assert_eq!(db.last_workspace().unwrap().0, WorkspaceId(2));

        // Most recent first: workspace 2 (just updated), then workspace 1
        let recent_workspaces = db.recent_workspaces(10);
        assert_eq!(
            recent_workspaces.get(0).unwrap(),
            &vec![PathBuf::from("/tmp2"), PathBuf::from("/tmp3")]
        );
        assert_eq!(
            recent_workspaces.get(1).unwrap(),
            &vec![PathBuf::from("/tmp")]
        );
    }
}