Detailed changes
@@ -1569,6 +1569,7 @@ dependencies = [
"log",
"parking_lot 0.11.2",
"serde",
+ "smol",
"sqlez",
"sqlez_macros",
"tempdir",
@@ -5596,6 +5597,7 @@ dependencies = [
"lazy_static",
"libsqlite3-sys",
"parking_lot 0.11.2",
+ "smol",
"thread_local",
]
@@ -23,6 +23,7 @@ lazy_static = "1.4.0"
log = { version = "0.4.16", features = ["kv_unstable_serde"] }
parking_lot = "0.11.1"
serde = { version = "1.0", features = ["derive"] }
+smol = "1.2"
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }
@@ -4,31 +4,36 @@ pub mod kvp;
pub use anyhow;
pub use indoc::indoc;
pub use lazy_static;
+pub use smol;
pub use sqlez;
pub use sqlez_macros;
use sqlez::domain::Migrator;
use sqlez::thread_safe_connection::ThreadSafeConnection;
+use sqlez_macros::sql;
use std::fs::{create_dir_all, remove_dir_all};
use std::path::Path;
use std::sync::atomic::{AtomicBool, Ordering};
use util::channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
use util::paths::DB_DIR;
-const INITIALIZE_QUERY: &'static str = indoc! {"
- PRAGMA journal_mode=WAL;
+const CONNECTION_INITIALIZE_QUERY: &'static str = sql!(
PRAGMA synchronous=NORMAL;
PRAGMA busy_timeout=1;
PRAGMA foreign_keys=TRUE;
PRAGMA case_sensitive_like=TRUE;
-"};
+);
+
+const DB_INITIALIZE_QUERY: &'static str = sql!(
+ PRAGMA journal_mode=WAL;
+);
lazy_static::lazy_static! {
static ref DB_WIPED: AtomicBool = AtomicBool::new(false);
}
/// Open or create a database at the given directory path.
-pub fn open_file_db<M: Migrator>() -> ThreadSafeConnection<M> {
+pub async fn open_file_db<M: Migrator>() -> ThreadSafeConnection<M> {
// Use 0 for now. Will implement incrementing and clearing of old db files soon TM
let current_db_dir = (*DB_DIR).join(Path::new(&format!("0-{}", *RELEASE_CHANNEL_NAME)));
@@ -43,12 +48,19 @@ pub fn open_file_db<M: Migrator>() -> ThreadSafeConnection<M> {
create_dir_all(&current_db_dir).expect("Should be able to create the database directory");
let db_path = current_db_dir.join(Path::new("db.sqlite"));
- ThreadSafeConnection::new(db_path.to_string_lossy().as_ref(), true)
- .with_initialize_query(INITIALIZE_QUERY)
+ ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
+ .with_db_initialization_query(DB_INITIALIZE_QUERY)
+ .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
+ .build()
+ .await
}
-pub fn open_memory_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
- ThreadSafeConnection::new(db_name, false).with_initialize_query(INITIALIZE_QUERY)
+pub async fn open_memory_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
+ ThreadSafeConnection::<M>::builder(db_name, false)
+ .with_db_initialization_query(DB_INITIALIZE_QUERY)
+ .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
+ .build()
+ .await
}
/// Implements a basic DB wrapper for a given domain
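Both open functions are now async; call sites either `.await` them (as the updated kvp and workspace tests below do) or drive them with `smol::block_on` from synchronous code. A minimal usage sketch, assuming `MyDomain` is some type implementing `Migrator` and the open functions are in scope:

    // Sketch only: `MyDomain` stands in for any `Migrator` implementation.
    let memory_db = smol::block_on(open_memory_db::<MyDomain>("example-db"));

    // Inside an async context, the file-backed variant is awaited directly:
    // let file_db = open_file_db::<MyDomain>().await;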
@@ -67,9 +79,9 @@ macro_rules! connection {
::db::lazy_static::lazy_static! {
pub static ref $id: $t = $t(if cfg!(any(test, feature = "test-support")) {
- ::db::open_memory_db(stringify!($id))
+ $crate::smol::block_on(::db::open_memory_db(stringify!($id)))
} else {
- ::db::open_file_db()
+ $crate::smol::block_on(::db::open_file_db())
});
}
};
@@ -15,9 +15,9 @@ impl std::ops::Deref for KeyValueStore {
lazy_static::lazy_static! {
pub static ref KEY_VALUE_STORE: KeyValueStore = KeyValueStore(if cfg!(any(test, feature = "test-support")) {
- open_memory_db(stringify!($id))
+ smol::block_on(open_memory_db("KEY_VALUE_STORE"))
} else {
- open_file_db()
+ smol::block_on(open_file_db())
});
}
@@ -62,7 +62,7 @@ mod tests {
#[gpui::test]
async fn test_kvp() {
- let db = KeyValueStore(crate::open_memory_db("test_kvp"));
+ let db = KeyValueStore(crate::open_memory_db("test_kvp").await);
assert_eq!(db.read_kvp("key-1").unwrap(), None);
@@ -9,6 +9,7 @@ edition = "2021"
anyhow = { version = "1.0.38", features = ["backtrace"] }
indoc = "1.0.7"
libsqlite3-sys = { version = "0.25.2", features = ["bundled"] }
+smol = "1.2"
thread_local = "1.1.4"
lazy_static = "1.4"
parking_lot = "0.11.1"
@@ -15,9 +15,9 @@ impl Connection {
// Setup the migrations table unconditionally
self.exec(indoc! {"
CREATE TABLE IF NOT EXISTS migrations (
- domain TEXT,
- step INTEGER,
- migration TEXT
+ domain TEXT,
+ step INTEGER,
+ migration TEXT
)"})?()?;
let completed_migrations =
@@ -1,4 +1,4 @@
-use futures::{Future, FutureExt};
+use futures::{channel::oneshot, Future, FutureExt};
use lazy_static::lazy_static;
use parking_lot::RwLock;
use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
@@ -10,17 +10,25 @@ use crate::{
util::UnboundedSyncSender,
};
-type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
+const MIGRATION_RETRIES: usize = 10;
+type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
lazy_static! {
+ /// List of queues of tasks by database uri. This lets us serialize writes to the database
+ /// and have a single worker thread per db file. This means many thread safe connections
+ /// (possibly with different migrations) could all be communicating with the same background
+ /// thread.
static ref QUEUES: RwLock<HashMap<Arc<str>, UnboundedSyncSender<QueuedWrite>>> =
Default::default();
}
+/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
+/// whatever. It derefs to a synchronous, read-only connection per thread. A write-capable connection
+/// may be accessed by passing a callback to the `write` function, which queues the callback for the background writer thread.
pub struct ThreadSafeConnection<M: Migrator = ()> {
uri: Arc<str>,
persistent: bool,
- initialize_query: Option<&'static str>,
+ connection_initialize_query: Option<&'static str>,
connections: Arc<ThreadLocal<Connection>>,
_migrator: PhantomData<M>,
}
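The per-uri queue described in that comment is a common pattern: a map from database uri to a channel whose receiving end is owned by a dedicated worker thread, so every write for one file runs in submission order. A self-contained sketch of the idea (illustrative names, not the sqlez types):

    use std::collections::HashMap;
    use std::sync::mpsc::{channel, Sender};
    use std::sync::{Arc, Mutex};
    use std::thread;

    type Job = Box<dyn FnOnce() + Send + 'static>;

    /// Stand-in for the global QUEUES map: one worker thread per key, with all
    /// jobs for that key executed in the order they were enqueued.
    #[derive(Default)]
    struct WriteQueues {
        queues: Mutex<HashMap<Arc<str>, Sender<Job>>>,
    }

    impl WriteQueues {
        fn enqueue(&self, key: &str, job: Job) {
            let mut queues = self.queues.lock().unwrap();
            let sender = queues.entry(Arc::from(key)).or_insert_with(|| {
                let (tx, rx) = channel::<Job>();
                // First job for this key: spawn its dedicated worker thread.
                thread::spawn(move || {
                    while let Ok(job) = rx.recv() {
                        job();
                    }
                });
                tx
            });
            sender.send(job).expect("worker thread exited");
        }
    }

The real implementation below also re-checks the map under the write lock, because the initial `contains_key` test happens under a read lock.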
@@ -28,87 +36,125 @@ pub struct ThreadSafeConnection<M: Migrator = ()> {
unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {}
unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
-impl<M: Migrator> ThreadSafeConnection<M> {
- pub fn new(uri: &str, persistent: bool) -> Self {
- Self {
- uri: Arc::from(uri),
- persistent,
- initialize_query: None,
- connections: Default::default(),
- _migrator: PhantomData,
- }
- }
+pub struct ThreadSafeConnectionBuilder<M: Migrator = ()> {
+ db_initialize_query: Option<&'static str>,
+ connection: ThreadSafeConnection<M>,
+}
+impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
/// Sets the query to run every time a connection is opened. This must
- /// be infallible (EG only use pragma statements)
- pub fn with_initialize_query(mut self, initialize_query: &'static str) -> Self {
- self.initialize_query = Some(initialize_query);
+ /// be infallible (EG only use pragma statements) and not cause writes
+ /// to the db or it will panic.
+ pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
+ self.connection.connection_initialize_query = Some(initialize_query);
+ self
+ }
+
+ /// Queues an initialization query for the database file. This must be infallible
+ /// but may cause changes to the database file such as with `PRAGMA journal_mode`.
+ pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
+ self.db_initialize_query = Some(initialize_query);
self
}
+ pub async fn build(self) -> ThreadSafeConnection<M> {
+ let db_initialize_query = self.db_initialize_query;
+
+ self.connection
+ .write(move |connection| {
+ if let Some(db_initialize_query) = db_initialize_query {
+ connection.exec(db_initialize_query).expect(&format!(
+ "Db initialize query failed to execute: {}",
+ db_initialize_query
+ ))()
+ .unwrap();
+ }
+
+ let mut failure_result = None;
+ for _ in 0..MIGRATION_RETRIES {
+ failure_result = Some(M::migrate(connection));
+ if failure_result.as_ref().unwrap().is_ok() {
+ break;
+ }
+ }
+
+ failure_result.unwrap().expect("Migration failed");
+ })
+ .await;
+
+ self.connection
+ }
+}
+
+impl<M: Migrator> ThreadSafeConnection<M> {
+ pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
+ ThreadSafeConnectionBuilder::<M> {
+ db_initialize_query: None,
+ connection: Self {
+ uri: Arc::from(uri),
+ persistent,
+ connection_initialize_query: None,
+ connections: Default::default(),
+ _migrator: PhantomData,
+ },
+ }
+ }
+
/// Opens a new db connection with the initialized file path. This is internal and only
/// called from the deref function.
- /// If opening fails, the connection falls back to a shared memory connection
fn open_file(&self) -> Connection {
- // This unwrap is secured by a panic in the constructor. Be careful if you remove it!
Connection::open_file(self.uri.as_ref())
}
- /// Opens a shared memory connection using the file path as the identifier. This unwraps
- /// as we expect it always to succeed
+ /// Opens a shared memory connection using the file path as the identifier. This is internal
+ /// and only called from the deref function.
fn open_shared_memory(&self) -> Connection {
Connection::open_memory(Some(self.uri.as_ref()))
}
- // Open a new connection for the given domain, leaving this
- // connection intact.
- pub fn for_domain<D2: Domain>(&self) -> ThreadSafeConnection<D2> {
- ThreadSafeConnection {
- uri: self.uri.clone(),
- persistent: self.persistent,
- initialize_query: self.initialize_query,
- connections: Default::default(),
- _migrator: PhantomData,
- }
- }
-
- pub fn write<T: 'static + Send + Sync>(
- &self,
- callback: impl 'static + Send + FnOnce(&Connection) -> T,
- ) -> impl Future<Output = T> {
+ fn queue_write_task(&self, callback: QueuedWrite) {
// Startup write thread for this database if one hasn't already
// been started and insert a channel to queue work for it
if !QUEUES.read().contains_key(&self.uri) {
- use std::sync::mpsc::channel;
-
- let (sender, reciever) = channel::<QueuedWrite>();
- let mut write_connection = self.create_connection();
- // Enable writes for this connection
- write_connection.write = true;
- thread::spawn(move || {
- while let Ok(write) = reciever.recv() {
- write(&write_connection)
- }
- });
-
let mut queues = QUEUES.write();
- queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
+ if !queues.contains_key(&self.uri) {
+ use std::sync::mpsc::channel;
+
+ let (sender, reciever) = channel::<QueuedWrite>();
+ let mut write_connection = self.create_connection();
+ // Enable writes for this connection
+ write_connection.write = true;
+ thread::spawn(move || {
+ while let Ok(write) = reciever.recv() {
+ write(&write_connection)
+ }
+ });
+
+ queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
+ }
}
// Grab the queue for this database
let queues = QUEUES.read();
let write_channel = queues.get(&self.uri).unwrap();
+ write_channel
+ .send(callback)
+ .expect("Could not send write action to backgorund thread");
+ }
+
+ pub fn write<T: 'static + Send + Sync>(
+ &self,
+ callback: impl 'static + Send + FnOnce(&Connection) -> T,
+ ) -> impl Future<Output = T> {
// Create a one shot channel for the result of the queued write
// so we can await on the result
- let (sender, reciever) = futures::channel::oneshot::channel();
- write_channel
- .send(Box::new(move |connection| {
- sender.send(callback(connection)).ok();
- }))
- .expect("Could not send write action to background thread");
+ let (sender, reciever) = oneshot::channel();
+ self.queue_write_task(Box::new(move |connection| {
+ sender.send(callback(connection)).ok();
+ }));
- reciever.map(|response| response.expect("Background thread unexpectedly closed"))
+ reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
}
pub(crate) fn create_connection(&self) -> Connection {
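Putting the builder and the write queue together, a hedged usage sketch (the uri, pragma, and table name are placeholders, and `()` is assumed to be the no-op `Migrator`):

    // Sketch: build an in-memory connection, then queue a write on the
    // background writer thread and block until it completes.
    let connection = smol::block_on(
        ThreadSafeConnection::<()>::builder("usage-example-db", false)
            .with_connection_initialize_query("PRAGMA case_sensitive_like=TRUE")
            .build(),
    );

    smol::block_on(connection.write(|conn| {
        conn.exec("CREATE TABLE IF NOT EXISTS example(value TEXT)")
            .unwrap()()
            .unwrap();
    }));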
@@ -118,10 +164,11 @@ impl<M: Migrator> ThreadSafeConnection<M> {
self.open_shared_memory()
};
- // Enable writes for the migrations and initialization queries
- connection.write = true;
+ // Disallow writes on the connection. The only writes allowed for thread safe connections
+ // are from the background thread that can serialize them.
+ connection.write = false;
- if let Some(initialize_query) = self.initialize_query {
+ if let Some(initialize_query) = self.connection_initialize_query {
connection.exec(initialize_query).expect(&format!(
"Initialize query failed to execute: {}",
initialize_query
@@ -129,20 +176,34 @@ impl<M: Migrator> ThreadSafeConnection<M> {
.unwrap()
}
- M::migrate(&connection).expect("Migrations failed");
-
- // Disable db writes for normal thread local connection
- connection.write = false;
connection
}
}
+impl ThreadSafeConnection<()> {
+ /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
+ /// This allows construction to be infallible and not write to the db.
+ pub fn new(
+ uri: &str,
+ persistent: bool,
+ connection_initialize_query: Option<&'static str>,
+ ) -> Self {
+ Self {
+ uri: Arc::from(uri),
+ persistent,
+ connection_initialize_query,
+ connections: Default::default(),
+ _migrator: PhantomData,
+ }
+ }
+}
+
impl<D: Domain> Clone for ThreadSafeConnection<D> {
fn clone(&self) -> Self {
Self {
uri: self.uri.clone(),
persistent: self.persistent,
- initialize_query: self.initialize_query.clone(),
+ connection_initialize_query: self.connection_initialize_query.clone(),
connections: self.connections.clone(),
_migrator: PhantomData,
}
@@ -163,11 +224,11 @@ impl<M: Migrator> Deref for ThreadSafeConnection<M> {
#[cfg(test)]
mod test {
- use std::{fs, ops::Deref, thread};
+ use indoc::indoc;
+ use lazy_static::__Deref;
+ use std::thread;
- use crate::domain::Domain;
-
- use super::ThreadSafeConnection;
+ use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};
#[test]
fn many_initialize_and_migrate_queries_at_once() {
@@ -185,27 +246,22 @@ mod test {
for _ in 0..100 {
handles.push(thread::spawn(|| {
- let _ = ThreadSafeConnection::<TestDomain>::new("annoying-test.db", false)
- .with_initialize_query(
- "
- PRAGMA journal_mode=WAL;
- PRAGMA synchronous=NORMAL;
- PRAGMA busy_timeout=1;
- PRAGMA foreign_keys=TRUE;
- PRAGMA case_sensitive_like=TRUE;
- ",
- )
- .deref();
+ let builder =
+ ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
+ .with_db_initialization_query("PRAGMA journal_mode=WAL")
+ .with_connection_initialize_query(indoc! {"
+ PRAGMA synchronous=NORMAL;
+ PRAGMA busy_timeout=1;
+ PRAGMA foreign_keys=TRUE;
+ PRAGMA case_sensitive_like=TRUE;
+ "});
+ let _ = smol::block_on(builder.build()).deref();
}));
}
for handle in handles {
let _ = handle.join();
}
-
- // fs::remove_file("annoying-test.db").unwrap();
- // fs::remove_file("annoying-test.db-shm").unwrap();
- // fs::remove_file("annoying-test.db-wal").unwrap();
}
#[test]
@@ -241,8 +297,10 @@ mod test {
}
}
- let _ = ThreadSafeConnection::<TestWorkspace>::new("wild_zed_lost_failure", false)
- .with_initialize_query("PRAGMA FOREIGN_KEYS=true")
- .deref();
+ let builder =
+ ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
+ .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");
+
+ smol::block_on(builder.build());
}
}
@@ -4,6 +4,10 @@ use std::sync::mpsc::Sender;
use parking_lot::Mutex;
use thread_local::ThreadLocal;
+/// Unbounded standard library sender which is stored per thread to get around
+/// the lack of Sync on the standard library version while still being unbounded.
+/// Note: this locks on the cloneable sender, but it's done once per thread, so it
+/// shouldn't result in too much contention.
pub struct UnboundedSyncSender<T: Send> {
clonable_sender: Mutex<Sender<T>>,
local_senders: ThreadLocal<Sender<T>>,
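For context, the send path that comment describes can look roughly like the following sketch; the real method bodies are not shown in this hunk, so treat the details as assumptions:

    use std::sync::mpsc::{SendError, Sender};

    use parking_lot::Mutex;
    use thread_local::ThreadLocal;

    pub struct UnboundedSyncSender<T: Send> {
        clonable_sender: Mutex<Sender<T>>,
        local_senders: ThreadLocal<Sender<T>>,
    }

    impl<T: Send> UnboundedSyncSender<T> {
        pub fn new(sender: Sender<T>) -> Self {
            Self {
                clonable_sender: Mutex::new(sender),
                local_senders: ThreadLocal::new(),
            }
        }

        pub fn send(&self, value: T) -> Result<(), SendError<T>> {
            // Clone the sender once per thread; later sends skip the mutex entirely.
            self.local_senders
                .get_or(|| self.clonable_sender.lock().clone())
                .send(value)
        }
    }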
@@ -3,7 +3,7 @@ use sqlez::thread_safe_connection::ThreadSafeConnection;
use syn::Error;
lazy_static::lazy_static! {
- static ref SQLITE: ThreadSafeConnection = ThreadSafeConnection::new(":memory:", false);
+ static ref SQLITE: ThreadSafeConnection = ThreadSafeConnection::new(":memory:", false, None);
}
#[proc_macro]
@@ -395,7 +395,7 @@ mod tests {
async fn test_next_id_stability() {
env_logger::try_init().ok();
- let db = WorkspaceDb(open_memory_db("test_next_id_stability"));
+ let db = WorkspaceDb(open_memory_db("test_next_id_stability").await);
db.write(|conn| {
conn.migrate(
@@ -442,7 +442,7 @@ mod tests {
async fn test_workspace_id_stability() {
env_logger::try_init().ok();
- let db = WorkspaceDb(open_memory_db("test_workspace_id_stability"));
+ let db = WorkspaceDb(open_memory_db("test_workspace_id_stability").await);
db.write(|conn| {
conn.migrate(
@@ -523,7 +523,7 @@ mod tests {
async fn test_full_workspace_serialization() {
env_logger::try_init().ok();
- let db = WorkspaceDb(open_memory_db("test_full_workspace_serialization"));
+ let db = WorkspaceDb(open_memory_db("test_full_workspace_serialization").await);
let dock_pane = crate::persistence::model::SerializedPane {
children: vec![
@@ -597,7 +597,7 @@ mod tests {
async fn test_workspace_assignment() {
env_logger::try_init().ok();
- let db = WorkspaceDb(open_memory_db("test_basic_functionality"));
+ let db = WorkspaceDb(open_memory_db("test_basic_functionality").await);
let workspace_1 = SerializedWorkspace {
id: 1,
@@ -689,7 +689,7 @@ mod tests {
async fn test_basic_dock_pane() {
env_logger::try_init().ok();
- let db = WorkspaceDb(open_memory_db("basic_dock_pane"));
+ let db = WorkspaceDb(open_memory_db("basic_dock_pane").await);
let dock_pane = crate::persistence::model::SerializedPane::new(
vec![
@@ -714,7 +714,7 @@ mod tests {
async fn test_simple_split() {
env_logger::try_init().ok();
- let db = WorkspaceDb(open_memory_db("simple_split"));
+ let db = WorkspaceDb(open_memory_db("simple_split").await);
// -----------------
// | 1,2 | 5,6 |
@@ -766,7 +766,7 @@ mod tests {
async fn test_cleanup_panes() {
env_logger::try_init().ok();
- let db = WorkspaceDb(open_memory_db("test_cleanup_panes"));
+ let db = WorkspaceDb(open_memory_db("test_cleanup_panes").await);
let center_pane = SerializedPaneGroup::Group {
axis: gpui::Axis::Horizontal,
@@ -162,11 +162,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut MutableAppContext) {
let app_state = Arc::downgrade(&app_state);
move |_: &NewFile, cx: &mut MutableAppContext| {
if let Some(app_state) = app_state.upgrade() {
- let task = open_new(&app_state, cx);
- cx.spawn(|_| async {
- task.await;
- })
- .detach();
+ open_new(&app_state, cx).detach();
}
}
});
@@ -174,11 +170,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut MutableAppContext) {
let app_state = Arc::downgrade(&app_state);
move |_: &NewWindow, cx: &mut MutableAppContext| {
if let Some(app_state) = app_state.upgrade() {
- let task = open_new(&app_state, cx);
- cx.spawn(|_| async {
- task.await;
- })
- .detach();
+ open_new(&app_state, cx).detach();
}
}
});
@@ -2641,13 +2633,16 @@ pub fn open_paths(
})
}
-fn open_new(app_state: &Arc<AppState>, cx: &mut MutableAppContext) -> Task<()> {
+pub fn open_new(app_state: &Arc<AppState>, cx: &mut MutableAppContext) -> Task<()> {
let task = Workspace::new_local(Vec::new(), app_state.clone(), cx);
cx.spawn(|mut cx| async move {
+ eprintln!("Open new task spawned");
let (workspace, opened_paths) = task.await;
+ eprintln!("workspace and path items created");
workspace.update(&mut cx, |_, cx| {
if opened_paths.is_empty() {
+ eprintln!("new file redispatched");
cx.dispatch_action(NewFile);
}
})
@@ -626,7 +626,7 @@ mod tests {
use theme::ThemeRegistry;
use workspace::{
item::{Item, ItemHandle},
- open_paths, pane, NewFile, Pane, SplitDirection, WorkspaceHandle,
+ open_new, open_paths, pane, NewFile, Pane, SplitDirection, WorkspaceHandle,
};
#[gpui::test]
@@ -762,8 +762,7 @@ mod tests {
#[gpui::test]
async fn test_new_empty_workspace(cx: &mut TestAppContext) {
let app_state = init(cx);
- cx.dispatch_global_action(workspace::NewFile);
- cx.foreground().run_until_parked();
+ cx.update(|cx| open_new(&app_state, cx)).await;
let window_id = *cx.window_ids().first().unwrap();
let workspace = cx.root_view::<Workspace>(window_id).unwrap();