make thread safe connection more thread safe

Kay Simmons and Mikayla Maki created

Co-Authored-By: Mikayla Maki <mikayla@zed.dev>

Change summary

Cargo.lock                                 |   2 
crates/db/Cargo.toml                       |   1 
crates/db/src/db.rs                        |  32 ++-
crates/db/src/kvp.rs                       |   6 
crates/sqlez/Cargo.toml                    |   1 
crates/sqlez/src/migrations.rs             |   6 
crates/sqlez/src/thread_safe_connection.rs | 230 +++++++++++++++--------
crates/sqlez/src/util.rs                   |   4 
crates/sqlez_macros/src/sqlez_macros.rs    |   2 
crates/workspace/src/persistence.rs        |  14 
crates/workspace/src/workspace.rs          |  17 -
crates/zed/src/zed.rs                      |   5 
12 files changed, 196 insertions(+), 124 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -1569,6 +1569,7 @@ dependencies = [
  "log",
  "parking_lot 0.11.2",
  "serde",
+ "smol",
  "sqlez",
  "sqlez_macros",
  "tempdir",
@@ -5596,6 +5597,7 @@ dependencies = [
  "lazy_static",
  "libsqlite3-sys",
  "parking_lot 0.11.2",
+ "smol",
  "thread_local",
 ]
 

crates/db/Cargo.toml 🔗

@@ -23,6 +23,7 @@ lazy_static = "1.4.0"
 log = { version = "0.4.16", features = ["kv_unstable_serde"] }
 parking_lot = "0.11.1"
 serde = { version = "1.0", features = ["derive"] }
+smol = "1.2"
 
 [dev-dependencies]
 gpui = { path = "../gpui", features = ["test-support"] }

crates/db/src/db.rs 🔗

@@ -4,31 +4,36 @@ pub mod kvp;
 pub use anyhow;
 pub use indoc::indoc;
 pub use lazy_static;
+pub use smol;
 pub use sqlez;
 pub use sqlez_macros;
 
 use sqlez::domain::Migrator;
 use sqlez::thread_safe_connection::ThreadSafeConnection;
+use sqlez_macros::sql;
 use std::fs::{create_dir_all, remove_dir_all};
 use std::path::Path;
 use std::sync::atomic::{AtomicBool, Ordering};
 use util::channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
 use util::paths::DB_DIR;
 
-const INITIALIZE_QUERY: &'static str = indoc! {"
-    PRAGMA journal_mode=WAL;
+const CONNECTION_INITIALIZE_QUERY: &'static str = sql!(
     PRAGMA synchronous=NORMAL;
     PRAGMA busy_timeout=1;
     PRAGMA foreign_keys=TRUE;
     PRAGMA case_sensitive_like=TRUE;
-"};
+);
+
+const DB_INITIALIZE_QUERY: &'static str = sql!(
+    PRAGMA journal_mode=WAL;
+);
 
 lazy_static::lazy_static! {
     static ref DB_WIPED: AtomicBool = AtomicBool::new(false);
 }
 
 /// Open or create a database at the given directory path.
