sqlez: Open named in-memory databases as SQLite URIs (#50967)

Danny Milosavljevic created

Closes #51011

Before you mark this PR as ready for review, make sure that you have:
- [X] Added a solid test coverage and/or screenshots from doing manual
testing
- [X] Done a self-review taking into account security and performance
aspects
- [X] Aligned any UI changes with the [UI
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)

Release Notes:

- N/A *or* Added/Fixed/Improved ...

Change summary

crates/sqlez/src/connection.rs             | 103 ++++++++++++++++++++++-
crates/sqlez/src/thread_safe_connection.rs | 101 +++++++++++++---------
2 files changed, 155 insertions(+), 49 deletions(-)

Detailed changes

crates/sqlez/src/connection.rs 🔗

@@ -18,7 +18,7 @@ pub struct Connection {
 unsafe impl Send for Connection {}
 
 impl Connection {
-    pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
+    fn open_with_flags(uri: &str, persistent: bool, flags: i32) -> Result<Self> {
         let mut connection = Self {
             sqlite3: ptr::null_mut(),
             persistent,
@@ -26,7 +26,6 @@ impl Connection {
             _sqlite: PhantomData,
         };
 
-        let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
         unsafe {
             sqlite3_open_v2(
                 CString::new(uri)?.as_ptr(),
@@ -44,6 +43,14 @@ impl Connection {
         Ok(connection)
     }
 
+    pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
+        Self::open_with_flags(
+            uri,
+            persistent,
+            SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE,
+        )
+    }
+
     /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
     /// instead.
     pub fn open_file(uri: &str) -> Self {
@@ -51,13 +58,17 @@ impl Connection {
     }
 
     pub fn open_memory(uri: Option<&str>) -> Self {
-        let in_memory_path = if let Some(uri) = uri {
-            format!("file:{}?mode=memory&cache=shared", uri)
+        if let Some(uri) = uri {
+            let in_memory_path = format!("file:{}?mode=memory&cache=shared", uri);
+            return Self::open_with_flags(
+                &in_memory_path,
+                false,
+                SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE | SQLITE_OPEN_URI,
+            )
+            .expect("Could not create fallback in memory db");
         } else {
-            ":memory:".to_string()
-        };
-
-        Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
+            Self::open(":memory:", false).expect("Could not create fallback in memory db")
+        }
     }
 
     pub fn persistent(&self) -> bool {
@@ -265,9 +276,50 @@ impl Drop for Connection {
 mod test {
     use anyhow::Result;
     use indoc::indoc;
+    use std::{
+        fs,
+        sync::atomic::{AtomicUsize, Ordering},
+    };
 
     use crate::connection::Connection;
 
+    static NEXT_NAMED_MEMORY_DB_ID: AtomicUsize = AtomicUsize::new(0);
+
+    fn unique_named_memory_db(prefix: &str) -> String {
+        format!(
+            "{prefix}_{}_{}",
+            std::process::id(),
+            NEXT_NAMED_MEMORY_DB_ID.fetch_add(1, Ordering::Relaxed)
+        )
+    }
+
+    fn literal_named_memory_paths(name: &str) -> [String; 3] {
+        let main = format!("file:{name}?mode=memory&cache=shared");
+        [main.clone(), format!("{main}-wal"), format!("{main}-shm")]
+    }
+
+    struct NamedMemoryPathGuard {
+        paths: [String; 3],
+    }
+
+    impl NamedMemoryPathGuard {
+        fn new(name: &str) -> Self {
+            let paths = literal_named_memory_paths(name);
+            for path in &paths {
+                let _ = fs::remove_file(path);
+            }
+            Self { paths }
+        }
+    }
+
+    impl Drop for NamedMemoryPathGuard {
+        fn drop(&mut self) {
+            for path in &self.paths {
+                let _ = fs::remove_file(path);
+            }
+        }
+    }
+
     #[test]
     fn string_round_trips() -> Result<()> {
         let connection = Connection::open_memory(Some("string_round_trips"));
@@ -382,6 +434,41 @@ mod test {
         assert_eq!(read_blobs, vec![blob]);
     }
 
+    #[test]
+    fn named_memory_connections_do_not_create_literal_backing_files() {
+        let name = unique_named_memory_db("named_memory_connections_do_not_create_backing_files");
+        let guard = NamedMemoryPathGuard::new(&name);
+
+        let connection1 = Connection::open_memory(Some(&name));
+        connection1
+            .exec(indoc! {"
+                CREATE TABLE shared (
+                    value INTEGER
+                )"})
+            .unwrap()()
+        .unwrap();
+        connection1
+            .exec("INSERT INTO shared (value) VALUES (7)")
+            .unwrap()()
+        .unwrap();
+
+        let connection2 = Connection::open_memory(Some(&name));
+        assert_eq!(
+            connection2
+                .select_row::<i64>("SELECT value FROM shared")
+                .unwrap()()
+            .unwrap(),
+            Some(7)
+        );
+
+        for path in &guard.paths {
+            assert!(
+                fs::metadata(path).is_err(),
+                "named in-memory database unexpectedly created backing file {path}"
+            );
+        }
+    }
+
     #[test]
     fn multi_step_statement_works() {
         let connection = Connection::open_memory(Some("multi_step_statement_works"));

crates/sqlez/src/thread_safe_connection.rs 🔗

@@ -7,12 +7,15 @@ use std::{
     ops::Deref,
     sync::{Arc, LazyLock},
     thread,
+    time::Duration,
 };
 use thread_local::ThreadLocal;
 
 use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender};
 
 const MIGRATION_RETRIES: usize = 10;
+const CONNECTION_INITIALIZE_RETRIES: usize = 50;
+const CONNECTION_INITIALIZE_RETRY_DELAY: Duration = Duration::from_millis(1);
 
 type QueuedWrite = Box<dyn 'static + Send + FnOnce()>;
 type WriteQueue = Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>;
@@ -197,21 +200,54 @@ impl ThreadSafeConnection {
             Self::open_shared_memory(uri)
         };
 
+        if let Some(initialize_query) = connection_initialize_query {
+            let mut last_error = None;
+            let initialized = (0..CONNECTION_INITIALIZE_RETRIES).any(|attempt| {
+                match connection
+                    .exec(initialize_query)
+                    .and_then(|mut statement| statement())
+                {
+                    Ok(()) => true,
+                    Err(err)
+                        if is_schema_lock_error(&err)
+                            && attempt + 1 < CONNECTION_INITIALIZE_RETRIES =>
+                    {
+                        last_error = Some(err);
+                        thread::sleep(CONNECTION_INITIALIZE_RETRY_DELAY);
+                        false
+                    }
+                    Err(err) => {
+                        panic!(
+                            "Initialize query failed to execute: {}\n\nCaused by:\n{err:#}",
+                            initialize_query
+                        )
+                    }
+                }
+            });
+
+            if !initialized {
+                let err = last_error
+                    .expect("connection initialization retries should record the last error");
+                panic!(
+                    "Initialize query failed to execute after retries: {}\n\nCaused by:\n{err:#}",
+                    initialize_query
+                );
+            }
+        }
+
         // Disallow writes on the connection. The only writes allowed for thread safe connections
         // are from the background thread that can serialize them.
         *connection.write.get_mut() = false;
 
-        if let Some(initialize_query) = connection_initialize_query {
-            connection.exec(initialize_query).unwrap_or_else(|_| {
-                panic!("Initialize query failed to execute: {}", initialize_query)
-            })()
-            .unwrap()
-        }
-
         connection
     }
 }
 
+fn is_schema_lock_error(err: &anyhow::Error) -> bool {
+    let message = format!("{err:#}");
+    message.contains("database schema is locked") || message.contains("database is locked")
+}
+
 impl ThreadSafeConnection {
     /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
     /// This allows construction to be infallible and not write to the db.
@@ -282,7 +318,7 @@ mod test {
     use indoc::indoc;
     use std::ops::Deref;
 
-    use std::thread;
+    use std::{thread, time::Duration};
 
     use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};
 
@@ -318,38 +354,21 @@ mod test {
     }
 
     #[test]
-    #[should_panic]
-    fn wild_zed_lost_failure() {
-        enum TestWorkspace {}
-        impl Domain for TestWorkspace {
-            const NAME: &str = "workspace";
-
-            const MIGRATIONS: &[&str] = &["
-                    CREATE TABLE workspaces(
-                        workspace_id INTEGER PRIMARY KEY,
-                        dock_visible INTEGER, -- Boolean
-                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
-                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
-                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
-                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
-                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
-                    ) STRICT;
-
-                    CREATE TABLE panes(
-                        pane_id INTEGER PRIMARY KEY,
-                        workspace_id INTEGER NOT NULL,
-                        active INTEGER NOT NULL, -- Boolean
-                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
-                            ON DELETE CASCADE
-                            ON UPDATE CASCADE
-                    ) STRICT;
-                "];
-        }
-
-        let builder =
-            ThreadSafeConnection::builder::<TestWorkspace>("wild_zed_lost_failure", false)
-                .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");
-
-        smol::block_on(builder.build()).unwrap();
+    fn connection_initialize_query_retries_transient_schema_lock() {
+        let name = "connection_initialize_query_retries_transient_schema_lock";
+        let locking_connection = crate::connection::Connection::open_memory(Some(name));
+        locking_connection.exec("BEGIN IMMEDIATE").unwrap()().unwrap();
+        locking_connection
+            .exec("CREATE TABLE test(col TEXT)")
+            .unwrap()()
+        .unwrap();
+
+        let releaser = thread::spawn(move || {
+            thread::sleep(Duration::from_millis(10));
+            locking_connection.exec("ROLLBACK").unwrap()().unwrap();
+        });
+
+        ThreadSafeConnection::create_connection(false, name, Some("PRAGMA FOREIGN_KEYS=true"));
+        releaser.join().unwrap();
     }
 }