diff --git a/crates/edit_prediction/src/udiff.rs b/crates/edit_prediction/src/udiff.rs
index a9cb4fea7713f4347011455ac9c4bf2fe77b06fa..f107db223137ef62d8f1b0e0327c8dd75616a2ae 100644
--- a/crates/edit_prediction/src/udiff.rs
+++ b/crates/edit_prediction/src/udiff.rs
@@ -112,10 +112,12 @@ pub async fn apply_diff(
         };
 
         buffer.read_with(cx, |buffer, _| {
-            edits.extend(
-                resolve_hunk_edits_in_buffer(hunk, buffer, ranges.as_slice(), status)
-                    .with_context(|| format!("Diff:\n{diff_str}"))?,
-            );
+            edits.extend(resolve_hunk_edits_in_buffer(
+                hunk,
+                buffer,
+                ranges.as_slice(),
+                status,
+            )?);
             anyhow::Ok(())
         })?;
     }
@@ -648,11 +650,7 @@ fn resolve_hunk_edits_in_buffer(
         })
         .ok_or_else(|| {
             if candidates.is_empty() {
-                anyhow!(
-                    "Failed to match context:\n\n```\n{}```\n\nBuffer contents:\n\n```\n{}```",
-                    hunk.context,
-                    buffer.text()
-                )
+                anyhow!("Failed to match context:\n\n```\n{}```\n", hunk.context,)
             } else {
                 anyhow!("Context is not unique enough:\n{}", hunk.context)
             }
diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs
index 91bee12f29a2cafa8833e7686da784ac527cfec1..b7c98c4035e1aaf14f3de484ab3233849a65a2b5 100644
--- a/crates/edit_prediction_cli/src/example.rs
+++ b/crates/edit_prediction_cli/src/example.rs
@@ -11,6 +11,7 @@ use project::Project;
 use serde::{Deserialize, Serialize};
 use std::{
     borrow::Cow,
+    collections::VecDeque,
     io::Read,
     path::{Path, PathBuf},
     sync::Arc,
@@ -214,9 +215,9 @@ pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
     });
 }
 
-pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>> {
+pub fn group_examples_by_repo(examples: Vec<Example>) -> VecDeque<Vec<Example>> {
     let mut examples_by_repo = HashMap::default();
-    for example in examples.iter_mut() {
+    for example in examples {
         examples_by_repo
             .entry(example.spec.repository_url.clone())
             .or_insert_with(Vec::new)
diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs
index 3103735ef1a4bbe2328a6ce420750ca54e775787..288db951854e7d6833fbe86f0949f17877abbb37 100644
--- a/crates/edit_prediction_cli/src/format_prompt.rs
+++ b/crates/edit_prediction_cli/src/format_prompt.rs
@@ -2,7 +2,7 @@ use crate::{
     FormatPromptArgs, PredictionProvider,
     example::{Example, ExamplePrompt},
     headless::EpAppState,
-    progress::{Progress, Step},
+    progress::{ExampleProgress, Step},
     retrieve_context::run_context_retrieval,
 };
 use anyhow::{Context as _, Result};
@@ -18,11 +18,12 @@ pub async fn run_format_prompt(
     example: &mut Example,
     args: &FormatPromptArgs,
     app_state: Arc<EpAppState>,
+    example_progress: &ExampleProgress,
     cx: AsyncApp,
 ) -> Result<()> {
-    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
+    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 
-    let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
+    let step_progress = example_progress.start(Step::FormatPrompt);
 
     let prompt_inputs = example
         .prompt_inputs
diff --git a/crates/edit_prediction_cli/src/headless.rs b/crates/edit_prediction_cli/src/headless.rs
index da96e7ef6520e952e2b7696eee6b82c243e90e4e..9d0291a2bf013e711027070fe7a2a4fe39d36b87 100644
--- a/crates/edit_prediction_cli/src/headless.rs
+++ b/crates/edit_prediction_cli/src/headless.rs
@@ -37,6 +37,10 @@ impl ProjectCache {
     pub fn get(&self, repository_url: &String) -> Option<Entity<Project>> {
         self.0.lock().unwrap().get(repository_url).cloned()
     }
+
+    pub fn remove(&self, repository_url: &String) {
+        self.0.lock().unwrap().remove(repository_url);
+    }
 }
 
 pub fn init(cx: &mut App) -> EpAppState {
diff --git a/crates/edit_prediction_cli/src/load_project.rs b/crates/edit_prediction_cli/src/load_project.rs
index 7dfaad47e1e6564912a1ed81c258c737cb06f4f1..0a5cf9546db79d580567069d57f40fef3dd6dedd 100644
--- a/crates/edit_prediction_cli/src/load_project.rs
+++ b/crates/edit_prediction_cli/src/load_project.rs
@@ -2,7 +2,7 @@ use crate::{
     example::{Example, ExamplePromptInputs, ExampleState},
     git,
     headless::EpAppState,
-    progress::{InfoStyle, Progress, Step, StepProgress},
+    progress::{ExampleProgress, InfoStyle, Step, StepProgress},
 };
 use anyhow::{Context as _, Result};
 use edit_prediction::{
@@ -18,13 +18,14 @@ use std::{fs, path::PathBuf, sync::Arc};
 pub async fn run_load_project(
     example: &mut Example,
     app_state: Arc<EpAppState>,
+    example_progress: &ExampleProgress,
     mut cx: AsyncApp,
 ) -> Result<()> {
     if example.state.is_some() {
         return Ok(());
     }
 
-    let progress = Progress::global().start(Step::LoadProject, &example.spec.name);
+    let progress = example_progress.start(Step::LoadProject);
 
     let project = setup_project(example, &app_state, &progress, &mut cx).await?;
 
@@ -160,15 +161,11 @@ async fn cursor_position(
         let mut matches = text.match_indices(&cursor_excerpt);
 
         let (excerpt_offset, _) = matches.next().with_context(|| {
-            format!(
-                "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Example: {}\nCursor excerpt did not exist in buffer.",
-                example.spec.name
-            )
+            format!("Cursor excerpt did not exist in buffer:\n\n{cursor_excerpt}\n",)
         })?;
         anyhow::ensure!(
             matches.next().is_none(),
-            "More than one cursor position match found for {}",
-            &example.spec.name
+            "More than one cursor position match found",
         );
         Ok(excerpt_offset)
     })?;
@@ -193,9 +190,6 @@ async fn setup_project(
     let worktree_path = setup_worktree(example, step_progress).await?;
 
     if let Some(project) = app_state.project_cache.get(&example.spec.repository_url) {
-        ep_store.update(cx, |ep_store, _| {
-            ep_store.clear_history_for_project(&project);
-        });
         let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone());
         let buffers = buffer_store.read_with(cx, |buffer_store, _| {
             buffer_store.buffers().collect::<Vec<_>>()
@@ -203,6 +197,9 @@ async fn setup_project(
         for buffer in buffers {
             buffer.update(cx, |buffer, cx| buffer.reload(cx)).await.ok();
         }
+        ep_store.update(cx, |ep_store, _| {
+            ep_store.clear_history_for_project(&project);
+        });
         return Ok(project);
     }
 
diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs
index 13a8399b3fe9853aaaa305355f3329170cf399fd..ce338a535ab7fb4ab50632c3d3eb528d06a074d2 100644
--- a/crates/edit_prediction_cli/src/main.rs
+++ b/crates/edit_prediction_cli/src/main.rs
@@ -30,6 +30,7 @@ use std::fmt::Display;
 use std::fs::{File, OpenOptions};
 use std::hash::{Hash, Hasher};
 use std::io::{BufRead, BufReader, BufWriter, Write};
+use std::sync::Mutex;
 use std::{path::PathBuf, sync::Arc};
 
 use crate::distill::run_distill;
@@ -496,7 +497,7 @@ fn main() {
 
         cx.spawn(async move |cx| {
             let result = async {
-                let mut examples = load_examples(
+                let examples = load_examples(
                     app_state.client.http_client(),
                     &args,
                     output.as_ref(),
@@ -536,97 +537,133 @@ fn main() {
                     None
                 };
 
-                let mut grouped_examples = group_examples_by_repo(&mut examples);
-                let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
-
-                for example_batch in example_batches {
-                    let futures = example_batch.into_iter().map(|repo_examples| async {
-                        for example in repo_examples.iter_mut() {
-                            let result = async {
-                                match &command {
-                                    Command::ParseExample => {}
-                                    Command::LoadProject => {
-                                        run_load_project(example, app_state.clone(), cx.clone())
+                let grouped_examples = Mutex::new(group_examples_by_repo(examples));
+                let finished_examples = Mutex::new(Vec::new());
+
+                let mut tasks = Vec::new();
+                for _ in 0..args.max_parallelism {
+                    tasks.push(async {
+                        loop {
+                            let Some(mut repo_examples) =
+                                grouped_examples.lock().unwrap().pop_front()
+                            else {
+                                break;
+                            };
+                            for example in &mut repo_examples {
+                                let example_progress =
+                                    Progress::global().start_group(&example.spec.name);
+
+                                let result = async {
+                                    match &command {
+                                        Command::ParseExample => {}
+                                        Command::LoadProject => {
+                                            run_load_project(
+                                                example,
+                                                app_state.clone(),
+                                                &example_progress,
+                                                cx.clone(),
+                                            )
                                             .await?;
-                                    }
-                                    Command::Context => {
-                                        run_context_retrieval(
-                                            example,
-                                            app_state.clone(),
-                                            cx.clone(),
-                                        )
-                                        .await?;
-                                    }
-                                    Command::FormatPrompt(args) => {
-                                        run_format_prompt(
-                                            example,
-                                            args,
-                                            app_state.clone(),
-                                            cx.clone(),
-                                        )
-                                        .await?;
-                                    }
-                                    Command::Predict(args) => {
-                                        run_prediction(
-                                            example,
-                                            args,
-                                            app_state.clone(),
-                                            cx.clone(),
-                                        )
-                                        .await?;
-                                    }
-                                    Command::Distill => {
-                                        run_distill(example).await?;
-                                    }
-                                    Command::Score(args) | Command::Eval(args) => {
-                                        run_scoring(example, &args, app_state.clone(), cx.clone())
+                                        }
+                                        Command::Context => {
+                                            run_context_retrieval(
+                                                example,
+                                                app_state.clone(),
+                                                &example_progress,
+                                                cx.clone(),
+                                            )
                                             .await?;
+                                        }
+                                        Command::FormatPrompt(args) => {
+                                            run_format_prompt(
+                                                example,
+                                                args,
+                                                app_state.clone(),
+                                                &example_progress,
+                                                cx.clone(),
+                                            )
+                                            .await?;
+                                        }
+                                        Command::Predict(args) => {
+                                            run_prediction(
+                                                example,
+                                                args,
+                                                app_state.clone(),
+                                                &example_progress,
+                                                cx.clone(),
+                                            )
+                                            .await?;
+                                        }
+                                        Command::Distill => {
+                                            run_distill(example).await?;
+                                        }
+                                        Command::Score(args) | Command::Eval(args) => {
+                                            run_scoring(
+                                                example,
+                                                &args,
+                                                app_state.clone(),
+                                                &example_progress,
+                                                cx.clone(),
+                                            )
+                                            .await?;
+                                        }
+                                        Command::Clean
+                                        | Command::Synthesize(_)
+                                        | Command::SplitCommit(_)
+                                        | Command::Split(_) => {
+                                            unreachable!()
+                                        }
                                     }
-                                    Command::Clean
-                                    | Command::Synthesize(_)
-                                    | Command::SplitCommit(_)
-                                    | Command::Split(_) => {
-                                        unreachable!()
-                                    }
+                                    anyhow::Ok(())
                                 }
-                                anyhow::Ok(())
-                            }
-                            .await;
-
-                            let failed = if let Err(error) = result {
-                                handle_error(
-                                    error,
-                                    &args,
-                                    &command,
-                                    &app_state,
-                                    failfast_on_single_example,
-                                    example,
-                                )
                                 .await;
-                                true
-                            } else {
-                                false
-                            };
-                            let should_write = !failed || args.failed == FailedHandling::Keep;
-                            if should_write {
-                                if let Some(ref mut sender) = output_sender.clone() {
-                                    let line = serde_json::to_string(example).unwrap();
-                                    sender
-                                        .send(line)
-                                        .await
-                                        .expect("Failed to send to output writer");
-                                } else if args.output.is_none()
-                                    && !matches!(command, Command::Eval(_))
-                                {
-                                    let line = serde_json::to_string(example).unwrap();
-                                    println!("{}", line);
+
+                                let failed = if let Err(error) = result {
+                                    handle_error(
+                                        error,
+                                        &args,
+                                        &command,
+                                        &app_state,
+                                        failfast_on_single_example,
+                                        &example,
+                                    )
+                                    .await;
+                                    true
+                                } else {
+                                    false
+                                };
+
+                                let should_write = !failed || args.failed == FailedHandling::Keep;
+                                if should_write {
+                                    if let Some(ref mut sender) = output_sender.clone() {
+                                        let line = serde_json::to_string(&example).unwrap();
+                                        sender
+                                            .send(line)
+                                            .await
+                                            .expect("Failed to send to output writer");
+                                    } else if args.output.is_none()
+                                        && !matches!(command, Command::Eval(_))
+                                    {
+                                        let line = serde_json::to_string(&example).unwrap();
+                                        println!("{}", line);
+                                    }
                                 }
                             }
+
+                            app_state
+                                .project_cache
+                                .remove(&repo_examples.first().unwrap().spec.repository_url);
+                            for example in &mut repo_examples {
+                                example.state.take();
+                            }
+                            finished_examples
+                                .lock()
+                                .unwrap()
+                                .extend_from_slice(&repo_examples);
                         }
                     });
-                    futures::future::join_all(futures).await;
                 }
+                futures::future::join_all(tasks).await;
 
                 Progress::global().finalize();
 
@@ -638,7 +675,7 @@ fn main() {
                 }
 
                 match &command {
-                    Command::Eval(_) => score::print_report(&examples),
+                    Command::Eval(_) => score::print_report(&finished_examples.lock().unwrap()),
                     _ => (),
                 };
 
diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs
index 17ff5347561c429a3d66987ce27a9f62c2506cae..9749675a08ddf59e221204a89603a67e5ea329ec 100644
--- a/crates/edit_prediction_cli/src/predict.rs
+++ b/crates/edit_prediction_cli/src/predict.rs
@@ -6,7 +6,7 @@ use crate::{
     headless::EpAppState,
     load_project::run_load_project,
     paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
-    progress::{InfoStyle, Progress, Step},
+    progress::{ExampleProgress, InfoStyle, Step},
     retrieve_context::run_context_retrieval,
 };
 use anyhow::Context as _;
@@ -28,6 +28,7 @@ pub async fn run_prediction(
     example: &mut Example,
     args: &PredictArgs,
     app_state: Arc<EpAppState>,
+    example_progress: &ExampleProgress,
     mut cx: AsyncApp,
 ) -> anyhow::Result<()> {
     let provider = args.provider;
@@ -41,17 +42,18 @@ pub async fn run_prediction(
         }
     }
 
-    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
+    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 
     if let PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) =
         args.provider
     {
-        let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);
+        let _step_progress = example_progress.start(Step::Predict);
 
         run_format_prompt(
             example,
             &FormatPromptArgs { provider },
             app_state.clone(),
+            example_progress,
             cx,
         )
         .await?;
@@ -60,9 +62,9 @@ pub async fn run_prediction(
         return predict_anthropic(example, repetition_count, version, batched).await;
     }
 
-    run_load_project(example, app_state.clone(), cx.clone()).await?;
+    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
 
-    let step_progress = Progress::global().start(Step::Predict, &example.spec.name);
+    let step_progress = example_progress.start(Step::Predict);
 
     if matches!(
         provider,
diff --git a/crates/edit_prediction_cli/src/progress.rs b/crates/edit_prediction_cli/src/progress.rs
index 90080c79eb93892d2a6d32f290f975d91b09579f..db6394d292dfd33e25564f0e3371c96875422885 100644
--- a/crates/edit_prediction_cli/src/progress.rs
+++ b/crates/edit_prediction_cli/src/progress.rs
@@ -22,6 +22,7 @@ struct ProgressInner {
     max_example_name_len: usize,
     status_lines_displayed: usize,
     total_examples: usize,
+    completed_examples: usize,
     failed_examples: usize,
     last_line_is_logging: bool,
     ticker: Option<Task<()>>,
@@ -101,12 +102,14 @@ impl Progress {
                    inner: Mutex::new(ProgressInner {
                        completed: Vec::new(),
                        in_progress: HashMap::new(),
-                        is_tty: std::env::var("NO_COLOR").is_err()
-                            && std::io::stderr().is_terminal(),
+                        is_tty: std::env::var("COLOR").is_ok()
+                            || (std::env::var("NO_COLOR").is_err()
+                                && std::io::stderr().is_terminal()),
                        terminal_width: get_terminal_width(),
                        max_example_name_len: 0,
                        status_lines_displayed: 0,
                        total_examples: 0,
+                        completed_examples: 0,
                        failed_examples: 0,
                        last_line_is_logging: false,
                        ticker: None,
@@ -119,6 +122,18 @@ impl Progress {
            .clone()
    }
 
+    pub fn start_group(self: &Arc<Self>, example_name: &str) -> ExampleProgress {
+        ExampleProgress {
+            progress: self.clone(),
+            example_name: example_name.to_string(),
+        }
+    }
+
+    fn increment_completed(&self) {
+        let mut inner = self.inner.lock().unwrap();
+        inner.completed_examples += 1;
+    }
+
     pub fn set_total_examples(&self, total: usize) {
         let mut inner = self.inner.lock().unwrap();
         inner.total_examples = total;
@@ -247,7 +262,6 @@ impl Progress {
             for _ in 0..inner.status_lines_displayed {
                 eprint!("\x1b[A\x1b[K");
             }
-            let _ = std::io::stderr().flush();
             inner.status_lines_displayed = 0;
         }
     }
@@ -317,7 +331,7 @@ impl Progress {
         let dim = "\x1b[2m";
 
         // Build the done/in-progress/total label
-        let done_count = inner.completed.len();
+        let done_count = inner.completed_examples;
         let in_progress_count = inner.in_progress.len();
         let failed_count = inner.failed_examples;
 
@@ -427,6 +441,23 @@ impl Progress {
     }
 }
 
+pub struct ExampleProgress {
+    progress: Arc<Progress>,
+    example_name: String,
+}
+
+impl ExampleProgress {
+    pub fn start(&self, step: Step) -> StepProgress {
+        self.progress.start(step, &self.example_name)
+    }
+}
+
+impl Drop for ExampleProgress {
+    fn drop(&mut self) {
+        self.progress.increment_completed();
+    }
+}
+
 pub struct StepProgress {
     progress: Arc<Progress>,
     step: Step,
diff --git a/crates/edit_prediction_cli/src/retrieve_context.rs b/crates/edit_prediction_cli/src/retrieve_context.rs
index 8d9a5b072920527884d3b83e727551efa2ffb985..6f3fafa91b7c67d11c6a2990e6039f4c7f40c0ff 100644
--- a/crates/edit_prediction_cli/src/retrieve_context.rs
+++ b/crates/edit_prediction_cli/src/retrieve_context.rs
@@ -2,7 +2,7 @@ use crate::{
     example::Example,
     headless::EpAppState,
     load_project::run_load_project,
-    progress::{InfoStyle, Progress, Step, StepProgress},
+    progress::{ExampleProgress, InfoStyle, Step, StepProgress},
 };
 use anyhow::Context as _;
 use collections::HashSet;
@@ -17,6 +17,7 @@ use std::time::Duration;
 pub async fn run_context_retrieval(
     example: &mut Example,
     app_state: Arc<EpAppState>,
+    example_progress: &ExampleProgress,
     mut cx: AsyncApp,
 ) -> anyhow::Result<()> {
     if example
@@ -27,11 +28,9 @@ pub async fn run_context_retrieval(
         return Ok(());
     }
 
-    run_load_project(example, app_state.clone(), cx.clone()).await?;
+    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
 
-    let step_progress: Arc<StepProgress> = Progress::global()
-        .start(Step::Context, &example.spec.name)
-        .into();
+    let step_progress: Arc<StepProgress> = example_progress.start(Step::Context).into();
 
     let state = example.state.as_ref().unwrap();
     let project = state.project.clone();
@@ -96,10 +95,13 @@ async fn wait_for_language_servers_to_start(
 
     step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
 
-    let timeout = cx
-        .background_executor()
-        .timer(Duration::from_secs(60 * 5))
-        .shared();
+    let timeout_duration = if starting_language_server_ids.is_empty() {
+        Duration::from_secs(30)
+    } else {
+        Duration::from_secs(60 * 5)
+    };
+
+    let timeout = cx.background_executor().timer(timeout_duration).shared();
 
     let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
     let added_subscription = cx.subscribe(project, {
@@ -121,7 +123,7 @@ async fn wait_for_language_servers_to_start(
                 }
             },
             _ = timeout.clone().fuse() => {
-                return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
+                return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
             }
         }
     }
@@ -190,7 +192,7 @@ async fn wait_for_language_servers_to_start(
                 }
             },
             _ = timeout.clone().fuse() => {
-                return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
+                return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
             }
         }
     }
diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs
index d713137f3decae3a2e25e0bbe520724c8756018d..f467f140b4975163e80ff28c8de4b12807edc034 100644
--- a/crates/edit_prediction_cli/src/score.rs
+++ b/crates/edit_prediction_cli/src/score.rs
@@ -4,7 +4,7 @@ use crate::{
     headless::EpAppState,
     metrics,
     predict::run_prediction,
-    progress::{Progress, Step},
+    progress::{ExampleProgress, Step},
 };
 use anyhow::Context as _;
 use edit_prediction::udiff::apply_diff_to_string;
@@ -15,11 +15,12 @@ pub async fn run_scoring(
     example: &mut Example,
     args: &PredictArgs,
     app_state: Arc<EpAppState>,
+    example_progress: &ExampleProgress,
     cx: AsyncApp,
 ) -> anyhow::Result<()> {
-    run_prediction(example, args, app_state, cx).await?;
+    run_prediction(example, args, app_state, example_progress, cx).await?;
 
-    let progress = Progress::global().start(Step::Score, &example.spec.name);
+    let progress = example_progress.start(Step::Score);
 
     progress.set_substatus("applying patches");
     let original_text = &example.prompt_inputs.as_ref().unwrap().content;