Detailed changes
@@ -1,6 +1,10 @@
use crate::{
EditPredictionExampleCaptureFeatureFlag, StoredEvent,
- cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
+ cursor_excerpt::editable_and_context_ranges_for_cursor_position,
+ example_spec::{
+ CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile,
+ ExampleSpec, MAX_CURSOR_FILE_SIZE,
+ },
};
use anyhow::Result;
use buffer_diff::BufferDiffSnapshot;
@@ -10,7 +14,7 @@ use gpui::{App, Entity, Task};
use language::{Buffer, ToPoint as _};
use project::{Project, WorktreeId};
use std::{collections::hash_map, fmt::Write as _, ops::Range, path::Path, sync::Arc};
-use text::{BufferSnapshot as TextBufferSnapshot, Point};
+use text::{BufferSnapshot as TextBufferSnapshot, Point, ToOffset as _};
pub(crate) const DEFAULT_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 10;
pub(crate) const DEFAULT_STAFF_EXAMPLE_CAPTURE_RATE_PER_10K_PREDICTIONS: u16 = 100;
@@ -20,6 +24,7 @@ pub fn capture_example(
buffer: Entity<Buffer>,
cursor_anchor: language::Anchor,
mut events: Vec<StoredEvent>,
+ related_files: Vec<zeta_prompt::RelatedFile>,
populate_expected_patch: bool,
cx: &mut App,
) -> Option<Task<Result<ExampleSpec>>> {
@@ -58,7 +63,16 @@ pub fn capture_example(
.and_then(|lang| lang.config().line_comments.first())
.map(|s| s.to_string())
.unwrap_or_default();
- let (cursor_excerpt, cursor_offset, cursor_excerpt_range) = cx
+
+ let full_cursor_offset = cursor_anchor.to_offset(&snapshot);
+ let cursor_point = cursor_anchor.to_point(&snapshot);
+ let cursor_file_content = if snapshot.len() <= MAX_CURSOR_FILE_SIZE {
+ Some(snapshot.text())
+ } else {
+ None
+ };
+
+ let (cursor_excerpt, cursor_offset_in_excerpt, cursor_excerpt_range) = cx
.background_executor()
.spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
.await;
@@ -99,6 +113,53 @@ pub fn capture_example(
rejected_patch = Some(empty_patch);
}
+ let prompt_input = cursor_file_content.map(|content| {
+ let captured_events: Vec<CapturedEvent> = events
+ .iter()
+ .map(|stored_event| {
+ let zeta_prompt::Event::BufferChange {
+ path,
+ old_path,
+ diff,
+ predicted,
+ in_open_source_repo,
+ } = stored_event.event.as_ref();
+ CapturedEvent {
+ path: strip_root_name(path, &root_name).into(),
+ old_path: strip_root_name(old_path, &root_name).into(),
+ diff: diff.clone(),
+ predicted: *predicted,
+ in_open_source_repo: *in_open_source_repo,
+ }
+ })
+ .collect();
+
+ let captured_related_files: Vec<CapturedRelatedFile> = related_files
+ .iter()
+ .map(|rf| CapturedRelatedFile {
+ path: strip_root_name(&rf.path, &root_name).into(),
+ max_row: rf.max_row,
+ excerpts: rf
+ .excerpts
+ .iter()
+ .map(|e| CapturedRelatedExcerpt {
+ row_range: e.row_range.clone(),
+ text: e.text.to_string(),
+ })
+ .collect(),
+ })
+ .collect();
+
+ CapturedPromptInput {
+ cursor_file_content: content,
+ cursor_offset: full_cursor_offset,
+ cursor_row: cursor_point.row,
+ cursor_column: cursor_point.column,
+ events: captured_events,
+ related_files: captured_related_files,
+ }
+ });
+
let mut spec = ExampleSpec {
name: generate_timestamp_name(),
repository_url,
@@ -111,8 +172,13 @@ pub fn capture_example(
edit_history,
expected_patches,
rejected_patch,
+ captured_prompt_input: prompt_input,
};
- spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
+ spec.set_cursor_excerpt(
+ &cursor_excerpt,
+ cursor_offset_in_excerpt,
+ &line_comment_prefix,
+ );
Ok(spec)
}))
}
@@ -414,6 +480,7 @@ mod tests {
buffer.clone(),
Anchor::MIN,
events,
+ Vec::new(),
true,
cx,
)
@@ -530,9 +597,36 @@ mod tests {
}
"}
.to_string()
- )
+ ),
+ captured_prompt_input: example.captured_prompt_input.clone(),
}
);
+
+ let prompt_input = example
+ .captured_prompt_input
+ .expect("should have captured prompt input");
+ assert!(
+ prompt_input.cursor_file_content.contains("fn main()"),
+ "cursor_file_content should contain file content"
+ );
+ assert_eq!(
+ prompt_input.cursor_offset, 0,
+ "cursor at Anchor::MIN should be offset 0"
+ );
+ assert_eq!(
+ prompt_input.cursor_row, 0,
+ "cursor at Anchor::MIN should be row 0"
+ );
+ assert_eq!(
+ prompt_input.cursor_column, 0,
+ "cursor at Anchor::MIN should be column 0"
+ );
+ assert!(prompt_input.events.len() > 0, "should have captured events");
+ assert_eq!(
+ prompt_input.related_files.len(),
+ 0,
+ "should have no related files (none passed)"
+ );
}
fn init_test(cx: &mut TestAppContext) {
@@ -1738,16 +1738,19 @@ impl EditPredictionStore {
let can_collect_example = snapshot
.file()
.is_some_and(|file| self.can_collect_file(&project, file, cx))
- && self.can_collect_events(&inputs.events, cx);
+ && self.can_collect_events(&inputs.events, cx)
+ && self.can_collect_related_files(&project, cx);
if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
let events_for_capture =
self.edit_history_for_project_with_pause_split_last_event(&project, cx);
+ let related_files_for_capture = inputs.related_files.clone();
if let Some(example_task) = capture_example::capture_example(
project.clone(),
active_buffer.clone(),
position,
events_for_capture,
+ related_files_for_capture,
false,
cx,
) {
@@ -2140,6 +2143,21 @@ impl EditPredictionStore {
})
}
+ fn can_collect_related_files(&self, project: &Entity<Project>, cx: &mut App) -> bool { // Gate deciding whether related files may be included in a captured example.
+ if !self.data_collection_choice.is_enabled(cx) {
+ return false; // Data collection is not enabled, so nothing may be collected.
+ }
+ // NOTE(review): this inspects the pairs returned by `context_for_project_with_buffers`; confirm they match the `related_files` the caller actually captures.
+ let related_with_buffers = self.context_for_project_with_buffers(project, cx);
+ // `all` is true for an empty list; a buffer that has no file makes the whole check false.
+ related_with_buffers.iter().all(|(_, buffer)| {
+ buffer
+ .read(cx)
+ .file()
+ .is_some_and(|file| self.is_file_open_source(project, &file, cx))
+ })
+ }
+
fn load_data_collection_choice() -> DataCollectionChoice {
let choice = KEY_VALUE_STORE
.read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
@@ -1,10 +1,15 @@
use anyhow::{Context as _, Result};
use serde::{Deserialize, Serialize};
-use std::{borrow::Cow, fmt::Write as _, mem, path::Path, sync::Arc};
+use std::{borrow::Cow, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc};
pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]";
pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";
+/// Maximum size of a cursor file whose full content is captured (64 KiB).
+/// The bound is inclusive: content is captured when the file's length is at most
+/// this value; larger files are not captured, falling back to git-based loading.
+pub const MAX_CURSOR_FILE_SIZE: usize = 64 * 1024;
+
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct ExampleSpec {
#[serde(default)]
@@ -23,6 +28,70 @@ pub struct ExampleSpec {
pub expected_patches: Vec<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub rejected_patch: Option<String>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub captured_prompt_input: Option<CapturedPromptInput>,
+}
+
+/// All data needed to run format_prompt without loading the project; stored in `ExampleSpec::captured_prompt_input`.
+#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
+pub struct CapturedPromptInput {
+ pub cursor_file_content: String, // Full text of the buffer that contains the cursor.
+ pub cursor_offset: usize, // Cursor offset within the whole file, not within the cursor excerpt.
+ pub cursor_row: u32, // Row of the cursor in the file (0 at the start of the file).
+ pub cursor_column: u32, // Column of the cursor in its row (0 at the start of the file).
+ pub events: Vec<CapturedEvent>, // Captured edit events; see `CapturedEvent::to_event`.
+ pub related_files: Vec<CapturedRelatedFile>, // Captured related files; see `CapturedRelatedFile::to_related_file`.
+}
+
+#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
+pub struct CapturedEvent { // Serializable copy of the fields of `zeta_prompt::Event::BufferChange`.
+ pub path: Arc<Path>, // `capture_example` stores this with the root name stripped.
+ pub old_path: Arc<Path>, // `capture_example` stores this with the root name stripped.
+ pub diff: String, // Copy of the event's `diff` field.
+ pub predicted: bool, // Copy of the event's `predicted` flag.
+ pub in_open_source_repo: bool, // Copy of the event's `in_open_source_repo` flag.
+}
+
+impl CapturedEvent {
+ pub fn to_event(&self) -> zeta_prompt::Event { // Rebuilds a `BufferChange` event from the captured fields, leaving `self` untouched.
+ zeta_prompt::Event::BufferChange {
+ path: self.path.clone(),
+ old_path: self.old_path.clone(),
+ diff: self.diff.clone(),
+ predicted: self.predicted,
+ in_open_source_repo: self.in_open_source_repo,
+ }
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
+pub struct CapturedRelatedFile { // Serializable copy of a `zeta_prompt::RelatedFile`.
+ pub path: Arc<Path>, // `capture_example` stores this with the root name stripped.
+ pub max_row: u32, // Copy of the related file's `max_row`.
+ pub excerpts: Vec<CapturedRelatedExcerpt>, // One entry per excerpt of the related file.
+}
+
+impl CapturedRelatedFile {
+ pub fn to_related_file(&self) -> zeta_prompt::RelatedFile { // Rebuilds a `zeta_prompt::RelatedFile` from the captured data, leaving `self` untouched.
+ zeta_prompt::RelatedFile {
+ path: self.path.clone(),
+ max_row: self.max_row,
+ excerpts: self
+ .excerpts
+ .iter()
+ .map(|e| zeta_prompt::RelatedExcerpt {
+ row_range: e.row_range.clone(),
+ text: e.text.clone().into(), // The owned String is converted into the prompt type's text field.
+ })
+ .collect(),
+ }
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
+pub struct CapturedRelatedExcerpt { // Serializable copy of a `zeta_prompt::RelatedExcerpt`.
+ pub row_range: Range<u32>, // Copy of the excerpt's `row_range`.
+ pub text: String, // Excerpt text as an owned String.
 }
const REASONING_HEADING: &str = "Reasoning";
@@ -169,6 +238,7 @@ impl ExampleSpec {
edit_history: String::new(),
expected_patches: Vec::new(),
rejected_patch: None,
+ captured_prompt_input: None,
};
if let Some(rest) = input.strip_prefix("+++\n")
@@ -415,6 +485,7 @@ mod tests {
edit_history: String::new(),
expected_patches: Vec::new(),
rejected_patch: None,
+ captured_prompt_input: None,
};
// Cursor before `42`
@@ -548,6 +619,7 @@ mod tests {
edit_history: String::new(),
expected_patches: Vec::new(),
rejected_patch: None,
+ captured_prompt_input: None,
};
// Cursor before `42` using inline marker
@@ -1,5 +1,5 @@
use crate::{
- example::Example,
+ example::{Example, ExamplePromptInputs},
headless::EpAppState,
load_project::run_load_project,
progress::{ExampleProgress, InfoStyle, Step, StepProgress},
@@ -28,6 +28,34 @@ pub async fn run_context_retrieval(
return Ok(());
}
+ if let Some(captured) = &example.spec.captured_prompt_input {
+ let step_progress = example_progress.start(Step::Context);
+ step_progress.set_substatus("using captured prompt input");
+
+ let edit_history: Vec<Arc<zeta_prompt::Event>> = captured
+ .events
+ .iter()
+ .map(|e| Arc::new(e.to_event()))
+ .collect();
+
+ let related_files: Vec<zeta_prompt::RelatedFile> = captured
+ .related_files
+ .iter()
+ .map(|rf| rf.to_related_file())
+ .collect();
+
+ example.prompt_inputs = Some(ExamplePromptInputs {
+ content: captured.cursor_file_content.clone(),
+ cursor_row: captured.cursor_row,
+ cursor_column: captured.cursor_column,
+ cursor_offset: captured.cursor_offset,
+ edit_history,
+ related_files: Some(related_files),
+ });
+
+ return Ok(());
+ }
+
run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
let step_progress: Arc<StepProgress> = example_progress.start(Step::Context).into();
@@ -371,6 +371,7 @@ pub fn generate_evaluation_example_from_ordered_commit(
reasoning: None,
uncommitted_diff: String::new(),
rejected_patch: None,
+ captured_prompt_input: None,
})
}
@@ -1402,6 +1403,7 @@ Date: Mon Jan 1 00:00:00 2024
reasoning: None,
uncommitted_diff: String::new(),
rejected_patch: None,
+ captured_prompt_input: None,
};
let json = serde_json::to_string(&case).unwrap();
@@ -792,6 +792,7 @@ async fn build_example(
edit_history,
expected_patches: vec![expected_patch_with_header],
rejected_patch: None,
+ captured_prompt_input: None,
};
spec.set_cursor_excerpt(&excerpt, cursor_offset, comment_prefix);
@@ -154,7 +154,15 @@ fn capture_example_as_markdown(
let events = ep_store.update(cx, |store, cx| {
store.edit_history_for_project_with_pause_split_last_event(&project, cx)
});
- let example = capture_example(project.clone(), buffer, cursor_anchor, events, true, cx)?;
+ let example = capture_example(
+ project.clone(),
+ buffer,
+ cursor_anchor,
+ events,
+ Vec::new(),
+ true,
+ cx,
+ )?;
let examples_dir = AllLanguageSettings::get_global(cx)
.edit_predictions