From b871130220054f942c74d8d570614c4ab97ea9b1 Mon Sep 17 00:00:00 2001
From: Max Brunsfeld
Date: Thu, 11 Dec 2025 17:58:53 -0800
Subject: [PATCH] Restructure concurrency in EP CLI to allow running many
 examples in big rust repos (#44673)

Release Notes:

- N/A
---
 crates/edit_prediction/src/edit_prediction.rs |  6 ++
 crates/edit_prediction_cli/src/example.rs     | 41 ++++----
 crates/edit_prediction_cli/src/headless.rs    | 22 ++++-
 .../edit_prediction_cli/src/load_project.rs   | 97 ++++++++++++-------
 crates/edit_prediction_cli/src/main.rs        | 53 ++++++----
 .../src/retrieve_context.rs                   | 37 ++++---
 crates/project/src/lsp_store.rs               | 13 ++-
 7 files changed, 173 insertions(+), 96 deletions(-)

diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs
index dd7b0090cb88c1564fc72de11ce9ec13e78f6a7c..6a7c6232d08b15fccacdd80a446432e453a80e20 100644
--- a/crates/edit_prediction/src/edit_prediction.rs
+++ b/crates/edit_prediction/src/edit_prediction.rs
@@ -577,6 +577,12 @@ impl EditPredictionStore {
         }
     }
 
+    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
+        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
+            project_state.events.clear();
+        }
+    }
+
     pub fn edit_history_for_project(
         &self,
         project: &Entity<Project>,
diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs
index 1e21526e80104013a320a63e764dab0926bdd6f0..9499aae0c1ebce7eeca3ef05fedbcf09c960e131 100644
--- a/crates/edit_prediction_cli/src/example.rs
+++ b/crates/edit_prediction_cli/src/example.rs
@@ -1,9 +1,6 @@
-use crate::{
-    PredictionProvider, PromptFormat,
-    metrics::ClassificationMetrics,
-    paths::{REPOS_DIR, WORKTREES_DIR},
-};
+use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics};
 use anyhow::{Context as _, Result};
+use collections::HashMap;
 use edit_prediction::udiff::OpenedBuffers;
 use gpui::Entity;
 use http_client::Url;
@@ -102,7 +99,7 @@ pub struct ExampleScore {
 }
 
 impl Example {
-    fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
+    pub fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
         // git@github.com:owner/repo.git
         if self.repository_url.contains('@') {
             let (owner, repo) = self
@@ -134,17 +131,6 @@ impl Example {
             Ok((owner.into(), repo.into()))
         }
     }
-
-    pub fn worktree_path(&self) -> PathBuf {
-        WORKTREES_DIR
-            .join(&self.name)
-            .join(self.repo_name().unwrap().1.as_ref())
-    }
-
-    pub fn repo_path(&self) -> PathBuf {
-        let (repo_owner, repo_name) = self.repo_name().expect("failed to get repo name");
-        REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref())
-    }
 }
 
 pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
@@ -218,6 +204,8 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
             }
         }
     }
+
+    sort_examples_by_repo_and_rev(&mut examples);
     examples
 }
 
@@ -235,6 +223,25 @@ pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
     }
 }
 
+pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
+    examples.sort_by(|a, b| {
+        a.repository_url
+            .cmp(&b.repository_url)
+            .then(b.revision.cmp(&a.revision))
+    });
+}
+
+pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
+    let mut examples_by_repo = HashMap::default();
+    for example in examples.iter_mut() {
+        examples_by_repo
+            .entry(example.repository_url.clone())
+            .or_insert_with(Vec::new)
+            .push(example);
+    }
+    examples_by_repo.into_values().collect()
+}
+
 fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
     use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
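
The example.rs changes above drop the per-example `worktree_path`/`repo_path` helpers in favor of sorting and grouping examples by repository. A minimal standalone sketch of that sort-then-group strategy, using only std types (`Ex` is a hypothetical stand-in for the CLI's `Example`, modeling just the two fields the ordering depends on):

```rust
// Sort ascending by repository URL, descending by revision, then bucket
// per repository. Mirrors sort_examples_by_repo_and_rev/group_examples_by_repo.
use std::collections::HashMap;

struct Ex {
    repository_url: String,
    revision: String,
}

fn sort_by_repo_and_rev(examples: &mut [Ex]) {
    // Ascending by repo URL, then descending by revision, mirroring
    // `b.revision.cmp(&a.revision)` in the patch.
    examples.sort_by(|a, b| {
        a.repository_url
            .cmp(&b.repository_url)
            .then(b.revision.cmp(&a.revision))
    });
}

fn group_by_repo(examples: &mut [Ex]) -> Vec<Vec<&mut Ex>> {
    // One bucket per repository; each bucket keeps the sorted order.
    let mut by_repo: HashMap<String, Vec<&mut Ex>> = HashMap::new();
    for example in examples.iter_mut() {
        by_repo
            .entry(example.repository_url.clone())
            .or_insert_with(Vec::new)
            .push(example);
    }
    by_repo.into_values().collect()
}

fn main() {
    let mut examples = vec![
        Ex { repository_url: "git@github.com:a/x.git".into(), revision: "v1".into() },
        Ex { repository_url: "git@github.com:b/y.git".into(), revision: "v2".into() },
        Ex { repository_url: "git@github.com:a/x.git".into(), revision: "v2".into() },
    ];
    sort_by_repo_and_rev(&mut examples);
    let groups = group_by_repo(&mut examples);
    assert_eq!(groups.len(), 2); // a/x (v2 then v1) and b/y
}
```

Grouping is what lets main.rs run one repository's examples serially against a shared checkout; within a repo, sorting by descending revision presumably keeps examples on the same commit adjacent so the worktree is reset less often.
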
diff --git a/crates/edit_prediction_cli/src/headless.rs b/crates/edit_prediction_cli/src/headless.rs
index fd20774168ea3c07f4efffdefe23f1b4ff5f5ef4..2deb96fdbf19a94c5649d87a7bf2f5fea0b601c2 100644
--- a/crates/edit_prediction_cli/src/headless.rs
+++ b/crates/edit_prediction_cli/src/headless.rs
@@ -1,4 +1,5 @@
 use client::{Client, ProxySettings, UserStore};
+use collections::HashMap;
 use extension::ExtensionHostProxy;
 use fs::RealFs;
 use gpui::http_client::read_proxy_from_env;
@@ -7,12 +8,13 @@
 use gpui_tokio::Tokio;
 use language::LanguageRegistry;
 use language_extension::LspAccess;
 use node_runtime::{NodeBinaryOptions, NodeRuntime};
+use project::Project;
 use project::project_settings::ProjectSettings;
 use release_channel::{AppCommitSha, AppVersion};
 use reqwest_client::ReqwestClient;
 use settings::{Settings, SettingsStore};
 use std::path::PathBuf;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 use util::ResultExt as _;
 
 /// Headless subset of `workspace::AppState`.
@@ -22,9 +24,22 @@ pub struct EpAppState {
     pub user_store: Entity<UserStore>,
     pub fs: Arc<dyn Fs>,
     pub node_runtime: NodeRuntime,
+    pub project_cache: ProjectCache,
+}
+
+#[derive(Default)]
+pub struct ProjectCache(Mutex<HashMap<String, Entity<Project>>>);
+
+impl ProjectCache {
+    pub fn insert(&self, repository_url: String, project: Entity<Project>) {
+        self.0.lock().unwrap().insert(repository_url, project);
+    }
+
+    pub fn get(&self, repository_url: &String) -> Option<Entity<Project>> {
+        self.0.lock().unwrap().get(repository_url).cloned()
+    }
 }
 
-// TODO: dedupe with crates/eval/src/eval.rs
 pub fn init(cx: &mut App) -> EpAppState {
     let app_commit_sha = option_env!("ZED_COMMIT_SHA").map(|s| AppCommitSha::new(s.to_owned()));
@@ -112,11 +127,14 @@
     prompt_store::init(cx);
     terminal_view::init(cx);
 
+    let project_cache = ProjectCache::default();
+
     EpAppState {
         languages,
         client,
         user_store,
         fs,
         node_runtime,
+        project_cache,
     }
 }
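
`ProjectCache` gives the `Arc`-shared `EpAppState` a place to stash one project per repository: interior mutability via `Mutex`, lookups returning clones. A minimal sketch of the same shape, assuming only std types (`Handle` stands in for gpui's cheaply clonable `Entity<Project>`):

```rust
// Mutex-guarded map so a cache living inside Arc-shared state can be
// used from `&self`, with no `&mut` plumbing through the call sites.
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

#[derive(Clone)]
struct Handle(Arc<String>); // stand-in for Entity<Project>

#[derive(Default)]
struct Cache(Mutex<HashMap<String, Handle>>);

impl Cache {
    fn insert(&self, key: String, value: Handle) {
        self.0.lock().unwrap().insert(key, value);
    }

    // Return a clone so the lock is dropped before the caller uses the value.
    fn get(&self, key: &str) -> Option<Handle> {
        self.0.lock().unwrap().get(key).cloned()
    }
}

fn main() {
    let cache = Arc::new(Cache::default());
    let url = "https://github.com/zed-industries/zed".to_string();
    cache.insert(url.clone(), Handle(Arc::new("project".into())));
    assert!(cache.get(&url).is_some()); // cache hit: reuse instead of rebuild
}
```

Returning a clone from `get` keeps the critical section short: the lock is released before the caller ever touches the project.
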
diff --git a/crates/edit_prediction_cli/src/load_project.rs b/crates/edit_prediction_cli/src/load_project.rs
index 4d703ffc57c282681c5fbcf35368e5a01839cd24..3e0b34241164801a30f959f759e1c0419ba324ff 100644
--- a/crates/edit_prediction_cli/src/load_project.rs
+++ b/crates/edit_prediction_cli/src/load_project.rs
@@ -1,6 +1,7 @@
 use crate::{
     example::{Example, ExampleBuffer, ExampleState},
     headless::EpAppState,
+    paths::{REPOS_DIR, WORKTREES_DIR},
 };
 use anyhow::{Result, anyhow};
 use collections::HashMap;
@@ -29,29 +30,11 @@ pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>,
     }
 
     let project = setup_project(example, &app_state, &mut cx).await;
-    let buffer_store = project
-        .read_with(&cx, |project, _| project.buffer_store().clone())
-        .unwrap();
-
-    let ep_store = cx
-        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
-        .unwrap();
-
-    cx.subscribe(&buffer_store, {
-        let project = project.clone();
-        move |_, event, cx| match event {
-            BufferStoreEvent::BufferAdded(buffer) => {
-                ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
-            }
-            _ => {}
-        }
-    })
-    .unwrap()
-    .detach();
 
     let _open_buffers = apply_edit_history(example, &project, &mut cx)
         .await
        .unwrap();
+
     let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
     example.buffer = buffer
         .read_with(&cx, |buffer, _cx| {
@@ -64,6 +47,7 @@
         })
     })
     .unwrap();
+
     example.state = Some(ExampleState {
         buffer,
         project,
@@ -149,7 +133,35 @@ async fn setup_project(
     example: &mut Example,
     app_state: &Arc<EpAppState>,
     cx: &mut AsyncApp,
 ) -> Entity<Project> {
-    setup_worktree(example).await;
+    let ep_store = cx
+        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
+        .unwrap();
+
+    let worktree_path = setup_worktree(example).await;
+
+    if let Some(project) = app_state.project_cache.get(&example.repository_url) {
+        ep_store
+            .update(cx, |ep_store, _| {
+                ep_store.clear_history_for_project(&project);
+            })
+            .unwrap();
+        let buffer_store = project
+            .read_with(cx, |project, _| project.buffer_store().clone())
+            .unwrap();
+        let buffers = buffer_store
+            .read_with(cx, |buffer_store, _| {
+                buffer_store.buffers().collect::<Vec<_>>()
+            })
+            .unwrap();
+        for buffer in buffers {
+            buffer
+                .update(cx, |buffer, cx| buffer.reload(cx))
+                .unwrap()
+                .await
+                .unwrap();
+        }
+        return project;
+    }
 
     let project = cx
         .update(|cx| {
@@ -168,30 +180,44 @@
     project
         .update(cx, |project, cx| {
             project.disable_worktree_scanner(cx);
-        })
-        .unwrap();
-
-    let worktree = project
-        .update(cx, |project, cx| {
-            project.create_worktree(&example.worktree_path(), true, cx)
+            project.create_worktree(&worktree_path, true, cx)
         })
         .unwrap()
         .await
         .unwrap();
-    worktree
-        .read_with(cx, |worktree, _cx| {
-            worktree.as_local().unwrap().scan_complete()
-        })
-        .unwrap()
-        .await;
+
+    app_state
+        .project_cache
+        .insert(example.repository_url.clone(), project.clone());
+
+    let buffer_store = project
+        .read_with(cx, |project, _| project.buffer_store().clone())
+        .unwrap();
+    cx.subscribe(&buffer_store, {
+        let project = project.clone();
+        move |_, event, cx| match event {
+            BufferStoreEvent::BufferAdded(buffer) => {
+                ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
+            }
+            _ => {}
+        }
+    })
+    .unwrap()
+    .detach();
+
     project
 }
 
-pub async fn setup_worktree(example: &Example) {
-    let repo_dir = example.repo_path();
+pub async fn setup_worktree(example: &Example) -> PathBuf {
+    let (repo_owner, repo_name) = example.repo_name().expect("failed to get repo name");
+    let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
+    let worktree_path = WORKTREES_DIR
+        .join(repo_owner.as_ref())
+        .join(repo_name.as_ref());
     let repo_lock = lock_repo(&repo_dir).await;
 
     if !repo_dir.is_dir() {
+        eprintln!("Cloning repository {}", example.repository_url);
         fs::create_dir_all(&repo_dir).unwrap();
         run_git(&repo_dir, &["init"]).await.unwrap();
         run_git(
@@ -227,7 +253,6 @@ pub async fn setup_worktree(example: &Example) {
     };
 
     // Create the worktree for this example if needed.
-    let worktree_path = example.worktree_path();
     if worktree_path.is_dir() {
         run_git(&worktree_path, &["clean", "--force", "-d"])
             .await
@@ -288,6 +313,8 @@ pub async fn setup_worktree(example: &Example) {
             );
         }
     }
+
+    worktree_path
 }
 
 async fn apply_edit_history(
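
With the cache in place, load_project.rs derives both the clone directory and the worktree from the repository's owner/name rather than from the example name, so every example of a repo shares one checkout, and a cache hit skips project creation entirely (clearing edit history and reloading open buffers instead). A sketch of that on-disk layout; `repo_name` and `layout` here are illustrative helpers, not the CLI's exact functions:

```rust
// One clone per repo under a repos root, one shared worktree per repo
// under a worktrees root; both keyed by owner/name parsed from the URL.
use std::path::{Path, PathBuf};

/// Extract (owner, repo) from an SSH or HTTPS remote URL.
fn repo_name(url: &str) -> Option<(String, String)> {
    // Drop a trailing ".git", then take the last two '/'- or ':'-separated
    // segments: works for git@github.com:owner/repo.git and
    // https://github.com/owner/repo alike.
    let trimmed = url.strip_suffix(".git").unwrap_or(url);
    let mut segments = trimmed.rsplit(|c| c == '/' || c == ':');
    let repo = segments.next()?.to_string();
    let owner = segments.next()?.to_string();
    Some((owner, repo))
}

/// One clone and one worktree per repository, shared by all of its examples.
fn layout(repos_dir: &Path, worktrees_dir: &Path, url: &str) -> Option<(PathBuf, PathBuf)> {
    let (owner, repo) = repo_name(url)?;
    let repo_dir = repos_dir.join(&owner).join(&repo);
    let worktree_path = worktrees_dir.join(&owner).join(&repo);
    Some((repo_dir, worktree_path))
}

fn main() {
    let (repo_dir, worktree_path) = layout(
        Path::new("repos"),
        Path::new("worktrees"),
        "git@github.com:zed-industries/zed.git",
    )
    .unwrap();
    assert_eq!(repo_dir, Path::new("repos/zed-industries/zed"));
    assert_eq!(worktree_path, Path::new("worktrees/zed-industries/zed"));
}
```
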
diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs
index cd05d909f351728b2a7c1c006662621310a5f89b..1091f0acfa182b95ed18bc6d560aaf7bca6225c7 100644
--- a/crates/edit_prediction_cli/src/main.rs
+++ b/crates/edit_prediction_cli/src/main.rs
@@ -15,10 +15,12 @@ use edit_prediction::EditPredictionStore;
 use gpui::Application;
 use reqwest_client::ReqwestClient;
 use serde::{Deserialize, Serialize};
+use std::sync::atomic::AtomicUsize;
+use std::sync::atomic::Ordering::SeqCst;
 use std::{path::PathBuf, sync::Arc};
 
 use crate::distill::run_distill;
-use crate::example::{read_examples, write_examples};
+use crate::example::{group_examples_by_repo, read_examples, write_examples};
 use crate::format_prompt::run_format_prompt;
 use crate::load_project::run_load_project;
 use crate::predict::run_prediction;
@@ -145,31 +147,40 @@
                 EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
 
             cx.spawn(async move |cx| {
-                match &command {
-                    Command::Predict(args) => predict::sync_batches(&args.provider).await,
-                    _ => (),
+                if let Command::Predict(args) = &command {
+                    predict::sync_batches(&args.provider).await
                 };
 
-                let chunks = examples.chunks_mut(args.max_parallelism);
-                let total_chunks = chunks.len();
-                for (batch_ix, data) in chunks.enumerate() {
-                    let mut futures = Vec::new();
-                    eprintln!("Processing batch: {}/{}", batch_ix + 1, total_chunks);
-
-                    for example in data.iter_mut() {
-                        let cx = cx.clone();
-                        let app_state = app_state.clone();
-                        futures.push(async {
+                let example_count = examples.len();
+                let example_ix = AtomicUsize::new(0);
+                let mut grouped_examples = group_examples_by_repo(&mut examples);
+
+                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
+                for example_batch in example_batches {
+                    let futures = example_batch.into_iter().map(|repo_examples| async {
+                        for example in repo_examples.iter_mut() {
+                            eprintln!(
+                                "Processing example: {}/{}",
+                                example_ix.load(SeqCst) + 1,
+                                example_count
+                            );
+                            example_ix.fetch_add(1, SeqCst);
                             match &command {
                                 Command::ParseExample => {}
                                 Command::LoadProject => {
-                                    run_load_project(example, app_state.clone(), cx).await;
+                                    run_load_project(example, app_state.clone(), cx.clone()).await;
                                 }
                                 Command::Context => {
-                                    run_context_retrieval(example, app_state, cx).await;
+                                    run_context_retrieval(example, app_state.clone(), cx.clone()).await;
                                 }
                                 Command::FormatPrompt(args) => {
-                                    run_format_prompt(example, args.prompt_format, app_state, cx).await;
+                                    run_format_prompt(
+                                        example,
+                                        args.prompt_format,
+                                        app_state.clone(),
+                                        cx.clone(),
+                                    )
+                                    .await;
                                 }
                                 Command::Predict(args) => {
                                     run_prediction(
@@ -177,7 +188,7 @@
                                         Some(args.provider),
                                         args.repetitions,
                                         app_state.clone(),
-                                        cx,
+                                        cx.clone(),
                                     )
                                     .await;
                                 }
@@ -185,14 +196,14 @@
                                     run_distill(example).await;
                                 }
                                 Command::Score(args) | Command::Eval(args) => {
-                                    run_scoring(example, &args, app_state, cx).await;
+                                    run_scoring(example, &args, app_state.clone(), cx.clone()).await;
                                 }
                                 Command::Clean => {
                                     unreachable!()
                                 }
                             }
-                        });
-                    }
+                        }
+                    });
 
                     futures::future::join_all(futures).await;
                 }
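
main.rs is where the concurrency restructuring lands: instead of one future per example, repository groups run concurrently within a batch while the examples inside each group run sequentially against the cached project, and an `AtomicUsize` tracks global progress. A runnable sketch of that shape using the `futures` crate (`process` is a hypothetical stand-in for the per-example commands, and chunking by `--max-parallelism` is elided; this is one batch):

```rust
// Repo groups run concurrently via join_all; examples within a group run
// sequentially so they can share one cached project/checkout.
use futures::executor::block_on;
use futures::future::join_all;
use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};

async fn process(example: &mut String) {
    example.push_str(" done");
}

fn main() {
    let mut groups: Vec<Vec<String>> = vec![
        vec!["repo-a #1".into(), "repo-a #2".into()], // run in order
        vec!["repo-b #1".into()],
    ];
    let total: usize = groups.iter().map(Vec::len).sum();
    let done = AtomicUsize::new(0);

    block_on(join_all(groups.iter_mut().map(|group| {
        let done = &done; // borrow, so the async block doesn't move the counter
        async move {
            for example in group.iter_mut() {
                // A single fetch_add both reads and bumps the global counter.
                let ix = done.fetch_add(1, SeqCst) + 1;
                eprintln!("Processing example: {}/{}", ix, total);
                process(example).await; // sequential within a repository
            }
        }
    })));
}
```

One deliberate difference from the patch: the patch does `example_ix.load(SeqCst) + 1` followed by a separate `fetch_add`, which two concurrent groups can interleave; folding both into one `fetch_add` keeps the printed indices unique.
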
index 042c17b8b147c611e826afa26f6c167656babc3b..0ef7a4676e30189f1417c0a8c339e8ac7f76e0ef 100644
--- a/crates/edit_prediction_cli/src/retrieve_context.rs
+++ b/crates/edit_prediction_cli/src/retrieve_context.rs
@@ -75,25 +75,24 @@ async fn wait_for_language_servers_to_start(
         .read_with(cx, |project, _| project.lsp_store())
         .unwrap();
 
-    let lang_server_ids = buffer
+    let (language_server_ids, mut starting_language_server_ids) = buffer
         .update(cx, |buffer, cx| {
             lsp_store.update(cx, |lsp_store, cx| {
-                lsp_store.language_servers_for_local_buffer(buffer, cx)
+                let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
+                let starting_ids = ids
+                    .iter()
+                    .copied()
+                    .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
+                    .collect::<HashSet<_>>();
+                (ids, starting_ids)
             })
         })
         .unwrap_or_default();
 
-    if !lang_server_ids.is_empty() {
-        project
-            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
-            .unwrap()
-            .detach();
-    }
-
     eprintln!(
         "{}⏵ Waiting for {} language servers",
         log_prefix,
-        lang_server_ids.len()
+        language_server_ids.len()
     );
 
     let timeout = cx
@@ -101,7 +100,7 @@
         .timer(Duration::from_secs(60 * 5))
         .shared();
 
-    let (mut tx, mut rx) = mpsc::channel(lang_server_ids.len());
+    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
     let added_subscription = cx.subscribe(project, {
         let log_prefix = log_prefix.clone();
         move |_, event, _| match event {
@@ -113,12 +112,11 @@
             }
         }
     });
-    let mut pending_language_server_ids = HashSet::from_iter(lang_server_ids.iter());
-    while !pending_language_server_ids.is_empty() {
+    while !starting_language_server_ids.is_empty() {
         futures::select! {
             language_server_id = rx.next() => {
                 if let Some(id) = language_server_id {
-                    pending_language_server_ids.remove(&id);
+                    starting_language_server_ids.remove(&id);
                }
            },
            _ = timeout.clone().fuse() => {
@@ -129,7 +127,14 @@
 
     drop(added_subscription);
 
-    let (mut tx, mut rx) = mpsc::channel(lang_server_ids.len());
+    if !language_server_ids.is_empty() {
+        project
+            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+            .unwrap()
+            .detach();
+    }
+
+    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
     let subscriptions = [
         cx.subscribe(&lsp_store, {
             let log_prefix = log_prefix.clone();
@@ -172,7 +177,7 @@
         .await
         .unwrap();
 
-    let mut pending_language_server_ids = HashSet::from_iter(lang_server_ids.into_iter());
+    let mut pending_language_server_ids = HashSet::from_iter(language_server_ids.into_iter());
     while !pending_language_server_ids.is_empty() {
         futures::select! {
             language_server_id = rx.next() => {
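
The retrieve_context.rs change computes, inside a single `lsp_store.update`, which of a buffer's language servers are still starting, waits for those over a channel, and only then saves the buffer. A compilable sketch of that wait loop, assuming `u32` ids stand in for `LanguageServerId` and a never-firing oneshot stands in for the five-minute timer:

```rust
// Drain server ids from a channel until every "starting" server has
// checked in, bailing out if the timeout fires first.
use futures::channel::{mpsc, oneshot};
use futures::{executor::block_on, FutureExt, StreamExt};
use std::collections::HashSet;

fn main() {
    let mut starting: HashSet<u32> = [1, 2, 3].into_iter().collect();
    let (mut tx, mut rx) = mpsc::channel::<u32>(starting.len());
    let (_keep_alive, timeout) = oneshot::channel::<()>();
    let mut timeout = timeout.fuse();

    // Stand-in for the subscription that forwards "language server added" events.
    for id in [1, 2, 3] {
        tx.try_send(id).unwrap();
    }

    block_on(async {
        while !starting.is_empty() {
            futures::select! {
                id = rx.next() => {
                    if let Some(id) = id {
                        starting.remove(&id);
                    }
                }
                _ = timeout => break,
            }
        }
    });
    assert!(starting.is_empty());
}
```

Note the patch also defers `project.save_buffer` until after this first wait completes, where it previously kicked off the save before any servers had started.
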
diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs
index 25798a6c09c3d8a851e93bd459d92ec8c3c62c77..9514ea03eff5e5ee2135ebd5e406c473f3fdea8d 100644
--- a/crates/project/src/lsp_store.rs
+++ b/crates/project/src/lsp_store.rs
@@ -201,7 +201,10 @@ pub enum LspFormatTarget {
     Ranges(BTreeMap<BufferId, Vec<Range<Anchor>>>),
 }
 
-pub type OpenLspBufferHandle = Entity<Entity<Buffer>>;
+#[derive(Clone, PartialEq, Eq, Hash)]
+pub struct OpenLspBufferHandle(Entity<OpenLspBuffer>);
+
+struct OpenLspBuffer(Entity<Buffer>);
 
 impl FormatTrigger {
     fn from_proto(value: i32) -> FormatTrigger {
@@ -4208,7 +4211,7 @@ impl LspStore {
         cx: &mut Context<Self>,
     ) -> OpenLspBufferHandle {
         let buffer_id = buffer.read(cx).remote_id();
-        let handle = cx.new(|_| buffer.clone());
+        let handle = OpenLspBufferHandle(cx.new(|_| OpenLspBuffer(buffer.clone())));
         if let Some(local) = self.as_local_mut() {
             let refcount = local.registered_buffers.entry(buffer_id).or_insert(0);
             if !ignore_refcounts {
@@ -4230,7 +4233,7 @@
             local.register_buffer_with_language_servers(buffer, only_register_servers, cx);
         }
         if !ignore_refcounts {
-            cx.observe_release(&handle, move |lsp_store, buffer, cx| {
+            cx.observe_release(&handle.0, move |lsp_store, buffer, cx| {
                 let refcount = {
                     let local = lsp_store.as_local_mut().unwrap();
                     let Some(refcount) = local.registered_buffers.get_mut(&buffer_id) else {
@@ -4247,8 +4250,8 @@
                 local.registered_buffers.remove(&buffer_id);
                 local.buffers_opened_in_servers.remove(&buffer_id);
 
-                if let Some(file) = File::from_dyn(buffer.read(cx).file()).cloned() {
-                    local.unregister_old_buffer_from_language_servers(buffer, &file, cx);
+                if let Some(file) = File::from_dyn(buffer.0.read(cx).file()).cloned() {
+                    local.unregister_old_buffer_from_language_servers(&buffer.0, &file, cx);
                     let buffer_abs_path = file.abs_path(cx);
 
                     for (_, buffer_pull_diagnostics_result_ids) in
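
The lsp_store.rs change replaces the `Entity<Entity<Buffer>>` alias with a newtype so handles gain `Clone`/`PartialEq`/`Eq`/`Hash`, while release observation moves to the wrapped entity (`&handle.0`). A sketch of why the newtype helps, with `Rc` standing in for gpui entities (`OpenBufferHandle`/`OpenBuffer` are illustrative, and gpui's `Entity` compares by entity id rather than by value):

```rust
// A dedicated struct can carry its own derives, so open-buffer handles can
// be stored in sets and compared, unlike a bare type alias used opaquely.
use std::collections::HashSet;
use std::rc::Rc;

#[derive(Clone, PartialEq, Eq, Hash)]
struct OpenBufferHandle(Rc<OpenBuffer>);

#[derive(PartialEq, Eq, Hash)]
struct OpenBuffer(u64); // stand-in for Entity<Buffer>

fn main() {
    let a = OpenBufferHandle(Rc::new(OpenBuffer(1)));
    let b = a.clone();
    let mut open = HashSet::new();
    open.insert(a);
    assert!(open.contains(&b)); // clones compare and hash as equal
    drop(open);
    drop(b);
    // Once every clone is gone the inner Rc drops; in the patch, that is the
    // point where observe_release unregisters the buffer from its servers.
}
```
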