Use a `Shared` future to represent started language servers

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/project/src/project.rs | 248 ++++++++++++++++++------------------
1 file changed, 126 insertions(+), 122 deletions(-)

Detailed changes

crates/project/src/project.rs 🔗

@@ -7,7 +7,7 @@ use anyhow::{anyhow, Context, Result};
 use client::{proto, Client, PeerId, TypedEnvelope, User, UserStore};
 use clock::ReplicaId;
 use collections::{hash_map, HashMap, HashSet};
-use futures::Future;
+use futures::{future::Shared, Future, FutureExt};
 use fuzzy::{PathMatch, PathMatchCandidate, PathMatchCandidateSet};
 use gpui::{
     AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task,
@@ -39,8 +39,8 @@ pub struct Project {
     active_entry: Option<ProjectEntry>,
     languages: Arc<LanguageRegistry>,
     language_servers: HashMap<(WorktreeId, String), Arc<LanguageServer>>,
-    loading_language_servers:
-        HashMap<(WorktreeId, String), watch::Receiver<Option<Arc<LanguageServer>>>>,
+    started_language_servers:
+        HashMap<(WorktreeId, String), Shared<Task<Option<Arc<LanguageServer>>>>>,
     client: Arc<client::Client>,
     user_store: ModelHandle<UserStore>,
     fs: Arc<dyn Fs>,
@@ -260,7 +260,7 @@ impl Project {
                 fs,
                 language_servers_with_diagnostics_running: 0,
                 language_servers: Default::default(),
-                loading_language_servers: Default::default(),
+                started_language_servers: Default::default(),
             }
         })
     }
@@ -312,7 +312,7 @@ impl Project {
                 },
                 language_servers_with_diagnostics_running: 0,
                 language_servers: Default::default(),
