From b8e40e6fdb61fc108f2db7372b3a38655b101875 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Sun, 14 Dec 2025 20:50:48 -0800 Subject: [PATCH] Add an action for capturing your last edit as an edit prediction example (#44841) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds a staff-only button to the edit prediction menu for capturing your current editing session as edit prediction example file. When you click that button, it opens a markdown tab with the example. By default, the most recent change that you've made is used as the expected patch, and all of the previous events are used as the editing history. Screenshot 2025-12-14 at 6 58 33 PM Release Notes: - N/A --- Cargo.lock | 5 +- crates/edit_prediction/Cargo.toml | 1 + crates/edit_prediction/src/edit_prediction.rs | 146 +++++++++--- .../src/edit_prediction_tests.rs | 97 +++++++- crates/edit_prediction/src/example_spec.rs | 212 ++++++++++++++++++ crates/edit_prediction_cli/Cargo.toml | 1 - crates/edit_prediction_cli/src/distill.rs | 2 +- crates/edit_prediction_cli/src/example.rs | 169 ++------------ .../edit_prediction_cli/src/format_prompt.rs | 19 +- .../edit_prediction_cli/src/load_project.rs | 42 ++-- crates/edit_prediction_cli/src/main.rs | 8 +- crates/edit_prediction_cli/src/predict.rs | 6 +- .../src/retrieve_context.rs | 2 +- crates/edit_prediction_cli/src/score.rs | 6 +- crates/edit_prediction_ui/Cargo.toml | 3 + .../src/edit_prediction_button.rs | 12 +- .../src/edit_prediction_ui.rs | 208 ++++++++++++++++- 17 files changed, 711 insertions(+), 228 deletions(-) create mode 100644 crates/edit_prediction/src/example_spec.rs diff --git a/Cargo.lock b/Cargo.lock index 436da4aef8c0849a61336a9645639c17da731029..dd57996c7ef6dd711c1e67725d1bdfd86d277729 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5130,6 +5130,7 @@ dependencies = [ "postage", "pretty_assertions", "project", + "pulldown-cmark 0.12.2", "rand 0.9.2", "regex", "release_channel", @@ -5184,7 +5185,6 @@ dependencies = [ "pretty_assertions", "project", "prompt_store", - "pulldown-cmark 0.12.2", "release_channel", "reqwest_client", "serde", @@ -5256,9 +5256,11 @@ dependencies = [ "feature_flags", "fs", "futures 0.3.31", + "git", "gpui", "indoc", "language", + "log", "lsp", "markdown", "menu", @@ -5272,6 +5274,7 @@ dependencies = [ "telemetry", "text", "theme", + "time", "ui", "util", "workspace", diff --git a/crates/edit_prediction/Cargo.toml b/crates/edit_prediction/Cargo.toml index 5f1799e2dc4bb5460a900664472ad33e3035d4f1..2d5fb36a581f7bd17bb76f79791c276c86c9c631 100644 --- a/crates/edit_prediction/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -41,6 +41,7 @@ open_ai.workspace = true postage.workspace = true pretty_assertions.workspace = true project.workspace = true +pulldown-cmark.workspace = true rand.workspace = true regex.workspace = true release_channel.workspace = true diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 8b96466667bbac8fba92549487821f0d450670ac..ff15d04cc1c0f8e7bbeb7f2a29b520a8ec32097a 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -25,7 +25,7 @@ use gpui::{ prelude::*, }; use language::language_settings::all_language_settings; -use language::{Anchor, Buffer, File, Point, ToPoint}; +use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToPoint}; use language::{BufferSnapshot, OffsetRangeExt}; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use project::{Project, ProjectPath, WorktreeId}; @@ -47,7 +47,8 @@ use thiserror::Error; use util::{RangeExt as _, ResultExt as _}; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; -mod cursor_excerpt; +pub mod cursor_excerpt; +pub mod example_spec; mod license_detection; pub mod mercury; mod onboarding_modal; @@ -89,6 +90,7 @@ actions!( /// Maximum number of events to track. const EVENT_COUNT_MAX: usize = 6; const CHANGE_GROUPING_LINE_SPAN: u32 = 8; +const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1); const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice"; const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15); @@ -265,6 +267,19 @@ impl ProjectState { .collect() } + pub fn events_split_by_pause(&self, cx: &App) -> Vec> { + self.events + .iter() + .cloned() + .chain(self.last_event.as_ref().iter().flat_map(|event| { + let (one, two) = event.split_by_pause(); + let one = one.finalize(&self.license_detection_watchers, cx); + let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx)); + one.into_iter().chain(two) + })) + .collect() + } + fn cancel_pending_prediction( &mut self, pending_prediction: PendingPrediction, @@ -385,15 +400,21 @@ impl std::ops::Deref for BufferEditPrediction<'_> { } struct RegisteredBuffer { - snapshot: BufferSnapshot, + file: Option>, + snapshot: TextBufferSnapshot, last_position: Option, _subscriptions: [gpui::Subscription; 2], } +#[derive(Clone)] struct LastEvent { - old_snapshot: BufferSnapshot, - new_snapshot: BufferSnapshot, + old_snapshot: TextBufferSnapshot, + new_snapshot: TextBufferSnapshot, + old_file: Option>, + new_file: Option>, end_edit_anchor: Option, + snapshot_after_last_editing_pause: Option, + last_edit_time: Option, } impl LastEvent { @@ -402,19 +423,19 @@ impl LastEvent { license_detection_watchers: &HashMap>, cx: &App, ) -> Option> { - let path = buffer_path_with_id_fallback(&self.new_snapshot, cx); - let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx); - - let file = self.new_snapshot.file(); - let old_file = self.old_snapshot.file(); - - let in_open_source_repo = [file, old_file].iter().all(|file| { - file.is_some_and(|file| { - license_detection_watchers - .get(&file.worktree_id(cx)) - .is_some_and(|watcher| watcher.is_project_open_source()) - }) - }); + let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx); + let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx); + + let in_open_source_repo = + [self.new_file.as_ref(), self.old_file.as_ref()] + .iter() + .all(|file| { + file.is_some_and(|file| { + license_detection_watchers + .get(&file.worktree_id(cx)) + .is_some_and(|watcher| watcher.is_project_open_source()) + }) + }); let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text()); @@ -431,10 +452,42 @@ impl LastEvent { })) } } + + pub fn split_by_pause(&self) -> (LastEvent, Option) { + let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else { + return (self.clone(), None); + }; + + let before = LastEvent { + old_snapshot: self.old_snapshot.clone(), + new_snapshot: boundary_snapshot.clone(), + old_file: self.old_file.clone(), + new_file: self.new_file.clone(), + end_edit_anchor: self.end_edit_anchor, + snapshot_after_last_editing_pause: None, + last_edit_time: self.last_edit_time, + }; + + let after = LastEvent { + old_snapshot: boundary_snapshot.clone(), + new_snapshot: self.new_snapshot.clone(), + old_file: self.old_file.clone(), + new_file: self.new_file.clone(), + end_edit_anchor: self.end_edit_anchor, + snapshot_after_last_editing_pause: None, + last_edit_time: self.last_edit_time, + }; + + (before, Some(after)) + } } -fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc { - if let Some(file) = snapshot.file() { +fn buffer_path_with_id_fallback( + file: Option<&Arc>, + snapshot: &TextBufferSnapshot, + cx: &App, +) -> Arc { + if let Some(file) = file { file.full_path(cx).into() } else { Path::new(&format!("untitled-{}", snapshot.remote_id())).into() @@ -585,6 +638,17 @@ impl EditPredictionStore { .unwrap_or_default() } + pub fn edit_history_for_project_with_pause_split_last_event( + &self, + project: &Entity, + cx: &App, + ) -> Vec> { + self.projects + .get(&project.entity_id()) + .map(|project_state| project_state.events_split_by_pause(cx)) + .unwrap_or_default() + } + pub fn context_for_project<'a>( &'a self, project: &Entity, @@ -802,10 +866,13 @@ impl EditPredictionStore { match project_state.registered_buffers.entry(buffer_id) { hash_map::Entry::Occupied(entry) => entry.into_mut(), hash_map::Entry::Vacant(entry) => { - let snapshot = buffer.read(cx).snapshot(); + let buf = buffer.read(cx); + let snapshot = buf.text_snapshot(); + let file = buf.file().cloned(); let project_entity_id = project.entity_id(); entry.insert(RegisteredBuffer { snapshot, + file, last_position: None, _subscriptions: [ cx.subscribe(buffer, { @@ -840,11 +907,14 @@ impl EditPredictionStore { let project_state = self.get_or_init_project(project, cx); let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx); - let new_snapshot = buffer.read(cx).snapshot(); + let buf = buffer.read(cx); + let new_file = buf.file().cloned(); + let new_snapshot = buf.text_snapshot(); if new_snapshot.version == registered_buffer.snapshot.version { return; } + let old_file = mem::replace(&mut registered_buffer.file, new_file.clone()); let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone()); let end_edit_anchor = new_snapshot .anchored_edits_since::(&old_snapshot.version) @@ -852,20 +922,16 @@ impl EditPredictionStore { .map(|(_, range)| range.end); let events = &mut project_state.events; - if let Some(LastEvent { - new_snapshot: last_new_snapshot, - end_edit_anchor: last_end_edit_anchor, - .. - }) = project_state.last_event.as_mut() - { + let now = cx.background_executor().now(); + if let Some(last_event) = project_state.last_event.as_mut() { let is_next_snapshot_of_same_buffer = old_snapshot.remote_id() - == last_new_snapshot.remote_id() - && old_snapshot.version == last_new_snapshot.version; + == last_event.new_snapshot.remote_id() + && old_snapshot.version == last_event.new_snapshot.version; let should_coalesce = is_next_snapshot_of_same_buffer && end_edit_anchor .as_ref() - .zip(last_end_edit_anchor.as_ref()) + .zip(last_event.end_edit_anchor.as_ref()) .is_some_and(|(a, b)| { let a = a.to_point(&new_snapshot); let b = b.to_point(&new_snapshot); @@ -873,8 +939,18 @@ impl EditPredictionStore { }); if should_coalesce { - *last_end_edit_anchor = end_edit_anchor; - *last_new_snapshot = new_snapshot; + let pause_elapsed = last_event + .last_edit_time + .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME) + .unwrap_or(false); + if pause_elapsed { + last_event.snapshot_after_last_editing_pause = + Some(last_event.new_snapshot.clone()); + } + + last_event.end_edit_anchor = end_edit_anchor; + last_event.new_snapshot = new_snapshot; + last_event.last_edit_time = Some(now); return; } } @@ -888,9 +964,13 @@ impl EditPredictionStore { } project_state.last_event = Some(LastEvent { + old_file, + new_file, old_snapshot, new_snapshot, end_edit_anchor, + snapshot_after_last_editing_pause: None, + last_edit_time: Some(now), }); } diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 9e4baa78ef4564ce4348ef1b51085ba0a6abdffc..5067aa0050d7a0831ca7668d17188fa6d41637b9 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -304,11 +304,102 @@ async fn test_request_events(cx: &mut TestAppContext) { let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap(); assert_eq!(prediction.edits.len(), 1); + assert_eq!(prediction.edits[0].1.as_ref(), " are you?"); +} + +#[gpui::test] +async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContext) { + let (ep_store, _requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\n\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + + ep_store.update(cx, |ep_store, cx| { + ep_store.register_buffer(&buffer, &project, cx); + }); + + // First burst: insert "How" + buffer.update(cx, |buffer, cx| { + buffer.edit(vec![(7..7, "How")], None, cx); + }); + + // Simulate a pause longer than the grouping threshold (e.g. 500ms). + cx.executor().advance_clock(LAST_CHANGE_GROUPING_TIME * 2); + cx.run_until_parked(); + + // Second burst: append " are you?" immediately after "How" on the same line. + // + // Keeping both bursts on the same line ensures the existing line-span coalescing logic + // groups them into a single `LastEvent`, allowing the pause-split getter to return two diffs. + buffer.update(cx, |buffer, cx| { + buffer.edit(vec![(10..10, " are you?")], None, cx); + }); + + // A second edit shortly after the first post-pause edit ensures the last edit timestamp is + // advanced after the pause boundary is recorded, making pause-splitting deterministic. + buffer.update(cx, |buffer, cx| { + buffer.edit(vec![(19..19, "!")], None, cx); + }); + + // Without time-based splitting, there is one event. + let events = ep_store.update(cx, |ep_store, cx| { + ep_store.edit_history_for_project(&project, cx) + }); + assert_eq!(events.len(), 1); + let zeta_prompt::Event::BufferChange { diff, .. } = events[0].as_ref(); assert_eq!( - prediction.edits[0].0.to_point(&snapshot).start, - language::Point::new(1, 3) + diff.as_str(), + indoc! {" + @@ -1,3 +1,3 @@ + Hello! + - + +How are you?! + Bye + "} + ); + + // With time-based splitting, there are two distinct events. + let events = ep_store.update(cx, |ep_store, cx| { + ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx) + }); + assert_eq!(events.len(), 2); + let zeta_prompt::Event::BufferChange { diff, .. } = events[0].as_ref(); + assert_eq!( + diff.as_str(), + indoc! {" + @@ -1,3 +1,3 @@ + Hello! + - + +How + Bye + "} + ); + + let zeta_prompt::Event::BufferChange { diff, .. } = events[1].as_ref(); + assert_eq!( + diff.as_str(), + indoc! {" + @@ -1,3 +1,3 @@ + Hello! + -How + +How are you?! + Bye + "} ); - assert_eq!(prediction.edits[0].1.as_ref(), " are you?"); } #[gpui::test] diff --git a/crates/edit_prediction/src/example_spec.rs b/crates/edit_prediction/src/example_spec.rs new file mode 100644 index 0000000000000000000000000000000000000000..bf221b576b890f1200c4ee3c095f73edaea71462 --- /dev/null +++ b/crates/edit_prediction/src/example_spec.rs @@ -0,0 +1,212 @@ +use serde::{Deserialize, Serialize}; +use std::{fmt::Write as _, mem, path::Path, sync::Arc}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ExampleSpec { + #[serde(default)] + pub name: String, + pub repository_url: String, + pub revision: String, + #[serde(default)] + pub uncommitted_diff: String, + pub cursor_path: Arc, + pub cursor_position: String, + pub edit_history: String, + pub expected_patch: String, +} + +const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff"; +const EDIT_HISTORY_HEADING: &str = "Edit History"; +const CURSOR_POSITION_HEADING: &str = "Cursor Position"; +const EXPECTED_PATCH_HEADING: &str = "Expected Patch"; +const EXPECTED_CONTEXT_HEADING: &str = "Expected Context"; +const REPOSITORY_URL_FIELD: &str = "repository_url"; +const REVISION_FIELD: &str = "revision"; + +impl ExampleSpec { + /// Format this example spec as markdown. + pub fn to_markdown(&self) -> String { + let mut markdown = String::new(); + + _ = writeln!(markdown, "# {}", self.name); + markdown.push('\n'); + + _ = writeln!(markdown, "repository_url = {}", self.repository_url); + _ = writeln!(markdown, "revision = {}", self.revision); + markdown.push('\n'); + + if !self.uncommitted_diff.is_empty() { + _ = writeln!(markdown, "## {}", UNCOMMITTED_DIFF_HEADING); + _ = writeln!(markdown); + _ = writeln!(markdown, "```diff"); + markdown.push_str(&self.uncommitted_diff); + if !markdown.ends_with('\n') { + markdown.push('\n'); + } + _ = writeln!(markdown, "```"); + markdown.push('\n'); + } + + _ = writeln!(markdown, "## {}", EDIT_HISTORY_HEADING); + _ = writeln!(markdown); + + if self.edit_history.is_empty() { + _ = writeln!(markdown, "(No edit history)"); + _ = writeln!(markdown); + } else { + _ = writeln!(markdown, "```diff"); + markdown.push_str(&self.edit_history); + if !markdown.ends_with('\n') { + markdown.push('\n'); + } + _ = writeln!(markdown, "```"); + markdown.push('\n'); + } + + _ = writeln!(markdown, "## {}", CURSOR_POSITION_HEADING); + _ = writeln!(markdown); + _ = writeln!(markdown, "```{}", self.cursor_path.to_string_lossy()); + markdown.push_str(&self.cursor_position); + if !markdown.ends_with('\n') { + markdown.push('\n'); + } + _ = writeln!(markdown, "```"); + markdown.push('\n'); + + _ = writeln!(markdown, "## {}", EXPECTED_PATCH_HEADING); + markdown.push('\n'); + _ = writeln!(markdown, "```diff"); + markdown.push_str(&self.expected_patch); + if !markdown.ends_with('\n') { + markdown.push('\n'); + } + _ = writeln!(markdown, "```"); + markdown.push('\n'); + + markdown + } + + /// Parse an example spec from markdown. + pub fn from_markdown(name: String, input: &str) -> anyhow::Result { + use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd}; + + let parser = Parser::new(input); + + let mut spec = ExampleSpec { + name, + repository_url: String::new(), + revision: String::new(), + uncommitted_diff: String::new(), + cursor_path: Path::new("").into(), + cursor_position: String::new(), + edit_history: String::new(), + expected_patch: String::new(), + }; + + let mut text = String::new(); + let mut block_info: CowStr = "".into(); + + #[derive(PartialEq)] + enum Section { + Start, + UncommittedDiff, + EditHistory, + CursorPosition, + ExpectedExcerpts, + ExpectedPatch, + Other, + } + + let mut current_section = Section::Start; + + for event in parser { + match event { + Event::Text(line) => { + text.push_str(&line); + + if let Section::Start = current_section + && let Some((field, value)) = line.split_once('=') + { + match field.trim() { + REPOSITORY_URL_FIELD => { + spec.repository_url = value.trim().to_string(); + } + REVISION_FIELD => { + spec.revision = value.trim().to_string(); + } + _ => {} + } + } + } + Event::End(TagEnd::Heading(HeadingLevel::H2)) => { + let title = mem::take(&mut text); + current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) { + Section::UncommittedDiff + } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) { + Section::EditHistory + } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) { + Section::CursorPosition + } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) { + Section::ExpectedPatch + } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) { + Section::ExpectedExcerpts + } else { + Section::Other + }; + } + Event::End(TagEnd::Heading(HeadingLevel::H3)) => { + mem::take(&mut text); + } + Event::End(TagEnd::Heading(HeadingLevel::H4)) => { + mem::take(&mut text); + } + Event::End(TagEnd::Heading(level)) => { + anyhow::bail!("Unexpected heading level: {level}"); + } + Event::Start(Tag::CodeBlock(kind)) => { + match kind { + CodeBlockKind::Fenced(info) => { + block_info = info; + } + CodeBlockKind::Indented => { + anyhow::bail!("Unexpected indented codeblock"); + } + }; + } + Event::Start(_) => { + text.clear(); + block_info = "".into(); + } + Event::End(TagEnd::CodeBlock) => { + let block_info = block_info.trim(); + match current_section { + Section::UncommittedDiff => { + spec.uncommitted_diff = mem::take(&mut text); + } + Section::EditHistory => { + spec.edit_history.push_str(&mem::take(&mut text)); + } + Section::CursorPosition => { + spec.cursor_path = Path::new(block_info).into(); + spec.cursor_position = mem::take(&mut text); + } + Section::ExpectedExcerpts => { + mem::take(&mut text); + } + Section::ExpectedPatch => { + spec.expected_patch = mem::take(&mut text); + } + Section::Start | Section::Other => {} + } + } + _ => {} + } + } + + if spec.cursor_path.as_ref() == Path::new("") || spec.cursor_position.is_empty() { + anyhow::bail!("Missing cursor position codeblock"); + } + + Ok(spec) + } +} diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml index 811808c72304f4c11a9858e61395e46024b83f1e..b6bace2a2c080626126af96f9ef51e435d6ab8fa 100644 --- a/crates/edit_prediction_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -40,7 +40,6 @@ node_runtime.workspace = true paths.workspace = true project.workspace = true prompt_store.workspace = true -pulldown-cmark.workspace = true release_channel.workspace = true reqwest_client.workspace = true serde.workspace = true diff --git a/crates/edit_prediction_cli/src/distill.rs b/crates/edit_prediction_cli/src/distill.rs index 085c5f744a1837cbb97f4c33b6f89b6031088e2b..abfe178ae61b6da522f43c93d40b6000800d0e4d 100644 --- a/crates/edit_prediction_cli/src/distill.rs +++ b/crates/edit_prediction_cli/src/distill.rs @@ -14,7 +14,7 @@ pub async fn run_distill(example: &mut Example) -> Result<()> { ) })?; - example.expected_patch = prediction.actual_patch; + example.spec.expected_patch = prediction.actual_patch; example.prompt = None; example.predictions = Vec::new(); example.score = Vec::new(); diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index 9499aae0c1ebce7eeca3ef05fedbcf09c960e131..e37619bf224b3fa506516714856cfbc5024ece14 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -1,6 +1,7 @@ use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics}; use anyhow::{Context as _, Result}; use collections::HashMap; +use edit_prediction::example_spec::ExampleSpec; use edit_prediction::udiff::OpenedBuffers; use gpui::Entity; use http_client::Url; @@ -11,23 +12,14 @@ use std::sync::Arc; use std::{ borrow::Cow, io::{Read, Write}, - mem, path::{Path, PathBuf}, }; use zeta_prompt::RelatedFile; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Example { - #[serde(default)] - pub name: String, - pub repository_url: String, - pub revision: String, - #[serde(default)] - pub uncommitted_diff: String, - pub cursor_path: Arc, - pub cursor_position: String, - pub edit_history: String, - pub expected_patch: String, + #[serde(flatten)] + pub spec: ExampleSpec, /// The full content of the file where an edit is being predicted, and the /// actual cursor offset. @@ -101,8 +93,9 @@ pub struct ExampleScore { impl Example { pub fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> { // git@github.com:owner/repo.git - if self.repository_url.contains('@') { + if self.spec.repository_url.contains('@') { let (owner, repo) = self + .spec .repository_url .split_once(':') .context("expected : in git url")? @@ -115,7 +108,7 @@ impl Example { )) // http://github.com/owner/repo.git } else { - let url = Url::parse(&self.repository_url)?; + let url = Url::parse(&self.spec.repository_url)?; let mut segments = url.path_segments().context("empty http url")?; let owner = segments .next() @@ -171,8 +164,8 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec { serde_json::from_str::(&content).unwrap_or_else(|error| { panic!("Failed to parse example file: {}\n{error}", path.display()) }); - if example.name.is_empty() { - example.name = filename; + if example.spec.name.is_empty() { + example.spec.name = filename; } examples.push(example); } @@ -189,8 +182,8 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec { line_ix + 1 ) }); - if example.name.is_empty() { - example.name = format!("{filename}-{line_ix}") + if example.spec.name.is_empty() { + example.spec.name = format!("{filename}-{line_ix}") } example }) @@ -225,9 +218,10 @@ pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) { pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) { examples.sort_by(|a, b| { - a.repository_url - .cmp(&b.repository_url) - .then(b.revision.cmp(&a.revision)) + a.spec + .repository_url + .cmp(&b.spec.repository_url) + .then(b.spec.revision.cmp(&a.spec.revision)) }); } @@ -235,145 +229,22 @@ pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec let mut examples_by_repo = HashMap::default(); for example in examples.iter_mut() { examples_by_repo - .entry(example.repository_url.clone()) + .entry(example.spec.repository_url.clone()) .or_insert_with(Vec::new) .push(example); } examples_by_repo.into_values().collect() } -fn parse_markdown_example(id: String, input: &str) -> Result { - use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd}; - - const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff"; - const EDIT_HISTORY_HEADING: &str = "Edit History"; - const CURSOR_POSITION_HEADING: &str = "Cursor Position"; - const EXPECTED_PATCH_HEADING: &str = "Expected Patch"; - const EXPECTED_CONTEXT_HEADING: &str = "Expected Context"; - const REPOSITORY_URL_FIELD: &str = "repository_url"; - const REVISION_FIELD: &str = "revision"; - - let parser = Parser::new(input); - - let mut example = Example { - name: id, - repository_url: String::new(), - revision: String::new(), - uncommitted_diff: String::new(), - cursor_path: PathBuf::new().into(), - cursor_position: String::new(), - edit_history: String::new(), - expected_patch: String::new(), +fn parse_markdown_example(name: String, input: &str) -> Result { + let spec = ExampleSpec::from_markdown(name, input)?; + Ok(Example { + spec, buffer: None, context: None, prompt: None, predictions: Vec::new(), score: Vec::new(), state: None, - }; - - let mut text = String::new(); - let mut block_info: CowStr = "".into(); - - #[derive(PartialEq)] - enum Section { - Start, - UncommittedDiff, - EditHistory, - CursorPosition, - ExpectedExcerpts, - ExpectedPatch, - Other, - } - - let mut current_section = Section::Start; - - for event in parser { - match event { - Event::Text(line) => { - text.push_str(&line); - - if let Section::Start = current_section - && let Some((field, value)) = line.split_once('=') - { - match field.trim() { - REPOSITORY_URL_FIELD => { - example.repository_url = value.trim().to_string(); - } - REVISION_FIELD => { - example.revision = value.trim().to_string(); - } - _ => {} - } - } - } - Event::End(TagEnd::Heading(HeadingLevel::H2)) => { - let title = mem::take(&mut text); - current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) { - Section::UncommittedDiff - } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) { - Section::EditHistory - } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) { - Section::CursorPosition - } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) { - Section::ExpectedPatch - } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) { - Section::ExpectedExcerpts - } else { - Section::Other - }; - } - Event::End(TagEnd::Heading(HeadingLevel::H3)) => { - mem::take(&mut text); - } - Event::End(TagEnd::Heading(HeadingLevel::H4)) => { - mem::take(&mut text); - } - Event::End(TagEnd::Heading(level)) => { - anyhow::bail!("Unexpected heading level: {level}"); - } - Event::Start(Tag::CodeBlock(kind)) => { - match kind { - CodeBlockKind::Fenced(info) => { - block_info = info; - } - CodeBlockKind::Indented => { - anyhow::bail!("Unexpected indented codeblock"); - } - }; - } - Event::Start(_) => { - text.clear(); - block_info = "".into(); - } - Event::End(TagEnd::CodeBlock) => { - let block_info = block_info.trim(); - match current_section { - Section::UncommittedDiff => { - example.uncommitted_diff = mem::take(&mut text); - } - Section::EditHistory => { - example.edit_history.push_str(&mem::take(&mut text)); - } - Section::CursorPosition => { - example.cursor_path = Path::new(block_info).into(); - example.cursor_position = mem::take(&mut text); - } - Section::ExpectedExcerpts => { - mem::take(&mut text); - } - Section::ExpectedPatch => { - example.expected_patch = mem::take(&mut text); - } - Section::Start | Section::Other => {} - } - } - _ => {} - } - } - if example.cursor_path.as_ref() == Path::new("") || example.cursor_position.is_empty() { - anyhow::bail!("Missing cursor position codeblock"); - } - - Ok(example) + }) } diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index c778b708b701492b0cc85a0030a1e9d090ce0724..f543d0799b379403f0caa980df76954649e1aceb 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -23,14 +23,14 @@ pub async fn run_format_prompt( ) -> Result<()> { run_context_retrieval(example, app_state.clone(), cx.clone()).await?; - let _step_progress = Progress::global().start(Step::FormatPrompt, &example.name); + let _step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name); match prompt_format { PromptFormat::Teacher => { let prompt = TeacherPrompt::format_prompt(example); example.prompt = Some(ExamplePrompt { input: prompt, - expected_output: example.expected_patch.clone(), // TODO + expected_output: example.spec.expected_patch.clone(), // TODO format: prompt_format, }); } @@ -54,7 +54,7 @@ pub async fn run_format_prompt( .files .clone(), ep_store.edit_history_for_project(&project, cx), - example.cursor_path.clone(), + example.spec.cursor_path.clone(), example .buffer .as_ref() @@ -63,7 +63,8 @@ pub async fn run_format_prompt( )) })??; let prompt = format_zeta_prompt(&input); - let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone())?; + let expected_output = + zeta2_output_for_patch(&input, &example.spec.expected_patch.clone())?; example.prompt = Some(ExamplePrompt { input: prompt, expected_output, @@ -85,7 +86,7 @@ impl TeacherPrompt { const MAX_HISTORY_LINES: usize = 128; pub fn format_prompt(example: &Example) -> String { - let edit_history = Self::format_edit_history(&example.edit_history); + let edit_history = Self::format_edit_history(&example.spec.edit_history); let context = Self::format_context(example); let editable_region = Self::format_editable_region(example); @@ -131,7 +132,7 @@ impl TeacherPrompt { --- a/{path} +++ b/{path} {diff}", - path = example.cursor_path.to_string_lossy(), + path = example.spec.cursor_path.to_string_lossy(), diff = diff, }; @@ -170,13 +171,13 @@ impl TeacherPrompt { fn format_editable_region(example: &Example) -> String { let mut result = String::new(); - let path_str = example.cursor_path.to_string_lossy(); + let path_str = example.spec.cursor_path.to_string_lossy(); result.push_str(&format!("`````path=\"{path_str}\"\n")); result.push_str(Self::EDITABLE_REGION_START); // TODO: control number of lines around cursor - result.push_str(&example.cursor_position); - if !example.cursor_position.ends_with('\n') { + result.push_str(&example.spec.cursor_position); + if !example.spec.cursor_position.ends_with('\n') { result.push('\n'); } diff --git a/crates/edit_prediction_cli/src/load_project.rs b/crates/edit_prediction_cli/src/load_project.rs index 4517e6ccbebca76a7ba8ce73322d6467000fc189..38f114d726d3626fac89982b7f3a98c55e92ac07 100644 --- a/crates/edit_prediction_cli/src/load_project.rs +++ b/crates/edit_prediction_cli/src/load_project.rs @@ -34,7 +34,7 @@ pub async fn run_load_project( return Ok(()); } - let progress = Progress::global().start(Step::LoadProject, &example.name); + let progress = Progress::global().start(Step::LoadProject, &example.spec.name); let project = setup_project(example, &app_state, &progress, &mut cx).await?; @@ -77,7 +77,7 @@ async fn cursor_position( ) -> Result<(Entity, Anchor)> { let language_registry = project.read_with(cx, |project, _| project.languages().clone())?; let result = language_registry - .load_language_for_file_path(&example.cursor_path) + .load_language_for_file_path(&example.spec.cursor_path) .await; if let Err(error) = result @@ -93,7 +93,7 @@ async fn cursor_position( .context("No visible worktrees") })??; - let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix) + let cursor_path = RelPath::new(&example.spec.cursor_path, PathStyle::Posix) .context("Failed to create RelPath")? .into_arc(); let cursor_buffer = project @@ -108,10 +108,11 @@ async fn cursor_position( })? .await?; let cursor_offset_within_excerpt = example + .spec .cursor_position .find(CURSOR_MARKER) .context("missing cursor marker")?; - let mut cursor_excerpt = example.cursor_position.clone(); + let mut cursor_excerpt = example.spec.cursor_position.clone(); cursor_excerpt.replace_range( cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()), "", @@ -123,10 +124,14 @@ async fn cursor_position( let (excerpt_offset, _) = matches.next().with_context(|| { format!( "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Example: {}\nCursor excerpt did not exist in buffer.", - example.name + example.spec.name ) })?; - anyhow::ensure!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name); + anyhow::ensure!( + matches.next().is_none(), + "More than one cursor position match found for {}", + &example.spec.name + ); Ok(excerpt_offset) })??; @@ -149,7 +154,7 @@ async fn setup_project( let worktree_path = setup_worktree(example, step_progress).await?; - if let Some(project) = app_state.project_cache.get(&example.repository_url) { + if let Some(project) = app_state.project_cache.get(&example.spec.repository_url) { ep_store.update(cx, |ep_store, _| { ep_store.clear_history_for_project(&project); })?; @@ -187,7 +192,7 @@ async fn setup_project( app_state .project_cache - .insert(example.repository_url.clone(), project.clone()); + .insert(example.spec.repository_url.clone(), project.clone()); let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?; cx.subscribe(&buffer_store, { @@ -218,7 +223,7 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu run_git(&repo_dir, &["init"]).await?; run_git( &repo_dir, - &["remote", "add", "origin", &example.repository_url], + &["remote", "add", "origin", &example.spec.repository_url], ) .await?; } @@ -226,7 +231,10 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu // Resolve the example to a revision, fetching it if needed. let revision = run_git( &repo_dir, - &["rev-parse", &format!("{}^{{commit}}", example.revision)], + &[ + "rev-parse", + &format!("{}^{{commit}}", example.spec.revision), + ], ) .await; let revision = if let Ok(revision) = revision { @@ -235,7 +243,7 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu step_progress.set_substatus("fetching"); if run_git( &repo_dir, - &["fetch", "--depth", "1", "origin", &example.revision], + &["fetch", "--depth", "1", "origin", &example.spec.revision], ) .await .is_err() @@ -256,7 +264,7 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu let worktree_path_string = worktree_path.to_string_lossy(); run_git( &repo_dir, - &["branch", "-f", &example.name, revision.as_str()], + &["branch", "-f", &example.spec.name, revision.as_str()], ) .await?; run_git( @@ -266,7 +274,7 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu "add", "-f", &worktree_path_string, - &example.name, + &example.spec.name, ], ) .await?; @@ -274,7 +282,7 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu drop(repo_lock); // Apply the uncommitted diff for this example. - if !example.uncommitted_diff.is_empty() { + if !example.spec.uncommitted_diff.is_empty() { step_progress.set_substatus("applying diff"); let mut apply_process = smol::process::Command::new("git") .current_dir(&worktree_path) @@ -283,7 +291,9 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu .spawn()?; let mut stdin = apply_process.stdin.take().context("Failed to get stdin")?; - stdin.write_all(example.uncommitted_diff.as_bytes()).await?; + stdin + .write_all(example.spec.uncommitted_diff.as_bytes()) + .await?; stdin.close().await?; drop(stdin); @@ -306,7 +316,7 @@ async fn apply_edit_history( project: &Entity, cx: &mut AsyncApp, ) -> Result { - edit_prediction::udiff::apply_diff(&example.edit_history, project, cx).await + edit_prediction::udiff::apply_diff(&example.spec.edit_history, project, cx).await } thread_local! { diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 3b185103390016f60fc4f621f280d16a58c363e5..dce0fbbed57dbc4b18faf93787cfb8f2341a126a 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -267,7 +267,7 @@ fn main() { if let Err(e) = result { Progress::global().increment_failed(); let failed_example_path = - FAILED_EXAMPLES_DIR.join(format!("{}.json", example.name)); + FAILED_EXAMPLES_DIR.join(format!("{}.json", example.spec.name)); app_state .fs .write( @@ -276,8 +276,8 @@ fn main() { ) .await .unwrap(); - let err_path = - FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example.name)); + let err_path = FAILED_EXAMPLES_DIR + .join(format!("{}_err.txt", example.spec.name)); app_state .fs .write(&err_path, e.to_string().as_bytes()) @@ -298,7 +298,7 @@ fn main() { Re-run this example with: cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m "}, - example.name, + example.spec.name, e, err_path.display(), failed_example_path.display(), diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 3e6104e3a8afc3adc609df094a70fc34138c1619..aa93c5415dea091164a68b76a34242697aac70e3 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -40,7 +40,7 @@ pub async fn run_prediction( provider, PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching ) { - let _step_progress = Progress::global().start(Step::Predict, &example.name); + let _step_progress = Progress::global().start(Step::Predict, &example.spec.name); if example.prompt.is_none() { run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?; @@ -52,7 +52,7 @@ pub async fn run_prediction( run_load_project(example, app_state.clone(), cx.clone()).await?; - let _step_progress = Progress::global().start(Step::Predict, &example.name); + let _step_progress = Progress::global().start(Step::Predict, &example.spec.name); if matches!( provider, @@ -90,7 +90,7 @@ pub async fn run_prediction( store.set_edit_prediction_model(model); })?; let state = example.state.as_ref().context("state must be set")?; - let run_dir = RUN_DIR.join(&example.name); + let run_dir = RUN_DIR.join(&example.spec.name); let updated_example = Arc::new(Mutex::new(example.clone())); let current_run_ix = Arc::new(AtomicUsize::new(0)); diff --git a/crates/edit_prediction_cli/src/retrieve_context.rs b/crates/edit_prediction_cli/src/retrieve_context.rs index a07c7ec8752ff987b8783c4fa15904078bd5612d..abba4504edc6c0733ffd8c0677e2e3304d8100fa 100644 --- a/crates/edit_prediction_cli/src/retrieve_context.rs +++ b/crates/edit_prediction_cli/src/retrieve_context.rs @@ -26,7 +26,7 @@ pub async fn run_context_retrieval( run_load_project(example, app_state.clone(), cx.clone()).await?; let step_progress: Arc = Progress::global() - .start(Step::Context, &example.name) + .start(Step::Context, &example.spec.name) .into(); let state = example.state.as_ref().unwrap(); diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index 314d19b67259e6a4a0fcff932826325f4366ddde..7b507e6d19c943de92eb0b22c7d24d4026789fed 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -25,9 +25,9 @@ pub async fn run_scoring( ) .await?; - let _progress = Progress::global().start(Step::Score, &example.name); + let _progress = Progress::global().start(Step::Score, &example.spec.name); - let expected_patch = parse_patch(&example.expected_patch); + let expected_patch = parse_patch(&example.spec.expected_patch); let mut scores = vec![]; @@ -71,7 +71,7 @@ pub fn print_report(examples: &[Example]) { eprintln!( "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}", - truncate_name(&example.name, 30), + truncate_name(&example.spec.name, 30), line_match.true_positives, line_match.false_positives, line_match.false_negatives, diff --git a/crates/edit_prediction_ui/Cargo.toml b/crates/edit_prediction_ui/Cargo.toml index 63d674250001483bb8963ce62b44af524686399e..b406a450601bef908c27a48be14fe9b1f2204c08 100644 --- a/crates/edit_prediction_ui/Cargo.toml +++ b/crates/edit_prediction_ui/Cargo.toml @@ -15,6 +15,9 @@ doctest = false [dependencies] anyhow.workspace = true buffer_diff.workspace = true +git.workspace = true +log.workspace = true +time.workspace = true client.workspace = true cloud_llm_client.workspace = true codestral.workspace = true diff --git a/crates/edit_prediction_ui/src/edit_prediction_button.rs b/crates/edit_prediction_ui/src/edit_prediction_button.rs index b008f09ec8886086578b571b3655dac566fb6c5d..bbf9f4677df278c014379964e7bdc714e6ce78d8 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_button.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_button.rs @@ -46,7 +46,9 @@ use workspace::{ }; use zed_actions::{OpenBrowser, OpenSettingsAt}; -use crate::{RatePredictions, rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag}; +use crate::{ + CaptureExample, RatePredictions, rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag, +}; actions!( edit_prediction, @@ -899,7 +901,13 @@ impl EditPredictionButton { .context(editor_focus_handle) .when( cx.has_flag::(), - |this| this.action("Rate Predictions", RatePredictions.boxed_clone()), + |this| { + this.action( + "Capture Edit Prediction Example", + CaptureExample.boxed_clone(), + ) + .action("Rate Predictions", RatePredictions.boxed_clone()) + }, ); } diff --git a/crates/edit_prediction_ui/src/edit_prediction_ui.rs b/crates/edit_prediction_ui/src/edit_prediction_ui.rs index 74c81fbfe16eec7846e70aefd59bbfeb282072dc..a762fd22aa7c32779a096fa97b2ea20ef3c9b744 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_ui.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_ui.rs @@ -3,15 +3,24 @@ mod edit_prediction_context_view; mod rate_prediction_modal; use std::any::{Any as _, TypeId}; +use std::path::Path; +use std::sync::Arc; use command_palette_hooks::CommandPaletteFilter; -use edit_prediction::{ResetOnboarding, Zeta2FeatureFlag}; +use edit_prediction::{ + EditPredictionStore, ResetOnboarding, Zeta2FeatureFlag, example_spec::ExampleSpec, +}; use edit_prediction_context_view::EditPredictionContextView; +use editor::Editor; use feature_flags::FeatureFlagAppExt as _; -use gpui::actions; +use git::repository::DiffType; +use gpui::{Window, actions}; +use language::ToPoint as _; +use log; use project::DisableAiSettings; use rate_prediction_modal::RatePredictionsModal; use settings::{Settings as _, SettingsStore}; +use text::ToOffset as _; use ui::{App, prelude::*}; use workspace::{SplitDirection, Workspace}; @@ -32,6 +41,8 @@ actions!( [ /// Opens the rate completions modal. RatePredictions, + /// Captures an ExampleSpec from the current editing session and opens it as Markdown. + CaptureExample, ] ); @@ -45,6 +56,7 @@ pub fn init(cx: &mut App) { } }); + workspace.register_action(capture_edit_prediction_example); workspace.register_action_renderer(|div, _, _, cx| { let has_flag = cx.has_flag::(); div.when(has_flag, |div| { @@ -78,6 +90,7 @@ fn feature_gate_predict_edits_actions(cx: &mut App) { let reset_onboarding_action_types = [TypeId::of::()]; let all_action_types = [ TypeId::of::(), + TypeId::of::(), TypeId::of::(), zed_actions::OpenZedPredictOnboarding.type_id(), TypeId::of::(), @@ -124,3 +137,194 @@ fn feature_gate_predict_edits_actions(cx: &mut App) { }) .detach(); } + +fn capture_edit_prediction_example( + workspace: &mut Workspace, + _: &CaptureExample, + window: &mut Window, + cx: &mut Context, +) { + let Some(ep_store) = EditPredictionStore::try_global(cx) else { + return; + }; + + let project = workspace.project().clone(); + + let (worktree_root, repository) = { + let project_ref = project.read(cx); + let worktree_root = project_ref + .visible_worktrees(cx) + .next() + .map(|worktree| worktree.read(cx).abs_path()); + let repository = project_ref.active_repository(cx); + (worktree_root, repository) + }; + + let (Some(worktree_root), Some(repository)) = (worktree_root, repository) else { + log::error!("CaptureExampleSpec: missing worktree or active repository"); + return; + }; + + let repository_snapshot = repository.read(cx).snapshot(); + if worktree_root.as_ref() != repository_snapshot.work_directory_abs_path.as_ref() { + log::error!( + "repository is not at worktree root (repo={:?}, worktree={:?})", + repository_snapshot.work_directory_abs_path, + worktree_root + ); + return; + } + + let Some(repository_url) = repository_snapshot + .remote_origin_url + .clone() + .or_else(|| repository_snapshot.remote_upstream_url.clone()) + else { + log::error!("active repository has no origin/upstream remote url"); + return; + }; + + let Some(revision) = repository_snapshot + .head_commit + .as_ref() + .map(|commit| commit.sha.to_string()) + else { + log::error!("active repository has no head commit"); + return; + }; + + let mut events = ep_store.update(cx, |store, cx| { + store.edit_history_for_project_with_pause_split_last_event(&project, cx) + }); + + let Some(editor) = workspace.active_item_as::(cx) else { + log::error!("no active editor"); + return; + }; + + let Some(project_path) = editor.read(cx).project_path(cx) else { + log::error!("active editor has no project path"); + return; + }; + + let Some((buffer, cursor_anchor)) = editor + .read(cx) + .buffer() + .read(cx) + .text_anchor_for_position(editor.read(cx).selections.newest_anchor().head(), cx) + else { + log::error!("failed to resolve cursor buffer/anchor"); + return; + }; + + let snapshot = buffer.read(cx).snapshot(); + let cursor_point = cursor_anchor.to_point(&snapshot); + let (_editable_range, context_range) = + edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position( + cursor_point, + &snapshot, + 100, + 50, + ); + + let cursor_path: Arc = repository + .read(cx) + .project_path_to_repo_path(&project_path, cx) + .map(|repo_path| Path::new(repo_path.as_unix_str()).into()) + .unwrap_or_else(|| Path::new(project_path.path.as_unix_str()).into()); + + let cursor_position = { + let context_start_offset = context_range.start.to_offset(&snapshot); + let cursor_offset = cursor_anchor.to_offset(&snapshot); + let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset); + let mut excerpt = snapshot.text_for_range(context_range).collect::(); + if cursor_offset_in_excerpt <= excerpt.len() { + excerpt.insert_str(cursor_offset_in_excerpt, zeta_prompt::CURSOR_MARKER); + } + excerpt + }; + + let markdown_language = workspace + .app_state() + .languages + .language_for_name("Markdown"); + + cx.spawn_in(window, async move |workspace_entity, cx| { + let markdown_language = markdown_language.await?; + + let uncommitted_diff_rx = repository.update(cx, |repository, cx| { + repository.diff(DiffType::HeadToWorktree, cx) + })?; + + let uncommitted_diff = match uncommitted_diff_rx.await { + Ok(Ok(diff)) => diff, + Ok(Err(error)) => { + log::error!("failed to compute uncommitted diff: {error:#}"); + return Ok(()); + } + Err(error) => { + log::error!("uncommitted diff channel dropped: {error:#}"); + return Ok(()); + } + }; + + let mut edit_history = String::new(); + let mut expected_patch = String::new(); + if let Some(last_event) = events.pop() { + for event in &events { + zeta_prompt::write_event(&mut edit_history, event); + if !edit_history.ends_with('\n') { + edit_history.push('\n'); + } + edit_history.push('\n'); + } + + zeta_prompt::write_event(&mut expected_patch, &last_event); + } + + let format = + time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]"); + let name = match format { + Ok(format) => { + let now = time::OffsetDateTime::now_local() + .unwrap_or_else(|_| time::OffsetDateTime::now_utc()); + now.format(&format) + .unwrap_or_else(|_| "unknown-time".to_string()) + } + Err(_) => "unknown-time".to_string(), + }; + + let markdown = ExampleSpec { + name, + repository_url, + revision, + uncommitted_diff, + cursor_path, + cursor_position, + edit_history, + expected_patch, + } + .to_markdown(); + + let buffer = project + .update(cx, |project, cx| project.create_buffer(false, cx))? + .await?; + buffer.update(cx, |buffer, cx| { + buffer.set_text(markdown, cx); + buffer.set_language(Some(markdown_language), cx); + })?; + + workspace_entity.update_in(cx, |workspace, window, cx| { + workspace.add_item_to_active_pane( + Box::new( + cx.new(|cx| Editor::for_buffer(buffer, Some(project.clone()), window, cx)), + ), + None, + true, + window, + cx, + ); + }) + }) + .detach_and_log_err(cx); +}