-pub fn open_file_db<M: Migrator>() -> ThreadSafeConnection<M> {
+pub async fn open_file_db<M: Migrator>() -> ThreadSafeConnection<M> {
     // Use 0 for now. Will implement incrementing and clearing of old db files soon TM
     let current_db_dir = (*DB_DIR).join(Path::new(&format!("0-{}", *RELEASE_CHANNEL_NAME)));
 
@@ -43,12 +48,19 @@ pub fn open_file_db<M: Migrator>() -> ThreadSafeConnection<M> {
     create_dir_all(&current_db_dir).expect("Should be able to create the database directory");
     let db_path = current_db_dir.join(Path::new("db.sqlite"));
 
-    ThreadSafeConnection::new(db_path.to_string_lossy().as_ref(), true)
-        .with_initialize_query(INITIALIZE_QUERY)
+    ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
+        .with_db_initialization_query(DB_INITIALIZE_QUERY)
+        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
+        .build()
+        .await
 }
 
-pub fn open_memory_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
-    ThreadSafeConnection::new(db_name, false).with_initialize_query(INITIALIZE_QUERY)
+pub async fn open_memory_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
+    ThreadSafeConnection::<M>::builder(db_name, false)
+        .with_db_initialization_query(DB_INITIALIZE_QUERY)
+        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
+        .build()
+        .await
 }
 
 /// Implements a basic DB wrapper for a given domain
@@ -67,9 +79,9 @@ macro_rules! connection {
 
         ::db::lazy_static::lazy_static! {
             pub static ref $id: $t = $t(if cfg!(any(test, feature = "test-support")) {
-                ::db::open_memory_db(stringify!($id))
+                $crate::smol::block_on(::db::open_memory_db(stringify!($id)))
             } else {
-                ::db::open_file_db()
+                $crate::smol::block_on(::db::open_file_db())
             });
         }
     };

crates/db/src/kvp.rs 🔗

@@ -15,9 +15,9 @@ impl std::ops::Deref for KeyValueStore {
 
 lazy_static::lazy_static! {
     pub static ref KEY_VALUE_STORE: KeyValueStore = KeyValueStore(if cfg!(any(test, feature = "test-support")) {
-        open_memory_db(stringify!($id))
+        smol::block_on(open_memory_db("KEY_VALUE_STORE"))
     } else {
-        open_file_db()
+        smol::block_on(open_file_db())
     });
 }
 
@@ -62,7 +62,7 @@ mod tests {
 
     #[gpui::test]
     async fn test_kvp() {
-        let db = KeyValueStore(crate::open_memory_db("test_kvp"));
+        let db = KeyValueStore(crate::open_memory_db("test_kvp").await);
 
         assert_eq!(db.read_kvp("key-1").unwrap(), None);
 

crates/sqlez/Cargo.toml 🔗

@@ -9,6 +9,7 @@ edition = "2021"
 anyhow = { version = "1.0.38", features = ["backtrace"] }
 indoc = "1.0.7"
 libsqlite3-sys = { version = "0.25.2", features = ["bundled"] }
+smol = "1.2"
 thread_local = "1.1.4"
 lazy_static = "1.4"
 parking_lot = "0.11.1"

crates/sqlez/src/migrations.rs 🔗

@@ -15,9 +15,9 @@ impl Connection {
             // Setup the migrations table unconditionally
             self.exec(indoc! {"
                 CREATE TABLE IF NOT EXISTS migrations (
-                domain TEXT,
-                step INTEGER,
-                migration TEXT
+                    domain TEXT,
+                    step INTEGER,
+                    migration TEXT
                 )"})?()?;
 
             let completed_migrations =

crates/sqlez/src/thread_safe_connection.rs 🔗

@@ -1,4 +1,4 @@
-use futures::{Future, FutureExt};
+use futures::{channel::oneshot, Future, FutureExt};
 use lazy_static::lazy_static;
 use parking_lot::RwLock;
 use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
@@ -10,17 +10,25 @@ use crate::{
     util::UnboundedSyncSender,
 };
 
-type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
+const MIGRATION_RETRIES: usize = 10;
 
+type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
 lazy_static! {
+    /// List of queues of tasks by database uri. This lets us serialize writes to the database
+    /// and have a single worker thread per db file. This means many thread safe connections
+    /// (possibly with different migrations) could all be communicating with the same background
+    /// thread.
     static ref QUEUES: RwLock<HashMap<Arc<str>, UnboundedSyncSender<QueuedWrite>>> =
         Default::default();
 }
 
+/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
+/// whatever. It derefs to a synchronous connection by thread that is read only. A write capable connection
+/// may be accessed by passing a callback to the `write` function which will queue the callback
 pub struct ThreadSafeConnection<M: Migrator = ()> {
     uri: Arc<str>,
     persistent: bool,
-    initialize_query: Option<&'static str>,
+    connection_initialize_query: Option<&'static str>,
     connections: Arc<ThreadLocal<Connection>>,
     _migrator: PhantomData<M>,
 }
@@ -28,87 +36,125 @@ pub struct ThreadSafeConnection<M: Migrator = ()> {
 unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {}
 unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
 
-impl<M: Migrator> ThreadSafeConnection<M> {
-    pub fn new(uri: &str, persistent: bool) -> Self {
-        Self {
-            uri: Arc::from(uri),
-            persistent,
-            initialize_query: None,
-            connections: Default::default(),
-            _migrator: PhantomData,
-        }
-    }
+pub struct ThreadSafeConnectionBuilder<M: Migrator = ()> {
+    db_initialize_query: Option<&'static str>,
+    connection: ThreadSafeConnection<M>,
+}
 
+impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
     /// Sets the query to run every time a connection is opened. This must
-    /// be infallible (EG only use pragma statements)
-    pub fn with_initialize_query(mut self, initialize_query: &'static str) -> Self {
-        self.initialize_query = Some(initialize_query);
+    /// be infallible (EG only use pragma statements) and not cause writes.
+    /// to the db or it will panic.
+    pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
+        self.connection.connection_initialize_query = Some(initialize_query);
+        self
+    }
+
+    /// Queues an initialization query for the database file. This must be infallible
+    /// but may cause changes to the database file such as with `PRAGMA journal_mode`
+    pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
+        self.db_initialize_query = Some(initialize_query);
         self
     }
 
+    pub async fn build(self) -> ThreadSafeConnection<M> {
+        let db_initialize_query = self.db_initialize_query;
+
+        self.connection
+            .write(move |connection| {
+                if let Some(db_initialize_query) = db_initialize_query {
+                    connection.exec(db_initialize_query).expect(&format!(
+                        "Db initialize query failed to execute: {}",
+                        db_initialize_query
+                    ))()
+                    .unwrap();
+                }
+
+                let mut failure_result = None;
+                for _ in 0..MIGRATION_RETRIES {
+                    failure_result = Some(M::migrate(connection));
+                    if failure_result.as_ref().unwrap().is_ok() {
+                        break;
+                    }
+                }
+
+                failure_result.unwrap().expect("Migration failed");
+            })
+            .await;
+
+        self.connection
+    }
+}
+
+impl<M: Migrator> ThreadSafeConnection<M> {
+    pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
+        ThreadSafeConnectionBuilder::<M> {
+            db_initialize_query: None,
+            connection: Self {
+                uri: Arc::from(uri),
+                persistent,
+                connection_initialize_query: None,
+                connections: Default::default(),
+                _migrator: PhantomData,
+            },
+        }
+    }
+
     /// Opens a new db connection with the initialized file path. This is internal and only
     /// called from the deref function.
-    /// If opening fails, the connection falls back to a shared memory connection
     fn open_file(&self) -> Connection {
-        // This unwrap is secured by a panic in the constructor. Be careful if you remove it!
         Connection::open_file(self.uri.as_ref())
     }
 
-    /// Opens a shared memory connection using the file path as the identifier. This unwraps
-    /// as we expect it always to succeed
+    /// Opens a shared memory connection using the file path as the identifier. This is internal
+    /// and only called from the deref function.
     fn open_shared_memory(&self) -> Connection {
         Connection::open_memory(Some(self.uri.as_ref()))
     }
 
-    // Open a new connection for the given domain, leaving this
-    // connection intact.
-    pub fn for_domain<D2: Domain>(&self) -> ThreadSafeConnection<D2> {
-        ThreadSafeConnection {
-            uri: self.uri.clone(),
-            persistent: self.persistent,
-            initialize_query: self.initialize_query,
-            connections: Default::default(),
-            _migrator: PhantomData,
-        }
-    }
-
-    pub fn write<T: 'static + Send + Sync>(
-        &self,
-        callback: impl 'static + Send + FnOnce(&Connection) -> T,
-    ) -> impl Future<Output = T> {
+    fn queue_write_task(&self, callback: QueuedWrite) {
         // Startup write thread for this database if one hasn't already
         // been started and insert a channel to queue work for it
         if !QUEUES.read().contains_key(&self.uri) {
-            use std::sync::mpsc::channel;
-
-            let (sender, reciever) = channel::<QueuedWrite>();
-            let mut write_connection = self.create_connection();
-            // Enable writes for this connection
-            write_connection.write = true;
-            thread::spawn(move || {
-                while let Ok(write) = reciever.recv() {
-                    write(&write_connection)
-                }
-            });
-
             let mut queues = QUEUES.write();
-            queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
+            if !queues.contains_key(&self.uri) {
+                use std::sync::mpsc::channel;
+
+                let (sender, reciever) = channel::<QueuedWrite>();
+                let mut write_connection = self.create_connection();
+                // Enable writes for this connection
+                write_connection.write = true;
+                thread::spawn(move || {
+                    while let Ok(write) = reciever.recv() {
+                        write(&write_connection)
+                    }
+                });
+
+                queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
+            }
         }
 
         // Grab the queue for this database
         let queues = QUEUES.read();
         let write_channel = queues.get(&self.uri).unwrap();
 
+        write_channel
+            .send(callback)
+            .expect("Could not send write action to backgorund thread");
+    }
+
+    pub fn write<T: 'static + Send + Sync>(
+        &self,
+        callback: impl 'static + Send + FnOnce(&Connection) -> T,
+    ) -> impl Future<Output = T> {
         // Create a one shot channel for the result of the queued write
         // so we can await on the result
-        let (sender, reciever) = futures::channel::oneshot::channel();
-        write_channel
-            .send(Box::new(move |connection| {
-                sender.send(callback(connection)).ok();
-            }))
-            .expect("Could not send write action to background thread");
+        let (sender, reciever) = oneshot::channel();
+        self.queue_write_task(Box::new(move |connection| {
+            sender.send(callback(connection)).ok();
+        }));
 
-        reciever.map(|response| response.expect("Background thread unexpectedly closed"))
+        reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
     }
 
     pub(crate) fn create_connection(&self) -> Connection {
@@ -118,10 +164,11 @@ impl<M: Migrator> ThreadSafeConnection<M> {
             self.open_shared_memory()
         };
 
-        // Enable writes for the migrations and initialization queries
-        connection.write = true;
+        // Disallow writes on the connection. The only writes allowed for thread safe connections
+        // are from the background thread that can serialize them.
+        connection.write = false;
 
-        if let Some(initialize_query) = self.initialize_query {
+        if let Some(initialize_query) = self.connection_initialize_query {
             connection.exec(initialize_query).expect(&format!(
                 "Initialize query failed to execute: {}",
                 initialize_query
@@ -129,20 +176,34 @@ impl<M: Migrator> ThreadSafeConnection<M> {
             .unwrap()
         }
 
-        M::migrate(&connection).expect("Migrations failed");
-
-        // Disable db writes for normal thread local connection
-        connection.write = false;
         connection
     }
 }
 
+impl ThreadSafeConnection<()> {
+    /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
+    /// This allows construction to be infallible and not write to the db.
+    pub fn new(
+        uri: &str,
+        persistent: bool,
+        connection_initialize_query: Option<&'static str>,
+    ) -> Self {
+        Self {
+            uri: Arc::from(uri),
+            persistent,
+            connection_initialize_query,
+            connections: Default::default(),
+            _migrator: PhantomData,
+        }
+    }
+}
+
 impl<D: Domain> Clone for ThreadSafeConnection<D> {
     fn clone(&self) -> Self {
         Self {
             uri: self.uri.clone(),
             persistent: self.persistent,
-            initialize_query: self.initialize_query.clone(),
+            connection_initialize_query: self.connection_initialize_query.clone(),
             connections: self.connections.clone(),
             _migrator: PhantomData,
         }
@@ -163,11 +224,11 @@ impl<M: Migrator> Deref for ThreadSafeConnection<M> {
 
 #[cfg(test)]
 mod test {
-    use std::{fs, ops::Deref, thread};
+    use indoc::indoc;
+    use lazy_static::__Deref;
+    use std::thread;
 
-    use crate::domain::Domain;
-
-    use super::ThreadSafeConnection;
+    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};
 
     #[test]
     fn many_initialize_and_migrate_queries_at_once() {
@@ -185,27 +246,22 @@ mod test {
 
         for _ in 0..100 {
             handles.push(thread::spawn(|| {
-                let _ = ThreadSafeConnection::<TestDomain>::new("annoying-test.db", false)
-                    .with_initialize_query(
-                        "
-                        PRAGMA journal_mode=WAL;
-                        PRAGMA synchronous=NORMAL;
-                        PRAGMA busy_timeout=1;
-                        PRAGMA foreign_keys=TRUE;
-                        PRAGMA case_sensitive_like=TRUE;
-                    ",
-                    )
-                    .deref();
+                let builder =
+                    ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
+                        .with_db_initialization_query("PRAGMA journal_mode=WAL")
+                        .with_connection_initialize_query(indoc! {"
+                                PRAGMA synchronous=NORMAL;
+                                PRAGMA busy_timeout=1;
+                                PRAGMA foreign_keys=TRUE;
+                                PRAGMA case_sensitive_like=TRUE;
+                            "});
+                let _ = smol::block_on(builder.build()).deref();
             }));
         }
 
         for handle in handles {
             let _ = handle.join();
         }
-
-        // fs::remove_file("annoying-test.db").unwrap();
-        // fs::remove_file("annoying-test.db-shm").unwrap();
-        // fs::remove_file("annoying-test.db-wal").unwrap();
     }
 
     #[test]
@@ -241,8 +297,10 @@ mod test {
             }
         }
 
-        let _ = ThreadSafeConnection::<TestWorkspace>::new("wild_zed_lost_failure", false)
-            .with_initialize_query("PRAGMA FOREIGN_KEYS=true")
-            .deref();
+        let builder =
+            ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
+                .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");
+
+        smol::block_on(builder.build());
     }
 }

crates/sqlez/src/util.rs 🔗

@@ -4,6 +4,10 @@ use std::sync::mpsc::Sender;
 use parking_lot::Mutex;
 use thread_local::ThreadLocal;
 
+/// Unbounded standard library sender which is stored per thread to get around
+/// the lack of sync on the standard library version while still being unbounded
+/// Note: this locks on the cloneable sender, but its done once per thread, so it
+/// shouldn't result in too much contention
 pub struct UnboundedSyncSender<T: Send> {
     clonable_sender: Mutex<Sender<T>>,
     local_senders: ThreadLocal<Sender<T>>,

crates/sqlez_macros/src/sqlez_macros.rs 🔗

@@ -3,7 +3,7 @@ use sqlez::thread_safe_connection::ThreadSafeConnection;
 use syn::Error;
 
 lazy_static::lazy_static! {
-    static ref SQLITE: ThreadSafeConnection = ThreadSafeConnection::new(":memory:", false);
+    static ref SQLITE: ThreadSafeConnection = ThreadSafeConnection::new(":memory:", false, None);
 }
 
 #[proc_macro]

crates/workspace/src/persistence.rs 🔗

@@ -395,7 +395,7 @@ mod tests {
     async fn test_next_id_stability() {
         env_logger::try_init().ok();
 
-        let db = WorkspaceDb(open_memory_db("test_next_id_stability"));
+        let db = WorkspaceDb(open_memory_db("test_next_id_stability").await);
 
         db.write(|conn| {
             conn.migrate(
@@ -442,7 +442,7 @@ mod tests {
     async fn test_workspace_id_stability() {
         env_logger::try_init().ok();
 
-        let db = WorkspaceDb(open_memory_db("test_workspace_id_stability"));
+        let db = WorkspaceDb(open_memory_db("test_workspace_id_stability").await);
 
         db.write(|conn| {
             conn.migrate(
@@ -523,7 +523,7 @@ mod tests {
     async fn test_full_workspace_serialization() {
         env_logger::try_init().ok();
 
-        let db = WorkspaceDb(open_memory_db("test_full_workspace_serialization"));
+        let db = WorkspaceDb(open_memory_db("test_full_workspace_serialization").await);
 
         let dock_pane = crate::persistence::model::SerializedPane {
             children: vec![
@@ -597,7 +597,7 @@ mod tests {
     async fn test_workspace_assignment() {
         env_logger::try_init().ok();
 
-        let db = WorkspaceDb(open_memory_db("test_basic_functionality"));
+        let db = WorkspaceDb(open_memory_db("test_basic_functionality").await);
 
         let workspace_1 = SerializedWorkspace {
             id: 1,
@@ -689,7 +689,7 @@ mod tests {
     async fn test_basic_dock_pane() {
         env_logger::try_init().ok();
 
-        let db = WorkspaceDb(open_memory_db("basic_dock_pane"));
+        let db = WorkspaceDb(open_memory_db("basic_dock_pane").await);
 
         let dock_pane = crate::persistence::model::SerializedPane::new(
             vec![
@@ -714,7 +714,7 @@ mod tests {
     async fn test_simple_split() {
         env_logger::try_init().ok();
 
-        let db = WorkspaceDb(open_memory_db("simple_split"));
+        let db = WorkspaceDb(open_memory_db("simple_split").await);
 
         //  -----------------
         //  | 1,2   | 5,6   |
@@ -766,7 +766,7 @@ mod tests {
     async fn test_cleanup_panes() {
         env_logger::try_init().ok();
 
-        let db = WorkspaceDb(open_memory_db("test_cleanup_panes"));
+        let db = WorkspaceDb(open_memory_db("test_cleanup_panes").await);
 
         let center_pane = SerializedPaneGroup::Group {
             axis: gpui::Axis::Horizontal,

crates/workspace/src/workspace.rs 🔗

@@ -162,11 +162,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut MutableAppContext) {
         let app_state = Arc::downgrade(&app_state);
         move |_: &NewFile, cx: &mut MutableAppContext| {
             if let Some(app_state) = app_state.upgrade() {
-                let task = open_new(&app_state, cx);
-                cx.spawn(|_| async {
-                    task.await;
-                })
-                .detach();
+                open_new(&app_state, cx).detach();
             }
         }
     });
@@ -174,11 +170,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut MutableAppContext) {
         let app_state = Arc::downgrade(&app_state);
         move |_: &NewWindow, cx: &mut MutableAppContext| {
             if let Some(app_state) = app_state.upgrade() {
-                let task = open_new(&app_state, cx);
-                cx.spawn(|_| async {
-                    task.await;
-                })
-                .detach();
+                open_new(&app_state, cx).detach();
             }
         }
     });
@@ -2641,13 +2633,16 @@ pub fn open_paths(
     })
 }
 
-fn open_new(app_state: &Arc<AppState>, cx: &mut MutableAppContext) -> Task<()> {
+pub fn open_new(app_state: &Arc<AppState>, cx: &mut MutableAppContext) -> Task<()> {
     let task = Workspace::new_local(Vec::new(), app_state.clone(), cx);
     cx.spawn(|mut cx| async move {
+        eprintln!("Open new task spawned");
         let (workspace, opened_paths) = task.await;
+        eprintln!("workspace and path items created");
 
         workspace.update(&mut cx, |_, cx| {
             if opened_paths.is_empty() {
+                eprintln!("new file redispatched");
                 cx.dispatch_action(NewFile);
             }
         })

crates/zed/src/zed.rs 🔗

@@ -626,7 +626,7 @@ mod tests {
     use theme::ThemeRegistry;
     use workspace::{
         item::{Item, ItemHandle},
-        open_paths, pane, NewFile, Pane, SplitDirection, WorkspaceHandle,
+        open_new, open_paths, pane, NewFile, Pane, SplitDirection, WorkspaceHandle,
     };
 
     #[gpui::test]
@@ -762,8 +762,7 @@ mod tests {
     #[gpui::test]
     async fn test_new_empty_workspace(cx: &mut TestAppContext) {
         let app_state = init(cx);
-        cx.dispatch_global_action(workspace::NewFile);
-        cx.foreground().run_until_parked();
+        cx.update(|cx| open_new(&app_state, cx)).await;
 
         let window_id = *cx.window_ids().first().unwrap();
         let workspace = cx.root_view::<Workspace>(window_id).unwrap();