zeta2: Include context in captured examples (#47516)

Ben Kunkle and Zed Zippy created

Closes #ISSUE

Release Notes:

- N/A *or* Added/Fixed/Improved ...

---------

Co-authored-by: Zed Zippy <234243425+zed-zippy[bot]@users.noreply.github.com>

Change summary

crates/edit_prediction/src/capture_example.rs       | 104 ++++++++++++++
crates/edit_prediction/src/edit_prediction.rs       |  20 ++
crates/edit_prediction/src/example_spec.rs          |  74 ++++++++++
crates/edit_prediction_cli/src/retrieve_context.rs  |  30 ++++
crates/edit_prediction_cli/src/split_commit.rs      |   2 
crates/edit_prediction_cli/src/synthesize.rs        |   1 
crates/edit_prediction_ui/src/edit_prediction_ui.rs |  10 +
7 files changed, 232 insertions(+), 9 deletions(-)

Detailed changes

crates/edit_prediction/src/capture_example.rs 🔗

@@ -1,6 +1,10 @@
 use crate::{
     EditPredictionExampleCaptureFeatureFlag, StoredEvent,
-    cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
+    cursor_excerpt::editable_and_context_ranges_for_cursor_position,
+    example_spec::{
+        CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile,
+        ExampleSpec, MAX_CURSOR_FILE_SIZE,
+    },
 };
 use anyhow::Result;
 use buffer_diff::BufferDiffSnapshot;
@@ -10,7 +14,7 @@ use gpui::{App, Entity, Task};
 use language::{Buffer, ToPoint as _};
 use project::{Project, WorktreeId};
 use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
-use text::{BufferSnapshot as TextBufferSnapshot, Point};
+use text::{BufferSnapshot as TextBufferSnapshot, Point, ToOffset as _};
 
 pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
 pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
@@ -20,6 +24,7 @@ pub fn capture_example(
     buffer: Entity<Buffer>,
     cursor_anchor: language::Anchor,
     mut events: Vec<StoredEvent>,
+    related_files: Vec<zeta_prompt::RelatedFile>,
     populate_expected_patch: bool,
     cx: &mut App,
 ) -> Option<Task<Result<ExampleSpec>>> {
@@ -58,7 +63,16 @@ pub fn capture_example(
             .and_then(|lang| lang.config().line_comments.first())
             .map(|s| s.to_string())
             .unwrap_or_default();
-        let (cursor_excerpt, cursor_offset, cursor_excerpt_range) = cx
+
+        let full_cursor_offset = cursor_anchor.to_offset(&snapshot);
+        let cursor_point = cursor_anchor.to_point(&snapshot);
+        let cursor_file_content = if snapshot.len() <= MAX_CURSOR_FILE_SIZE {
+            Some(snapshot.text())
+        } else {
+            None
+        };
+
+        let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
             .background_executor()
             .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
             .await;
@@ -99,6 +113,53 @@ pub fn capture_example(
             rejected_patch = Some(empty_patch);
         }
 
+        let prompt_input = cursor_file_content.map(|content| {
+            let captured_events: Vec<CapturedEvent> = events
+                .iter()
+                .map(|stored_event| {
+                    let zeta_prompt::Event::BufferChange {
+                        path,
+                        old_path,
+                        diff,
+                        predicted,
+                        in_open_source_repo,
+                    } = stored_event.event.as_ref();
+                    CapturedEvent {
+                        path: strip_root_name(path, &root_name).into(),
+                        old_path: strip_root_name(old_path, &root_name).into(),
+                        diff: diff.clone(),
+                        predicted: *predicted,
+                        in_open_source_repo: *in_open_source_repo,
+                    }
+                })
+                .collect();
+
+            let captured_related_files: Vec<CapturedRelatedFile> = related_files
+                .iter()
+                .map(|rf| CapturedRelatedFile {
+                    path: strip_root_name(&rf.path, &root_name).into(),
+                    max_row: rf.max_row,
+                    excerpts: rf
+                        .excerpts
+                        .iter()
+                        .map(|e| CapturedRelatedExcerpt {
+                            row_range: e.row_range.clone(),
+                            text: e.text.to_string(),
+                        })
+                        .collect(),
+                })
+                .collect();
+
+            CapturedPromptInput {
+                cursor_file_content: content,
+                cursor_offset: full_cursor_offset,
+                cursor_row: cursor_point.row,
+                cursor_column: cursor_point.column,
+                events: captured_events,
+                related_files: captured_related_files,
+            }
+        });
+
         let mut spec = ExampleSpec {
             name: generate_timestamp_name(),
             repository_url,
@@ -111,8 +172,13 @@ pub fn capture_example(
             edit_history,
             expected_patches,
             rejected_patch,
+            captured_prompt_input: prompt_input,
         };
-        spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
+        spec.set_cursor_excerpt(
+            &cursor_excerpt,
+            cursor_offset_in_excerpt,
+            &line_comment_prefix,
+        );
         Ok(spec)
     }))
 }
