@@ -266,11 +266,12 @@ dependencies = [
[[package]]
name = "async-lock"
-version = "2.5.0"
+version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e97a171d191782fba31bb902b14ad94e24a68145032b7eedf871ab0bc0d077b6"
+checksum = "c8101efe8695a6c17e02911402145357e718ac92d3ff88ae8419e84b1707b685"
dependencies = [
"event-listener",
+ "futures-lite",
]
[[package]]
@@ -1031,7 +1032,6 @@ name = "collab"
version = "0.2.2"
dependencies = [
"anyhow",
- "async-trait",
"async-tungstenite",
"axum",
"axum-extra",
@@ -19,7 +19,6 @@ rpc = { path = "../rpc" }
util = { path = "../util" }
anyhow = "1.0.40"
-async-trait = "0.1.50"
async-tungstenite = "0.16"
axum = { version = "0.5", features = ["json", "headers", "ws"] }
axum-extra = { version = "0.3", features = ["erased-json"] }
@@ -2,7 +2,7 @@ mod store;
use crate::{
auth,
- db::{self, ProjectId, RoomId, User, UserId},
+ db::{self, DefaultDb, ProjectId, RoomId, User, UserId},
AppState, Result,
};
use anyhow::anyhow;
@@ -80,6 +80,17 @@ struct Response<R> {
struct Session {
user_id: UserId,
connection_id: ConnectionId,
+ db: Arc<Mutex<DbHandle>>,
+}
+
+struct DbHandle(Arc<DefaultDb>);
+
+impl Deref for DbHandle {
+ type Target = DefaultDb;
+
+ fn deref(&self) -> &Self::Target {
+ self.0.as_ref()
+ }
}
impl<R: RequestMessage> Response<R> {
@@ -352,6 +363,8 @@ impl Server {
let handle_io = handle_io.fuse();
futures::pin_mut!(handle_io);
+ let db = Arc::new(Mutex::new(DbHandle(this.app_state.db.clone())));
+
// Handlers for foreground messages are pushed into the following `FuturesUnordered`.
// This prevents deadlocks when e.g., client A performs a request to client B and
// client B performs a request to client A. If both clients stop processing further
@@ -382,6 +395,7 @@ impl Server {
let session = Session {
user_id,
connection_id,
+ db: db.clone(),
};
let handle_message = (handler)(this.clone(), message, session);
drop(span_enter);
@@ -1409,9 +1423,10 @@ impl Server {
session: Session,
) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id);
- let project_connection_ids = self
- .app_state
+ let project_connection_ids = session
.db
+ .lock()
+ .await
.project_connection_ids(project_id, session.connection_id)
.await?;