diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index d3e10834f10b071e3602b7f399fbc8f28509fff1..91bee12f29a2cafa8833e7686da784ac527cfec1 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -12,7 +12,6 @@ use serde::{Deserialize, Serialize}; use std::{ borrow::Cow, io::Read, - ops::Range, path::{Path, PathBuf}, sync::Arc, }; @@ -60,8 +59,6 @@ pub struct ExamplePromptInputs { pub cursor_row: u32, pub cursor_column: u32, pub cursor_offset: usize, - pub context_range: Range, - pub editable_range: Range, pub edit_history: Vec>, pub related_files: Option>, } diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index a6ce738f3071e97c0f83bd6b17d65867449b4de7..d876c24726b783102f166049ae0f07e6e7c78d81 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -6,10 +6,12 @@ use crate::{ retrieve_context::run_context_retrieval, }; use anyhow::{Context as _, Result}; -use gpui::AsyncApp; +use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position; +use gpui::{AppContext, AsyncApp}; +use language::{Buffer, OffsetRangeExt, Point}; use similar::DiffableStr; -use std::fmt::Write as _; use std::sync::Arc; +use std::{fmt::Write as _, ops::Range}; use zeta_prompt::format_zeta_prompt; pub async fn run_format_prompt( @@ -18,7 +20,7 @@ pub async fn run_format_prompt( app_state: Arc, cx: AsyncApp, ) -> Result<()> { - run_context_retrieval(example, app_state, cx).await?; + run_context_retrieval(example, app_state.clone(), cx.clone()).await?; let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name); @@ -27,10 +29,35 @@ pub async fn run_format_prompt( .as_ref() .context("prompt_inputs must be set after context retrieval")?; + let language = app_state + .languages + .load_language_for_file_path(&example.spec.cursor_path) + .await + .ok(); + let snapshot_fut = cx.update(|cx| { + Buffer::build_snapshot( + prompt_inputs.content.as_str().into(), + language, + Some(app_state.languages.clone()), + cx, + ) + }); + let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column); + let snapshot = cx.background_spawn(snapshot_fut).await; + + let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position( + cursor_point, + &snapshot, + edit_prediction::zeta2::MAX_EDITABLE_TOKENS, + edit_prediction::zeta2::MAX_CONTEXT_TOKENS, + ); + let editable_range = editable_range.to_offset(&snapshot); + let context_range = context_range.to_offset(&snapshot); + match args.provider { PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => { step_progress.set_substatus("formatting teacher prompt"); - let prompt = TeacherPrompt::format_prompt(example); + let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range); example.prompt = Some(ExamplePrompt { input: prompt, expected_output: example @@ -45,15 +72,13 @@ pub async fn run_format_prompt( PredictionProvider::Zeta2(version) => { step_progress.set_substatus("formatting zeta2 prompt"); - let context_start = prompt_inputs.context_range.start; + let context_start = context_range.start; let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start; - let editable_range_in_excerpt = (prompt_inputs.editable_range.start - context_start) - ..(prompt_inputs.editable_range.end - context_start); + let editable_range_in_excerpt = + (editable_range.start - context_start)..(editable_range.end - context_start); let input = zeta_prompt::ZetaPromptInput { cursor_path: example.spec.cursor_path.clone(), - cursor_excerpt: prompt_inputs.content[prompt_inputs.context_range.clone()] - .to_string() - .into(), + cursor_excerpt: prompt_inputs.content[context_range].to_string().into(), editable_range_in_excerpt, cursor_offset_in_excerpt, events: prompt_inputs.edit_history.clone(), @@ -109,10 +134,14 @@ impl TeacherPrompt { /// Truncate edit history to this number of last lines const MAX_HISTORY_LINES: usize = 128; - pub fn format_prompt(example: &Example) -> String { + pub fn format_prompt( + example: &Example, + editable_range: Range, + context_range: Range, + ) -> String { let edit_history = Self::format_edit_history(&example.spec.edit_history); let context = Self::format_context(example); - let cursor_excerpt = Self::format_cursor_excerpt(example); + let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range); let prompt = Self::PROMPT .replace("{{context}}", &context) @@ -123,25 +152,21 @@ impl TeacherPrompt { } pub fn parse(example: &Example, response: &str) -> Result { - // Ideally, we should always be able to find cursor position in the retrieved context. - // In reality, sometimes we don't find it for these reasons: - // 1. `example.cursor_position` contains _more_ context than included in the retrieved context - // (can be fixed by getting cursor coordinates at the load_example stage) - // 2. Context retriever just didn't include cursor line. - // - // In that case, fallback to using `cursor_position` as excerpt. - let prompt_inputs = example - .prompt_inputs - .as_ref() - .context("`prompt_inputs` should be filled in in the context collection step")?; - // Extract updated (new) editable region from the model response. // The model may include editable region markers in its output, so we need to strip them. let new_editable_region = extract_last_codeblock(response); let mut new_editable_region = Self::extract_editable_region(&new_editable_region); - - let old_editable_region = - prompt_inputs.content[prompt_inputs.editable_range.clone()].to_string(); + let old_editable_region = Self::extract_editable_region( + &example + .prompt + .as_ref() + .context("example prompt missing")? + .input, + ); + let prompt_inputs = example + .prompt_inputs + .as_ref() + .context("example is missing prompt inputs")?; // Normalize leading newlines: if old starts with newline but new doesn't, // prepend newline to new to preserve whitespace structure. @@ -150,8 +175,12 @@ impl TeacherPrompt { new_editable_region.insert(0, '\n'); } - let editable_region_start_line = prompt_inputs.content - [..prompt_inputs.editable_range.start] + let (editable_region_offset, _) = prompt_inputs + .content + .match_indices(&old_editable_region) + .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset)) + .unwrap(); + let editable_region_start_line = prompt_inputs.content[..editable_region_offset] .matches('\n') .count(); @@ -229,30 +258,24 @@ impl TeacherPrompt { prompt } - fn format_cursor_excerpt(example: &Example) -> String { + fn format_cursor_excerpt( + example: &Example, + editable_range: Range, + context_range: Range, + ) -> String { let mut result = String::new(); let prompt_inputs = example.prompt_inputs.as_ref().unwrap(); let path_str = example.spec.cursor_path.to_string_lossy(); result.push_str(&format!("`````{path_str}\n")); - result.push_str( - &prompt_inputs.content - [prompt_inputs.context_range.start..prompt_inputs.editable_range.start], - ); + result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]); result.push_str(Self::EDITABLE_REGION_START); - result.push_str( - &prompt_inputs.content[prompt_inputs.editable_range.start..prompt_inputs.cursor_offset], - ); + result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]); result.push_str(Self::USER_CURSOR_MARKER); - result.push_str( - &prompt_inputs.content[prompt_inputs.cursor_offset..prompt_inputs.editable_range.end], - ); + result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]); result.push_str(Self::EDITABLE_REGION_END); - result.push_str( - &prompt_inputs.content - [prompt_inputs.editable_range.end..prompt_inputs.context_range.end], - ); + result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]); result.push_str("\n`````"); result diff --git a/crates/edit_prediction_cli/src/load_project.rs b/crates/edit_prediction_cli/src/load_project.rs index b6296b75ebca4af237d71f31fd209610dfe63f53..7afbb871ff3957c3c70d64093b30e2dcea661b78 100644 --- a/crates/edit_prediction_cli/src/load_project.rs +++ b/crates/edit_prediction_cli/src/load_project.rs @@ -7,13 +7,11 @@ use crate::{ use anyhow::{Context as _, Result}; use edit_prediction::{ EditPredictionStore, - cursor_excerpt::editable_and_context_ranges_for_cursor_position, udiff::{OpenedBuffers, refresh_worktree_entries, strip_diff_path_prefix}, - zeta2, }; use futures::AsyncWriteExt as _; use gpui::{AsyncApp, Entity}; -use language::{Anchor, Buffer, LanguageNotFound, OffsetRangeExt as _, ToOffset, ToPoint}; +use language::{Anchor, Buffer, LanguageNotFound, ToOffset, ToPoint}; use project::Project; use project::buffer_store::BufferStoreEvent; use std::{fs, path::PathBuf, sync::Arc}; @@ -55,15 +53,6 @@ pub async fn run_load_project( let (prompt_inputs, language_name) = buffer.read_with(&cx, |buffer, _cx| { let cursor_point = cursor_position.to_point(&buffer); - let snapshot = buffer.snapshot(); - let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position( - cursor_point, - &snapshot, - zeta2::MAX_EDITABLE_TOKENS, - zeta2::MAX_CONTEXT_TOKENS, - ); - let editable_range = editable_range.to_offset(&snapshot); - let context_range = context_range.to_offset(&snapshot); let language_name = buffer .language() .map(|l| l.name().to_string()) @@ -74,10 +63,12 @@ pub async fn run_load_project( cursor_row: cursor_point.row, cursor_column: cursor_point.column, cursor_offset: cursor_position.to_offset(&buffer), - context_range, - editable_range, edit_history, - related_files: None, + related_files: example + .prompt_inputs + .take() + .map(|inputs| inputs.related_files) + .unwrap_or_default(), }, language_name, ) diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index a5f92ba55fd83d8bb5979fe6b9d831f185dcd338..be22da635b320407befc46f04748f18c874294ac 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -287,7 +287,7 @@ async fn predict_anthropic( .collect::>() .join("\n"); - let actual_patch = TeacherPrompt::parse(example, &actual_output)?; + let actual_patch = TeacherPrompt::parse(&example, &actual_output)?; let prediction = ExamplePrediction { actual_patch,