ep cli: Compute editable region during format-prompt (#46929)

Created by Agus Zubiaga, Max Brunsfeld, and Ben Kunkle

Release Notes:

- N/A

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Ben Kunkle <ben@zed.dev>

Change summary

crates/edit_prediction_cli/src/example.rs       |   3 
crates/edit_prediction_cli/src/format_prompt.rs | 111 +++++++++++-------
crates/edit_prediction_cli/src/load_project.rs  |  21 +--
crates/edit_prediction_cli/src/predict.rs       |   2 
4 files changed, 74 insertions(+), 63 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/example.rs 🔗

@@ -12,7 +12,6 @@ use serde::{Deserialize, Serialize};
 use std::{
     borrow::Cow,
     io::Read,
-    ops::Range,
     path::{Path, PathBuf},
     sync::Arc,
 };
@@ -60,8 +59,6 @@ pub struct ExamplePromptInputs {
     pub cursor_row: u32,
     pub cursor_column: u32,
     pub cursor_offset: usize,
-    pub context_range: Range<usize>,
-    pub editable_range: Range<usize>,
     pub edit_history: Vec<Arc<zeta_prompt::Event>>,
     pub related_files: Option<Vec<RelatedFile>>,
 }

crates/edit_prediction_cli/src/format_prompt.rs 🔗

@@ -6,10 +6,12 @@ use crate::{
     retrieve_context::run_context_retrieval,
 };
 use anyhow::{Context as _, Result};
-use gpui::AsyncApp;
+use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position;
+use gpui::{AppContext, AsyncApp};
+use language::{Buffer, OffsetRangeExt, Point};
 use similar::DiffableStr;
-use std::fmt::Write as _;
 use std::sync::Arc;
+use std::{fmt::Write as _, ops::Range};
 use zeta_prompt::format_zeta_prompt;
 
 pub async fn run_format_prompt(
@@ -18,7 +20,7 @@ pub async fn run_format_prompt(
     app_state: Arc<EpAppState>,
     cx: AsyncApp,
 ) -> Result<()> {
-    run_context_retrieval(example, app_state, cx).await?;
+    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
 
     let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
 
@@ -27,10 +29,35 @@ pub async fn run_format_prompt(
         .as_ref()
         .context("prompt_inputs must be set after context retrieval")?;
 
+    let language = app_state
+        .languages
+        .load_language_for_file_path(&example.spec.cursor_path)
+        .await
+        .ok();
+    let snapshot_fut = cx.update(|cx| {
+        Buffer::build_snapshot(
+            prompt_inputs.content.as_str().into(),
+            language,
+            Some(app_state.languages.clone()),
+            cx,
+        )
+    });
+    let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
+    let snapshot = cx.background_spawn(snapshot_fut).await;
+
+    let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
+        cursor_point,
+        &snapshot,
+        edit_prediction::zeta2::MAX_EDITABLE_TOKENS,
+        edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
+    );
+    let editable_range = editable_range.to_offset(&snapshot);
+    let context_range = context_range.to_offset(&snapshot);
+
     match args.provider {
         PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
             step_progress.set_substatus("formatting teacher prompt");
-            let prompt = TeacherPrompt::format_prompt(example);
+            let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
             example.prompt = Some(ExamplePrompt {
                 input: prompt,
                 expected_output: example
@@ -45,15 +72,13 @@ pub async fn run_format_prompt(
         PredictionProvider::Zeta2(version) => {
             step_progress.set_substatus("formatting zeta2 prompt");
 
-            let context_start = prompt_inputs.context_range.start;
+            let context_start = context_range.start;
             let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
-            let editable_range_in_excerpt = (prompt_inputs.editable_range.start - context_start)
-                ..(prompt_inputs.editable_range.end - context_start);
+            let editable_range_in_excerpt =
+                (editable_range.start - context_start)..(editable_range.end - context_start);
             let input = zeta_prompt::ZetaPromptInput {
                 cursor_path: example.spec.cursor_path.clone(),
-                cursor_excerpt: prompt_inputs.content[prompt_inputs.context_range.clone()]
-                    .to_string()
-                    .into(),
+                cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
                 editable_range_in_excerpt,
                 cursor_offset_in_excerpt,
                 events: prompt_inputs.edit_history.clone(),
@@ -109,10 +134,14 @@ impl TeacherPrompt {
     /// Truncate edit history to this number of last lines
     const MAX_HISTORY_LINES: usize = 128;
 
-    pub fn format_prompt(example: &Example) -> String {
+    pub fn format_prompt(
+        example: &Example,
+        editable_range: Range<usize>,
+        context_range: Range<usize>,
+    ) -> String {
         let edit_history = Self::format_edit_history(&example.spec.edit_history);
         let context = Self::format_context(example);
-        let cursor_excerpt = Self::format_cursor_excerpt(example);
+        let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
 
         let prompt = Self::PROMPT
             .replace("{{context}}", &context)
@@ -123,25 +152,21 @@ impl TeacherPrompt {
     }
 
     pub fn parse(example: &Example, response: &str) -> Result<String> {
-        // Ideally, we should always be able to find cursor position in the retrieved context.
-        // In reality, sometimes we don't find it for these reasons:
-        // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
-        //    (can be fixed by getting cursor coordinates at the load_example stage)
-        // 2. Context retriever just didn't include cursor line.
-        //
-        // In that case, fallback to using `cursor_position` as excerpt.
-        let prompt_inputs = example
-            .prompt_inputs
-            .as_ref()
-            .context("`prompt_inputs` should be filled in in the context collection step")?;
-
         // Extract updated (new) editable region from the model response.
         // The model may include editable region markers in its output, so we need to strip them.
         let new_editable_region = extract_last_codeblock(response);
         let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
-
-        let old_editable_region =
-            prompt_inputs.content[prompt_inputs.editable_range.clone()].to_string();
+        let old_editable_region = Self::extract_editable_region(
+            &example
+                .prompt
+                .as_ref()
+                .context("example prompt missing")?
+                .input,
+        );
+        let prompt_inputs = example
+            .prompt_inputs
+            .as_ref()
+            .context("example is missing prompt inputs")?;
 
         // Normalize leading newlines: if old starts with newline but new doesn't,
         // prepend newline to new to preserve whitespace structure.
@@ -150,8 +175,12 @@ impl TeacherPrompt {
             new_editable_region.insert(0, '\n');
         }
 
-        let editable_region_start_line = prompt_inputs.content
-            [..prompt_inputs.editable_range.start]
+        let (editable_region_offset, _) = prompt_inputs
+            .content
+            .match_indices(&old_editable_region)
+            .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
+            .unwrap();
+        let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
             .matches('\n')
             .count();
 
@@ -229,30 +258,24 @@ impl TeacherPrompt {
         prompt
     }
 
-    fn format_cursor_excerpt(example: &Example) -> String {
+    fn format_cursor_excerpt(
+        example: &Example,
+        editable_range: Range<usize>,
+        context_range: Range<usize>,
+    ) -> String {
         let mut result = String::new();
 
         let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
 
         let path_str = example.spec.cursor_path.to_string_lossy();
         result.push_str(&format!("`````{path_str}\n"));
-        result.push_str(
-            &prompt_inputs.content
-                [prompt_inputs.context_range.start..prompt_inputs.editable_range.start],
-        );
+        result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
         result.push_str(Self::EDITABLE_REGION_START);
-        result.push_str(
-            &prompt_inputs.content[prompt_inputs.editable_range.start..prompt_inputs.cursor_offset],
-        );
+        result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
         result.push_str(Self::USER_CURSOR_MARKER);
-        result.push_str(
-            &prompt_inputs.content[prompt_inputs.cursor_offset..prompt_inputs.editable_range.end],
-        );
+        result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
         result.push_str(Self::EDITABLE_REGION_END);
-        result.push_str(
-            &prompt_inputs.content
-                [prompt_inputs.editable_range.end..prompt_inputs.context_range.end],
-        );
+        result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
         result.push_str("\n`````");
 
         result

crates/edit_prediction_cli/src/load_project.rs 🔗

@@ -7,13 +7,11 @@ use crate::{
 use anyhow::{Context as _, Result};
 use edit_prediction::{
     EditPredictionStore,
-    cursor_excerpt::editable_and_context_ranges_for_cursor_position,
     udiff::{OpenedBuffers, refresh_worktree_entries, strip_diff_path_prefix},
-    zeta2,
 };
 use futures::AsyncWriteExt as _;
 use gpui::{AsyncApp, Entity};
-use language::{Anchor, Buffer, LanguageNotFound, OffsetRangeExt as _, ToOffset, ToPoint};
+use language::{Anchor, Buffer, LanguageNotFound, ToOffset, ToPoint};
 use project::Project;
 use project::buffer_store::BufferStoreEvent;
 use std::{fs, path::PathBuf, sync::Arc};
@@ -55,15 +53,6 @@ pub async fn run_load_project(
 
     let (prompt_inputs, language_name) = buffer.read_with(&cx, |buffer, _cx| {
         let cursor_point = cursor_position.to_point(&buffer);
-        let snapshot = buffer.snapshot();
-        let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
-            cursor_point,
-            &snapshot,
-            zeta2::MAX_EDITABLE_TOKENS,
-            zeta2::MAX_CONTEXT_TOKENS,
-        );
-        let editable_range = editable_range.to_offset(&snapshot);
-        let context_range = context_range.to_offset(&snapshot);
         let language_name = buffer
             .language()
             .map(|l| l.name().to_string())
@@ -74,10 +63,12 @@ pub async fn run_load_project(
                 cursor_row: cursor_point.row,
                 cursor_column: cursor_point.column,
                 cursor_offset: cursor_position.to_offset(&buffer),
-                context_range,
-                editable_range,
                 edit_history,
-                related_files: None,
+                related_files: example
+                    .prompt_inputs
+                    .take()
+                    .map(|inputs| inputs.related_files)
+                    .unwrap_or_default(),
             },
             language_name,
         )

crates/edit_prediction_cli/src/predict.rs 🔗

@@ -287,7 +287,7 @@ async fn predict_anthropic(
         .collect::<Vec<String>>()
         .join("\n");
 
-    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
+    let actual_patch = TeacherPrompt::parse(&example, &actual_output)?;
 
     let prediction = ExamplePrediction {
         actual_patch,