@@ -414,6 +480,7 @@ mod tests {
                     buffer.clone(),
                     Anchor::MIN,
                     events,
+                    Vec::new(),
                     true,
                     cx,
                 )
@@ -530,9 +597,36 @@ mod tests {
                          }
                     "}
                     .to_string()
-                )
+                ),
+                captured_prompt_input: example.captured_prompt_input.clone(),
             }
         );
+
+        let prompt_input = example
+            .captured_prompt_input
+            .expect("should have captured prompt input");
+        assert!(
+            prompt_input.cursor_file_content.contains("fn main()"),
+            "cursor_file_content should contain file content"
+        );
+        assert_eq!(
+            prompt_input.cursor_offset, 0,
+            "cursor at Anchor::MIN should be offset 0"
+        );
+        assert_eq!(
+            prompt_input.cursor_row, 0,
+            "cursor at Anchor::MIN should be row 0"
+        );
+        assert_eq!(
+            prompt_input.cursor_column, 0,
+            "cursor at Anchor::MIN should be column 0"
+        );
+        assert!(prompt_input.events.len() > 0, "should have captured events");
+        assert_eq!(
+            prompt_input.related_files.len(),
+            0,
+            "should have no related files (none passed)"
+        );
     }
 
     fn init_test(cx: &mut TestAppContext) {

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -1738,16 +1738,19 @@ impl EditPredictionStore {
         let can_collect_example = snapshot
             .file()
             .is_some_and(|file| self.can_collect_file(&project, file, cx))
-            && self.can_collect_events(&inputs.events, cx);
+            && self.can_collect_events(&inputs.events, cx)
+            && self.can_collect_related_files(&project, cx);
 
         if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
             let events_for_capture =
                 self.edit_history_for_project_with_pause_split_last_event(&project, cx);
+            let related_files_for_capture = inputs.related_files.clone();
             if let Some(example_task) = capture_example::capture_example(
                 project.clone(),
                 active_buffer.clone(),
                 position,
                 events_for_capture,
+                related_files_for_capture,
                 false,
                 cx,
             ) {
@@ -2140,6 +2143,21 @@ impl EditPredictionStore {
         })
     }
 
+    fn can_collect_related_files(&self, project: &Entity<Project>, cx: &mut App) -> bool {
+        if !self.data_collection_choice.is_enabled(cx) {
+            return false;
+        }
+
+        let related_with_buffers = self.context_for_project_with_buffers(project, cx);
+
+        related_with_buffers.iter().all(|(_, buffer)| {
+            buffer
+                .read(cx)
+                .file()
+                .is_some_and(|file| self.is_file_open_source(project, &file, cx))
+        })
+    }
+
     fn load_data_collection_choice() -> DataCollectionChoice {
         let choice = KEY_VALUE_STORE
             .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)

crates/edit_prediction/src/example_spec.rs 🔗

@@ -1,10 +1,15 @@
 use anyhow::{Context as _, Result};
 use serde::{Deserialize, Serialize};
-use std::{borrow::Cow, fmt::Write as _, mem, path::Path, sync::Arc};
+use std::{borrow::Cow, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc};
 
 pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]";
 pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";
 
