diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs index 3be486bae5cdd64510bbb6eb73f9f06632c88bda..aa8727bc47046594e08dc75a2715fea9c0a4c824 100644 --- a/crates/edit_prediction/src/zeta2.rs +++ b/crates/edit_prediction/src/zeta2.rs @@ -280,31 +280,7 @@ pub(crate) fn edit_prediction_accepted( #[cfg(feature = "cli-support")] pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result { - let text = &input.cursor_excerpt; - let editable_region = input.editable_range_in_excerpt.clone(); - let old_prefix = &text[..editable_region.start]; - let old_suffix = &text[editable_region.end..]; - - // Try applying the patch directly first - let new = match crate::udiff::apply_diff_to_string(patch, text) { - Ok(new) => new, - Err(_) if !text.ends_with('\n') => { - // If the text doesn't end with a newline, the patch context may expect one - // (due to missing "no newline at EOF" markers). Try again with a trailing newline. - let text_with_newline = format!("{}\n", text); - let mut new = crate::udiff::apply_diff_to_string(patch, &text_with_newline)?; - // Remove the trailing newline we added if the result still has it - if new.ends_with('\n') && !text.ends_with('\n') { - new.pop(); - } - new - } - Err(e) => return Err(e), - }; - - if !new.starts_with(old_prefix) || !new.ends_with(old_suffix) { - anyhow::bail!("Patch shouldn't affect text outside of editable region"); - } - - Ok(new[editable_region.start..new.len() - old_suffix.len()].to_string()) + let old_editable_region = &input.cursor_excerpt[input.editable_range_in_excerpt.clone()]; + let new_editable_region = crate::udiff::apply_diff_to_string(patch, old_editable_region)?; + Ok(new_editable_region) } diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index 34504d5299f3bc7740f8c4d35d5e9de3d66e7791..b531da33aa1a4a68cdd67d3dc17f4dc70b364d04 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ 
b/crates/edit_prediction_cli/src/format_prompt.rs @@ -6,12 +6,13 @@ use crate::{ progress::{Progress, Step}, retrieve_context::run_context_retrieval, }; -use anyhow::{Context as _, Result, ensure}; +use anyhow::{Context as _, Result}; use edit_prediction::{ EditPredictionStore, zeta2::{zeta2_output_for_patch, zeta2_prompt_input}, }; use gpui::{AsyncApp, Entity}; +use std::fmt::Write as _; use std::sync::Arc; use zeta_prompt::format_zeta_prompt; @@ -102,7 +103,7 @@ pub struct TeacherPrompt; impl TeacherPrompt { const PROMPT: &str = include_str!("teacher.prompt.md"); pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n"; - pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>"; + pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>"; pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>"; /// Truncate edit history to this number of last lines @@ -111,12 +112,12 @@ impl TeacherPrompt { pub fn format_prompt(example: &Example) -> String { let edit_history = Self::format_edit_history(&example.spec.edit_history); let context = Self::format_context(example); - let editable_region = Self::format_editable_region(example); + let cursor_excerpt = Self::format_cursor_excerpt(example); let prompt = Self::PROMPT .replace("{{context}}", &context) .replace("{{edit_history}}", &edit_history) - .replace("{{editable_region}}", &editable_region); + .replace("{{cursor_excerpt}}", &cursor_excerpt); prompt } @@ -133,7 +134,6 @@ impl TeacherPrompt { .buffer .as_ref() .context("`buffer` should be filled in in the context collection step")?; - let cursor_file = &example_buffer.content; // Extract updated (new) editable region from the model response. // The model may include editable region markers in its output, so we need to strip them. 
@@ -150,15 +150,17 @@ impl TeacherPrompt { new_editable_region.insert(0, '\n'); } - ensure!( - cursor_file.contains(&old_editable_region), - "Something's wrong: editable_region is not found in the cursor file" - ); + let editable_region_start_line = example_buffer.content + [..example_buffer.editable_range.start] + .matches('\n') + .count(); - // Apply editable region to a larger context and compute diff. - // This is needed to get a better context lines around the editable region - let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region); - let diff = language::unified_diff(&cursor_file, &edited_file); + let diff = language::unified_diff_with_offsets( + &old_editable_region, + &new_editable_region, + editable_region_start_line as u32, + editable_region_start_line as u32, + ); let diff = indoc::formatdoc! {" --- a/{path} @@ -192,21 +194,44 @@ impl TeacherPrompt { } fn format_context(example: &Example) -> String { - assert!(example.context.is_some(), "Missing context retriever step"); + let context = example + .context + .as_ref() + .expect("Missing context retriever step"); + + if context.files.is_empty() { + return "(No context)".to_string(); + } let mut prompt = String::new(); - zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files); + for file in context.files.as_ref() { + let path_str = file.path.to_string_lossy(); + writeln!(&mut prompt, "`````{path_str}").ok(); + let mut prev_row = 0; + for excerpt in &file.excerpts { + if excerpt.row_range.start > prev_row { + prompt.push_str("…\n"); + } + prompt.push_str(&excerpt.text); + prompt.push('\n'); + prev_row = excerpt.row_range.end; + } + if prev_row < file.max_row { + prompt.push_str("…\n"); + } + prompt.push_str("\n`````"); + } prompt } - fn format_editable_region(example: &Example) -> String { + fn format_cursor_excerpt(example: &Example) -> String { let mut result = String::new(); let example_buffer = example.buffer.as_ref().unwrap(); let path_str = 
example.spec.cursor_path.to_string_lossy(); - result.push_str(&format!("`````path=\"{path_str}\"\n")); + result.push_str(&format!("`````{path_str}\n")); result.push_str( &example_buffer.content [example_buffer.context_range.start..example_buffer.editable_range.start], @@ -240,7 +265,7 @@ impl TeacherPrompt { let region = &text[start..end]; let region = region.strip_suffix('\n').unwrap_or(region); - region.replace("<|user_cursor|>", "") + region.replace(Self::USER_CURSOR_MARKER, "") } fn is_udiff_content_line(s: &str) -> bool { @@ -356,8 +381,7 @@ mod tests { parsed, indoc::indoc! {" one - two three - "} + two three"} ); } @@ -403,8 +427,7 @@ mod tests { fn test_extract_editable_region_no_markers() { let text = indoc::indoc! {" one - two three - "}; + two three"}; let parsed = TeacherPrompt::extract_editable_region(text); assert_eq!( parsed, @@ -428,8 +451,7 @@ mod tests { parsed, indoc::indoc! {" one - two three - "} + two three"} ); } } diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index e64246a2af7fb278a6ff5f0d6bfa1db6943d64fe..6e5b28e18f47aaa85a0d11c1754cc9c107cc18ea 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -443,8 +443,11 @@ fn main() { let mut examples = load_examples(app_state.client.http_client(), &args, output.as_ref()).await?; - if let Command::Predict(args) = &command { - predict::sync_batches(&args.provider).await?; + match &command { + Command::Predict(args) | Command::Score(args) | Command::Eval(args) => { + predict::sync_batches(&args.provider).await?; + } + _ => (), } let failfast_on_single_example = examples.len() == 1; @@ -561,7 +564,13 @@ fn main() { Progress::global().finalize(); match &command { - Command::Predict(args) => predict::sync_batches(&args.provider).await?, + Command::Predict(args) | Command::Score(args) | Command::Eval(args) => { + predict::sync_batches(&args.provider).await?; + } + _ => (), + } + + match &command { 
Command::Eval(_) => score::print_report(&examples), _ => (), }; @@ -606,7 +615,7 @@ async fn handle_error( .await .unwrap(); - let file_path = example + let cursor_path = example .repo_name() .unwrap() .worktree_path() @@ -625,9 +634,9 @@ async fn handle_error( "}, example.spec.name, error, - err_path.display(), - file_path.display(), failed_example_path.display(), + err_path.display(), + cursor_path.display(), command, failed_example_path.display(), ); diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 8a9500f96967415171f4627ffc7c6ce40f355c66..04a58ee2e7b66f2ce40db626088baec898262405 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -46,9 +46,7 @@ pub async fn run_prediction( ) { let _step_progress = Progress::global().start(Step::Predict, &example.spec.name); - if example.prompt.is_none() { - run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?; - } + run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?; let batched = matches!(provider, PredictionProvider::Teacher); return predict_anthropic(example, repetition_count, batched).await; diff --git a/crates/edit_prediction_cli/src/teacher.prompt.md b/crates/edit_prediction_cli/src/teacher.prompt.md index d629152da6739ec1d603857f6a9ee556c8986fe8..d161988dc74528802c66413638ad6a7c3d6a5123 100644 --- a/crates/edit_prediction_cli/src/teacher.prompt.md +++ b/crates/edit_prediction_cli/src/teacher.prompt.md @@ -1,53 +1,80 @@ # Instructions -You are a code completion assistant helping a programmer finish their work. Your task is to: +You are an edit prediction assistant in a code editor. Your task is to predict the next edit to a given region of code surrounding the user's cursor. 1. Analyze the edit history to understand what the programmer is trying to achieve 2. Identify any incomplete refactoring or changes that need to be finished -3. 
Make the remaining edits that a human programmer would logically make next (by rewriting the corresponding code sections) -4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere. +3. Make the remaining edits that a human programmer would logically make next (by rewriting the code around their cursor) -Focus on: -- Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs) -- Completing any partially-applied changes across the codebase +## Focus on + +- Completing any partially-applied changes made - Ensuring consistency with the programming style and patterns already established - Making edits that maintain or improve code quality -- If the programmer started refactoring one instance of a pattern, find and update ALL similar instances -- Don't write a lot of code if you're not sure what to do -Rules: +## Rules + - Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals. - Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code. -- Keep existing formatting unless it's absolutely necessary +- Keep existing formatting unless it's absolutely necessary +- Don't write a lot of code if you're not sure what to do + +# Input Format + +You will be provided with: +1. The user's *edit history*, in chronological order. Use this to infer the user's trajectory and predict the next most logical edit. +2. A set of *related excerpts* from the user's codebase. Some of these may be needed for correctly predicting the next edit. + - `…` may appear within a related file to indicate that some code has been skipped. +3. An excerpt from the user's *current file*. + - Within the user's current file, there is an *editable region* delimited by the `<|editable_region_start|>` and `<|editable_region_end|>` tags. 
You can only predict edits in this region. + - The `<|user_cursor|>` tag marks the user's current cursor position, as it stands after the last edit in the history. + +# Output Format + +- Briefly explain the user's current intent based on the edit history and their current cursor location. +- Output the entire editable region, applying the edits that you predict the user will make next. +- If you're unsure about some portion of the next edit, you may still predict the surrounding code (such as a function definition, `for` loop, etc.) and place the `<|user_cursor|>` within it for the user to fill in. +- Wrap the edited code in a codeblock with exactly five backticks. -Input format: -- You receive small code fragments called context (structs, field definitions, function signatures, etc.). They may or may not be relevant. -- Never modify the context code. -- You also receive a code snippet between <|editable_region_start|> and <|editable_region_end|>. This is the editable region. -- The cursor position is marked with <|user_cursor|>. +## Example -Output format: -- Return the entire editable region, applying any edits you make. -- Remove the <|user_cursor|> marker. -- Wrap the edited code in a block of exactly five backticks. +### Input -Output example: ````` - // `zed --askpass` Makes zed operate in nc/netcat mode for use with askpass - if let Some(socket) = &args.askpass {{ - askpass::main(socket); - return Ok(()); - }} +struct Product { + name: String, + price: u32, +} + +fn calculate_total(products: &[Product]) -> u32 { +<|editable_region_start|> + let mut total = 0; + for product in products { + total += <|user_cursor|>; + } + total +<|editable_region_end|> +} ````` -## User Edits History +### Output -{{edit_history}} +The user is computing a sum based on a list of products. The only numeric field on `Product` is `price`, so they must intend to sum the prices. 
-## Code Context +````` + let mut total = 0; + for product in products { + total += product.price; + } + total +````` + +# 1. User Edits History + +{{edit_history}} {{context}} -## Editable region +# 3. Current File -{{editable_region}} +{{cursor_excerpt}} diff --git a/crates/language/src/text_diff.rs b/crates/language/src/text_diff.rs index 4bca5b60febd86972e39dfbab3ae53621eef507e..774fae2cb832397b07aaa2fbcedef22c119f8bf3 100644 --- a/crates/language/src/text_diff.rs +++ b/crates/language/src/text_diff.rs @@ -16,7 +16,7 @@ pub fn unified_diff(old_text: &str, new_text: &str) -> String { } /// Computes a diff between two strings, returning a unified diff string with -/// hunk headers adjusted to reflect the given starting line numbers (1-indexed). +/// hunk headers adjusted to reflect the given starting line numbers (zero-indexed). pub fn unified_diff_with_offsets( old_text: &str, new_text: &str,