diff --git a/Cargo.lock b/Cargo.lock index 6c12bd886d7c880f68bca078815dd17ae4ca392f..de94d83793cfdf397099fb24ded02d8cff599e35 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5253,6 +5253,7 @@ dependencies = [ "text", "thiserror 2.0.17", "time", + "toml 0.8.23", "ui", "util", "uuid", diff --git a/crates/edit_prediction/Cargo.toml b/crates/edit_prediction/Cargo.toml index 5145777ff0733d674f33a597db6682e611e4d0fc..f2534545abc04ce769cf46be7692f7954bb0bfcf 100644 --- a/crates/edit_prediction/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -56,6 +56,7 @@ telemetry_events.workspace = true text.workspace = true thiserror.workspace = true time.workspace = true +toml.workspace = true ui.workspace = true util.workspace = true uuid.workspace = true diff --git a/crates/edit_prediction/src/capture_example.rs b/crates/edit_prediction/src/capture_example.rs index 0171b711dd24384e1588c076d4b4f678723d7459..232081c579f3e0c01d33d04d1bfebdeb50621cfd 100644 --- a/crates/edit_prediction/src/capture_example.rs +++ b/crates/edit_prediction/src/capture_example.rs @@ -7,15 +7,14 @@ use buffer_diff::BufferDiffSnapshot; use collections::HashMap; use gpui::{App, Entity, Task}; use language::{Buffer, ToPoint as _}; -use project::Project; +use project::{Project, WorktreeId}; use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc}; -use text::{BufferSnapshot as TextBufferSnapshot, ToOffset as _}; +use text::BufferSnapshot as TextBufferSnapshot; pub fn capture_example( project: Entity, buffer: Entity, cursor_anchor: language::Anchor, - last_event_is_expected_patch: bool, cx: &mut App, ) -> Option>> { let ep_store = EditPredictionStore::try_global(cx)?; @@ -43,8 +42,26 @@ pub fn capture_example( let git_store = project.read(cx).git_store().clone(); Some(cx.spawn(async move |mut cx| { - let snapshots_by_path = collect_snapshots(&project, &git_store, &events, &mut cx).await?; - let cursor_excerpt = cx + let snapshots_by_path = + collect_snapshots(&project, &git_store, worktree_id, 
&events, &mut cx).await?; + + events.retain(|stored_event| { + match stored_event.event.as_ref() { + zeta_prompt::Event::BufferChange { path, .. } => { + if !snapshots_by_path.contains_key(path) { + return false; + } + } + } + true + }); + + let line_comment_prefix = snapshot + .language() + .and_then(|lang| lang.config().line_comments.first()) + .map(|s| s.to_string()) + .unwrap_or_default(); + let (cursor_excerpt, cursor_offset) = cx .background_executor() .spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) }) .await; @@ -54,13 +71,6 @@ pub fn capture_example( .await; let mut edit_history = String::new(); - let mut expected_patch = String::new(); - if last_event_is_expected_patch { - if let Some(stored_event) = events.pop() { - zeta_prompt::write_event(&mut expected_patch, &stored_event.event); - } - } - for stored_event in &events { zeta_prompt::write_event(&mut edit_history, &stored_event.event); if !edit_history.ends_with('\n') { @@ -68,57 +78,62 @@ pub fn capture_example( } } - let name = generate_timestamp_name(); - - Ok(ExampleSpec { - name, + let mut spec = ExampleSpec { + name: generate_timestamp_name(), repository_url, revision, + tags: Vec::new(), + reasoning: None, uncommitted_diff, cursor_path: cursor_path.as_std_path().into(), - cursor_position: cursor_excerpt, + cursor_position: String::new(), edit_history, - expected_patch, - }) + expected_patches: Vec::new(), + }; + spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix); + Ok(spec) })) } fn compute_cursor_excerpt( snapshot: &language::BufferSnapshot, cursor_anchor: language::Anchor, -) -> String { +) -> (String, usize) { + use text::ToOffset as _; + let cursor_point = cursor_anchor.to_point(snapshot); let (_editable_range, context_range) = editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50); - let context_start_offset = context_range.start.to_offset(snapshot); let cursor_offset = cursor_anchor.to_offset(snapshot); let 
cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset); - let mut excerpt = snapshot.text_for_range(context_range).collect::(); - if cursor_offset_in_excerpt <= excerpt.len() { - excerpt.insert_str(cursor_offset_in_excerpt, zeta_prompt::CURSOR_MARKER); - } - excerpt + let excerpt = snapshot.text_for_range(context_range).collect::(); + (excerpt, cursor_offset_in_excerpt) } async fn collect_snapshots( project: &Entity, git_store: &Entity, + worktree_id: WorktreeId, events: &[StoredEvent], cx: &mut gpui::AsyncApp, ) -> Result, (TextBufferSnapshot, BufferDiffSnapshot)>> { let mut snapshots_by_path = HashMap::default(); + let root_name = project.read_with(cx, |project, cx| { + project + .worktree_for_id(worktree_id, cx) + .unwrap() + .read(cx) + .root_name() + .to_owned() + })?; for stored_event in events { let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref(); if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| { - let project_path = project.find_project_path(path, cx)?; - let full_path = project - .worktree_for_id(project_path.worktree_id, cx)? - .read(cx) - .root_name() - .join(&project_path.path) - .as_std_path() - .into(); + let project_path = project + .find_project_path(path, cx) + .filter(|path| path.worktree_id == worktree_id)?; + let full_path = root_name.join(&project_path.path).as_std_path().into(); Some((project_path, full_path)) })? 
{ if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) { @@ -289,9 +304,7 @@ mod tests { cx.run_until_parked(); let mut example = cx - .update(|cx| { - capture_example(project.clone(), buffer.clone(), Anchor::MIN, false, cx).unwrap() - }) + .update(|cx| capture_example(project.clone(), buffer.clone(), Anchor::MIN, cx).unwrap()) .await .unwrap(); example.name = "test".to_string(); @@ -302,6 +315,8 @@ mod tests { name: "test".to_string(), repository_url: "https://github.com/test/repo.git".to_string(), revision: "abc123def456".to_string(), + tags: Vec::new(), + reasoning: None, uncommitted_diff: indoc! {" --- a/project/src/main.rs +++ b/project/src/main.rs @@ -322,7 +337,8 @@ mod tests { .to_string(), cursor_path: Path::new("project/src/main.rs").into(), cursor_position: indoc! {" - <|user_cursor|>fn main() { + fn main() { + ^[CURSOR_POSITION] // comment 1 one(); two(); @@ -355,7 +371,7 @@ mod tests { seven(); "} .to_string(), - expected_patch: "".to_string(), + expected_patches: Vec::new() } ); } diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index c47c1a70e843a17cc98083dec664c7fc05933426..1c3035d178f86f172f0457ed08dfd4246a626783 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -688,12 +688,14 @@ impl EditPredictionStore { pub fn clear_history(&mut self) { for project_state in self.projects.values_mut() { project_state.events.clear(); + project_state.last_event.take(); } } pub fn clear_history_for_project(&mut self, project: &Entity) { if let Some(project_state) = self.projects.get_mut(&project.entity_id()) { project_state.events.clear(); + project_state.last_event.take(); } } @@ -2044,7 +2046,9 @@ impl EditPredictionStore { "Edit Prediction Rated", rating, inputs = prediction.inputs, - output = prediction.edit_preview.as_unified_diff(&prediction.edits), + output = prediction + .edit_preview + 
.as_unified_diff(prediction.snapshot.file(), &prediction.edits), feedback ); self.client.telemetry().flush_events().detach(); diff --git a/crates/edit_prediction/src/example_spec.rs b/crates/edit_prediction/src/example_spec.rs index 8a30c7b85c494fcc7a0b32ece665f814e8452bc4..d4c36d1f9fa43abc5d3f67d771b9fdacb1f425e9 100644 --- a/crates/edit_prediction/src/example_spec.rs +++ b/crates/edit_prediction/src/example_spec.rs @@ -1,5 +1,9 @@ +use anyhow::{Context as _, Result}; use serde::{Deserialize, Serialize}; -use std::{fmt::Write as _, mem, path::Path, sync::Arc}; +use std::{borrow::Cow, fmt::Write as _, mem, path::Path, sync::Arc}; + +pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]"; +pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>"; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct ExampleSpec { @@ -7,34 +11,81 @@ pub struct ExampleSpec { pub name: String, pub repository_url: String, pub revision: String, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub tags: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub reasoning: Option, #[serde(default)] pub uncommitted_diff: String, pub cursor_path: Arc, pub cursor_position: String, pub edit_history: String, - pub expected_patch: String, + pub expected_patches: Vec, } +const REASONING_HEADING: &str = "Reasoning"; const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff"; const EDIT_HISTORY_HEADING: &str = "Edit History"; const CURSOR_POSITION_HEADING: &str = "Cursor Position"; const EXPECTED_PATCH_HEADING: &str = "Expected Patch"; const EXPECTED_CONTEXT_HEADING: &str = "Expected Context"; -const REPOSITORY_URL_FIELD: &str = "repository_url"; -const REVISION_FIELD: &str = "revision"; + +#[derive(Serialize, Deserialize)] +struct FrontMatter<'a> { + repository_url: Cow<'a, str>, + revision: Cow<'a, str>, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + tags: Vec, +} impl ExampleSpec { + /// Generate a sanitized filename for this example. 
+ pub fn filename(&self) -> String { + self.name + .chars() + .map(|c| match c { + ' ' | ':' | '~' | '^' | '?' | '*' | '[' | '\\' | '@' | '{' | '/' | '<' | '>' + | '|' | '"' => '-', + c => c, + }) + .collect() + } + /// Format this example spec as markdown. pub fn to_markdown(&self) -> String { + use std::fmt::Write as _; + + let front_matter = FrontMatter { + repository_url: Cow::Borrowed(&self.repository_url), + revision: Cow::Borrowed(&self.revision), + tags: self.tags.clone(), + }; + let front_matter_toml = + toml::to_string_pretty(&front_matter).unwrap_or_else(|_| String::new()); + let mut markdown = String::new(); - _ = writeln!(markdown, "# {}", self.name); + _ = writeln!(markdown, "+++"); + markdown.push_str(&front_matter_toml); + if !markdown.ends_with('\n') { + markdown.push('\n'); + } + _ = writeln!(markdown, "+++"); markdown.push('\n'); - _ = writeln!(markdown, "repository_url = {}", self.repository_url); - _ = writeln!(markdown, "revision = {}", self.revision); + _ = writeln!(markdown, "# {}", self.name); markdown.push('\n'); + if let Some(reasoning) = &self.reasoning { + _ = writeln!(markdown, "## {}", REASONING_HEADING); + markdown.push('\n'); + markdown.push_str(reasoning); + if !markdown.ends_with('\n') { + markdown.push('\n'); + } + markdown.push('\n'); + } + if !self.uncommitted_diff.is_empty() { _ = writeln!(markdown, "## {}", UNCOMMITTED_DIFF_HEADING); _ = writeln!(markdown); @@ -75,34 +126,48 @@ impl ExampleSpec { _ = writeln!(markdown, "## {}", EXPECTED_PATCH_HEADING); markdown.push('\n'); - _ = writeln!(markdown, "```diff"); - markdown.push_str(&self.expected_patch); - if !markdown.ends_with('\n') { + for patch in &self.expected_patches { + _ = writeln!(markdown, "```diff"); + markdown.push_str(patch); + if !markdown.ends_with('\n') { + markdown.push('\n'); + } + _ = writeln!(markdown, "```"); markdown.push('\n'); } - _ = writeln!(markdown, "```"); - markdown.push('\n'); markdown } /// Parse an example spec from markdown. 
- pub fn from_markdown(name: String, input: &str) -> anyhow::Result { + pub fn from_markdown(mut input: &str) -> anyhow::Result { use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd}; - let parser = Parser::new(input); - let mut spec = ExampleSpec { - name, + name: String::new(), repository_url: String::new(), revision: String::new(), + tags: Vec::new(), + reasoning: None, uncommitted_diff: String::new(), cursor_path: Path::new("").into(), cursor_position: String::new(), edit_history: String::new(), - expected_patch: String::new(), + expected_patches: Vec::new(), }; + if let Some(rest) = input.strip_prefix("+++\n") + && let Some((front_matter, rest)) = rest.split_once("+++\n") + { + if let Ok(data) = toml::from_str::>(front_matter) { + spec.repository_url = data.repository_url.into_owned(); + spec.revision = data.revision.into_owned(); + spec.tags = data.tags; + } + input = rest.trim_start(); + } + + let parser = Parser::new(input); let mut text = String::new(); let mut block_info: CowStr = "".into(); @@ -123,20 +188,9 @@ impl ExampleSpec { match event { Event::Text(line) => { text.push_str(&line); - - if let Section::Start = current_section - && let Some((field, value)) = line.split_once('=') - { - match field.trim() { - REPOSITORY_URL_FIELD => { - spec.repository_url = value.trim().to_string(); - } - REVISION_FIELD => { - spec.revision = value.trim().to_string(); - } - _ => {} - } - } + } + Event::End(TagEnd::Heading(HeadingLevel::H1)) => { + spec.name = mem::take(&mut text); } Event::End(TagEnd::Heading(HeadingLevel::H2)) => { let title = mem::take(&mut text); @@ -194,7 +248,7 @@ impl ExampleSpec { mem::take(&mut text); } Section::ExpectedPatch => { - spec.expected_patch = mem::take(&mut text); + spec.expected_patches.push(mem::take(&mut text)); } Section::Start | Section::Other => {} } @@ -209,4 +263,326 @@ impl ExampleSpec { Ok(spec) } + + /// Returns the excerpt of text around the cursor, and the offset of the cursor within 
that + /// excerpt. + /// + /// The cursor's position is marked with a special comment that appears + /// below the cursor line, which contains the string `[CURSOR_POSITION]`, + /// preceded by an arrow marking the cursor's column. The arrow can be + /// either: + /// - `^` - The cursor column is at the position of the `^` character (pointing up to the cursor) + /// - `<` - The cursor column is at the first non-whitespace character on that line. + pub fn cursor_excerpt(&self) -> Result<(String, usize)> { + let input = &self.cursor_position; + + // Check for inline cursor marker first + if let Some(inline_offset) = input.find(INLINE_CURSOR_MARKER) { + let excerpt = input[..inline_offset].to_string() + + &input[inline_offset + INLINE_CURSOR_MARKER.len()..]; + return Ok((excerpt, inline_offset)); + } + + let marker_offset = input + .find(CURSOR_POSITION_MARKER) + .context("missing [CURSOR_POSITION] marker")?; + let marker_line_start = input[..marker_offset] + .rfind('\n') + .map(|pos| pos + 1) + .unwrap_or(0); + let marker_line_end = input[marker_line_start..] + .find('\n') + .map(|pos| marker_line_start + pos + 1) + .unwrap_or(input.len()); + let marker_line = &input[marker_line_start..marker_line_end].trim_end_matches('\n'); + + let cursor_column = if let Some(cursor_offset) = marker_line.find('^') { + cursor_offset + } else if let Some(less_than_pos) = marker_line.find('<') { + marker_line + .find(|c: char| !c.is_whitespace()) + .unwrap_or(less_than_pos) + } else { + anyhow::bail!( + "cursor position marker line must contain '^' or '<' before [CURSOR_POSITION]" + ); + }; + + let mut excerpt = input[..marker_line_start].to_string() + &input[marker_line_end..]; + excerpt.truncate(excerpt.trim_end_matches('\n').len()); + + // The cursor is on the line above the marker line. 
+ let cursor_line_end = marker_line_start.saturating_sub(1); + let cursor_line_start = excerpt[..cursor_line_end] + .rfind('\n') + .map(|pos| pos + 1) + .unwrap_or(0); + let cursor_offset = cursor_line_start + cursor_column; + + Ok((excerpt, cursor_offset)) + } + + /// Sets the cursor position excerpt from a plain excerpt and cursor byte offset. + /// + /// The `line_comment_prefix` is used to format the marker line as a comment. + /// If the cursor column is less than the comment prefix length, the `<` format is used. + /// Otherwise, the `^` format is used. + pub fn set_cursor_excerpt( + &mut self, + excerpt: &str, + cursor_offset: usize, + line_comment_prefix: &str, + ) { + // Find which line the cursor is on and its column + let cursor_line_start = excerpt[..cursor_offset] + .rfind('\n') + .map(|pos| pos + 1) + .unwrap_or(0); + let cursor_line_end = excerpt[cursor_line_start..] + .find('\n') + .map(|pos| cursor_line_start + pos + 1) + .unwrap_or(excerpt.len()); + let cursor_line = &excerpt[cursor_line_start..cursor_line_end]; + let cursor_line_indent = &cursor_line[..cursor_line.len() - cursor_line.trim_start().len()]; + let cursor_column = cursor_offset - cursor_line_start; + + // Build the marker line + let mut marker_line = String::new(); + if cursor_column < line_comment_prefix.len() { + for _ in 0..cursor_column { + marker_line.push(' '); + } + marker_line.push_str(line_comment_prefix); + write!(marker_line, " <{}", CURSOR_POSITION_MARKER).unwrap(); + } else { + if cursor_column >= cursor_line_indent.len() + line_comment_prefix.len() { + marker_line.push_str(cursor_line_indent); + } + marker_line.push_str(line_comment_prefix); + while marker_line.len() < cursor_column { + marker_line.push(' '); + } + write!(marker_line, "^{}", CURSOR_POSITION_MARKER).unwrap(); + } + + // Build the final cursor_position string + let mut result = String::with_capacity(excerpt.len() + marker_line.len() + 2); + result.push_str(&excerpt[..cursor_line_end]); + if 
!result.ends_with('\n') { + result.push('\n'); + } + result.push_str(&marker_line); + if cursor_line_end < excerpt.len() { + result.push('\n'); + result.push_str(&excerpt[cursor_line_end..]); + } + + self.cursor_position = result; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use indoc::indoc; + + #[test] + fn test_cursor_excerpt_with_caret() { + let mut spec = ExampleSpec { + name: String::new(), + repository_url: String::new(), + revision: String::new(), + tags: Vec::new(), + reasoning: None, + uncommitted_diff: String::new(), + cursor_path: Path::new("test.rs").into(), + cursor_position: String::new(), + edit_history: String::new(), + expected_patches: Vec::new(), + }; + + // Cursor before `42` + let excerpt = indoc! {" + fn main() { + let x = 42; + println!(\"{}\", x); + }" + }; + let offset = excerpt.find("42").unwrap(); + let position_string = indoc! {" + fn main() { + let x = 42; + // ^[CURSOR_POSITION] + println!(\"{}\", x); + }" + } + .to_string(); + + spec.set_cursor_excerpt(excerpt, offset, "//"); + assert_eq!(spec.cursor_position, position_string); + assert_eq!( + spec.cursor_excerpt().unwrap(), + (excerpt.to_string(), offset) + ); + + // Cursor after `l` in `let` + let offset = excerpt.find("et x").unwrap(); + let position_string = indoc! {" + fn main() { + let x = 42; + // ^[CURSOR_POSITION] + println!(\"{}\", x); + }" + } + .to_string(); + + spec.set_cursor_excerpt(excerpt, offset, "//"); + assert_eq!(spec.cursor_position, position_string); + assert_eq!( + spec.cursor_excerpt().unwrap(), + (excerpt.to_string(), offset) + ); + + // Cursor before `let` + let offset = excerpt.find("let").unwrap(); + let position_string = indoc! 
{" + fn main() { + let x = 42; + // ^[CURSOR_POSITION] + println!(\"{}\", x); + }" + } + .to_string(); + + spec.set_cursor_excerpt(excerpt, offset, "//"); + assert_eq!(spec.cursor_position, position_string); + assert_eq!( + spec.cursor_excerpt().unwrap(), + (excerpt.to_string(), offset) + ); + + // Cursor at beginning of the line with `let` + let offset = excerpt.find(" let").unwrap(); + let position_string = indoc! {" + fn main() { + let x = 42; + // <[CURSOR_POSITION] + println!(\"{}\", x); + }" + } + .to_string(); + + spec.set_cursor_excerpt(excerpt, offset, "//"); + assert_eq!(spec.cursor_position, position_string); + assert_eq!( + spec.cursor_excerpt().unwrap(), + (excerpt.to_string(), offset) + ); + + // Cursor at end of line, after the semicolon + let offset = excerpt.find(';').unwrap() + 1; + let position_string = indoc! {" + fn main() { + let x = 42; + // ^[CURSOR_POSITION] + println!(\"{}\", x); + }" + } + .to_string(); + + spec.set_cursor_excerpt(excerpt, offset, "//"); + assert_eq!(spec.cursor_position, position_string); + assert_eq!( + spec.cursor_excerpt().unwrap(), + (excerpt.to_string(), offset) + ); + + // Caret at end of file (no trailing newline) + let excerpt = indoc! {" + fn main() { + let x = 42;" + }; + let offset = excerpt.find(';').unwrap() + 1; + let position_string = indoc! 
{" + fn main() { + let x = 42; + // ^[CURSOR_POSITION]" + } + .to_string(); + + spec.set_cursor_excerpt(excerpt, offset, "//"); + assert_eq!(spec.cursor_position, position_string); + assert_eq!( + spec.cursor_excerpt().unwrap(), + (excerpt.to_string(), offset) + ); + } + + #[test] + fn test_cursor_excerpt_with_inline_marker() { + let mut spec = ExampleSpec { + name: String::new(), + repository_url: String::new(), + revision: String::new(), + tags: Vec::new(), + reasoning: None, + uncommitted_diff: String::new(), + cursor_path: Path::new("test.rs").into(), + cursor_position: String::new(), + edit_history: String::new(), + expected_patches: Vec::new(), + }; + + // Cursor before `42` using inline marker + spec.cursor_position = indoc! {" + fn main() { + let x = <|user_cursor|>42; + println!(\"{}\", x); + }" + } + .to_string(); + + let expected_excerpt = indoc! {" + fn main() { + let x = 42; + println!(\"{}\", x); + }" + }; + let expected_offset = expected_excerpt.find("42").unwrap(); + + assert_eq!( + spec.cursor_excerpt().unwrap(), + (expected_excerpt.to_string(), expected_offset) + ); + + // Cursor at beginning of line + spec.cursor_position = indoc! {" + fn main() { + <|user_cursor|> let x = 42; + }" + } + .to_string(); + + let expected_excerpt = indoc! 
{" + fn main() { + let x = 42; + }" + }; + let expected_offset = expected_excerpt.find(" let").unwrap(); + + assert_eq!( + spec.cursor_excerpt().unwrap(), + (expected_excerpt.to_string(), expected_offset) + ); + + // Cursor at end of file + spec.cursor_position = "fn main() {}<|user_cursor|>".to_string(); + let expected_excerpt = "fn main() {}"; + let expected_offset = expected_excerpt.len(); + + assert_eq!( + spec.cursor_excerpt().unwrap(), + (expected_excerpt.to_string(), expected_offset) + ); + } } diff --git a/crates/edit_prediction/src/udiff.rs b/crates/edit_prediction/src/udiff.rs index 78fec03dd78301d56ac6e3f914ba60432e41637d..cf854c0d1ab511dca4c55ec401553f6767b26d1e 100644 --- a/crates/edit_prediction/src/udiff.rs +++ b/crates/edit_prediction/src/udiff.rs @@ -1,23 +1,19 @@ -use std::borrow::Cow; -use std::fmt::Display; -use std::sync::Arc; use std::{ - fmt::{Debug, Write}, + borrow::Cow, + fmt::{Debug, Display, Write}, mem, ops::Range, path::Path, + sync::Arc, }; -use anyhow::Context as _; -use anyhow::Result; -use anyhow::anyhow; +use anyhow::{Context as _, Result, anyhow}; use collections::HashMap; -use gpui::AsyncApp; -use gpui::Entity; -use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot}; -use project::{Project, ProjectPath}; -use util::paths::PathStyle; -use util::rel_path::RelPath; +use gpui::{AsyncApp, Entity}; +use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot, text_diff}; +use postage::stream::Stream as _; +use project::Project; +use util::{paths::PathStyle, rel_path::RelPath}; #[derive(Clone, Debug)] pub struct OpenedBuffers(#[allow(unused)] HashMap>); @@ -28,56 +24,50 @@ pub async fn apply_diff( project: &Entity, cx: &mut AsyncApp, ) -> Result { - let mut included_files = HashMap::default(); - - let worktree_id = project.read_with(cx, |project, cx| { - anyhow::Ok( - project - .visible_worktrees(cx) - .next() - .context("no worktrees")? 
- .read(cx) - .id(), - ) - })??; + let worktree = project + .read_with(cx, |project, cx| project.visible_worktrees(cx).next())? + .context("project has no worktree")?; + // Ensure the files in the diff are loaded, since worktree scanning is disabled in + // edit prediction CLI. + let mut paths = Vec::new(); for line in diff_str.lines() { - let diff_line = DiffLine::parse(line); - - if let DiffLine::OldPath { path } = diff_line { - let buffer = project - .update(cx, |project, cx| { - let project_path = ProjectPath { - worktree_id, - path: RelPath::new(Path::new(path.as_ref()), PathStyle::Posix)?.into_arc(), - }; - anyhow::Ok(project.open_buffer(project_path, cx)) - })?? - .await?; - - included_files.insert(path.to_string(), buffer); + if let DiffLine::OldPath { path } = DiffLine::parse(line) { + paths.push(RelPath::new(Path::new(path.as_ref()), PathStyle::Posix)?.into_arc()); } } + worktree + .update(cx, |worktree, _| { + worktree + .as_local() + .unwrap() + .refresh_entries_for_paths(paths) + })? + .recv() + .await; - let ranges = [Anchor::MIN..Anchor::MAX]; + let mut included_files = HashMap::default(); + let ranges = [Anchor::MIN..Anchor::MAX]; let mut diff = DiffParser::new(diff_str); let mut current_file = None; let mut edits = vec![]; while let Some(event) = diff.next()? { match event { - DiffEvent::Hunk { - path: file_path, - hunk, - } => { - let (buffer, ranges) = match current_file { + DiffEvent::Hunk { path, hunk } => { + let buffer = match current_file { None => { - let buffer = included_files - .get_mut(file_path.as_ref()) - .expect("Opened all files in diff"); - - current_file = Some((buffer, ranges.as_slice())); + let buffer = project + .update(cx, |project, cx| { + let project_path = project + .find_project_path(path.as_ref(), cx) + .context("no such path")?; + anyhow::Ok(project.open_buffer(project_path, cx)) + })?? 
+ .await?; + included_files.insert(path.to_string(), buffer.clone()); + current_file = Some(buffer); current_file.as_ref().unwrap() } Some(ref current) => current, @@ -85,14 +75,14 @@ pub async fn apply_diff( buffer.read_with(cx, |buffer, _| { edits.extend( - resolve_hunk_edits_in_buffer(hunk, buffer, ranges) + resolve_hunk_edits_in_buffer(hunk, buffer, ranges.as_slice()) .with_context(|| format!("Diff:\n{diff_str}"))?, ); anyhow::Ok(()) })??; } DiffEvent::FileEnd { renamed_to } => { - let (buffer, _) = current_file + let buffer = current_file .take() .context("Got a FileEnd event before an Hunk event")?; @@ -128,6 +118,65 @@ pub async fn apply_diff( Ok(OpenedBuffers(included_files)) } +/// Extract the diff for a specific file from a multi-file diff. +/// Returns an error if the file is not found in the diff. +pub fn extract_file_diff(full_diff: &str, file_path: &str) -> Result { + let mut result = String::new(); + let mut in_target_file = false; + let mut found_file = false; + + for line in full_diff.lines() { + if line.starts_with("diff --git") { + if in_target_file { + break; + } + in_target_file = line.contains(&format!("a/{}", file_path)) + || line.contains(&format!("b/{}", file_path)); + if in_target_file { + found_file = true; + } + } + + if in_target_file { + result.push_str(line); + result.push('\n'); + } + } + + if !found_file { + anyhow::bail!("File '{}' not found in diff", file_path); + } + + Ok(result) +} + +/// Strip unnecessary git metadata lines from a diff, keeping only the lines +/// needed for patch application: path headers (--- and +++), hunk headers (@@), +/// and content lines (+, -, space). +pub fn strip_diff_metadata(diff: &str) -> String { + let mut result = String::new(); + + for line in diff.lines() { + let dominated = DiffLine::parse(line); + match dominated { + // Keep path headers, hunk headers, and content lines + DiffLine::OldPath { .. } + | DiffLine::NewPath { .. 
} + | DiffLine::HunkHeader(_) + | DiffLine::Context(_) + | DiffLine::Deletion(_) + | DiffLine::Addition(_) => { + result.push_str(line); + result.push('\n'); + } + // Skip garbage lines (diff --git, index, etc.) + DiffLine::Garbage(_) => {} + } + } + + result +} + pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result { let mut diff = DiffParser::new(diff_str); @@ -151,6 +200,51 @@ pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result { Ok(text) } +/// Returns the individual edits that would be applied by a diff to the given content. +/// Each edit is a tuple of (byte_range_in_content, replacement_text). +/// Uses sub-line diffing to find the precise character positions of changes. +/// Returns an empty vec if the hunk context is not found or is ambiguous. +pub fn edits_for_diff(content: &str, diff_str: &str) -> Result, String)>> { + let mut diff = DiffParser::new(diff_str); + let mut result = Vec::new(); + + while let Some(event) = diff.next()? { + match event { + DiffEvent::Hunk { hunk, .. } => { + if hunk.context.is_empty() { + return Ok(Vec::new()); + } + + // Find the context in the content + let first_match = content.find(&hunk.context); + let Some(context_offset) = first_match else { + return Ok(Vec::new()); + }; + + // Check for ambiguity - if context appears more than once, reject + if content[context_offset + 1..].contains(&hunk.context) { + return Ok(Vec::new()); + } + + // Use sub-line diffing to find precise edit positions + for edit in &hunk.edits { + let old_text = &content + [context_offset + edit.range.start..context_offset + edit.range.end]; + let edits_within_hunk = text_diff(old_text, &edit.text); + for (inner_range, inner_text) in edits_within_hunk { + let absolute_start = context_offset + edit.range.start + inner_range.start; + let absolute_end = context_offset + edit.range.start + inner_range.end; + result.push((absolute_start..absolute_end, inner_text.to_string())); + } + } + } + DiffEvent::FileEnd { .. 
} => {} + } + } + + Ok(result) +} + struct PatchFile<'a> { old_path: Cow<'a, str>, new_path: Cow<'a, str>, @@ -873,4 +967,135 @@ mod tests { FakeFs::new(cx.background_executor.clone()) } + + #[test] + fn test_extract_file_diff() { + let multi_file_diff = indoc! {r#" + diff --git a/file1.txt b/file1.txt + index 1234567..abcdefg 100644 + --- a/file1.txt + +++ b/file1.txt + @@ -1,3 +1,4 @@ + line1 + +added line + line2 + line3 + diff --git a/file2.txt b/file2.txt + index 2345678..bcdefgh 100644 + --- a/file2.txt + +++ b/file2.txt + @@ -1,2 +1,2 @@ + -old line + +new line + unchanged + "#}; + + let file1_diff = extract_file_diff(multi_file_diff, "file1.txt").unwrap(); + assert_eq!( + file1_diff, + indoc! {r#" + diff --git a/file1.txt b/file1.txt + index 1234567..abcdefg 100644 + --- a/file1.txt + +++ b/file1.txt + @@ -1,3 +1,4 @@ + line1 + +added line + line2 + line3 + "#} + ); + + let file2_diff = extract_file_diff(multi_file_diff, "file2.txt").unwrap(); + assert_eq!( + file2_diff, + indoc! {r#" + diff --git a/file2.txt b/file2.txt + index 2345678..bcdefgh 100644 + --- a/file2.txt + +++ b/file2.txt + @@ -1,2 +1,2 @@ + -old line + +new line + unchanged + "#} + ); + + let result = extract_file_diff(multi_file_diff, "nonexistent.txt"); + assert!(result.is_err()); + } + + #[test] + fn test_edits_for_diff() { + let content = indoc! {" + fn main() { + let x = 1; + let y = 2; + println!(\"{} {}\", x, y); + } + "}; + + let diff = indoc! 
{" + --- a/file.rs + +++ b/file.rs + @@ -1,5 +1,5 @@ + fn main() { + - let x = 1; + + let x = 42; + let y = 2; + println!(\"{} {}\", x, y); + } + "}; + + let edits = edits_for_diff(content, diff).unwrap(); + assert_eq!(edits.len(), 1); + + let (range, replacement) = &edits[0]; + // With sub-line diffing, the edit should start at "1" (the actual changed character) + let expected_start = content.find("let x = 1;").unwrap() + "let x = ".len(); + assert_eq!(range.start, expected_start); + // The deleted text is just "1" + assert_eq!(range.end, expected_start + "1".len()); + // The replacement text + assert_eq!(replacement, "42"); + + // Verify the cursor would be positioned at the column of "1" + let line_start = content[..range.start] + .rfind('\n') + .map(|p| p + 1) + .unwrap_or(0); + let cursor_column = range.start - line_start; + // " let x = " is 12 characters, so column 12 + assert_eq!(cursor_column, " let x = ".len()); + } + + #[test] + fn test_strip_diff_metadata() { + let diff_with_metadata = indoc! {r#" + diff --git a/file.txt b/file.txt + index 1234567..abcdefg 100644 + --- a/file.txt + +++ b/file.txt + @@ -1,3 +1,4 @@ + context line + -removed line + +added line + more context + "#}; + + let stripped = strip_diff_metadata(diff_with_metadata); + + assert_eq!( + stripped, + indoc! 
{r#" + --- a/file.txt + +++ b/file.txt + @@ -1,3 +1,4 @@ + context line + -removed line + +added line + more context + "#} + ); + } } diff --git a/crates/edit_prediction_cli/src/anthropic_client.rs b/crates/edit_prediction_cli/src/anthropic_client.rs index 8afc4d1c03f8a37ae258cc2926daf85caebe3d8a..e70f834223b5670782d45040b390e49b524bb82f 100644 --- a/crates/edit_prediction_cli/src/anthropic_client.rs +++ b/crates/edit_prediction_cli/src/anthropic_client.rs @@ -1,8 +1,10 @@ use anthropic::{ - ANTHROPIC_API_URL, Message, Request as AnthropicRequest, RequestContent, - Response as AnthropicResponse, Role, non_streaming_completion, + ANTHROPIC_API_URL, Event, Message, Request as AnthropicRequest, RequestContent, + Response as AnthropicResponse, ResponseContent, Role, non_streaming_completion, + stream_completion, }; use anyhow::Result; +use futures::StreamExt as _; use http_client::HttpClient; use indoc::indoc; use reqwest_client::ReqwestClient; @@ -15,12 +17,12 @@ use std::path::Path; use std::sync::Arc; pub struct PlainLlmClient { - http_client: Arc, - api_key: String, + pub http_client: Arc, + pub api_key: String, } impl PlainLlmClient { - fn new() -> Result { + pub fn new() -> Result { let http_client: Arc = Arc::new(ReqwestClient::new()); let api_key = std::env::var("ANTHROPIC_API_KEY") .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?; @@ -30,7 +32,7 @@ impl PlainLlmClient { }) } - async fn generate( + pub async fn generate( &self, model: &str, max_tokens: u64, @@ -63,6 +65,72 @@ impl PlainLlmClient { Ok(response) } + + pub async fn generate_streaming( + &self, + model: &str, + max_tokens: u64, + messages: Vec, + mut on_progress: F, + ) -> Result + where + F: FnMut(usize, &str), + { + let request = AnthropicRequest { + model: model.to_string(), + max_tokens, + messages, + tools: Vec::new(), + thinking: None, + tool_choice: None, + system: None, + metadata: None, + stop_sequences: Vec::new(), + temperature: None, + top_k: None, + 
top_p: None, + }; + + let mut stream = stream_completion( + self.http_client.as_ref(), + ANTHROPIC_API_URL, + &self.api_key, + request, + None, + ) + .await + .map_err(|e| anyhow::anyhow!("{:?}", e))?; + + let mut response: Option = None; + let mut text_content = String::new(); + + while let Some(event_result) = stream.next().await { + let event = event_result.map_err(|e| anyhow::anyhow!("{:?}", e))?; + + match event { + Event::MessageStart { message } => { + response = Some(message); + } + Event::ContentBlockDelta { delta, .. } => { + if let anthropic::ContentDelta::TextDelta { text } = delta { + text_content.push_str(&text); + on_progress(text_content.len(), &text_content); + } + } + _ => {} + } + } + + let mut response = response.ok_or_else(|| anyhow::anyhow!("No response received"))?; + + if response.content.is_empty() && !text_content.is_empty() { + response + .content + .push(ResponseContent::Text { text: text_content }); + } + + Ok(response) + } } pub struct BatchingLlmClient { @@ -408,6 +476,29 @@ impl AnthropicClient { } } + #[allow(dead_code)] + pub async fn generate_streaming( + &self, + model: &str, + max_tokens: u64, + messages: Vec, + on_progress: F, + ) -> Result> + where + F: FnMut(usize, &str), + { + match self { + AnthropicClient::Plain(plain_llm_client) => plain_llm_client + .generate_streaming(model, max_tokens, messages, on_progress) + .await + .map(Some), + AnthropicClient::Batch(_) => { + anyhow::bail!("Streaming not supported with batching client") + } + AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"), + } + } + pub async fn sync_batches(&self) -> Result<()> { match self { AnthropicClient::Plain(_) => Ok(()), diff --git a/crates/edit_prediction_cli/src/distill.rs b/crates/edit_prediction_cli/src/distill.rs index abfe178ae61b6da522f43c93d40b6000800d0e4d..d6343871e8054fc54062f3d3f7f5210374b36812 100644 --- a/crates/edit_prediction_cli/src/distill.rs +++ b/crates/edit_prediction_cli/src/distill.rs @@ -1,20 +1,15 
@@ -use anyhow::{Result, anyhow}; +use anyhow::Result; use std::mem; use crate::example::Example; pub async fn run_distill(example: &mut Example) -> Result<()> { - let [prediction]: [_; 1] = - mem::take(&mut example.predictions) - .try_into() - .map_err(|preds: Vec<_>| { - anyhow!( - "Example has {} predictions, but it should have exactly one", - preds.len() - ) - })?; + let predictions = mem::take(&mut example.predictions) + .into_iter() + .map(|p| p.actual_patch) + .collect(); - example.spec.expected_patch = prediction.actual_patch; + example.spec.expected_patches = predictions; example.prompt = None; example.predictions = Vec::new(); example.score = Vec::new(); diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index e37619bf224b3fa506516714856cfbc5024ece14..63a53b0d7dc667b05171d486e078617187f24fe6 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -1,4 +1,4 @@ -use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics}; +use crate::{PredictionProvider, PromptFormat}; use anyhow::{Context as _, Result}; use collections::HashMap; use edit_prediction::example_spec::ExampleSpec; @@ -87,7 +87,6 @@ pub struct ExamplePrediction { #[derive(Clone, Debug, Serialize, Deserialize)] pub struct ExampleScore { pub delta_chr_f: f32, - pub line_match: ClassificationMetrics, } impl Example { @@ -190,7 +189,11 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec { .collect::>(), ), "md" => { - examples.push(parse_markdown_example(filename, &content).unwrap()); + let mut example = parse_markdown_example(&content).unwrap(); + if example.spec.name.is_empty() { + example.spec.name = filename; + } + examples.push(example); } ext => { panic!("{} has invalid example extension `{ext}`", path.display()) @@ -236,8 +239,8 @@ pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec examples_by_repo.into_values().collect() } -fn parse_markdown_example(name: String, 
input: &str) -> Result { - let spec = ExampleSpec::from_markdown(name, input)?; +fn parse_markdown_example(input: &str) -> Result { + let spec = ExampleSpec::from_markdown(input)?; Ok(Example { spec, buffer: None, diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index 34e8a92d4140cdbdedb6bd2583e0994eb55b802d..9463b2349e035c195bde3d11b3e25f5ecc4bd018 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -30,7 +30,12 @@ pub async fn run_format_prompt( let prompt = TeacherPrompt::format_prompt(example); example.prompt = Some(ExamplePrompt { input: prompt, - expected_output: example.spec.expected_patch.clone(), // TODO + expected_output: example + .spec + .expected_patches + .first() + .cloned() + .unwrap_or_default(), format: prompt_format, }); } @@ -68,8 +73,15 @@ pub async fn run_format_prompt( )) })??; let prompt = format_zeta_prompt(&input); - let expected_output = - zeta2_output_for_patch(&input, &example.spec.expected_patch.clone())?; + let expected_output = zeta2_output_for_patch( + &input, + &example + .spec + .expected_patches + .first() + .context("expected patches is empty")? 
+ .clone(), + )?; example.prompt = Some(ExamplePrompt { input: prompt, expected_output, @@ -86,6 +98,7 @@ impl TeacherPrompt { const PROMPT: &str = include_str!("teacher.prompt.md"); pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n"; pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>"; + pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>"; /// Truncate edit history to this number of last lines const MAX_HISTORY_LINES: usize = 128; @@ -181,13 +194,15 @@ impl TeacherPrompt { result.push_str(Self::EDITABLE_REGION_START); // TODO: control number of lines around cursor - result.push_str(&example.spec.cursor_position); - if !example.spec.cursor_position.ends_with('\n') { + let (mut excerpt, offset) = example.spec.cursor_excerpt().unwrap(); + excerpt.insert_str(offset, Self::USER_CURSOR_MARKER); + result.push_str(&excerpt); + if !result.ends_with('\n') { result.push('\n'); } - result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END)); - result.push_str("`````"); + result.push_str(Self::EDITABLE_REGION_END); + result.push_str("\n`````"); result } diff --git a/crates/edit_prediction_cli/src/git.rs b/crates/edit_prediction_cli/src/git.rs new file mode 100644 index 0000000000000000000000000000000000000000..f2fe183d76e9eeacba04e988cfd86602ab8d597e --- /dev/null +++ b/crates/edit_prediction_cli/src/git.rs @@ -0,0 +1,110 @@ +use anyhow::{Context as _, Result}; +use collections::HashMap; +use futures::lock::{Mutex, OwnedMutexGuard}; +use std::{ + cell::RefCell, + path::{Path, PathBuf}, + sync::Arc, +}; + +use crate::paths::REPOS_DIR; + +thread_local! 
{ + static REPO_LOCKS: RefCell>>> = RefCell::new(HashMap::default()); +} + +#[must_use] +pub async fn lock_repo(path: impl AsRef) -> OwnedMutexGuard<()> { + REPO_LOCKS + .with(|cell| { + cell.borrow_mut() + .entry(path.as_ref().to_path_buf()) + .or_default() + .clone() + }) + .lock_owned() + .await +} + +pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result { + let output = smol::process::Command::new("git") + .current_dir(repo_path) + .args(args) + .output() + .await?; + + anyhow::ensure!( + output.status.success(), + "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}", + args.join(" "), + repo_path.display(), + output.status, + String::from_utf8_lossy(&output.stderr), + String::from_utf8_lossy(&output.stdout), + ); + Ok(String::from_utf8(output.stdout)?.trim().to_string()) +} + +pub fn parse_repo_url(url: &str) -> Result<(String, String)> { + if url.contains('@') { + let (_, path) = url.split_once(':').context("expected : in git url")?; + let (owner, repo) = path.split_once('/').context("expected / in git url")?; + Ok((owner.to_string(), repo.trim_end_matches(".git").to_string())) + } else { + let parsed = http_client::Url::parse(url)?; + let mut segments = parsed.path_segments().context("empty http url")?; + let owner = segments.next().context("expected owner")?; + let repo = segments.next().context("expected repo")?; + Ok((owner.to_string(), repo.trim_end_matches(".git").to_string())) + } +} + +pub fn repo_path_for_url(url: &str) -> Result { + let (owner, name) = parse_repo_url(url)?; + Ok(REPOS_DIR.join(&owner).join(&name)) +} + +pub async fn ensure_repo_cloned(repo_url: &str) -> Result { + let repo_path = repo_path_for_url(repo_url)?; + let _lock = lock_repo(&repo_path).await; + + if !repo_path.is_dir() { + log::info!("Cloning {} into {:?}", repo_url, repo_path); + std::fs::create_dir_all(&repo_path)?; + run_git(&repo_path, &["init"]).await?; + run_git(&repo_path, &["remote", "add", "origin", repo_url]).await?; + } + + // Always 
fetch to get latest commits + run_git(&repo_path, &["fetch", "origin"]).await?; + + // Check if we have a valid HEAD, if not checkout FETCH_HEAD + let has_head = run_git(&repo_path, &["rev-parse", "HEAD"]).await.is_ok(); + if !has_head { + // Use reset to set HEAD without needing a branch + run_git(&repo_path, &["reset", "--hard", "FETCH_HEAD"]).await?; + } + + Ok(repo_path) +} + +pub async fn fetch_if_needed(repo_path: &Path, revision: &str) -> Result { + let resolved = run_git( + repo_path, + &["rev-parse", &format!("{}^{{commit}}", revision)], + ) + .await; + + if let Ok(sha) = resolved { + return Ok(sha); + } + + if run_git(repo_path, &["fetch", "--depth", "1", "origin", revision]) + .await + .is_err() + { + run_git(repo_path, &["fetch", "origin"]).await?; + } + + run_git(repo_path, &["rev-parse", "FETCH_HEAD"]).await +} diff --git a/crates/edit_prediction_cli/src/load_project.rs b/crates/edit_prediction_cli/src/load_project.rs index 70daf00b79486fd917556cffaa26b1fd01ed4d28..3a501c0bb24a6cf170b92118a6e82e1af1ed4f15 100644 --- a/crates/edit_prediction_cli/src/load_project.rs +++ b/crates/edit_prediction_cli/src/load_project.rs @@ -1,29 +1,19 @@ use crate::{ example::{Example, ExampleBuffer, ExampleState}, + git, headless::EpAppState, - paths::{REPOS_DIR, WORKTREES_DIR}, + paths::WORKTREES_DIR, progress::{InfoStyle, Progress, Step, StepProgress}, }; use anyhow::{Context as _, Result}; -use collections::HashMap; use edit_prediction::EditPredictionStore; use edit_prediction::udiff::OpenedBuffers; -use futures::{ - AsyncWriteExt as _, - lock::{Mutex, OwnedMutexGuard}, -}; +use futures::AsyncWriteExt as _; use gpui::{AsyncApp, Entity}; use language::{Anchor, Buffer, LanguageNotFound, ToOffset, ToPoint}; +use project::Project; use project::buffer_store::BufferStoreEvent; -use project::{Project, ProjectPath}; -use std::{ - cell::RefCell, - fs, - path::{Path, PathBuf}, - sync::Arc, -}; -use util::{paths::PathStyle, rel_path::RelPath}; -use zeta_prompt::CURSOR_MARKER; 
+use std::{fs, path::PathBuf, sync::Arc}; pub async fn run_load_project( example: &mut Example, @@ -86,37 +76,22 @@ async fn cursor_position( return Err(error); } - let worktree = project.read_with(cx, |project, cx| { - project - .visible_worktrees(cx) - .next() - .context("No visible worktrees") - })??; - - let cursor_path = RelPath::new(&example.spec.cursor_path, PathStyle::Posix) - .context("Failed to create RelPath")? - .into_arc(); - let cursor_buffer = project - .update(cx, |project, cx| { - project.open_buffer( - ProjectPath { - worktree_id: worktree.read(cx).id(), - path: cursor_path, - }, - cx, - ) + let cursor_path = project + .read_with(cx, |project, cx| { + project.find_project_path(&example.spec.cursor_path, cx) })? + .with_context(|| { + format!( + "failed to find cursor path {}", + example.spec.cursor_path.display() + ) + })?; + let cursor_buffer = project + .update(cx, |project, cx| project.open_buffer(cursor_path, cx))? .await?; - let cursor_offset_within_excerpt = example - .spec - .cursor_position - .find(CURSOR_MARKER) - .context("missing cursor marker")?; - let mut cursor_excerpt = example.spec.cursor_position.clone(); - cursor_excerpt.replace_range( - cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()), - "", - ); + + let (cursor_excerpt, cursor_offset_within_excerpt) = example.spec.cursor_excerpt()?; + let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| { let text = buffer.text(); @@ -212,17 +187,17 @@ async fn setup_project( async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Result { let (repo_owner, repo_name) = example.repo_name().context("failed to get repo name")?; - let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref()); + let repo_dir = git::repo_path_for_url(&example.spec.repository_url)?; let worktree_path = WORKTREES_DIR .join(repo_owner.as_ref()) .join(repo_name.as_ref()); - let repo_lock = lock_repo(&repo_dir).await; + let repo_lock = 
git::lock_repo(&repo_dir).await; if !repo_dir.is_dir() { step_progress.set_substatus(format!("cloning {}", repo_name)); fs::create_dir_all(&repo_dir)?; - run_git(&repo_dir, &["init"]).await?; - run_git( + git::run_git(&repo_dir, &["init"]).await?; + git::run_git( &repo_dir, &["remote", "add", "origin", &example.spec.repository_url], ) @@ -230,53 +205,26 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu } // Resolve the example to a revision, fetching it if needed. - let revision = run_git( - &repo_dir, - &[ - "rev-parse", - &format!("{}^{{commit}}", example.spec.revision), - ], - ) - .await; - let revision = if let Ok(revision) = revision { - revision - } else { - step_progress.set_substatus("fetching"); - if run_git( - &repo_dir, - &["fetch", "--depth", "1", "origin", &example.spec.revision], - ) - .await - .is_err() - { - run_git(&repo_dir, &["fetch", "origin"]).await?; - } - let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?; - revision - }; + step_progress.set_substatus("fetching"); + let revision = git::fetch_if_needed(&repo_dir, &example.spec.revision).await?; // Create the worktree for this example if needed. 
step_progress.set_substatus("preparing worktree"); if worktree_path.is_dir() { - run_git(&worktree_path, &["clean", "--force", "-d"]).await?; - run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?; - run_git(&worktree_path, &["checkout", revision.as_str()]).await?; + git::run_git(&worktree_path, &["clean", "--force", "-d"]).await?; + git::run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?; + git::run_git(&worktree_path, &["checkout", revision.as_str()]).await?; } else { let worktree_path_string = worktree_path.to_string_lossy(); - run_git( + let branch_name = example.spec.filename(); + git::run_git( &repo_dir, - &["branch", "-f", &example.spec.name, revision.as_str()], + &["branch", "-f", &branch_name, revision.as_str()], ) .await?; - run_git( + git::run_git( &repo_dir, - &[ - "worktree", - "add", - "-f", - &worktree_path_string, - &example.spec.name, - ], + &["worktree", "add", "-f", &worktree_path_string, &branch_name], ) .await?; } @@ -319,39 +267,3 @@ async fn apply_edit_history( ) -> Result { edit_prediction::udiff::apply_diff(&example.spec.edit_history, project, cx).await } - -thread_local! 
{ - static REPO_LOCKS: RefCell>>> = RefCell::new(HashMap::default()); -} - -#[must_use] -pub async fn lock_repo(path: impl AsRef) -> OwnedMutexGuard<()> { - REPO_LOCKS - .with(|cell| { - cell.borrow_mut() - .entry(path.as_ref().to_path_buf()) - .or_default() - .clone() - }) - .lock_owned() - .await -} - -async fn run_git(repo_path: &Path, args: &[&str]) -> Result { - let output = smol::process::Command::new("git") - .current_dir(repo_path) - .args(args) - .output() - .await?; - - anyhow::ensure!( - output.status.success(), - "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}", - args.join(" "), - repo_path.display(), - output.status, - String::from_utf8_lossy(&output.stderr), - String::from_utf8_lossy(&output.stdout), - ); - Ok(String::from_utf8(output.stdout)?.trim().to_string()) -} diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index dce0fbbed57dbc4b18faf93787cfb8f2341a126a..6074bed9b625fc7150442f51440bbb415560aa58 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -2,6 +2,7 @@ mod anthropic_client; mod distill; mod example; mod format_prompt; +mod git; mod headless; mod load_project; mod metrics; @@ -10,6 +11,7 @@ mod predict; mod progress; mod retrieve_context; mod score; +mod synthesize; use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum}; use edit_prediction::EditPredictionStore; @@ -28,6 +30,7 @@ use crate::predict::run_prediction; use crate::progress::Progress; use crate::retrieve_context::run_context_retrieval; use crate::score::run_scoring; +use crate::synthesize::{SynthesizeConfig, run_synthesize}; #[derive(Parser, Debug)] #[command(name = "ep")] @@ -67,6 +70,8 @@ enum Command { Distill, /// Print aggregated scores Eval(PredictArgs), + /// Generate eval examples by analyzing git commits from a repository + Synthesize(SynthesizeArgs), /// Remove git repositories and worktrees Clean, } @@ -118,6 +123,9 @@ impl Display for Command { 
.unwrap() .get_name() ), + Command::Synthesize(args) => { + write!(f, "synthesize --repo={}", args.repo) + } Command::Clean => write!(f, "clean"), } } @@ -143,7 +151,7 @@ struct PredictArgs { repetitions: usize, } -#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)] enum PredictionProvider { Sweep, Mercury, @@ -153,6 +161,25 @@ enum PredictionProvider { TeacherNonBatching, } +#[derive(Debug, Args)] +struct SynthesizeArgs { + /// Repository URL (git@github.com:owner/repo or https://...) + #[clap(long)] + repo: String, + + /// Number of examples to generate + #[clap(long, default_value_t = 5)] + count: usize, + + /// Maximum commits to scan before giving up + #[clap(long, default_value_t = 100)] + max_commits: usize, + + /// Ignore state file and reprocess all commits + #[clap(long)] + fresh: bool, +} + impl EpArgs { fn output_path(&self) -> Option { if self.in_place { @@ -189,6 +216,25 @@ fn main() { std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap(); return; } + Command::Synthesize(synth_args) => { + let Some(output_dir) = args.output else { + panic!("output dir is required"); + }; + let config = SynthesizeConfig { + repo_url: synth_args.repo.clone(), + count: synth_args.count, + max_commits: synth_args.max_commits, + output_dir, + fresh: synth_args.fresh, + }; + smol::block_on(async { + if let Err(e) = run_synthesize(config).await { + eprintln!("Error: {:?}", e); + std::process::exit(1); + } + }); + return; + } _ => {} } @@ -256,7 +302,7 @@ fn main() { run_scoring(example, &args, app_state.clone(), cx.clone()) .await?; } - Command::Clean => { + Command::Clean | Command::Synthesize(_) => { unreachable!() } } diff --git a/crates/edit_prediction_cli/src/metrics.rs b/crates/edit_prediction_cli/src/metrics.rs index b3e5eb8688724c821953a56c4fe82e67c75e13b6..2dc767d683fe3ebf15abd462b3f6ecc0f986742f 100644 --- a/crates/edit_prediction_cli/src/metrics.rs +++ 
b/crates/edit_prediction_cli/src/metrics.rs @@ -1,34 +1,17 @@ -use collections::{HashMap, HashSet}; -use edit_prediction::udiff::DiffLine; -use serde::{Deserialize, Serialize}; +use collections::HashMap; type Counts = HashMap; type CountsDelta = HashMap; -#[derive(Default, Debug, Clone, Serialize, Deserialize)] -pub struct ClassificationMetrics { - pub true_positives: usize, - pub false_positives: usize, - pub false_negatives: usize, +#[derive(Default, Debug, Clone)] +struct ClassificationMetrics { + true_positives: usize, + false_positives: usize, + false_negatives: usize, } impl ClassificationMetrics { - pub fn from_sets( - expected: &HashSet, - actual: &HashSet, - ) -> ClassificationMetrics { - let true_positives = expected.intersection(actual).count(); - let false_positives = actual.difference(expected).count(); - let false_negatives = expected.difference(actual).count(); - - ClassificationMetrics { - true_positives, - false_positives, - false_negatives, - } - } - - pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics { + fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics { let mut true_positives = 0; let mut false_positives = 0; let mut false_negatives = 0; @@ -56,27 +39,7 @@ impl ClassificationMetrics { } } - pub fn aggregate<'a>( - scores: impl Iterator, - ) -> ClassificationMetrics { - let mut true_positives = 0; - let mut false_positives = 0; - let mut false_negatives = 0; - - for score in scores { - true_positives += score.true_positives; - false_positives += score.false_positives; - false_negatives += score.false_negatives; - } - - ClassificationMetrics { - true_positives, - false_positives, - false_negatives, - } - } - - pub fn precision(&self) -> f64 { + fn precision(&self) -> f64 { if self.true_positives + self.false_positives == 0 { 0.0 } else { @@ -84,42 +47,13 @@ impl ClassificationMetrics { } } - pub fn recall(&self) -> f64 { + fn recall(&self) -> f64 { if self.true_positives + 
self.false_negatives == 0 { 0.0 } else { self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64 } } - - pub fn f1_score(&self) -> f64 { - let recall = self.recall(); - let precision = self.precision(); - if precision + recall == 0.0 { - 0.0 - } else { - 2.0 * precision * recall / (precision + recall) - } - } -} - -pub fn line_match_score( - expected_patch: &[DiffLine], - actual_patch: &[DiffLine], -) -> ClassificationMetrics { - let expected_change_lines = expected_patch - .iter() - .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_))) - .map(|line| line.to_string()) - .collect(); - - let actual_change_lines = actual_patch - .iter() - .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_))) - .map(|line| line.to_string()) - .collect(); - - ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines) } enum ChrfWhitespace { @@ -135,55 +69,26 @@ const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Ignore; /// Computes a delta-chrF score that compares two sets of edits. /// /// This metric works by: -/// 1. Reconstructing original, golden (expected result), and actual texts from diffs -/// 2. Computing n-gram count differences (deltas) between original→golden and original→actual -/// 3. 
Comparing these deltas to measure how well actual edits match expected edits -pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 { - // Reconstruct texts from diffs - let mut original_text = String::new(); // state of the text before any edits - let mut golden_text = String::new(); // text after applying golden edits - let mut actual_text = String::new(); // text after applying actual edits - - for line in expected { - match line { - DiffLine::Context(s) => { - original_text.push_str(s); - golden_text.push_str(s); - } - DiffLine::Deletion(s) => { - original_text.push_str(s); - } - DiffLine::Addition(s) => { - golden_text.push_str(s); - } - _ => {} - } - } - - for line in actual { - match line { - DiffLine::Context(s) | DiffLine::Addition(s) => { - actual_text.push_str(s); - } - _ => {} - } - } - - // Edge case - if original_text == golden_text && golden_text == actual_text { +/// 1. Computing n-gram count differences (deltas) between original→expected and original→actual +/// 2. Comparing these deltas to measure how well actual edits match expected edits +/// +/// Returns a score from 0.0 to 100.0, where 100.0 means the actual edits perfectly match +/// the expected edits. 
+pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> f64 { + // Edge case: if all texts are identical, the edits match perfectly + if original == expected && expected == actual { return 100.0; } - // Compute the metric - let original_ngrams = chr_f_ngram_counts(&original_text); - let golden_ngrams = chr_f_ngram_counts(&golden_text); - let actual_ngrams = chr_f_ngram_counts(&actual_text); + let original_ngrams = chr_f_ngram_counts(original); + let expected_ngrams = chr_f_ngram_counts(expected); + let actual_ngrams = chr_f_ngram_counts(actual); let mut total_precision = 0.0; let mut total_recall = 0.0; for order in 0..CHR_F_CHAR_ORDER { - let expected_delta = compute_ngram_delta(&golden_ngrams[order], &original_ngrams[order]); + let expected_delta = compute_ngram_delta(&expected_ngrams[order], &original_ngrams[order]); let actual_delta = compute_ngram_delta(&actual_ngrams[order], &original_ngrams[order]); if expected_delta.is_empty() && actual_delta.is_empty() { @@ -255,7 +160,7 @@ fn ngram_delta_to_counts(delta: &CountsDelta) -> Counts { for (ngram, &delta) in delta { if delta > 0 { counts.insert(ngram.clone(), delta as usize); - } else { + } else if delta < 0 { counts.insert(format!("¬{ngram}"), delta.unsigned_abs()); } } @@ -278,94 +183,68 @@ fn count_ngrams(text: &str, n: usize) -> Counts { #[cfg(test)] mod test { use super::*; - use edit_prediction::udiff::DiffLine; #[test] fn test_delta_chr_f_perfect_match() { - let diff = vec![ - DiffLine::Context("fn main() {"), - DiffLine::Deletion(" println!(\"Hello\");"), - DiffLine::Addition(" println!(\"Hello, World!\");"), - DiffLine::Context("}"), - ]; - - let score = delta_chr_f(&diff, &diff); + let original = "fn main() { println!(\"Hello\");}"; + let expected = "fn main() { println!(\"Hello, World!\");}"; + + let score = delta_chr_f(original, expected, expected); assert!((score - 100.0).abs() < 1e-2); } #[test] fn test_delta_chr_f_wrong_edit() { // When the edit is wrong - let expected = vec![ - 
DiffLine::Context("one "), - DiffLine::Deletion("two "), - DiffLine::Context("three"), - ]; - - let actual = vec![ - DiffLine::Context("one "), - DiffLine::Context("two "), - DiffLine::Deletion("three"), - DiffLine::Addition("four"), - ]; + let original = "one two three"; + let expected = "one three"; // deleted "two " + let actual = "one two four"; // deleted "three", added "four" // Then the score should be low - let score = delta_chr_f(&expected, &actual); + let score = delta_chr_f(original, expected, actual); assert!(score > 20.0 && score < 40.0); } #[test] fn test_delta_chr_f_partial_match() { - let expected = vec![ - DiffLine::Deletion("let x = 42;"), - DiffLine::Addition("let x = 100;"), - ]; - - let actual = vec![ - DiffLine::Deletion("let x = 42;"), - DiffLine::Addition("let x = 99;"), - ]; + let original = "let x = 42;"; + let expected = "let x = 100;"; + let actual = "let x = 99;"; // We got the edit location right, but the replacement text is wrong. // Deleted ngrams will match, bringing the score somewhere in the middle. 
- let score = delta_chr_f(&expected, &actual); + let score = delta_chr_f(original, expected, actual); assert!(score > 40.0 && score < 60.0); } #[test] fn test_delta_chr_f_missed_edit() { // When predictions makes no changes - let expected = vec![ - DiffLine::Context("prefix "), - DiffLine::Deletion("old"), - DiffLine::Addition("new"), - DiffLine::Context(" suffix"), - ]; - - let actual = vec![ - DiffLine::Context("prefix "), - DiffLine::Context("old"), - DiffLine::Context(" suffix"), - ]; + let original = "prefix old suffix"; + let expected = "prefix new suffix"; + let actual = "prefix old suffix"; // no change // Then the score should be low (all expected changes are false negatives) - let score = delta_chr_f(&expected, &actual); + let score = delta_chr_f(original, expected, actual); assert!(score < 20.0); } #[test] fn test_delta_chr_f_extra_edit() { // When adding unexpected content - let expected = vec![DiffLine::Context("hello"), DiffLine::Context("world")]; - - let actual = vec![ - DiffLine::Context("hello"), - DiffLine::Addition("extra"), - DiffLine::Context("world"), - ]; + let original = "helloworld"; + let expected = "helloworld"; // no change expected + let actual = "helloextraworld"; // added "extra" // Then the score should be low (all actual changes are false positives) - let score = delta_chr_f(&expected, &actual); + let score = delta_chr_f(original, expected, actual); assert!(score < 20.0); } + + #[test] + fn test_delta_chr_f_no_changes() { + let text = "unchanged text"; + let score = delta_chr_f(text, text, text); + assert!((score - 100.0).abs() < 1e-2); + } } diff --git a/crates/edit_prediction_cli/src/paths.rs b/crates/edit_prediction_cli/src/paths.rs index e5d420d0e3dbeda9c50b8e5a3683238149dbc604..4ca9d787181febf39f4b2089dc20e307c20fc735 100644 --- a/crates/edit_prediction_cli/src/paths.rs +++ b/crates/edit_prediction_cli/src/paths.rs @@ -17,7 +17,11 @@ pub static RUN_DIR: LazyLock = LazyLock::new(|| { 
.join(chrono::Local::now().format("%d-%m-%y-%H_%M_%S").to_string()) }); pub static LATEST_EXAMPLE_RUN_DIR: LazyLock = LazyLock::new(|| DATA_DIR.join("latest")); +pub static LATEST_FAILED_EXAMPLES_DIR: LazyLock = + LazyLock::new(|| DATA_DIR.join("latest_failed")); pub static LLM_CACHE_DB: LazyLock = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite")); +pub static SYNTHESIZE_STATE_FILE: LazyLock = + LazyLock::new(|| DATA_DIR.join("synthesize_state.json")); pub static FAILED_EXAMPLES_DIR: LazyLock = LazyLock::new(|| ensure_dir(&RUN_DIR.join("failed"))); diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index aa93c5415dea091164a68b76a34242697aac70e3..51f4523605957f803d81dd51c6f0489f449da881 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -28,12 +28,16 @@ pub async fn run_prediction( app_state: Arc, mut cx: AsyncApp, ) -> anyhow::Result<()> { - if !example.predictions.is_empty() { - return Ok(()); - } - let provider = provider.context("provider is required")?; + if let Some(existing_prediction) = example.predictions.first() { + if existing_prediction.provider == provider { + return Ok(()); + } else { + example.predictions.clear(); + } + } + run_context_retrieval(example, app_state.clone(), cx.clone()).await?; if matches!( @@ -184,7 +188,9 @@ pub async fn run_prediction( let actual_patch = prediction .and_then(|prediction| { let prediction = prediction.prediction.ok()?; - prediction.edit_preview.as_unified_diff(&prediction.edits) + prediction + .edit_preview + .as_unified_diff(prediction.snapshot.file(), &prediction.edits) }) .unwrap_or_default(); diff --git a/crates/edit_prediction_cli/src/progress.rs b/crates/edit_prediction_cli/src/progress.rs index ddc710f202cc98e5932c234cb6bebcc93b28171c..c6157b1de9c8f09b1442ca9f3badf02c139b2b01 100644 --- a/crates/edit_prediction_cli/src/progress.rs +++ b/crates/edit_prediction_cli/src/progress.rs @@ -46,6 +46,7 @@ pub enum 
Step { FormatPrompt, Predict, Score, + Synthesize, } #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -62,6 +63,7 @@ impl Step { Step::FormatPrompt => "Format", Step::Predict => "Predict", Step::Score => "Score", + Step::Synthesize => "Synthesize", } } @@ -72,6 +74,7 @@ impl Step { Step::FormatPrompt => "\x1b[34m", Step::Predict => "\x1b[32m", Step::Score => "\x1b[31m", + Step::Synthesize => "\x1b[36m", } } } diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index 7b507e6d19c943de92eb0b22c7d24d4026789fed..4ea5a5b8792a6454a7dea3eeeb58ae401cec795a 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -2,11 +2,12 @@ use crate::{ PredictArgs, example::{Example, ExampleScore}, headless::EpAppState, - metrics::{self, ClassificationMetrics}, + metrics, predict::run_prediction, progress::{Progress, Step}, }; -use edit_prediction::udiff::DiffLine; +use anyhow::Context as _; +use edit_prediction::udiff::apply_diff_to_string; use gpui::AsyncApp; use std::sync::Arc; @@ -27,18 +28,32 @@ pub async fn run_scoring( let _progress = Progress::global().start(Step::Score, &example.spec.name); - let expected_patch = parse_patch(&example.spec.expected_patch); + let original_text = &example.buffer.as_ref().unwrap().content; + let expected_texts: Vec = example + .spec + .expected_patches + .iter() + .map(|patch| { + apply_diff_to_string(original_text, patch) + .with_context(|| format!("Expected patch did not apply for {}", example.spec.name)) + }) + .collect::, _>>()?; let mut scores = vec![]; - - for pred in &example.predictions { - let actual_patch = parse_patch(&pred.actual_patch); - let line_match = metrics::line_match_score(&expected_patch, &actual_patch); - let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32; - + for prediction in &example.predictions { + let actual_text = match apply_diff_to_string(original_text, &prediction.actual_patch) { + Ok(text) => text, + 
Err(_) => { + scores.push(ExampleScore { delta_chr_f: 0.0 }); + continue; + } + }; + let best_delta_chr_f = expected_texts + .iter() + .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32) + .fold(0.0, f32::max); scores.push(ExampleScore { - delta_chr_f, - line_match, + delta_chr_f: best_delta_chr_f, }); } @@ -46,42 +61,25 @@ pub async fn run_scoring( Ok(()) } -fn parse_patch(patch: &str) -> Vec> { - patch.lines().map(DiffLine::parse).collect() -} - pub fn print_report(examples: &[Example]) { eprintln!( "──────────────────────────────────────────────────────────────────────────────────────" ); - eprintln!( - "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}", - "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF" - ); + eprintln!("{:<50} {:>10}", "Example name", "DeltaChrF"); eprintln!( "──────────────────────────────────────────────────────────────────────────────────────" ); - let mut all_line_match_scores = Vec::new(); let mut all_delta_chr_f_scores = Vec::new(); for example in examples { for score in example.score.iter() { - let line_match = &score.line_match; - eprintln!( - "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}", - truncate_name(&example.spec.name, 30), - line_match.true_positives, - line_match.false_positives, - line_match.false_negatives, - line_match.precision() * 100.0, - line_match.recall() * 100.0, - line_match.f1_score() * 100.0, + "{:<50} {:>9.2}", + truncate_name(&example.spec.name, 50), score.delta_chr_f ); - all_line_match_scores.push(line_match.clone()); all_delta_chr_f_scores.push(score.delta_chr_f); } } @@ -90,22 +88,11 @@ pub fn print_report(examples: &[Example]) { "──────────────────────────────────────────────────────────────────────────────────────" ); - if !all_line_match_scores.is_empty() { - let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter()); + if !all_delta_chr_f_scores.is_empty() { let avg_delta_chr_f: f32 = 
all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
-        eprintln!(
-            "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
-            "TOTAL",
-            total_line_match.true_positives,
-            total_line_match.false_positives,
-            total_line_match.false_negatives,
-            total_line_match.precision() * 100.0,
-            total_line_match.recall() * 100.0,
-            total_line_match.f1_score() * 100.0,
-            avg_delta_chr_f
-        );
+        eprintln!("{:<50} {:>9.2}", "AVERAGE", avg_delta_chr_f);
         eprintln!(
             "──────────────────────────────────────────────────────────────────────────────────────"
         );
diff --git a/crates/edit_prediction_cli/src/synthesize.rs b/crates/edit_prediction_cli/src/synthesize.rs
new file mode 100644
index 0000000000000000000000000000000000000000..b79f84b1c712867b01ed3e5a27b96bf0dd1b56e3
--- /dev/null
+++ b/crates/edit_prediction_cli/src/synthesize.rs
@@ -0,0 +1,892 @@
+use crate::{
+    anthropic_client::PlainLlmClient,
+    git::{ensure_repo_cloned, run_git},
+    paths::{FAILED_EXAMPLES_DIR, LATEST_FAILED_EXAMPLES_DIR, SYNTHESIZE_STATE_FILE},
+    progress::{InfoStyle, Progress, Step, StepProgress},
+};
+use anthropic::ResponseContent;
+use anyhow::{Context as _, Result};
+use chrono::Local;
+use collections::{HashMap, HashSet};
+use edit_prediction::{
+    example_spec::ExampleSpec,
+    udiff::{apply_diff_to_string, edits_for_diff},
+};
+use indoc::indoc;
+use serde::{Deserialize, Serialize};
+use std::{
+    path::{Path, PathBuf},
+    sync::Arc,
+};
+
+#[derive(Debug, Clone)]
+pub struct SynthesizeConfig {
+    pub repo_url: String,
+    pub count: usize,
+    pub max_commits: usize,
+    pub output_dir: PathBuf,
+    pub fresh: bool,
+}
+
+#[derive(Debug, Default, Serialize, Deserialize)]
+struct SynthesizeState {
+    repositories: HashMap<String, RepoState>,
+}
+
+#[derive(Debug, Default, Serialize, Deserialize)]
+struct RepoState {
+    processed_commits: HashSet<String>,
+    examples_generated: usize,
+}
+
+impl SynthesizeState {
+    fn load() -> Self {
+        if SYNTHESIZE_STATE_FILE.exists() {
+            
std::fs::read_to_string(&*SYNTHESIZE_STATE_FILE)
+                .ok()
+                .and_then(|s| serde_json::from_str(&s).ok())
+                .unwrap_or_default()
+        } else {
+            Self::default()
+        }
+    }
+
+    fn save(&self) -> Result<()> {
+        let content = serde_json::to_string_pretty(self)?;
+        std::fs::write(&*SYNTHESIZE_STATE_FILE, content)?;
+        Ok(())
+    }
+
+    fn is_processed(&self, repo_url: &str, commit_sha: &str) -> bool {
+        self.repositories
+            .get(repo_url)
+            .is_some_and(|repo| repo.processed_commits.contains(commit_sha))
+    }
+
+    fn mark_processed(&mut self, repo_url: &str, commit_sha: &str, examples_count: usize) {
+        let repo = self.repositories.entry(repo_url.to_string()).or_default();
+        repo.processed_commits.insert(commit_sha.to_string());
+        repo.examples_generated += examples_count;
+    }
+}
+
+#[derive(Debug)]
+struct CommitInfo {
+    sha: String,
+    parent_sha: String,
+    message: String,
+    diff: String,
+    expanded_diff: String,
+}
+
+/// Claude's response parsed into structured form
+#[derive(Debug)]
+struct ClaudeResponse {
+    name: String,
+    reasoning: String,
+    edit_history_hunks: Vec<String>,
+    expected_patch_hunks: Vec<String>,
+}
+
+pub async fn run_synthesize(config: SynthesizeConfig) -> Result<()> {
+    let mut state = if config.fresh {
+        SynthesizeState::default()
+    } else {
+        SynthesizeState::load()
+    };
+
+    std::fs::create_dir_all(&config.output_dir)?;
+    std::fs::create_dir_all(&*FAILED_EXAMPLES_DIR)?;
+
+    // Create "latest_failed" symlink pointing to this run's failed directory
+    if LATEST_FAILED_EXAMPLES_DIR.is_symlink() {
+        std::fs::remove_file(&*LATEST_FAILED_EXAMPLES_DIR)?;
+    }
+    #[cfg(unix)]
+    std::os::unix::fs::symlink(&*FAILED_EXAMPLES_DIR, &*LATEST_FAILED_EXAMPLES_DIR)?;
+    #[cfg(windows)]
+    std::os::windows::fs::symlink_dir(&*FAILED_EXAMPLES_DIR, &*LATEST_FAILED_EXAMPLES_DIR)?;
+
+    let progress = Progress::global();
+    progress.set_total_examples(config.count);
+
+    let clone_progress = progress.start(Step::Synthesize, "clone");
+    let repo_path = ensure_repo_cloned(&config.repo_url).await?;
+    
drop(clone_progress); + + let client = PlainLlmClient::new()?; + let mut examples_generated = 0; + let mut commits_skipped = 0; + let batch_size = config.max_commits; + + 'outer: loop { + let list_progress = progress.start(Step::Synthesize, "list-commits"); + let commits = list_commits(&repo_path, batch_size, commits_skipped).await?; + drop(list_progress); + + if commits.is_empty() { + break; + } + + commits_skipped += commits.len(); + + for commit in commits { + if examples_generated >= config.count { + break 'outer; + } + + if !config.fresh && state.is_processed(&config.repo_url, &commit.sha) { + continue; + } + + if should_skip_commit(&commit) { + continue; + } + + let commit_label = format!( + "{} {}", + &commit.sha[..8], + truncate_message(&commit.message, 40) + ); + let step_progress = Arc::new(progress.start(Step::Synthesize, &commit_label)); + + // Single Claude call to identify and copy hunks + step_progress.set_substatus("analyzing..."); + let claude_response = + match analyze_commit(&client, &config, &commit, step_progress.clone()).await { + Ok(Some(response)) => response, + Ok(None) => { + step_progress.set_info("no pattern", InfoStyle::Normal); + state.mark_processed(&config.repo_url, &commit.sha, 0); + state.save()?; + continue; + } + Err(e) => { + step_progress.set_info(format!("error: {:?}", e), InfoStyle::Warning); + state.mark_processed(&config.repo_url, &commit.sha, 0); + state.save()?; + continue; + } + }; + + // Validate and build the example + step_progress.set_substatus("validating..."); + match build_example(&config, &commit, &repo_path, &claude_response).await { + Ok(spec) => { + let timestamp = Local::now().format("%Y-%m-%d--%H-%M-%S"); + let filename = format!("{}.md", timestamp); + let path = config.output_dir.join(&filename); + std::fs::write(&path, spec.to_markdown())?; + examples_generated += 1; + step_progress.set_info(filename, InfoStyle::Normal); + } + Err(rejection_reason) => { + log::debug!("Example rejected: {}", 
rejection_reason);
+                    let timestamp = Local::now().format("%Y-%m-%d--%H-%M-%S%.3f");
+                    let filename = format!("{}.md", timestamp);
+                    let path = FAILED_EXAMPLES_DIR.join(&filename);
+                    let content = format_rejected_example(&claude_response, &rejection_reason);
+                    if let Err(e) = std::fs::write(&path, content) {
+                        log::warn!("Failed to write rejected example: {:?}", e);
+                    }
+                    step_progress.set_info(format!("rejected: {}", filename), InfoStyle::Warning);
+                }
+            }
+
+            state.mark_processed(&config.repo_url, &commit.sha, 1);
+            state.save()?;
+        }
+    }
+
+    progress.finalize();
+    Ok(())
+}
+
+fn truncate_message(msg: &str, max_len: usize) -> String {
+    let first_line = msg.lines().next().unwrap_or("");
+    if first_line.len() <= max_len {
+        first_line.to_string()
+    } else {
+        format!("{}...", &first_line[..max_len - 3])
+    }
+}
+
+fn should_skip_commit(commit: &CommitInfo) -> bool {
+    let lines_changed = commit
+        .diff
+        .lines()
+        .filter(|l| l.starts_with('+') || l.starts_with('-'))
+        .count();
+    lines_changed < 10
+        || lines_changed > 1000
+        || is_non_code_commit(commit)
+        || is_rename_commit(commit)
+}
+
+fn is_non_code_commit(commit: &CommitInfo) -> bool {
+    let non_code_extensions = [
+        ".md", ".txt", ".json", ".yaml", ".yml", ".toml", ".lock", ".svg", ".png", ".jpg", ".gif",
+        ".ico", ".woff", ".ttf", ".eot",
+    ];
+
+    let diff_files: Vec<&str> = commit
+        .diff
+        .lines()
+        .filter(|l| l.starts_with("+++ b/") || l.starts_with("--- a/"))
+        .filter_map(|l| {
+            l.strip_prefix("+++ b/")
+                .or_else(|| l.strip_prefix("--- a/"))
+        })
+        .collect();
+
+    if diff_files.is_empty() {
+        return false;
+    }
+
+    diff_files
+        .iter()
+        .all(|f| non_code_extensions.iter().any(|ext| f.ends_with(ext)))
+}
+
+fn is_rename_commit(commit: &CommitInfo) -> bool {
+    commit.diff.contains("similarity index")
+        || commit.diff.contains("rename from")
+        || commit.diff.contains("rename to")
+}
+
+async fn list_commits(
+    repo_path: &Path,
+    max_commits: usize,
+    skip: usize,
+) -> Result<Vec<CommitInfo>> {
+    let output = 
run_git( + repo_path, + &[ + "log", + "--no-merges", + &format!("--skip={}", skip), + &format!("-{}", max_commits), + "--format=%H|%P|%s", + ], + ) + .await?; + + let mut commits = Vec::new(); + for line in output.lines() { + let parts: Vec<&str> = line.splitn(3, '|').collect(); + if parts.len() < 3 { + continue; + } + let sha = parts[0].to_string(); + let parent_sha = parts[1].split_whitespace().next().unwrap_or("").to_string(); + if parent_sha.is_empty() { + continue; + } + + // Get standard diff (for skip checks) + let diff = run_git(repo_path, &["show", "--format=", &sha]) + .await + .unwrap_or_default(); + + // Get expanded diff with 30 lines of context + let expanded_diff = run_git(repo_path, &["show", "-U30", "--format=", &sha]) + .await + .unwrap_or_default(); + + commits.push(CommitInfo { + sha, + parent_sha, + message: parts[2].to_string(), + diff, + expanded_diff, + }); + } + + Ok(commits) +} + +fn build_prompt(config: &SynthesizeConfig, commit: &CommitInfo) -> String { + format!( + indoc! {r#" + You are analyzing a git commit to construct a realistic edit prediction example. + + Your goal is to tell the story of a programmer's editing session: what sequence of changes did they make, and what change logically comes next? We use these examples to train a model to predict edits, so the quality of the EDIT HISTORY is what matters most. + + An edit prediction example consists of: + 1. **Edit History**: 3-6 hunks showing what the programmer did BEFORE making the expected patch. This is the most important part - it must tell a coherent story of the changes leading up to the prediction. + 2. **Expected Patch**: One small hunk that logically follows from the edit history. + + Both single-file and multi-file patterns are acceptable. + + ## What Makes a Good Example + + The edit history should read like a story: "First the programmer changed X, then Y, then Z, and now they need to change W." 
+ + GOOD examples (rich sequences with 3+ steps): + - Removing a parameter: docstring update → constructor change → field removal → (predict) usage site update + - Adding a feature: type definition → first usage → second usage → (predict) third usage + - Bug fix pattern: fix in file A → fix in file B → fix in file C → (predict) fix in file D + + BAD examples (respond NO_PATTERN): + - Commits where all changes are independent (no narrative thread) + - Simple find-and-replace (renaming, version bumps) + - Documentation-only or config-only changes + - Changes where you can only find 1-2 hunks for the edit history + + ## Commit Information + + Repository: {repo_url} + Commit: {sha} + Message: {message} + + ## Diff (30 lines context) + + ```diff + {expanded_diff} + ``` + + ## Your Task + + First, THINK through whether this commit can support a good example: + + 1. What is the high-level pattern in this commit? + 2. Can you identify at least 4 related hunks (3 for edit history + 1 for expected patch)? + 3. What would be the narrative? (First... then... then... finally predict...) + 4. Which specific hunk should be the expected patch (the "punchline")? + + If you cannot construct a coherent 3+ hunk story, respond with just: + NO_PATTERN: + + If you CAN construct a good example, respond in this format: + + ANALYSIS: + Pattern: + Steps: + 1. - + 2. - + 3. - + 4. [EXPECTED PATCH] - + + NAME: + + EDIT_HISTORY: + + Hunk 1: + ```diff + --- a/src/models/user.py + +++ b/src/models/user.py + @@ -15,7 +15,6 @@ class User: + """A user in the system. + + Attributes: + - email: The user's email address. + name: The user's display name. 
+             """
+        ```
+
+        Hunk 2:
+        ```diff
+        --- a/src/models/user.py
+        +++ b/src/models/user.py
+        @@ -25,10 +24,9 @@ class User:
+             def __init__(
+                 self,
+                 name: str,
+        -        email: str,
+                 created_at: datetime,
+             ):
+                 self.name = name
+        -        self.email = email
+                 self.created_at = created_at
+        ```
+
+        Hunk 3:
+        ```diff
+        --- a/src/api/handlers.py
+        +++ b/src/api/handlers.py
+        @@ -42,7 +42,6 @@ def create_user(request):
+             data = request.json()
+             user = User(
+                 name=data["name"],
+        -        email=data["email"],
+                 created_at=datetime.now(),
+             )
+             return user.save()
+        ```
+
+        EXPECTED_PATCH:
+        ```diff
+        --- a/src/api/handlers.py
+        +++ b/src/api/handlers.py
+        @@ -58,7 +57,6 @@ def update_user(request, user_id):
+             user = User.get(user_id)
+             user.name = data.get("name", user.name)
+        -    user.email = data.get("email", user.email)
+             user.save()
+             return user
+        ```
+
+        ## Requirements for the diffs
+
+        Edit history:
+        - MUST have 3-6 hunks (if you cannot find 3+, respond NO_PATTERN instead)
+        - Each hunk needs file headers (--- a/path and +++ b/path)
+        - Hunks must be valid unified diffs that apply to the parent commit
+        - Order hunks as a programmer would naturally make the changes
+
+        Expected patch:
+        - Must be a SINGLE hunk from a SINGLE file
+        - Must be SMALL: 1-15 changed lines (not counting context)
+        - Must be clearly predictable from the edit history narrative
+    "#},
+        repo_url = config.repo_url,
+        sha = commit.sha,
+        message = commit.message,
+        expanded_diff = commit.expanded_diff,
+    )
+}
+
+async fn analyze_commit(
+    client: &PlainLlmClient,
+    config: &SynthesizeConfig,
+    commit: &CommitInfo,
+    step_progress: Arc<StepProgress>,
+) -> Result<Option<ClaudeResponse>> {
+    use anthropic::{Message, RequestContent, Role};
+
+    let prompt = build_prompt(config, commit);
+    let messages = vec![Message {
+        role: Role::User,
+        content: vec![RequestContent::Text {
+            text: prompt,
+            cache_control: None,
+        }],
+    }];
+
+    let response = client
+        .generate_streaming("claude-sonnet-4-5", 8192, messages, |chars, _text| {
+            
step_progress.set_substatus(format!("analyzing: {:.1}K", chars as f64 / 1000.0));
+        })
+        .await?;
+
+    // Extract text content from response
+    let response_text: String = response
+        .content
+        .iter()
+        .filter_map(|block| {
+            if let ResponseContent::Text { text } = block {
+                Some(text.as_str())
+            } else {
+                None
+            }
+        })
+        .collect::<Vec<_>>()
+        .join("\n");
+
+    parse_claude_response(&response_text)
+}
+
+fn parse_claude_response(response: &str) -> Result<Option<ClaudeResponse>> {
+    // Check for NO_PATTERN
+    if response.contains("NO_PATTERN:") {
+        return Ok(None);
+    }
+
+    // Parse NAME
+    let name = response
+        .lines()
+        .find(|l| l.starts_with("NAME:"))
+        .map(|l| l.strip_prefix("NAME:").unwrap_or("").trim().to_string())
+        .unwrap_or_else(|| "unnamed example".to_string());
+
+    // Parse ANALYSIS section (Claude's planning) - this is the primary reasoning
+    let reasoning = extract_section(
+        response,
+        "ANALYSIS:",
+        &["NAME:", "REASONING:", "EDIT_HISTORY:", "EXPECTED_PATCH:"],
+    )
+    .unwrap_or_default();
+
+    // Parse EDIT_HISTORY diff block
+    let edit_history_hunks = extract_diff_block(response, "EDIT_HISTORY:")?;
+
+    // Parse EXPECTED_PATCH diff block
+    let expected_patch_hunks = extract_diff_block(response, "EXPECTED_PATCH:")?;
+
+    if edit_history_hunks.is_empty() {
+        anyhow::bail!("No edit history hunks found in response");
+    }
+    if expected_patch_hunks.is_empty() {
+        anyhow::bail!("No expected patch hunks found in response");
+    }
+
+    Ok(Some(ClaudeResponse {
+        name,
+        reasoning,
+        edit_history_hunks,
+        expected_patch_hunks,
+    }))
+}
+
+fn extract_section(text: &str, start_marker: &str, end_markers: &[&str]) -> Option<String> {
+    let start_idx = text.find(start_marker)?;
+    let content_start = start_idx + start_marker.len();
+
+    let end_idx = end_markers
+        .iter()
+        .filter_map(|marker| text[content_start..].find(marker))
+        .min()
+        .map(|idx| content_start + idx)
+        .unwrap_or(text.len());
+
+    Some(text[content_start..end_idx].trim().to_string())
+}
+
+fn extract_diff_block(text: &str, 
section_marker: &str) -> Result<Vec<String>> {
+    let section_start = text
+        .find(section_marker)
+        .context(format!("Section {} not found", section_marker))?;
+
+    let after_marker = &text[section_start + section_marker.len()..];
+
+    // Find where the next major section starts (to bound our search)
+    let section_end = ["EXPECTED_PATCH:", "## "]
+        .iter()
+        .filter(|&&m| m != section_marker)
+        .filter_map(|marker| after_marker.find(marker))
+        .min()
+        .unwrap_or(after_marker.len());
+
+    let section_content = &after_marker[..section_end];
+
+    // Collect all ```diff blocks in this section
+    let mut hunks = Vec::new();
+    let mut search_start = 0;
+
+    while let Some(diff_start) = section_content[search_start..].find("```diff") {
+        let abs_diff_start = search_start + diff_start;
+        let block_content_start = section_content[abs_diff_start..]
+            .find('\n')
+            .map(|i| abs_diff_start + i + 1)
+            .unwrap_or(abs_diff_start);
+
+        if let Some(block_end_rel) = section_content[block_content_start..].find("```") {
+            let block_end = block_content_start + block_end_rel;
+            let diff_content = section_content[block_content_start..block_end].trim();
+
+            // Split this block into hunks (in case multiple hunks in one block)
+            hunks.extend(split_into_hunks(diff_content));
+
+            search_start = block_end + 3;
+        } else {
+            break;
+        }
+    }
+
+    if hunks.is_empty() {
+        anyhow::bail!("No diff blocks found in section {}", section_marker);
+    }
+
+    Ok(hunks)
+}
+
+/// Split a diff block into individual hunks, preserving file headers
+fn split_into_hunks(diff: &str) -> Vec<String> {
+    let mut hunks = Vec::new();
+    let mut current_file_header: Option<String> = None;
+    let mut current_hunk: Vec<String> = Vec::new();
+    let mut in_hunk = false;
+
+    for line in diff.lines() {
+        if line.starts_with("--- a/") || line.starts_with("--- /") {
+            // Start of file header - flush previous hunk
+            if in_hunk && !current_hunk.is_empty() {
+                let mut hunk_text = String::new();
+                if let Some(ref header) = current_file_header {
+                    hunk_text.push_str(header);
+                    
hunk_text.push('\n');
+                }
+                hunk_text.push_str(&current_hunk.join("\n"));
+                hunks.push(hunk_text);
+                current_hunk.clear();
+            }
+            current_file_header = Some(line.to_string());
+            in_hunk = false;
+        } else if line.starts_with("+++ b/") || line.starts_with("+++ /") {
+            if let Some(ref mut header) = current_file_header {
+                header.push('\n');
+                header.push_str(line);
+            }
+        } else if line.starts_with("@@ ") {
+            // New hunk - flush previous
+            if in_hunk && !current_hunk.is_empty() {
+                let mut hunk_text = String::new();
+                if let Some(ref header) = current_file_header {
+                    hunk_text.push_str(header);
+                    hunk_text.push('\n');
+                }
+                hunk_text.push_str(&current_hunk.join("\n"));
+                hunks.push(hunk_text);
+                current_hunk.clear();
+            }
+            current_hunk.push(line.to_string());
+            in_hunk = true;
+        } else if in_hunk {
+            current_hunk.push(line.to_string());
+        }
+    }
+
+    // Flush final hunk
+    if !current_hunk.is_empty() {
+        let mut hunk_text = String::new();
+        if let Some(ref header) = current_file_header {
+            hunk_text.push_str(header);
+            hunk_text.push('\n');
+        }
+        hunk_text.push_str(&current_hunk.join("\n"));
+        hunks.push(hunk_text);
+    }
+
+    hunks
+}
+
+/// Validate Claude's output by applying diffs and build the ExampleSpec
+async fn build_example(
+    config: &SynthesizeConfig,
+    commit: &CommitInfo,
+    repo_path: &Path,
+    response: &ClaudeResponse,
+) -> Result<ExampleSpec, String> {
+    // Validate expected patch hunks
+    if response.expected_patch_hunks.len() != 1 {
+        return Err(format!(
+            "Expected exactly 1 expected patch hunk, got {}",
+            response.expected_patch_hunks.len()
+        ));
+    }
+
+    // Parse the expected patch to determine cursor file
+    let expected_patch = &response.expected_patch_hunks[0];
+    let cursor_file = extract_file_from_hunk(expected_patch)
+        .ok_or_else(|| "Could not determine file from expected patch".to_string())?;
+
+    // Get the file content before the commit
+    let before_content = run_git(
+        repo_path,
+        &["show", &format!("{}^:{}", commit.sha, cursor_file)],
+    )
+    .await
+    .map_err(|e| format!("Failed to get 
file content for {}: {}", cursor_file, e))?; + + // Build edit history diff from Claude's hunks + let edit_history = response.edit_history_hunks.join("\n"); + + // Apply edit history to get intermediate state (validates edit history) + let intermediate_state = + apply_edit_history_to_content(&before_content, &edit_history, &cursor_file)?; + + // Validate expected patch applies to intermediate state + let expected_patch_with_header = ensure_diff_header(expected_patch, &cursor_file); + apply_diff_to_string(&intermediate_state, &expected_patch_with_header) + .map_err(|e| format!("Expected patch failed to apply: {}", e))?; + + // Find where the expected patch edits would apply in the intermediate state + let edits = edits_for_diff(&intermediate_state, &expected_patch_with_header) + .map_err(|e| format!("Failed to parse expected patch: {}", e))?; + if edits.is_empty() { + return Err( + "Could not locate expected patch in file (context not found or ambiguous)".to_string(), + ); + } + + // Use the start of the first edit for cursor positioning + let cursor_byte_offset = edits[0].0.start; + + // Extract excerpt around the edit location + let (excerpt, cursor_offset) = extract_cursor_excerpt(&intermediate_state, cursor_byte_offset)?; + + // Build the ExampleSpec and use set_cursor_excerpt to format with comment marker + let comment_prefix = line_comment_prefix(&cursor_file); + let reasoning_with_source = format!( + "Source commit: {} ({})\n\n{}", + commit.sha, + truncate_message(&commit.message, 60), + response.reasoning + ); + let mut spec = ExampleSpec { + name: response.name.clone(), + repository_url: config.repo_url.clone(), + revision: commit.parent_sha.clone(), + tags: Vec::new(), + reasoning: Some(reasoning_with_source), + uncommitted_diff: String::new(), + cursor_path: Arc::from(Path::new(&cursor_file)), + cursor_position: String::new(), + edit_history, + expected_patches: vec![expected_patch_with_header], + }; + spec.set_cursor_excerpt(&excerpt, cursor_offset, 
comment_prefix);
+
+    Ok(spec)
+}
+
+/// Extract file path from a hunk (looks for --- a/path or +++ b/path)
+fn extract_file_from_hunk(hunk: &str) -> Option<String> {
+    for line in hunk.lines() {
+        if let Some(path) = line.strip_prefix("+++ b/") {
+            return Some(path.to_string());
+        }
+        if let Some(path) = line.strip_prefix("--- a/") {
+            return Some(path.to_string());
+        }
+    }
+    None
+}
+
+/// Ensure a hunk has proper file headers
+fn ensure_diff_header(hunk: &str, file_path: &str) -> String {
+    if hunk.contains("--- a/") || hunk.contains("+++ b/") {
+        return hunk.to_string();
+    }
+    format!("--- a/{}\n+++ b/{}\n{}", file_path, file_path, hunk)
+}
+
+/// Apply edit history to file content, only if hunks affect this file
+fn apply_edit_history_to_content(
+    content: &str,
+    edit_history: &str,
+    cursor_file: &str,
+) -> Result<String, String> {
+    // Extract just the hunks for this file from the edit history
+    let file_diff = extract_file_diff_from_combined(edit_history, cursor_file);
+
+    if file_diff.is_empty() {
+        return Ok(content.to_string());
+    }
+
+    apply_diff_to_string(content, &file_diff)
+        .map_err(|e| format!("Failed to apply edit history: {}", e))
+}
+
+/// Extract hunks for a specific file from a combined diff
+fn extract_file_diff_from_combined(combined_diff: &str, target_file: &str) -> String {
+    let mut result = String::new();
+    let mut in_target_file = false;
+    let mut found_header = false;
+
+    for line in combined_diff.lines() {
+        if line.starts_with("--- a/") {
+            let file = line.strip_prefix("--- a/").unwrap_or("");
+            in_target_file = file == target_file;
+            if in_target_file {
+                result.push_str(line);
+                result.push('\n');
+                found_header = false;
+            }
+        } else if line.starts_with("+++ b/") && in_target_file {
+            result.push_str(line);
+            result.push('\n');
+            found_header = true;
+        } else if in_target_file && found_header {
+            if line.starts_with("--- a/") {
+                break;
+            }
+            result.push_str(line);
+            result.push('\n');
+        }
+    }
+
+    result
+}
+
+/// Extract a cursor position excerpt from 
content around a byte offset. +/// Returns the excerpt and the cursor offset within the excerpt. +fn extract_cursor_excerpt( + content: &str, + cursor_byte_offset: usize, +) -> Result<(String, usize), String> { + // Find the line containing the cursor + let line_start = content[..cursor_byte_offset] + .rfind('\n') + .map(|pos| pos + 1) + .unwrap_or(0); + let line_end = content[cursor_byte_offset..] + .find('\n') + .map(|pos| cursor_byte_offset + pos) + .unwrap_or(content.len()); + + // Get context lines before + let lines_before: Vec<&str> = content[..line_start].lines().collect(); + let context_before: Vec<&str> = lines_before.iter().rev().take(3).rev().cloned().collect(); + + // Get context lines after + let after_line_end = if line_end < content.len() { + line_end + 1 + } else { + line_end + }; + let context_after: Vec<&str> = content[after_line_end..].lines().take(4).collect(); + + // The line containing the cursor + let cursor_line = &content[line_start..line_end]; + let cursor_column = cursor_byte_offset - line_start; + + // Build the excerpt + let mut excerpt = String::new(); + for line in context_before { + excerpt.push_str(line); + excerpt.push('\n'); + } + // Track where cursor will be in the excerpt + let cursor_offset_in_excerpt = excerpt.len() + cursor_column; + // Line containing cursor + excerpt.push_str(cursor_line); + excerpt.push('\n'); + for line in context_after { + excerpt.push_str(line); + excerpt.push('\n'); + } + + // Trim trailing newline + if excerpt.ends_with('\n') { + excerpt.pop(); + } + + Ok((excerpt, cursor_offset_in_excerpt)) +} + +/// Get the line comment prefix for a file based on its extension +fn line_comment_prefix(file_path: &str) -> &'static str { + let extension = file_path.rsplit('.').next().unwrap_or(""); + match extension { + "rs" | "c" | "cpp" | "cc" | "h" | "hpp" | "js" | "ts" | "tsx" | "jsx" | "go" | "java" + | "swift" | "kt" | "kts" | "scala" | "cs" | "m" | "mm" | "zig" | "v" | "d" => "//", + "py" | "rb" | "sh" | 
"bash" | "zsh" | "pl" | "pm" | "r" | "jl" | "yaml" | "yml" + | "toml" | "coffee" | "cr" | "ex" | "exs" | "elixir" => "#", + "lua" | "hs" | "sql" => "--", + "lisp" | "clj" | "cljs" | "scm" | "rkt" | "el" => ";", + "erl" | "hrl" => "%", + _ => "//", + } +} + +fn format_rejected_example(response: &ClaudeResponse, rejection_reason: &str) -> String { + let mut content = String::new(); + content.push_str("# Rejected Example\n\n"); + content.push_str(&format!("## Name\n\n{}\n\n", response.name)); + content.push_str(&format!("## Reasoning\n\n{}\n\n", response.reasoning)); + content.push_str("## Edit History Hunks\n\n```diff\n"); + for hunk in &response.edit_history_hunks { + content.push_str(hunk); + content.push_str("\n\n"); + } + content.push_str("```\n\n"); + content.push_str("## Expected Patch Hunks\n\n```diff\n"); + for hunk in &response.expected_patch_hunks { + content.push_str(hunk); + content.push_str("\n\n"); + } + content.push_str("```\n\n"); + content.push_str(&format!("## Rejection Reason\n\n{}\n", rejection_reason)); + content +} diff --git a/crates/edit_prediction_ui/src/edit_prediction_ui.rs b/crates/edit_prediction_ui/src/edit_prediction_ui.rs index 9dc7623f79cab61350a06740440bae5cb4e3f40f..6f29bbaf2c89df295768b3db5dcde94a3a643195 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_ui.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_ui.rs @@ -150,7 +150,7 @@ fn capture_example_as_markdown( .buffer() .read(cx) .text_anchor_for_position(editor.selections.newest_anchor().head(), cx)?; - let example = capture_example(project.clone(), buffer, cursor_anchor, true, cx)?; + let example = capture_example(project.clone(), buffer, cursor_anchor, cx)?; let examples_dir = AllLanguageSettings::get_global(cx) .edit_predictions diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index c919aeca15714864e5a54f8b56d3f8517994cb56..e990a67ccd59983a526b41a38abcf59d1d2e8108 100644 --- a/crates/language/src/buffer.rs +++ 
b/crates/language/src/buffer.rs
@@ -13,7 +13,7 @@ use crate::{
     },
     task_context::RunnableRange,
     text_diff::text_diff,
-    unified_diff,
+    unified_diff_with_offsets,
 };
 pub use crate::{
     Grammar, Language, LanguageRegistry,
@@ -773,7 +773,11 @@ pub struct EditPreview {
 }
 
 impl EditPreview {
-    pub fn as_unified_diff(&self, edits: &[(Range<Anchor>, impl AsRef<str>)]) -> Option<String> {
+    pub fn as_unified_diff(
+        &self,
+        file: Option<&Arc<dyn File>>,
+        edits: &[(Range<Anchor>, impl AsRef<str>)],
+    ) -> Option<String> {
         let (first, _) = edits.first()?;
         let (last, _) = edits.last()?;
 
@@ -788,7 +792,7 @@
         let old_end = Point::new(old_end.row + 4, 0).min(self.old_snapshot.max_point());
         let new_end = Point::new(new_end.row + 4, 0).min(self.applied_edits_snapshot.max_point());
 
-        Some(unified_diff(
+        let diff_body = unified_diff_with_offsets(
             &self
                 .old_snapshot
                 .text_for_range(start..old_end)
@@ -797,7 +801,17 @@
                 .applied_edits_snapshot
                 .text_for_range(start..new_end)
                 .collect::<String>(),
-        ))
+            start.row,
+            start.row,
+        );
+
+        let path = file.map(|f| f.path().as_unix_str());
+        let header = match path {
+            Some(p) => format!("--- a/{}\n+++ b/{}\n", p, p),
+            None => String::new(),
+        };
+
+        Some(format!("{}{}", header, diff_body))
     }
 
     pub fn highlight_edits(
diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs
index 6755d5df7ee5f2f1244ea75ff90e5dbd9bf6543e..380c0689d4d7fca146ff773e371d0bd754a8408c 100644
--- a/crates/project/src/project.rs
+++ b/crates/project/src/project.rs
@@ -4564,13 +4564,13 @@ impl Project {
 
         for worktree in worktree_store.visible_worktrees(cx) {
             let worktree = worktree.read(cx);
-            if let Ok(path) = RelPath::new(path, path_style)
-                && let Some(entry) = worktree.entry_for_path(&path)
-            {
-                return Some(ProjectPath {
-                    worktree_id: worktree.id(),
-                    path: entry.path.clone(),
-                });
+            if let Ok(rel_path) = RelPath::new(path, path_style) {
+                if let Some(entry) = worktree.entry_for_path(&rel_path) {
+                    return Some(ProjectPath {
+                        worktree_id: worktree.id(),
+                        path: 
entry.path.clone(), + }); + } } } }