Restructure concurrency in EP CLI to allow running many examples in big Rust repos (#44673)

Created by Max Brunsfeld

Release Notes:

- N/A
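
This PR changes how the CLI schedules work: instead of splitting examples into fixed-size batches regardless of origin, it sorts and groups them by repository, runs up to `max_parallelism` repositories concurrently, and processes each repository's examples serially against a cached `Project`. Between examples, the project's edit-prediction history is cleared and open buffers are reloaded from disk, so large repositories only pay the project-setup cost once.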

Change summary

crates/edit_prediction/src/edit_prediction.rs      |  6 
crates/edit_prediction_cli/src/example.rs          | 41 +++--
crates/edit_prediction_cli/src/headless.rs         | 22 +++
crates/edit_prediction_cli/src/load_project.rs     | 97 ++++++++++-----
crates/edit_prediction_cli/src/main.rs             | 53 +++++---
crates/edit_prediction_cli/src/retrieve_context.rs | 37 +++--
crates/project/src/lsp_store.rs                    | 13 +
7 files changed, 173 insertions(+), 96 deletions(-)

Detailed changes

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -577,6 +577,12 @@ impl EditPredictionStore {
         }
     }
 
+    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
+        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
+            project_state.events.clear();
+        }
+    }
+
     pub fn edit_history_for_project(
         &self,
         project: &Entity<Project>,
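
This new method is the hook that lets one project serve many examples. A minimal sketch of the call site, mirroring the cached-project branch of `setup_project` in `load_project.rs` below (`ep_store`, `project`, and `cx` are assumed to be in scope):

```rust
// Before replaying the next example's edit history against a reused project,
// drop the edit events recorded for the previous example so they can't leak
// into the new prediction context.
ep_store.update(cx, |ep_store, _cx| {
    ep_store.clear_history_for_project(&project);
});
```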

crates/edit_prediction_cli/src/example.rs 🔗

@@ -1,9 +1,6 @@
-use crate::{
-    PredictionProvider, PromptFormat,
-    metrics::ClassificationMetrics,
-    paths::{REPOS_DIR, WORKTREES_DIR},
-};
+use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics};
 use anyhow::{Context as _, Result};
+use collections::HashMap;
 use edit_prediction::udiff::OpenedBuffers;
 use gpui::Entity;
 use http_client::Url;
@@ -102,7 +99,7 @@ pub struct ExampleScore {
 }
 
 impl Example {
-    fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
+    pub fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
         // git@github.com:owner/repo.git
         if self.repository_url.contains('@') {
             let (owner, repo) = self
@@ -134,17 +131,6 @@ impl Example {
             Ok((owner.into(), repo.into()))
         }
     }
-
-    pub fn worktree_path(&self) -> PathBuf {
-        WORKTREES_DIR
-            .join(&self.name)
-            .join(self.repo_name().unwrap().1.as_ref())
-    }
-
-    pub fn repo_path(&self) -> PathBuf {
-        let (repo_owner, repo_name) = self.repo_name().expect("failed to get repo name");
-        REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref())
-    }
 }
 
 pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
@@ -218,6 +204,8 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
             }
         }
     }
+
+    sort_examples_by_repo_and_rev(&mut examples);
     examples
 }
 
@@ -235,6 +223,25 @@ pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
     }
 }
 
+pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
+    examples.sort_by(|a, b| {
+        a.repository_url
+            .cmp(&b.repository_url)
+            .then(b.revision.cmp(&a.revision))
+    });
+}
+
+pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
+    let mut examples_by_repo = HashMap::default();
+    for example in examples.iter_mut() {
+        examples_by_repo
+            .entry(example.repository_url.clone())
+            .or_insert_with(Vec::new)
+            .push(example);
+    }
+    examples_by_repo.into_values().collect()
+}
+
 fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
     use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
 

crates/edit_prediction_cli/src/headless.rs 🔗

@@ -1,4 +1,5 @@
 use client::{Client, ProxySettings, UserStore};
+use collections::HashMap;
 use extension::ExtensionHostProxy;
 use fs::RealFs;
 use gpui::http_client::read_proxy_from_env;
@@ -7,12 +8,13 @@ use gpui_tokio::Tokio;
 use language::LanguageRegistry;
 use language_extension::LspAccess;
 use node_runtime::{NodeBinaryOptions, NodeRuntime};
+use project::Project;
 use project::project_settings::ProjectSettings;
 use release_channel::{AppCommitSha, AppVersion};
 use reqwest_client::ReqwestClient;
 use settings::{Settings, SettingsStore};
 use std::path::PathBuf;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 use util::ResultExt as _;
 
 /// Headless subset of `workspace::AppState`.
@@ -22,9 +24,22 @@ pub struct EpAppState {
     pub user_store: Entity<UserStore>,
     pub fs: Arc<dyn fs::Fs>,
     pub node_runtime: NodeRuntime,
+    pub project_cache: ProjectCache,
+}
+
+#[derive(Default)]
+pub struct ProjectCache(Mutex<HashMap<String, Entity<Project>>>);
+
+impl ProjectCache {
+    pub fn insert(&self, repository_url: String, project: Entity<Project>) {
+        self.0.lock().unwrap().insert(repository_url, project);
+    }
+
+    pub fn get(&self, repository_url: &String) -> Option<Entity<Project>> {
+        self.0.lock().unwrap().get(repository_url).cloned()
+    }
 }
 
-// TODO: dedupe with crates/eval/src/eval.rs
 pub fn init(cx: &mut App) -> EpAppState {
     let app_commit_sha = option_env!("ZED_COMMIT_SHA").map(|s| AppCommitSha::new(s.to_owned()));
 
@@ -112,11 +127,14 @@ pub fn init(cx: &mut App) -> EpAppState {
     prompt_store::init(cx);
     terminal_view::init(cx);
 
+    let project_cache = ProjectCache::default();
+
     EpAppState {
         languages,
         client,
         user_store,
         fs,
         node_runtime,
+        project_cache,
     }
 }
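
`ProjectCache` guards its map with a `Mutex` so the concurrent per-repository tasks can share one cache through `Arc<EpAppState>`. A sketch of the intended lookup pattern, mirroring `setup_project` in `load_project.rs` (`build_project_for` is a hypothetical stand-in for the project-construction path):

```rust
// Reuse the project for this repository if one was already built; otherwise
// build it and publish it for subsequent examples from the same repository.
let project = match app_state.project_cache.get(&example.repository_url) {
    Some(project) => project, // caller clears edit history and reloads buffers
    None => {
        let project = build_project_for(example, cx).await; // hypothetical helper
        app_state
            .project_cache
            .insert(example.repository_url.clone(), project.clone());
        project
    }
};
```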

crates/edit_prediction_cli/src/load_project.rs 🔗

@@ -1,6 +1,7 @@
 use crate::{
     example::{Example, ExampleBuffer, ExampleState},
     headless::EpAppState,
+    paths::{REPOS_DIR, WORKTREES_DIR},
 };
 use anyhow::{Result, anyhow};
 use collections::HashMap;
@@ -29,29 +30,11 @@ pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>,
     }
 
     let project = setup_project(example, &app_state, &mut cx).await;
-    let buffer_store = project
-        .read_with(&cx, |project, _| project.buffer_store().clone())
-        .unwrap();
-
-    let ep_store = cx
-        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
-        .unwrap();
-
-    cx.subscribe(&buffer_store, {
-        let project = project.clone();
-        move |_, event, cx| match event {
-            BufferStoreEvent::BufferAdded(buffer) => {
-                ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
-            }
-            _ => {}
-        }
-    })
-    .unwrap()
-    .detach();
 
     let _open_buffers = apply_edit_history(example, &project, &mut cx)
         .await
         .unwrap();
+
     let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
     example.buffer = buffer
         .read_with(&cx, |buffer, _cx| {
@@ -64,6 +47,7 @@ pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>,
             })
         })
         .unwrap();
+
     example.state = Some(ExampleState {
         buffer,
         project,
@@ -149,7 +133,35 @@ async fn setup_project(
     app_state: &Arc<EpAppState>,
     cx: &mut AsyncApp,
 ) -> Entity<Project> {
-    setup_worktree(example).await;
+    let ep_store = cx
+        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
+        .unwrap();
+
+    let worktree_path = setup_worktree(example).await;
+
+    if let Some(project) = app_state.project_cache.get(&example.repository_url) {
+        ep_store
+            .update(cx, |ep_store, _| {
+                ep_store.clear_history_for_project(&project);
+            })
+            .unwrap();
+        let buffer_store = project
+            .read_with(cx, |project, _| project.buffer_store().clone())
+            .unwrap();
+        let buffers = buffer_store
+            .read_with(cx, |buffer_store, _| {
+                buffer_store.buffers().collect::<Vec<_>>()
+            })
+            .unwrap();
+        for buffer in buffers {
+            buffer
+                .update(cx, |buffer, cx| buffer.reload(cx))
+                .unwrap()
+                .await
+                .unwrap();
+        }
+        return project;
+    }
 
     let project = cx
         .update(|cx| {
@@ -168,30 +180,44 @@ async fn setup_project(
     project
         .update(cx, |project, cx| {
             project.disable_worktree_scanner(cx);
-        })
-        .unwrap();
-
-    let worktree = project
-        .update(cx, |project, cx| {
-            project.create_worktree(&example.worktree_path(), true, cx)
+            project.create_worktree(&worktree_path, true, cx)
         })
         .unwrap()
         .await
         .unwrap();
-    worktree
-        .read_with(cx, |worktree, _cx| {
-            worktree.as_local().unwrap().scan_complete()
-        })
-        .unwrap()
-        .await;
+
+    app_state
+        .project_cache
+        .insert(example.repository_url.clone(), project.clone());
+
+    let buffer_store = project
+        .read_with(cx, |project, _| project.buffer_store().clone())
+        .unwrap();
+    cx.subscribe(&buffer_store, {
+        let project = project.clone();
+        move |_, event, cx| match event {
+            BufferStoreEvent::BufferAdded(buffer) => {
+                ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
+            }
+            _ => {}
+        }
+    })
+    .unwrap()
+    .detach();
+
     project
 }
 
-pub async fn setup_worktree(example: &Example) {
-    let repo_dir = example.repo_path();
+pub async fn setup_worktree(example: &Example) -> PathBuf {
+    let (repo_owner, repo_name) = example.repo_name().expect("failed to get repo name");
+    let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
+    let worktree_path = WORKTREES_DIR
+        .join(repo_owner.as_ref())
+        .join(repo_name.as_ref());
     let repo_lock = lock_repo(&repo_dir).await;
 
     if !repo_dir.is_dir() {
+        eprintln!("Cloning repository {}", example.repository_url);
         fs::create_dir_all(&repo_dir).unwrap();
         run_git(&repo_dir, &["init"]).await.unwrap();
         run_git(
@@ -227,7 +253,6 @@ pub async fn setup_worktree(example: &Example) {
     };
 
     // Create the worktree for this example if needed.
-    let worktree_path = example.worktree_path();
     if worktree_path.is_dir() {
         run_git(&worktree_path, &["clean", "--force", "-d"])
             .await
@@ -288,6 +313,8 @@ pub async fn setup_worktree(example: &Example) {
             );
         }
     }
+
+    worktree_path
 }
 
 async fn apply_edit_history(
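
`setup_worktree` now derives both directories from the repository owner and name and returns the worktree path, rather than keying worktrees by example name. A small sketch of the resulting layout (root names assumed from `paths.rs`):

```rust
use std::path::{Path, PathBuf};

// Bare clones live under <repos>/<owner>/<repo>, and one shared worktree per
// repository lives under <worktrees>/<owner>/<repo>. Every example from the
// same repository reuses that single checkout, with `git clean` and
// `git checkout` resetting it between examples.
fn repo_dir(repos: &Path, owner: &str, repo: &str) -> PathBuf {
    repos.join(owner).join(repo)
}

fn worktree_dir(worktrees: &Path, owner: &str, repo: &str) -> PathBuf {
    worktrees.join(owner).join(repo)
}
```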

crates/edit_prediction_cli/src/main.rs 🔗

@@ -15,10 +15,12 @@ use edit_prediction::EditPredictionStore;
 use gpui::Application;
 use reqwest_client::ReqwestClient;
 use serde::{Deserialize, Serialize};
+use std::sync::atomic::AtomicUsize;
+use std::sync::atomic::Ordering::SeqCst;
 use std::{path::PathBuf, sync::Arc};
 
 use crate::distill::run_distill;
-use crate::example::{read_examples, write_examples};
+use crate::example::{group_examples_by_repo, read_examples, write_examples};
 use crate::format_prompt::run_format_prompt;
 use crate::load_project::run_load_project;
 use crate::predict::run_prediction;
@@ -145,31 +147,40 @@ fn main() {
         EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
 
         cx.spawn(async move |cx| {
-            match &command {
-                Command::Predict(args) => predict::sync_batches(&args.provider).await,
-                _ => (),
+            if let Command::Predict(args) = &command {
+                predict::sync_batches(&args.provider).await
             };
 
-            let chunks = examples.chunks_mut(args.max_parallelism);
-            let total_chunks = chunks.len();
-            for (batch_ix, data) in chunks.enumerate() {
-                let mut futures = Vec::new();
-                eprintln!("Processing batch: {}/{}", batch_ix + 1, total_chunks);
-
-                for example in data.iter_mut() {
-                    let cx = cx.clone();
-                    let app_state = app_state.clone();
-                    futures.push(async {
+            let example_count = examples.len();
+            let example_ix = AtomicUsize::new(0);
+            let mut grouped_examples = group_examples_by_repo(&mut examples);
+
+            let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
+            for example_batch in example_batches {
+                let futures = example_batch.into_iter().map(|repo_examples| async {
+                    for example in repo_examples.iter_mut() {
+                        eprintln!(
+                            "Processing example: {}/{}",
+                            example_ix.load(SeqCst) + 1,
+                            example_count
+                        );
+                        example_ix.fetch_add(1, SeqCst);
                         match &command {
                             Command::ParseExample => {}
                             Command::LoadProject => {
-                                run_load_project(example, app_state.clone(), cx).await;
+                                run_load_project(example, app_state.clone(), cx.clone()).await;
                             }
                             Command::Context => {
-                                run_context_retrieval(example, app_state, cx).await;
+                                run_context_retrieval(example, app_state.clone(), cx.clone()).await;
                             }
                             Command::FormatPrompt(args) => {
-                                run_format_prompt(example, args.prompt_format, app_state, cx).await;
+                                run_format_prompt(
+                                    example,
+                                    args.prompt_format,
+                                    app_state.clone(),
+                                    cx.clone(),
+                                )
+                                .await;
                             }
                             Command::Predict(args) => {
                                 run_prediction(
@@ -177,7 +188,7 @@ fn main() {
                                     Some(args.provider),
                                     args.repetitions,
                                     app_state.clone(),
-                                    cx,
+                                    cx.clone(),
                                 )
                                 .await;
                             }
@@ -185,14 +196,14 @@ fn main() {
                                 run_distill(example).await;
                             }
                             Command::Score(args) | Command::Eval(args) => {
-                                run_scoring(example, &args, app_state, cx).await;
+                                run_scoring(example, &args, app_state.clone(), cx.clone()).await;
                             }
                             Command::Clean => {
                                 unreachable!()
                             }
                         }
-                    });
-                }
+                    }
+                });
                 futures::future::join_all(futures).await;
             }
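
The scheduling change in a nutshell: the loop now chunks repository groups rather than individual examples, so concurrency is across repositories while each repository's examples run in order against its cached project. A self-contained model of that shape, with `String` standing in for `Example`:

```rust
use futures::future::join_all;

async fn process_all(mut grouped: Vec<Vec<String>>, max_parallelism: usize) {
    for batch in grouped.chunks_mut(max_parallelism) {
        // One future per repository; futures within a batch run concurrently.
        let futures = batch.iter_mut().map(|repo_examples| async move {
            for example in repo_examples.iter_mut() {
                // Serial within a repository, so each example can safely
                // reuse the cached project and on-disk worktree.
                process_one(example).await;
            }
        });
        join_all(futures).await;
    }
}

async fn process_one(_example: &mut String) {
    // Stand-in for run_load_project / run_context_retrieval / run_prediction.
}
```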
 

crates/edit_prediction_cli/src/retrieve_context.rs 🔗

@@ -75,25 +75,24 @@ async fn wait_for_language_servers_to_start(
         .read_with(cx, |project, _| project.lsp_store())
         .unwrap();
 
-    let lang_server_ids = buffer
+    let (language_server_ids, mut starting_language_server_ids) = buffer
         .update(cx, |buffer, cx| {
             lsp_store.update(cx, |lsp_store, cx| {
-                lsp_store.language_servers_for_local_buffer(buffer, cx)
+                let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
+                let starting_ids = ids
+                    .iter()
+                    .copied()
+                    .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
+                    .collect::<HashSet<_>>();
+                (ids, starting_ids)
             })
         })
         .unwrap_or_default();
 
-    if !lang_server_ids.is_empty() {
-        project
-            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
-            .unwrap()
-            .detach();
-    }
-
     eprintln!(
         "{}⏵ Waiting for {} language servers",
         log_prefix,
-        lang_server_ids.len()
+        language_server_ids.len()
     );
 
     let timeout = cx
@@ -101,7 +100,7 @@ async fn wait_for_language_servers_to_start(
         .timer(Duration::from_secs(60 * 5))
         .shared();
 
-    let (mut tx, mut rx) = mpsc::channel(lang_server_ids.len());
+    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
     let added_subscription = cx.subscribe(project, {
         let log_prefix = log_prefix.clone();
         move |_, event, _| match event {
@@ -113,12 +112,11 @@ async fn wait_for_language_servers_to_start(
         }
     });
 
-    let mut pending_language_server_ids = HashSet::from_iter(lang_server_ids.iter());
-    while !pending_language_server_ids.is_empty() {
+    while !starting_language_server_ids.is_empty() {
         futures::select! {
             language_server_id = rx.next() => {
                 if let Some(id) = language_server_id {
-                    pending_language_server_ids.remove(&id);
+                    starting_language_server_ids.remove(&id);
                 }
             },
             _ = timeout.clone().fuse() => {
@@ -129,7 +127,14 @@ async fn wait_for_language_servers_to_start(
 
     drop(added_subscription);
 
-    let (mut tx, mut rx) = mpsc::channel(lang_server_ids.len());
+    if !language_server_ids.is_empty() {
+        project
+            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+            .unwrap()
+            .detach();
+    }
+
+    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
     let subscriptions = [
         cx.subscribe(&lsp_store, {
             let log_prefix = log_prefix.clone();
@@ -172,7 +177,7 @@ async fn wait_for_language_servers_to_start(
         .await
         .unwrap();
 
-    let mut pending_language_server_ids = HashSet::from_iter(lang_server_ids.into_iter());
+    let mut pending_language_server_ids = HashSet::from_iter(language_server_ids.into_iter());
     while !pending_language_server_ids.is_empty() {
         futures::select! {
             language_server_id = rx.next() => {
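
The wait loop's shape, as a self-contained sketch with stand-in id and channel types: server notifications drain the pending set, while a shared timeout bounds the whole wait so one stuck language server cannot hang the run.

```rust
use futures::{FutureExt as _, StreamExt as _, channel::mpsc, future::Shared, select};
use std::{collections::HashSet, future::Future};

async fn wait_for_ids(
    mut pending: HashSet<u64>,
    mut rx: mpsc::Receiver<u64>,
    timeout: Shared<impl Future<Output = ()>>,
) {
    while !pending.is_empty() {
        select! {
            id = rx.next() => {
                // A server reported in; stop waiting for it.
                if let Some(id) = id {
                    pending.remove(&id);
                }
            }
            _ = timeout.clone().fuse() => {
                eprintln!("Timed out waiting for language servers");
                break;
            }
        }
    }
}
```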

crates/project/src/lsp_store.rs 🔗

@@ -201,7 +201,10 @@ pub enum LspFormatTarget {
     Ranges(BTreeMap<BufferId, Vec<Range<Anchor>>>),
 }
 
-pub type OpenLspBufferHandle = Entity<Entity<Buffer>>;
+#[derive(Clone, PartialEq, Eq, Hash)]
+pub struct OpenLspBufferHandle(Entity<OpenLspBuffer>);
+
+struct OpenLspBuffer(Entity<Buffer>);
 
 impl FormatTrigger {
     fn from_proto(value: i32) -> FormatTrigger {
@@ -4208,7 +4211,7 @@ impl LspStore {
         cx: &mut Context<Self>,
     ) -> OpenLspBufferHandle {
         let buffer_id = buffer.read(cx).remote_id();
-        let handle = cx.new(|_| buffer.clone());
+        let handle = OpenLspBufferHandle(cx.new(|_| OpenLspBuffer(buffer.clone())));
         if let Some(local) = self.as_local_mut() {
             let refcount = local.registered_buffers.entry(buffer_id).or_insert(0);
             if !ignore_refcounts {
@@ -4230,7 +4233,7 @@ impl LspStore {
                 local.register_buffer_with_language_servers(buffer, only_register_servers, cx);
             }
             if !ignore_refcounts {
-                cx.observe_release(&handle, move |lsp_store, buffer, cx| {
+                cx.observe_release(&handle.0, move |lsp_store, buffer, cx| {
                     let refcount = {
                         let local = lsp_store.as_local_mut().unwrap();
                         let Some(refcount) = local.registered_buffers.get_mut(&buffer_id) else {
@@ -4247,8 +4250,8 @@ impl LspStore {
                         local.registered_buffers.remove(&buffer_id);
 
                         local.buffers_opened_in_servers.remove(&buffer_id);
-                        if let Some(file) = File::from_dyn(buffer.read(cx).file()).cloned() {
-                            local.unregister_old_buffer_from_language_servers(buffer, &file, cx);
+                        if let Some(file) = File::from_dyn(buffer.0.read(cx).file()).cloned() {
+                            local.unregister_old_buffer_from_language_servers(&buffer.0, &file, cx);
 
                             let buffer_abs_path = file.abs_path(cx);
                             for (_, buffer_pull_diagnostics_result_ids) in
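
The handle changes from a bare `Entity<Entity<Buffer>>` alias to an opaque newtype: release observation now targets the wrapper's inner entity (`handle.0`), and the release callback unwraps `buffer.0` to reach the underlying `Entity<Buffer>`. A minimal, non-GPUI illustration of why a newtype is used here instead of an alias (an alias cannot hide the wrapped type or carry its own derives):

```rust
use std::collections::HashSet;

// `Handle` stands in for OpenLspBufferHandle(Entity<OpenLspBuffer>): the
// wrapper is a distinct type with its own derived trait impls, and callers
// can no longer treat it as an ordinary nested Entity.
#[derive(Clone, PartialEq, Eq, Hash)]
struct Handle(u64);

fn main() {
    let mut open_handles = HashSet::new();
    open_handles.insert(Handle(1));
    assert!(open_handles.contains(&Handle(1)));
}
```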