+/// Maximum cursor file size to capture (64KB).
+/// Files larger than this will not have their content captured,
+/// falling back to git-based loading.
+pub const MAX_CURSOR_FILE_SIZE: usize = 64 * 1024;
+
 #[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
 pub struct ExampleSpec {
     #[serde(default)]
@@ -23,6 +28,70 @@ pub struct ExampleSpec {
     pub expected_patches: Vec<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub rejected_patch: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub captured_prompt_input: Option<CapturedPromptInput>,
+}
+
+/// All data needed to run format_prompt without loading the project.
+#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
+pub struct CapturedPromptInput {
+    pub cursor_file_content: String,
+    pub cursor_offset: usize,
+    pub cursor_row: u32,
+    pub cursor_column: u32,
+    pub events: Vec<CapturedEvent>,
+    pub related_files: Vec<CapturedRelatedFile>,
+}
+
+#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
+pub struct CapturedEvent {
+    pub path: Arc<Path>,
+    pub old_path: Arc<Path>,
+    pub diff: String,
+    pub predicted: bool,
+    pub in_open_source_repo: bool,
+}
+
+impl CapturedEvent {
+    pub fn to_event(&self) -> zeta_prompt::Event {
+        zeta_prompt::Event::BufferChange {
+            path: self.path.clone(),
+            old_path: self.old_path.clone(),
+            diff: self.diff.clone(),
+            predicted: self.predicted,
+            in_open_source_repo: self.in_open_source_repo,
+        }
+    }
+}
+
+#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
+pub struct CapturedRelatedFile {
+    pub path: Arc<Path>,
+    pub max_row: u32,
+    pub excerpts: Vec<CapturedRelatedExcerpt>,
+}
+
+impl CapturedRelatedFile {
+    pub fn to_related_file(&self) -> zeta_prompt::RelatedFile {
+        zeta_prompt::RelatedFile {
+            path: self.path.clone(),
+            max_row: self.max_row,
+            excerpts: self
+                .excerpts
+                .iter()
+                .map(|e| zeta_prompt::RelatedExcerpt {
+                    row_range: e.row_range.clone(),
+                    text: e.text.clone().into(),
+                })
+                .collect(),
+        }
+    }
+}
+
+#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
+pub struct CapturedRelatedExcerpt {
+    pub row_range: Range<u32>,
+    pub text: String,
 }
 
 const REASONING_HEADING: &str = "Reasoning";
@@ -169,6 +238,7 @@ impl ExampleSpec {
             edit_history: String::new(),
             expected_patches: Vec::new(),
             rejected_patch: None,
+            captured_prompt_input: None,
         };
 
         if let Some(rest) = input.strip_prefix("+++\n")
@@ -415,6 +485,7 @@ mod tests {
             edit_history: String::new(),
             expected_patches: Vec::new(),
             rejected_patch: None,
+            captured_prompt_input: None,
         };
 
         // Cursor before `42`
@@ -548,6 +619,7 @@ mod tests {
             edit_history: String::new(),
             expected_patches: Vec::new(),
             rejected_patch: None,
+            captured_prompt_input: None,
         };
 
         // Cursor before `42` using inline marker

crates/edit_prediction_cli/src/retrieve_context.rs 🔗

@@ -1,5 +1,5 @@
 use crate::{
-    example::Example,
+    example::{Example, ExamplePromptInputs},
     headless::EpAppState,
     load_project::run_load_project,
     progress::{ExampleProgress, InfoStyle, Step, StepProgress},
@@ -28,6 +28,34 @@ pub async fn run_context_retrieval(
         return Ok(());
     }
 
+    if let Some(captured) = &example.spec.captured_prompt_input {
+        let step_progress = example_progress.start(Step::Context);
+        step_progress.set_substatus("using captured prompt input");
+
+        let edit_history: Vec<Arc<zeta_prompt::Event>> = captured
+            .events
+            .iter()
+            .map(|e| Arc::new(e.to_event()))
+            .collect();
+
+        let related_files: Vec<zeta_prompt::RelatedFile> = captured
+            .related_files
+            .iter()
+            .map(|rf| rf.to_related_file())
+            .collect();
+
+        example.prompt_inputs = Some(ExamplePromptInputs {
+            content: captured.cursor_file_content.clone(),
+            cursor_row: captured.cursor_row,
+            cursor_column: captured.cursor_column,
+            cursor_offset: captured.cursor_offset,
+            edit_history,
+            related_files: Some(related_files),
+        });
+
+        return Ok(());
+    }
+
     run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
 
     let step_progress: Arc<StepProgress> = example_progress.start(Step::Context).into();

crates/edit_prediction_cli/src/split_commit.rs 🔗

@@ -371,6 +371,7 @@ pub fn generate_evaluation_example_from_ordered_commit(
         reasoning: None,
         uncommitted_diff: String::new(),
         rejected_patch: None,
+        captured_prompt_input: None,
     })
 }
 
@@ -1402,6 +1403,7 @@ Date: Mon Jan 1 00:00:00 2024
             reasoning: None,
             uncommitted_diff: String::new(),
             rejected_patch: None,
+            captured_prompt_input: None,
         };
 
         let json = serde_json::to_string(&case).unwrap();

crates/edit_prediction_cli/src/synthesize.rs 🔗

@@ -792,6 +792,7 @@ async fn build_example(
         edit_history,
         expected_patches: vec![expected_patch_with_header],
         rejected_patch: None,
+        captured_prompt_input: None,
     };
     spec.set_cursor_excerpt(&excerpt, cursor_offset, comment_prefix);
 

crates/edit_prediction_ui/src/edit_prediction_ui.rs 🔗

@@ -154,7 +154,15 @@ fn capture_example_as_markdown(
     let events = ep_store.update(cx, |store, cx| {
         store.edit_history_for_project_with_pause_split_last_event(&project, cx)
     });
-    let example = capture_example(project.clone(), buffer, cursor_anchor, events, true, cx)?;
+    let example = capture_example(
+        project.clone(),
+        buffer,
+        cursor_anchor,
+        events,
+        Vec::new(),
+        true,
+        cx,
+    )?;
 
     let examples_dir = AllLanguageSettings::get_global(cx)
         .edit_predictions