Move Arc outside of rpc::Client

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

server/src/tests.rs  |  6 +++++-
zed/src/channel.rs   |  2 +-
zed/src/lib.rs       |  2 +-
zed/src/rpc.rs       | 40 ++++++++++++++++++++++++----------------
zed/src/workspace.rs |  2 +-
zed/src/worktree.rs  | 26 +++++++++++++-------------
6 files changed, 45 insertions(+), 33 deletions(-)

Detailed changes

server/src/tests.rs 🔗

@@ -549,7 +549,11 @@ impl TestServer {
         }
     }
 
-    async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> (UserId, Client) {
+    async fn create_client(
+        &mut self,
+        cx: &mut TestAppContext,
+        name: &str,
+    ) -> (UserId, Arc<Client>) {
         let user_id = self.app_state.db.create_user(name, false).await.unwrap();
         let client = Client::new();
         let (client_conn, server_conn) = Channel::bidirectional();

zed/src/channel.rs 🔗

@@ -64,7 +64,7 @@ impl Channel {
     fn handle_message_sent(
         &mut self,
         message: TypedEnvelope<ChannelMessageSent>,
-        rpc: rpc::Client,
+        rpc: Arc<rpc::Client>,
         cx: &mut ModelContext<Self>,
     ) -> Result<()> {
         Ok(())

zed/src/lib.rs 🔗

@@ -30,7 +30,7 @@ pub struct AppState {
     pub settings: watch::Receiver<Settings>,
     pub languages: Arc<language::LanguageRegistry>,
     pub themes: Arc<settings::ThemeRegistry>,
-    pub rpc: rpc::Client,
+    pub rpc: Arc<rpc::Client>,
     pub fs: Arc<dyn fs::Fs>,
 }
 

zed/src/rpc.rs 🔗

@@ -23,14 +23,13 @@ lazy_static! {
         std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev:443".to_string());
 }
 
-#[derive(Clone)]
 pub struct Client {
     peer: Arc<Peer>,
-    state: Arc<RwLock<ClientState>>,
+    state: RwLock<ClientState>,
 }
 
 #[derive(Default)]
-pub struct ClientState {
+struct ClientState {
     connection_id: Option<ConnectionId>,
     entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
     model_handlers: HashMap<
@@ -40,28 +39,33 @@ pub struct ClientState {
 }
 
 pub struct Subscription {
-    state: Weak<RwLock<ClientState>>,
+    client: Weak<Client>,
     id: (TypeId, u64),
 }
 
 impl Drop for Subscription {
     fn drop(&mut self) {
-        if let Some(state) = self.state.upgrade() {
-            let _ = state.write().model_handlers.remove(&self.id).unwrap();
+        if let Some(client) = self.client.upgrade() {
+            client
+                .state
+                .write()
+                .model_handlers
+                .remove(&self.id)
+                .unwrap();
         }
     }
 }
 
 impl Client {
-    pub fn new() -> Self {
-        Self {
+    pub fn new() -> Arc<Self> {
+        Arc::new(Self {
             peer: Peer::new(),
             state: Default::default(),
-        }
+        })
     }
 
     pub fn subscribe_from_model<T, M, F>(
-        &self,
+        self: &Arc<Self>,
         remote_id: u64,
         cx: &mut ModelContext<M>,
         mut handler: F,
@@ -72,7 +76,7 @@ impl Client {
         F: 'static
             + Send
             + Sync
-            + FnMut(&mut M, TypedEnvelope<T>, Client, &mut ModelContext<M>) -> Result<()>,
+            + FnMut(&mut M, TypedEnvelope<T>, Arc<Self>, &mut ModelContext<M>) -> Result<()>,
     {
         let subscription_id = (TypeId::of::<T>(), remote_id);
         let client = self.clone();
@@ -108,12 +112,12 @@ impl Client {
         }
 
         Subscription {
-            state: Arc::downgrade(&self.state),
+            client: Arc::downgrade(self),
             id: subscription_id,
         }
     }
 
-    pub async fn log_in_and_connect(&self, cx: AsyncAppContext) -> surf::Result<()> {
+    pub async fn log_in_and_connect(self: &Arc<Self>, cx: AsyncAppContext) -> surf::Result<()> {
         if self.state.read().connection_id.is_some() {
             return Ok(());
         }
@@ -144,7 +148,11 @@ impl Client {
         Ok(())
     }
 
-    pub async fn add_connection<Conn>(&self, conn: Conn, cx: AsyncAppContext) -> surf::Result<()>
+    pub async fn add_connection<Conn>(
+        self: &Arc<Self>,
+        conn: Conn,
+        cx: AsyncAppContext,
+    ) -> surf::Result<()>
     where
         Conn: 'static
             + futures::Sink<WebSocketMessage, Error = WebSocketError>
@@ -155,11 +163,11 @@ impl Client {
         let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
         {
             let mut cx = cx.clone();
-            let state = self.state.clone();
+            let this = self.clone();
             cx.foreground()
                 .spawn(async move {
                     while let Some(message) = incoming.recv().await {
-                        let mut state = state.write();
+                        let mut state = this.state.write();
                         if let Some(extract_entity_id) =
                             state.entity_id_extractors.get(&message.payload_type_id())
                         {

zed/src/workspace.rs 🔗

@@ -312,7 +312,7 @@ pub struct State {
 pub struct Workspace {
     pub settings: watch::Receiver<Settings>,
     languages: Arc<LanguageRegistry>,
-    rpc: rpc::Client,
+    rpc: Arc<rpc::Client>,
     fs: Arc<dyn Fs>,
     modal: Option<AnyViewHandle>,
     center: PaneGroup,

zed/src/worktree.rs 🔗

@@ -107,7 +107,7 @@ impl Worktree {
     }
 
     pub async fn open_remote(
-        rpc: rpc::Client,
+        rpc: Arc<rpc::Client>,
         id: u64,
         access_token: String,
         languages: Arc<LanguageRegistry>,
@@ -125,7 +125,7 @@ impl Worktree {
 
     async fn remote(
         open_response: proto::OpenWorktreeResponse,
-        rpc: rpc::Client,
+        rpc: Arc<rpc::Client>,
         languages: Arc<LanguageRegistry>,
         cx: &mut AsyncAppContext,
     ) -> Result<ModelHandle<Self>> {
@@ -283,7 +283,7 @@ impl Worktree {
     pub fn handle_add_peer(
         &mut self,
         envelope: TypedEnvelope<proto::AddPeer>,
-        _: rpc::Client,
+        _: Arc<rpc::Client>,
         cx: &mut ModelContext<Self>,
     ) -> Result<()> {
         match self {
@@ -295,7 +295,7 @@ impl Worktree {
     pub fn handle_remove_peer(
         &mut self,
         envelope: TypedEnvelope<proto::RemovePeer>,
-        _: rpc::Client,
+        _: Arc<rpc::Client>,
         cx: &mut ModelContext<Self>,
     ) -> Result<()> {
         match self {
@@ -307,7 +307,7 @@ impl Worktree {
     pub fn handle_update(
         &mut self,
         envelope: TypedEnvelope<proto::UpdateWorktree>,
-        _: rpc::Client,
+        _: Arc<rpc::Client>,
         cx: &mut ModelContext<Self>,
     ) -> anyhow::Result<()> {
         self.as_remote_mut()
@@ -318,7 +318,7 @@ impl Worktree {
     pub fn handle_open_buffer(
         &mut self,
         envelope: TypedEnvelope<proto::OpenBuffer>,
-        rpc: rpc::Client,
+        rpc: Arc<rpc::Client>,
         cx: &mut ModelContext<Self>,
     ) -> anyhow::Result<()> {
         let receipt = envelope.receipt();
@@ -341,7 +341,7 @@ impl Worktree {
     pub fn handle_close_buffer(
         &mut self,
         envelope: TypedEnvelope<proto::CloseBuffer>,
-        _: rpc::Client,
+        _: Arc<rpc::Client>,
         cx: &mut ModelContext<Self>,
     ) -> anyhow::Result<()> {
         self.as_local_mut()
@@ -397,7 +397,7 @@ impl Worktree {
     pub fn handle_update_buffer(
         &mut self,
         envelope: TypedEnvelope<proto::UpdateBuffer>,
-        _: rpc::Client,
+        _: Arc<rpc::Client>,
         cx: &mut ModelContext<Self>,
     ) -> Result<()> {
         let payload = envelope.payload.clone();
@@ -444,7 +444,7 @@ impl Worktree {
     pub fn handle_save_buffer(
         &mut self,
         envelope: TypedEnvelope<proto::SaveBuffer>,
-        rpc: rpc::Client,
+        rpc: Arc<rpc::Client>,
         cx: &mut ModelContext<Self>,
     ) -> Result<()> {
         let sender_id = envelope.original_sender_id()?;
@@ -488,7 +488,7 @@ impl Worktree {
     pub fn handle_buffer_saved(
         &mut self,
         envelope: TypedEnvelope<proto::BufferSaved>,
-        _: rpc::Client,
+        _: Arc<rpc::Client>,
         cx: &mut ModelContext<Self>,
     ) -> Result<()> {
         let payload = envelope.payload.clone();
@@ -966,7 +966,7 @@ impl LocalWorktree {
 
     pub fn share(
         &mut self,
-        rpc: rpc::Client,
+        rpc: Arc<rpc::Client>,
         cx: &mut ModelContext<Worktree>,
     ) -> Task<anyhow::Result<(u64, String)>> {
         let snapshot = self.snapshot();
@@ -1068,7 +1068,7 @@ impl fmt::Debug for LocalWorktree {
 }
 
 struct ShareState {
-    rpc: rpc::Client,
+    rpc: Arc<rpc::Client>,
     remote_id: u64,
     snapshots_tx: Sender<Snapshot>,
     _subscriptions: Vec<rpc::Subscription>,
@@ -1078,7 +1078,7 @@ pub struct RemoteWorktree {
     remote_id: u64,
     snapshot: Snapshot,
     snapshot_rx: watch::Receiver<Snapshot>,
-    rpc: rpc::Client,
+    rpc: Arc<rpc::Client>,
     updates_tx: postage::mpsc::Sender<proto::UpdateWorktree>,
     replica_id: ReplicaId,
     open_buffers: HashMap<usize, RemoteBuffer>,