diff --git a/crates/edit_prediction/src/capture_example.rs b/crates/edit_prediction/src/capture_example.rs index 1c199be39b7b004bf47ad3e152e264c53efda73b..39eb1cbf45089e871be50884a682c8647a917f7f 100644 --- a/crates/edit_prediction/src/capture_example.rs +++ b/crates/edit_prediction/src/capture_example.rs @@ -1,6 +1,10 @@ use crate::{ EditPredictionExampleCaptureFeatureFlag, StoredEvent, - cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec, + cursor_excerpt::editable_and_context_ranges_for_cursor_position, + example_spec::{ + CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, + ExampleSpec, MAX_CURSOR_FILE_SIZE, + }, }; use anyhow::Result; use buffer_diff::BufferDiffSnapshot; @@ -10,7 +14,7 @@ use gpui::{App, Entity, Task}; use language::{Buffer, ToPoint as _}; use project::{Project, WorktreeId}; use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc}; -use text::{BufferSnapshot as TextBufferSnapshot, Point}; +use text::{BufferSnapshot as TextBufferSnapshot, Point, ToOffset as _}; pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10; pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100; @@ -20,6 +24,7 @@ pub fn capture_example( buffer: Entity, cursor_anchor: language::Anchor, mut events: Vec, + related_files: Vec, populate_expected_patch: bool, cx: &mut App, ) -> Option>> { @@ -58,7 +63,16 @@ pub fn capture_example( .and_then(|lang| lang.config().line_comments.first()) .map(|s| s.to_string()) .unwrap_or_default(); - let (cursor_excerpt, cursor_offset, cursor_excerpt_range) = cx + + let full_cursor_offset = cursor_anchor.to_offset(&snapshot); + let cursor_point = cursor_anchor.to_point(&snapshot); + let cursor_file_content = if snapshot.len() <= MAX_CURSOR_FILE_SIZE { + Some(snapshot.text()) + } else { + None + }; + + let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx .background_executor() .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) }) .await; @@ -99,6 +113,53 @@ pub fn capture_example( rejected_patch = Some(empty_patch); } + let prompt_input = cursor_file_content.map(|content| { + let captured_events: Vec = events + .iter() + .map(|stored_event| { + let zeta_prompt::Event::BufferChange { + path, + old_path, + diff, + predicted, + in_open_source_repo, + } = stored_event.event.as_ref(); + CapturedEvent { + path: strip_root_name(path, &root_name).into(), + old_path: strip_root_name(old_path, &root_name).into(), + diff: diff.clone(), + predicted: *predicted, + in_open_source_repo: *in_open_source_repo, + } + }) + .collect(); + + let captured_related_files: Vec = related_files + .iter() + .map(|rf| CapturedRelatedFile { + path: strip_root_name(&rf.path, &root_name).into(), + max_row: rf.max_row, + excerpts: rf + .excerpts + .iter() + .map(|e| CapturedRelatedExcerpt { + row_range: e.row_range.clone(), + text: e.text.to_string(), + }) + .collect(), + }) + .collect(); + + CapturedPromptInput { + cursor_file_content: content, + cursor_offset: full_cursor_offset, + cursor_row: cursor_point.row, + cursor_column: cursor_point.column, + events: captured_events, + related_files: captured_related_files, + } + }); + let mut spec = ExampleSpec { name: generate_timestamp_name(), repository_url, @@ -111,8 +172,13 @@ pub fn capture_example( edit_history, expected_patches, rejected_patch, + captured_prompt_input: prompt_input, }; - spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix); + spec.set_cursor_excerpt( + &cursor_excerpt, + cursor_offset_in_excerpt, + &line_comment_prefix, + ); Ok(spec) })) } @@ -414,6 +480,7 @@ mod tests { buffer.clone(), Anchor::MIN, events, + Vec::new(), true, cx, ) @@ -530,9 +597,36 @@ mod tests { } "} .to_string() - ) + ), + captured_prompt_input: example.captured_prompt_input.clone(), } ); + + let prompt_input = example + .captured_prompt_input + .expect("should have captured prompt input"); + assert!( + prompt_input.cursor_file_content.contains("fn main()"), + "cursor_file_content should contain file content" + ); + assert_eq!( + prompt_input.cursor_offset, 0, + "cursor at Anchor::MIN should be offset 0" + ); + assert_eq!( + prompt_input.cursor_row, 0, + "cursor at Anchor::MIN should be row 0" + ); + assert_eq!( + prompt_input.cursor_column, 0, + "cursor at Anchor::MIN should be column 0" + ); + assert!(prompt_input.events.len() > 0, "should have captured events"); + assert_eq!( + prompt_input.related_files.len(), + 0, + "should have no related files (none passed)" + ); } fn init_test(cx: &mut TestAppContext) { diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 38bccbe65bddc1ce1763a0d362a52c0db09be69a..6a04ab62bc94257cbeb8618a7eadd85970f20a5b 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -1738,16 +1738,19 @@ impl EditPredictionStore { let can_collect_example = snapshot .file() .is_some_and(|file| self.can_collect_file(&project, file, cx)) - && self.can_collect_events(&inputs.events, cx); + && self.can_collect_events(&inputs.events, cx) + && self.can_collect_related_files(&project, cx); if can_collect_example && should_sample_edit_prediction_example_capture(cx) { let events_for_capture = self.edit_history_for_project_with_pause_split_last_event(&project, cx); + let related_files_for_capture = inputs.related_files.clone(); if let Some(example_task) = capture_example::capture_example( project.clone(), active_buffer.clone(), position, events_for_capture, + related_files_for_capture, false, cx, ) { @@ -2140,6 +2143,21 @@ impl EditPredictionStore { }) } + fn can_collect_related_files(&self, project: &Entity, cx: &mut App) -> bool { + if !self.data_collection_choice.is_enabled(cx) { + return false; + } + + let related_with_buffers = self.context_for_project_with_buffers(project, cx); + + related_with_buffers.iter().all(|(_, buffer)| { + buffer + .read(cx) + .file() + .is_some_and(|file| self.is_file_open_source(project, &file, cx)) + }) + } + fn load_data_collection_choice() -> DataCollectionChoice { let choice = KEY_VALUE_STORE .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE) diff --git a/crates/edit_prediction/src/example_spec.rs b/crates/edit_prediction/src/example_spec.rs index 80d0ba5732633eefff84bdb4ba02795aa1ca70b0..09ef97dffc60d1eda26c292e943c5619eb0bda39 100644 --- a/crates/edit_prediction/src/example_spec.rs +++ b/crates/edit_prediction/src/example_spec.rs @@ -1,10 +1,15 @@ use anyhow::{Context as _, Result}; use serde::{Deserialize, Serialize}; -use std::{borrow::Cow, fmt::Write as _, mem, path::Path, sync::Arc}; +use std::{borrow::Cow, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc}; pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]"; pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>"; +/// Maximum cursor file size to capture (64KB). +/// Files larger than this will not have their content captured, +/// falling back to git-based loading. +pub const MAX_CURSOR_FILE_SIZE: usize = 64 * 1024; + #[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)] pub struct ExampleSpec { #[serde(default)] @@ -23,6 +28,70 @@ pub struct ExampleSpec { pub expected_patches: Vec, #[serde(default, skip_serializing_if = "Option::is_none")] pub rejected_patch: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub captured_prompt_input: Option, +} + +/// All data needed to run format_prompt without loading the project. +#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)] +pub struct CapturedPromptInput { + pub cursor_file_content: String, + pub cursor_offset: usize, + pub cursor_row: u32, + pub cursor_column: u32, + pub events: Vec, + pub related_files: Vec, +} + +#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)] +pub struct CapturedEvent { + pub path: Arc, + pub old_path: Arc, + pub diff: String, + pub predicted: bool, + pub in_open_source_repo: bool, +} + +impl CapturedEvent { + pub fn to_event(&self) -> zeta_prompt::Event { + zeta_prompt::Event::BufferChange { + path: self.path.clone(), + old_path: self.old_path.clone(), + diff: self.diff.clone(), + predicted: self.predicted, + in_open_source_repo: self.in_open_source_repo, + } + } +} + +#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)] +pub struct CapturedRelatedFile { + pub path: Arc, + pub max_row: u32, + pub excerpts: Vec, +} + +impl CapturedRelatedFile { + pub fn to_related_file(&self) -> zeta_prompt::RelatedFile { + zeta_prompt::RelatedFile { + path: self.path.clone(), + max_row: self.max_row, + excerpts: self + .excerpts + .iter() + .map(|e| zeta_prompt::RelatedExcerpt { + row_range: e.row_range.clone(), + text: e.text.clone().into(), + }) + .collect(), + } + } +} + +#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)] +pub struct CapturedRelatedExcerpt { + pub row_range: Range, + pub text: String, } const REASONING_HEADING: &str = "Reasoning"; @@ -169,6 +238,7 @@ impl ExampleSpec { edit_history: String::new(), expected_patches: Vec::new(), rejected_patch: None, + captured_prompt_input: None, }; if let Some(rest) = input.strip_prefix("+++\n") @@ -415,6 +485,7 @@ mod tests { edit_history: String::new(), expected_patches: Vec::new(), rejected_patch: None, + captured_prompt_input: None, }; // Cursor before `42` @@ -548,6 +619,7 @@ mod tests { edit_history: String::new(), expected_patches: Vec::new(), rejected_patch: None, + captured_prompt_input: None, }; // Cursor before `42` using inline marker diff --git a/crates/edit_prediction_cli/src/retrieve_context.rs b/crates/edit_prediction_cli/src/retrieve_context.rs index 6f3fafa91b7c67d11c6a2990e6039f4c7f40c0ff..c2b8e6974411560ec216de4fc341e0c7310e3368 100644 --- a/crates/edit_prediction_cli/src/retrieve_context.rs +++ b/crates/edit_prediction_cli/src/retrieve_context.rs @@ -1,5 +1,5 @@ use crate::{ - example::Example, + example::{Example, ExamplePromptInputs}, headless::EpAppState, load_project::run_load_project, progress::{ExampleProgress, InfoStyle, Step, StepProgress}, @@ -28,6 +28,34 @@ pub async fn run_context_retrieval( return Ok(()); } + if let Some(captured) = &example.spec.captured_prompt_input { + let step_progress = example_progress.start(Step::Context); + step_progress.set_substatus("using captured prompt input"); + + let edit_history: Vec> = captured + .events + .iter() + .map(|e| Arc::new(e.to_event())) + .collect(); + + let related_files: Vec = captured + .related_files + .iter() + .map(|rf| rf.to_related_file()) + .collect(); + + example.prompt_inputs = Some(ExamplePromptInputs { + content: captured.cursor_file_content.clone(), + cursor_row: captured.cursor_row, + cursor_column: captured.cursor_column, + cursor_offset: captured.cursor_offset, + edit_history, + related_files: Some(related_files), + }); + + return Ok(()); + } + run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?; let step_progress: Arc = example_progress.start(Step::Context).into(); diff --git a/crates/edit_prediction_cli/src/split_commit.rs b/crates/edit_prediction_cli/src/split_commit.rs index b411be9b8e3a0d5f63472b13ce340727e48b2b3a..4a034f0c35b56266aa67dffe53c6c178f15bcfb0 100644 --- a/crates/edit_prediction_cli/src/split_commit.rs +++ b/crates/edit_prediction_cli/src/split_commit.rs @@ -371,6 +371,7 @@ pub fn generate_evaluation_example_from_ordered_commit( reasoning: None, uncommitted_diff: String::new(), rejected_patch: None, + captured_prompt_input: None, }) } @@ -1402,6 +1403,7 @@ Date: Mon Jan 1 00:00:00 2024 reasoning: None, uncommitted_diff: String::new(), rejected_patch: None, + captured_prompt_input: None, }; let json = serde_json::to_string(&case).unwrap(); diff --git a/crates/edit_prediction_cli/src/synthesize.rs b/crates/edit_prediction_cli/src/synthesize.rs index 0577108400d0fa09126b49195079eb1c4625d9d1..1d7b4eb874fc099b6a898d60be683e358a96b55b 100644 --- a/crates/edit_prediction_cli/src/synthesize.rs +++ b/crates/edit_prediction_cli/src/synthesize.rs @@ -792,6 +792,7 @@ async fn build_example( edit_history, expected_patches: vec![expected_patch_with_header], rejected_patch: None, + captured_prompt_input: None, }; spec.set_cursor_excerpt(&excerpt, cursor_offset, comment_prefix); diff --git a/crates/edit_prediction_ui/src/edit_prediction_ui.rs b/crates/edit_prediction_ui/src/edit_prediction_ui.rs index 2ca852a0140651b515734dd144c868bfebe04328..2ac16ba81370884b20e5fb869bbb1b7cc2c4545c 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_ui.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_ui.rs @@ -154,7 +154,15 @@ fn capture_example_as_markdown( let events = ep_store.update(cx, |store, cx| { store.edit_history_for_project_with_pause_split_last_event(&project, cx) }); - let example = capture_example(project.clone(), buffer, cursor_anchor, events, true, cx)?; + let example = capture_example( + project.clone(), + buffer, + cursor_anchor, + events, + Vec::new(), + true, + cx, + )?; let examples_dir = AllLanguageSettings::get_global(cx) .edit_predictions