diff --git a/crates/edit_prediction/Cargo.toml b/crates/edit_prediction/Cargo.toml index c9237232e5e0bb6167fbeee8732d46ee584b080b..53ddb99bd3f458a540c6593a2b1d6b1b547e463b 100644 --- a/crates/edit_prediction/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -12,7 +12,7 @@ workspace = true path = "src/edit_prediction.rs" [features] -eval-support = [] +cli-support = [] [dependencies] ai_onboarding.workspace = true diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 11151c1fc19437655075c00589d02725131445ed..dd7b0090cb88c1564fc72de11ce9ec13e78f6a7c 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -55,7 +55,7 @@ pub mod open_ai_response; mod prediction; pub mod sweep_ai; -#[cfg(any(test, feature = "test-support", feature = "eval-support"))] +#[cfg(any(test, feature = "test-support", feature = "cli-support"))] pub mod udiff; mod zed_edit_prediction_delegate; @@ -158,7 +158,7 @@ pub struct EditPredictionStore { use_context: bool, options: ZetaOptions, update_required: bool, - #[cfg(feature = "eval-support")] + #[cfg(feature = "cli-support")] eval_cache: Option>, edit_prediction_model: EditPredictionModel, pub sweep_ai: SweepAi, @@ -505,7 +505,7 @@ impl EditPredictionStore { }, ), update_required: false, - #[cfg(feature = "eval-support")] + #[cfg(feature = "cli-support")] eval_cache: None, edit_prediction_model: EditPredictionModel::Zeta2, sweep_ai: SweepAi::new(cx), @@ -554,7 +554,7 @@ impl EditPredictionStore { .is_some() } - #[cfg(feature = "eval-support")] + #[cfg(feature = "cli-support")] pub fn with_eval_cache(&mut self, cache: Arc) { self.eval_cache = Some(cache); } @@ -1590,8 +1590,8 @@ impl EditPredictionStore { client: Arc, llm_token: LlmApiToken, app_version: Version, - #[cfg(feature = "eval-support")] eval_cache: Option>, - #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind, + #[cfg(feature = "cli-support")] 
eval_cache: Option>, + #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind, ) -> Result<(open_ai::Response, Option)> { let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() { http_client::Url::parse(&predict_edits_url)? @@ -1601,7 +1601,7 @@ impl EditPredictionStore { .build_zed_llm_url("/predict_edits/raw", &[])? }; - #[cfg(feature = "eval-support")] + #[cfg(feature = "cli-support")] let cache_key = if let Some(cache) = eval_cache { use collections::FxHasher; use std::hash::{Hash, Hasher}; @@ -1635,7 +1635,7 @@ impl EditPredictionStore { ) .await?; - #[cfg(feature = "eval-support")] + #[cfg(feature = "cli-support")] if let Some((cache, request, key)) = cache_key { cache.write(key, &request, &serde_json::to_string_pretty(&response)?); } @@ -1767,7 +1767,7 @@ impl EditPredictionStore { } } - #[cfg(feature = "eval-support")] + #[cfg(feature = "cli-support")] pub fn set_context_for_buffer( &mut self, project: &Entity, @@ -1892,10 +1892,10 @@ pub struct ZedUpdateRequiredError { minimum_version: Version, } -#[cfg(feature = "eval-support")] +#[cfg(feature = "cli-support")] pub type EvalCacheKey = (EvalCacheEntryKind, u64); -#[cfg(feature = "eval-support")] +#[cfg(feature = "cli-support")] #[derive(Debug, Clone, Copy, PartialEq)] pub enum EvalCacheEntryKind { Context, @@ -1903,7 +1903,7 @@ pub enum EvalCacheEntryKind { Prediction, } -#[cfg(feature = "eval-support")] +#[cfg(feature = "cli-support")] impl std::fmt::Display for EvalCacheEntryKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -1914,7 +1914,7 @@ impl std::fmt::Display for EvalCacheEntryKind { } } -#[cfg(feature = "eval-support")] +#[cfg(feature = "cli-support")] pub trait EvalCache: Send + Sync { fn read(&self, key: EvalCacheKey) -> Option; fn write(&self, key: EvalCacheKey, input: &str, value: &str); diff --git a/crates/edit_prediction/src/udiff.rs b/crates/edit_prediction/src/udiff.rs index 
cefeee9f1ab011e3cd13a1a24f71b18330d3042a..78fec03dd78301d56ac6e3f914ba60432e41637d 100644 --- a/crates/edit_prediction/src/udiff.rs +++ b/crates/edit_prediction/src/udiff.rs @@ -138,7 +138,7 @@ pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result { DiffEvent::Hunk { hunk, .. } => { let hunk_offset = text .find(&hunk.context) - .ok_or_else(|| anyhow!("couldn't result hunk {:?}", hunk.context))?; + .ok_or_else(|| anyhow!("couldn't resolve hunk {:?}", hunk.context))?; for edit in hunk.edits.iter().rev() { let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end); text.replace_range(range, &edit.text); diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs index 034954f5760939fc31b3e5e1e8a09737c5b2e568..8586e6caaea1fdc9c865ddba8894f680d766b4a9 100644 --- a/crates/edit_prediction/src/zeta2.rs +++ b/crates/edit_prediction/src/zeta2.rs @@ -1,4 +1,4 @@ -#[cfg(feature = "eval-support")] +#[cfg(feature = "cli-support")] use crate::EvalCacheEntryKind; use crate::open_ai_response::text_from_response; use crate::prediction::EditPredictionResult; @@ -44,7 +44,7 @@ pub fn request_prediction_with_zeta2( let llm_token = store.llm_token.clone(); let app_version = AppVersion::global(cx); - #[cfg(feature = "eval-support")] + #[cfg(feature = "cli-support")] let eval_cache = store.eval_cache.clone(); let request_task = cx.background_spawn({ @@ -95,9 +95,9 @@ pub fn request_prediction_with_zeta2( client, llm_token, app_version, - #[cfg(feature = "eval-support")] + #[cfg(feature = "cli-support")] eval_cache, - #[cfg(feature = "eval-support")] + #[cfg(feature = "cli-support")] EvalCacheEntryKind::Prediction, ) .await; @@ -226,3 +226,15 @@ pub fn zeta2_prompt_input( }; (editable_offset_range, prompt_input) } + +#[cfg(feature = "cli-support")] +pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> String { + eprintln!("{}", patch); + eprintln!("---------------------"); + eprintln!("{}", 
input.cursor_excerpt); + crate::udiff::apply_diff_to_string( + patch, + &input.cursor_excerpt[input.editable_range_in_excerpt.clone()], + ) + .unwrap() +} diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml index 0e7fff8d70156c58147069f8da64035d6a80adc8..14f146a122b55b5a05529d4a32302a6dd65825d7 100644 --- a/crates/edit_prediction_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -52,7 +52,7 @@ sqlez_macros.workspace = true terminal_view.workspace = true util.workspace = true watch.workspace = true -edit_prediction = { workspace = true, features = ["eval-support"] } +edit_prediction = { workspace = true, features = ["cli-support"] } wasmtime.workspace = true zeta_prompt.workspace = true zlog.workspace = true diff --git a/crates/edit_prediction_cli/src/distill.rs b/crates/edit_prediction_cli/src/distill.rs new file mode 100644 index 0000000000000000000000000000000000000000..495b3cd88cbd05ad1917517580b913aacf4fb107 --- /dev/null +++ b/crates/edit_prediction_cli/src/distill.rs @@ -0,0 +1,14 @@ +use std::mem; + +use crate::example::Example; + +pub async fn run_distill(example: &mut Example) { + let [prediction]: [_; 1] = mem::take(&mut example.predictions) + .try_into() + .expect("Run predict first with a single repetition"); + + example.expected_patch = prediction.actual_patch; + example.prompt = None; + example.predictions = Vec::new(); + example.score = Vec::new(); +} diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index a13b339ae69b9584f3b47186d8b6c36f458a2b76..1e21526e80104013a320a63e764dab0926bdd6f0 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -25,6 +25,7 @@ pub struct Example { pub name: String, pub repository_url: String, pub revision: String, + #[serde(default)] pub uncommitted_diff: String, pub cursor_path: Arc, pub cursor_position: String, @@ -195,9 +196,9 @@ pub fn read_examples(inputs: &[PathBuf]) 
-> Vec { .enumerate() .map(|(line_ix, line)| { let mut example = - serde_json::from_str::(line).unwrap_or_else(|_| { + serde_json::from_str::(line).unwrap_or_else(|error| { panic!( - "Failed to parse example on {}:{}", + "Failed to parse example on {}:{}\n{error}", path.display(), line_ix + 1 ) @@ -264,12 +265,12 @@ fn parse_markdown_example(id: String, input: &str) -> Result { state: None, }; - let mut name = String::new(); let mut text = String::new(); let mut block_info: CowStr = "".into(); #[derive(PartialEq)] enum Section { + Start, UncommittedDiff, EditHistory, CursorPosition, @@ -278,14 +279,16 @@ fn parse_markdown_example(id: String, input: &str) -> Result { Other, } - let mut current_section = Section::Other; + let mut current_section = Section::Start; for event in parser { match event { Event::Text(line) => { text.push_str(&line); - if let Some((field, value)) = line.split_once('=') { + if let Section::Start = current_section + && let Some((field, value)) = line.split_once('=') + { match field.trim() { REPOSITORY_URL_FIELD => { example.repository_url = value.trim().to_string(); @@ -297,14 +300,6 @@ fn parse_markdown_example(id: String, input: &str) -> Result { } } } - Event::End(TagEnd::Heading(HeadingLevel::H1)) => { - if !name.is_empty() { - anyhow::bail!( - "Found multiple H1 headings. There should only be one with the name of the example." 
- ); - } - name = mem::take(&mut text); - } Event::End(TagEnd::Heading(HeadingLevel::H2)) => { let title = mem::take(&mut text); current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) { @@ -363,7 +358,7 @@ fn parse_markdown_example(id: String, input: &str) -> Result { Section::ExpectedPatch => { example.expected_patch = mem::take(&mut text); } - Section::Other => {} + Section::Start | Section::Other => {} } } _ => {} diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index 53ef6ebfde77dcecba9926062cdfd75c1ee3521c..598d98fdb7646585641dd9fc47668506935644f4 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -2,9 +2,13 @@ use crate::{ PromptFormat, example::{Example, ExamplePrompt}, headless::EpAppState, + load_project::run_load_project, retrieve_context::run_context_retrieval, }; -use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input}; +use edit_prediction::{ + EditPredictionStore, + zeta2::{zeta2_output_for_patch, zeta2_prompt_input}, +}; use gpui::AsyncApp; use std::sync::Arc; use zeta_prompt::format_zeta_prompt; @@ -15,11 +19,20 @@ pub async fn run_format_prompt( app_state: Arc, mut cx: AsyncApp, ) { - run_context_retrieval(example, app_state, cx.clone()).await; - - let prompt = match prompt_format { - PromptFormat::Teacher => TeacherPrompt::format(example), + run_context_retrieval(example, app_state.clone(), cx.clone()).await; + + match prompt_format { + PromptFormat::Teacher => { + let prompt = TeacherPrompt::format_prompt(example); + example.prompt = Some(ExamplePrompt { + input: prompt, + expected_output: example.expected_patch.clone(), // TODO + format: prompt_format, + }); + } PromptFormat::Zeta2 => { + run_load_project(example, app_state, cx.clone()).await; + let ep_store = cx .update(|cx| EditPredictionStore::try_global(cx).unwrap()) .unwrap(); @@ -41,30 +54,28 @@ pub async fn run_format_prompt( ) 
}) .unwrap(); - format_zeta_prompt(&input) + let prompt = format_zeta_prompt(&input); + let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone()); + example.prompt = Some(ExamplePrompt { + input: prompt, + expected_output, + format: prompt_format, + }); } }; - - example.prompt = Some(ExamplePrompt { - input: prompt, - expected_output: example.expected_patch.clone(), // TODO - format: prompt_format, - }); } -pub trait PromptFormatter { - fn format(example: &Example) -> String; -} +pub struct TeacherPrompt; -pub trait PromptParser { - /// Return unified diff patch of prediction given raw LLM response - fn parse(example: &Example, response: &str) -> String; -} +impl TeacherPrompt { + const PROMPT: &str = include_str!("teacher.prompt.md"); + pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n"; + pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>"; -pub struct TeacherPrompt; + /// Truncate edit history to this number of last lines + const MAX_HISTORY_LINES: usize = 128; -impl PromptFormatter for TeacherPrompt { - fn format(example: &Example) -> String { + pub fn format_prompt(example: &Example) -> String { let edit_history = Self::format_edit_history(&example.edit_history); let context = Self::format_context(example); let editable_region = Self::format_editable_region(example); @@ -76,15 +87,46 @@ impl PromptFormatter for TeacherPrompt { prompt } -} -impl TeacherPrompt { - const PROMPT: &str = include_str!("teacher.prompt.md"); - pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n"; - pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>"; + pub fn parse(example: &Example, response: &str) -> String { + // Ideally, we should always be able to find cursor position in the retrieved context. + // In reality, sometimes we don't find it for these reasons: + // 1. 
`example.cursor_position` contains _more_ context than included in the retrieved context + // (can be fixed by getting cursor coordinates at the load_example stage) + // 2. Context retriever just didn't include cursor line. + // + // In that case, fallback to using `cursor_position` as excerpt. + let cursor_file = &example + .buffer + .as_ref() + .expect("`buffer` should be filled in in the context collection step") + .content; - /// Truncate edit history to this number of last lines - const MAX_HISTORY_LINES: usize = 128; + // Extract updated (new) editable region from the model response + let new_editable_region = extract_last_codeblock(response); + + // Reconstruct old editable region we sent to the model + let old_editable_region = Self::format_editable_region(example); + let old_editable_region = Self::extract_editable_region(&old_editable_region); + if !cursor_file.contains(&old_editable_region) { + panic!("Something's wrong: editable_region is not found in the cursor file") + } + + // Apply editable region to a larger context and compute diff. + // This is needed to get a better context lines around the editable region + let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region); + let diff = language::unified_diff(&cursor_file, &edited_file); + + let diff = indoc::formatdoc! {" + --- a/{path} + +++ b/{path} + {diff}", + path = example.cursor_path.to_string_lossy(), + diff = diff, + }; + + diff + } fn format_edit_history(edit_history: &str) -> String { // Strip comments ("garbage lines") from edit history @@ -157,49 +199,6 @@ impl TeacherPrompt { } } -impl PromptParser for TeacherPrompt { - fn parse(example: &Example, response: &str) -> String { - // Ideally, we should always be able to find cursor position in the retrieved context. - // In reality, sometimes we don't find it for these reasons: - // 1. 
`example.cursor_position` contains _more_ context than included in the retrieved context - // (can be fixed by getting cursor coordinates at the load_example stage) - // 2. Context retriever just didn't include cursor line. - // - // In that case, fallback to using `cursor_position` as excerpt. - let cursor_file = &example - .buffer - .as_ref() - .expect("`buffer` should be filled in in the context collection step") - .content; - - // Extract updated (new) editable region from the model response - let new_editable_region = extract_last_codeblock(response); - - // Reconstruct old editable region we sent to the model - let old_editable_region = Self::format_editable_region(example); - let old_editable_region = Self::extract_editable_region(&old_editable_region); - if !cursor_file.contains(&old_editable_region) { - panic!("Something's wrong: editable_region is not found in the cursor file") - } - - // Apply editable region to a larger context and compute diff. - // This is needed to get a better context lines around the editable region - let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region); - let diff = language::unified_diff(&cursor_file, &edited_file); - - let diff = indoc::formatdoc! 
{" - --- a/{path} - +++ b/{path} - {diff} - ", - path = example.cursor_path.to_string_lossy(), - diff = diff, - }; - - diff - } -} - fn extract_last_codeblock(text: &str) -> String { let mut last_block = None; let mut search_start = 0; @@ -221,7 +220,7 @@ fn extract_last_codeblock(text: &str) -> String { } if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) { - let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1]; + let code_block = &text[backtick_end + 1..backtick_end + end_pos]; last_block = Some(code_block.to_string()); search_start = backtick_end + end_pos + backtick_count; } else { @@ -250,7 +249,7 @@ mod tests { ````` "}; let last_block = extract_last_codeblock(text); - assert_eq!(last_block, "last block"); + assert_eq!(last_block, "last block\n"); } #[test] diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 7b9f9dbac4ce8f286a0710b92a14addb5d17b20d..cd05d909f351728b2a7c1c006662621310a5f89b 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -1,4 +1,5 @@ mod anthropic_client; +mod distill; mod example; mod format_prompt; mod headless; @@ -16,6 +17,7 @@ use reqwest_client::ReqwestClient; use serde::{Deserialize, Serialize}; use std::{path::PathBuf, sync::Arc}; +use crate::distill::run_distill; use crate::example::{read_examples, write_examples}; use crate::format_prompt::run_format_prompt; use crate::load_project::run_load_project; @@ -54,6 +56,9 @@ enum Command { Predict(PredictArgs), /// Computes a score based on actual and expected patches Score(PredictArgs), + /// Prepares a distillation dataset by copying each example's predicted patch + /// into its expected patch and clearing the prompt, predictions, and scores.
+ Distill, /// Print aggregated scores Eval(PredictArgs), /// Remove git repositories and worktrees @@ -87,6 +92,7 @@ enum PredictionProvider { Zeta1, Zeta2, Teacher, + TeacherNonBatching, } impl EpArgs { @@ -175,6 +181,9 @@ fn main() { ) .await; } + Command::Distill => { + run_distill(example).await; + } Command::Score(args) | Command::Eval(args) => { run_scoring(example, &args, app_state, cx).await; } diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 11ed0e3bab0551d1e9d3e87cc98ef91ee015ac13..4ff3e1d947fd886633108cbba0d32909f72304e4 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -2,7 +2,7 @@ use crate::{ PredictionProvider, PromptFormat, anthropic_client::AnthropicClient, example::{Example, ExamplePrediction}, - format_prompt::{PromptParser, TeacherPrompt, run_format_prompt}, + format_prompt::{TeacherPrompt, run_format_prompt}, headless::EpAppState, load_project::run_load_project, paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR}, @@ -30,20 +30,24 @@ pub async fn run_prediction( return; } - run_load_project(example, app_state.clone(), cx.clone()).await; run_context_retrieval(example, app_state.clone(), cx.clone()).await; let provider = provider.unwrap(); - if matches!(provider, PredictionProvider::Teacher) { + if matches!( + provider, + PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching + ) { if example.prompt.is_none() { run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await; } - let batched = true; + let batched = matches!(provider, PredictionProvider::Teacher); return predict_anthropic(example, repetition_count, batched).await; } + run_load_project(example, app_state.clone(), cx.clone()).await; + if matches!( provider, PredictionProvider::Zeta1 | PredictionProvider::Zeta2 @@ -75,7 +79,9 @@ pub async fn run_prediction( PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2, PredictionProvider::Sweep 
=> edit_prediction::EditPredictionModel::Sweep, PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury, - PredictionProvider::Teacher => unreachable!(), + PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => { + unreachable!() + } }; store.set_edit_prediction_model(model); }) diff --git a/crates/edit_prediction_cli/src/teacher.prompt.md b/crates/edit_prediction_cli/src/teacher.prompt.md index 238d3b7ac1297583727f562f1755d084ff5a3ceb..d629152da6739ec1d603857f6a9ee556c8986fe8 100644 --- a/crates/edit_prediction_cli/src/teacher.prompt.md +++ b/crates/edit_prediction_cli/src/teacher.prompt.md @@ -18,6 +18,7 @@ Focus on: Rules: - Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals. - Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code. +- Keep existing formatting unless changing it is absolutely necessary. Input format: - You receive small code fragments called context (structs, field definitions, function signatures, etc.). They may or may not be relevant.