-                loading_language_servers: Default::default(),
+                started_language_servers: Default::default(),
             };
             for worktree in worktrees {
                 this.add_worktree(&worktree, cx);
@@ -824,7 +824,7 @@ impl Project {
         worktree_path: Arc<Path>,
         language: Arc<Language>,
         cx: &mut ModelContext<Self>,
-    ) -> Task<Option<Arc<LanguageServer>>> {
+    ) -> Shared<Task<Option<Arc<LanguageServer>>>> {
         enum LspEvent {
             DiagnosticsStart,
             DiagnosticsUpdate(lsp::PublishDiagnosticsParams),
@@ -832,131 +832,137 @@ impl Project {
         }
 
         let key = (worktree_id, language.name().to_string());
-        if let Some(language_server) = self.language_servers.get(&key) {
-            return Task::ready(Some(language_server.clone()));
-        } else if let Some(mut language_server) = self.loading_language_servers.get(&key).cloned() {
-            return cx
-                .foreground()
-                .spawn(async move { language_server.recv().await.flatten() });
-        }
-
-        let (mut language_server_tx, language_server_rx) = watch::channel();
-        self.loading_language_servers
-            .insert(key.clone(), language_server_rx);
-        let language_server = language.start_server(worktree_path, cx);
-        let rpc = self.client.clone();
-        cx.spawn_weak(|this, mut cx| async move {
-            let language_server = language_server.await.log_err().flatten();
-            if let Some(this) = this.upgrade(&cx) {
-                this.update(&mut cx, |this, _| {
-                    this.loading_language_servers.remove(&key);
-                    if let Some(language_server) = language_server.clone() {
-                        this.language_servers.insert(key, language_server);
+        self.started_language_servers
+            .entry(key.clone())
+            .or_insert_with(|| {
+                let language_server = language.start_server(worktree_path, cx);
+                let rpc = self.client.clone();
+                cx.spawn_weak(|this, mut cx| async move {
+                    let language_server = language_server.await.log_err().flatten();
+                    if let Some(this) = this.upgrade(&cx) {
+                        this.update(&mut cx, |this, _| {
+                            if let Some(language_server) = language_server.clone() {
+                                this.language_servers.insert(key, language_server);
+                            }
+                        });
                     }
-                });
-            }
 
-            let language_server = language_server?;
-            *language_server_tx.borrow_mut() = Some(language_server.clone());
-
-            let disk_based_sources = language
-                .disk_based_diagnostic_sources()
-                .cloned()
-                .unwrap_or_default();
-            let disk_based_diagnostics_progress_token =
-                language.disk_based_diagnostics_progress_token().cloned();
-            let has_disk_based_diagnostic_progress_token =
-                disk_based_diagnostics_progress_token.is_some();
-            let (diagnostics_tx, diagnostics_rx) = smol::channel::unbounded();
-
-            // Listen for `PublishDiagnostics` notifications.
-            language_server
-                .on_notification::<lsp::notification::PublishDiagnostics, _>({
-                    let diagnostics_tx = diagnostics_tx.clone();
-                    move |params| {
-                        if !has_disk_based_diagnostic_progress_token {
-                            block_on(diagnostics_tx.send(LspEvent::DiagnosticsStart)).ok();
-                        }
-                        block_on(diagnostics_tx.send(LspEvent::DiagnosticsUpdate(params))).ok();
-                        if !has_disk_based_diagnostic_progress_token {
-                            block_on(diagnostics_tx.send(LspEvent::DiagnosticsFinish)).ok();
-                        }
-                    }
-                })
-                .detach();
+                    let language_server = language_server?;
 
-            // Listen for `Progress` notifications. Send an event when the language server
-            // transitions between running jobs and not running any jobs.
-            let mut running_jobs_for_this_server: i32 = 0;
-            language_server
-                .on_notification::<lsp::notification::Progress, _>(move |params| {
-                    let token = match params.token {
-                        lsp::NumberOrString::Number(_) => None,
-                        lsp::NumberOrString::String(token) => Some(token),
-                    };
+                    let disk_based_sources = language
+                        .disk_based_diagnostic_sources()
+                        .cloned()
+                        .unwrap_or_default();
+                    let disk_based_diagnostics_progress_token =
+                        language.disk_based_diagnostics_progress_token().cloned();
+                    let has_disk_based_diagnostic_progress_token =
+                        disk_based_diagnostics_progress_token.is_some();
+                    let (diagnostics_tx, diagnostics_rx) = smol::channel::unbounded();
+
+                    // Listen for `PublishDiagnostics` notifications.
+                    language_server
+                        .on_notification::<lsp::notification::PublishDiagnostics, _>({
+                            let diagnostics_tx = diagnostics_tx.clone();
+                            move |params| {
+                                if !has_disk_based_diagnostic_progress_token {
+                                    block_on(diagnostics_tx.send(LspEvent::DiagnosticsStart)).ok();
+                                }
+                                block_on(diagnostics_tx.send(LspEvent::DiagnosticsUpdate(params)))
+                                    .ok();
+                                if !has_disk_based_diagnostic_progress_token {
+                                    block_on(diagnostics_tx.send(LspEvent::DiagnosticsFinish)).ok();
+                                }
+                            }
+                        })
+                        .detach();
+
+                    // Listen for `Progress` notifications. Send an event when the language server
+                    // transitions between running jobs and not running any jobs.
+                    let mut running_jobs_for_this_server: i32 = 0;
+                    language_server
+                        .on_notification::<lsp::notification::Progress, _>(move |params| {
+                            let token = match params.token {
+                                lsp::NumberOrString::Number(_) => None,
+                                lsp::NumberOrString::String(token) => Some(token),
+                            };
 
-                    if token == disk_based_diagnostics_progress_token {
-                        match params.value {
-                            lsp::ProgressParamsValue::WorkDone(progress) => match progress {
-                                lsp::WorkDoneProgress::Begin(_) => {
-                                    running_jobs_for_this_server += 1;
-                                    if running_jobs_for_this_server == 1 {
-                                        block_on(diagnostics_tx.send(LspEvent::DiagnosticsStart))
-                                            .ok();
+                            if token == disk_based_diagnostics_progress_token {
+                                match params.value {
+                                    lsp::ProgressParamsValue::WorkDone(progress) => {
+                                        match progress {
+                                            lsp::WorkDoneProgress::Begin(_) => {
+                                                running_jobs_for_this_server += 1;
+                                                if running_jobs_for_this_server == 1 {
+                                                    block_on(
+                                                        diagnostics_tx
+                                                            .send(LspEvent::DiagnosticsStart),
+                                                    )
+                                                    .ok();
+                                                }
+                                            }
+                                            lsp::WorkDoneProgress::End(_) => {
+                                                running_jobs_for_this_server -= 1;
+                                                if running_jobs_for_this_server == 0 {
+                                                    block_on(
+                                                        diagnostics_tx
+                                                            .send(LspEvent::DiagnosticsFinish),
+                                                    )
+                                                    .ok();
+                                                }
+                                            }
+                                            _ => {}
+                                        }
                                     }
                                 }
-                                lsp::WorkDoneProgress::End(_) => {
-                                    running_jobs_for_this_server -= 1;
-                                    if running_jobs_for_this_server == 0 {
-                                        block_on(diagnostics_tx.send(LspEvent::DiagnosticsFinish))
-                                            .ok();
-                                    }
+                            }
+                        })
+                        .detach();
+
+                    // Process all the LSP events.
+                    cx.spawn(|mut cx| async move {
+                        while let Ok(message) = diagnostics_rx.recv().await {
+                            let this = this.upgrade(&cx)?;
+                            match message {
+                                LspEvent::DiagnosticsStart => {
+                                    this.update(&mut cx, |this, cx| {
+                                        this.disk_based_diagnostics_started(cx);
+                                        if let Some(project_id) = this.remote_id() {
+                                            rpc.send(proto::DiskBasedDiagnosticsUpdating {
+                                                project_id,
+                                            })
+                                            .log_err();
+                                        }
+                                    });
                                 }
-                                _ => {}
-                            },
-                        }
-                    }
-                })
-                .detach();
-
-            // Process all the LSP events.
-            cx.spawn(|mut cx| async move {
-                while let Ok(message) = diagnostics_rx.recv().await {
-                    let this = this.upgrade(&cx)?;
-                    match message {
-                        LspEvent::DiagnosticsStart => {
-                            this.update(&mut cx, |this, cx| {
-                                this.disk_based_diagnostics_started(cx);
-                                if let Some(project_id) = this.remote_id() {
-                                    rpc.send(proto::DiskBasedDiagnosticsUpdating { project_id })
-                                        .log_err();
+                                LspEvent::DiagnosticsUpdate(mut params) => {
+                                    language.process_diagnostics(&mut params);
+                                    this.update(&mut cx, |this, cx| {
+                                        this.update_diagnostics(params, &disk_based_sources, cx)
+                                            .log_err();
+                                    });
                                 }
-                            });
-                        }
-                        LspEvent::DiagnosticsUpdate(mut params) => {
-                            language.process_diagnostics(&mut params);
-                            this.update(&mut cx, |this, cx| {
-                                this.update_diagnostics(params, &disk_based_sources, cx)
-                                    .log_err();
-                            });
-                        }
-                        LspEvent::DiagnosticsFinish => {
-                            this.update(&mut cx, |this, cx| {
-                                this.disk_based_diagnostics_finished(cx);
-                                if let Some(project_id) = this.remote_id() {
-                                    rpc.send(proto::DiskBasedDiagnosticsUpdated { project_id })
-                                        .log_err();
+                                LspEvent::DiagnosticsFinish => {
+                                    this.update(&mut cx, |this, cx| {
+                                        this.disk_based_diagnostics_finished(cx);
+                                        if let Some(project_id) = this.remote_id() {
+                                            rpc.send(proto::DiskBasedDiagnosticsUpdated {
+                                                project_id,
+                                            })
+                                            .log_err();
+                                        }
+                                    });
                                 }
-                            });
+                            }
                         }
-                    }
-                }
-                Some(())
-            })
-            .detach();
+                        Some(())
+                    })
+                    .detach();
 
-            Some(language_server)
-        })
+                    Some(language_server)
+                })
+                .shared()
+            })
+            .clone()
     }
 
     pub fn update_diagnostics(
@@ -2885,8 +2891,6 @@ impl Entity for Project {
         &mut self,
         _: &mut MutableAppContext,
     ) -> Option<std::pin::Pin<Box<dyn 'static + Future<Output = ()>>>> {
-        use futures::FutureExt;
-
         let shutdown_futures = self
             .language_servers
             .drain()