Cargo.lock 🔗
@@ -5253,6 +5253,7 @@ dependencies = [
"text",
"thiserror 2.0.17",
"time",
+ "toml 0.8.23",
"ui",
"util",
"uuid",
Created by Max Brunsfeld
* Fix some bugs in capture of EP examples from running app
* Tweak markdown format for EP examples
* Store repo and revision in TOML front matter
* Represent cursor position using a comment line
* Allow multiple expected patches in evals
* Remove line-based scoring criteria for evals
* Add a `synthesize` subcommand to the EP cli that generates examples
from git commits
Release Notes:
- N/A
Cargo.lock | 1
crates/edit_prediction/Cargo.toml | 1
crates/edit_prediction/src/capture_example.rs | 94
crates/edit_prediction/src/edit_prediction.rs | 6
crates/edit_prediction/src/example_spec.rs | 440 ++++++
crates/edit_prediction/src/udiff.rs | 329 ++++
crates/edit_prediction_cli/src/anthropic_client.rs | 103 +
crates/edit_prediction_cli/src/distill.rs | 17
crates/edit_prediction_cli/src/example.rs | 13
crates/edit_prediction_cli/src/format_prompt.rs | 29
crates/edit_prediction_cli/src/git.rs | 110 +
crates/edit_prediction_cli/src/load_project.rs | 154 --
crates/edit_prediction_cli/src/main.rs | 50
crates/edit_prediction_cli/src/metrics.rs | 219 --
crates/edit_prediction_cli/src/paths.rs | 4
crates/edit_prediction_cli/src/predict.rs | 16
crates/edit_prediction_cli/src/progress.rs | 3
crates/edit_prediction_cli/src/score.rs | 75
crates/edit_prediction_cli/src/synthesize.rs | 892 +++++++++++++++
crates/edit_prediction_ui/src/edit_prediction_ui.rs | 2
crates/language/src/buffer.rs | 22
crates/project/src/project.rs | 14
22 files changed, 2,087 insertions(+), 507 deletions(-)
@@ -5253,6 +5253,7 @@ dependencies = [
"text",
"thiserror 2.0.17",
"time",
+ "toml 0.8.23",
"ui",
"util",
"uuid",
@@ -56,6 +56,7 @@ telemetry_events.workspace = true
text.workspace = true
thiserror.workspace = true
time.workspace = true
+toml.workspace = true
ui.workspace = true
util.workspace = true
uuid.workspace = true
@@ -7,15 +7,14 @@ use buffer_diff::BufferDiffSnapshot;
use collections::HashMap;
use gpui::{App, Entity, Task};
use language::{Buffer, ToPoint as _};
-use project::Project;
+use project::{Project, WorktreeId};
use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
-use text::{BufferSnapshot as TextBufferSnapshot, ToOffset as _};
+use text::BufferSnapshot as TextBufferSnapshot;
pub fn capture_example(
project: Entity<Project>,
buffer: Entity<Buffer>,
cursor_anchor: language::Anchor,
- last_event_is_expected_patch: bool,
cx: &mut App,
) -> Option<Task<Result<ExampleSpec>>> {
let ep_store = EditPredictionStore::try_global(cx)?;
@@ -43,8 +42,26 @@ pub fn capture_example(
let git_store = project.read(cx).git_store().clone();
Some(cx.spawn(async move |mut cx| {
- let snapshots_by_path = collect_snapshots(&project, &git_store, &events, &mut cx).await?;
- let cursor_excerpt = cx
+ let snapshots_by_path =
+ collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
+
+ events.retain(|stored_event| {
+ match stored_event.event.as_ref() {
+ zeta_prompt::Event::BufferChange { path, .. } => {
+ if !snapshots_by_path.contains_key(path) {
+ return false;
+ }
+ }
+ }
+ true
+ });
+
+ let line_comment_prefix = snapshot
+ .language()
+ .and_then(|lang| lang.config().line_comments.first())
+ .map(|s| s.to_string())
+ .unwrap_or_default();
+ let (cursor_excerpt, cursor_offset) = cx
.background_executor()
.spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
.await;
@@ -54,13 +71,6 @@ pub fn capture_example(
.await;
let mut edit_history = String::new();
- let mut expected_patch = String::new();
- if last_event_is_expected_patch {
- if let Some(stored_event) = events.pop() {
- zeta_prompt::write_event(&mut expected_patch, &stored_event.event);
- }
- }
-
for stored_event in &events {
zeta_prompt::write_event(&mut edit_history, &stored_event.event);
if !edit_history.ends_with('\n') {
@@ -68,57 +78,62 @@ pub fn capture_example(
}
}
- let name = generate_timestamp_name();
-
- Ok(ExampleSpec {
- name,
+ let mut spec = ExampleSpec {
+ name: generate_timestamp_name(),
repository_url,
revision,
+ tags: Vec::new(),
+ reasoning: None,
uncommitted_diff,
cursor_path: cursor_path.as_std_path().into(),
- cursor_position: cursor_excerpt,
+ cursor_position: String::new(),
edit_history,
- expected_patch,
- })
+ expected_patches: Vec::new(),
+ };
+ spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
+ Ok(spec)
}))
}
fn compute_cursor_excerpt(
snapshot: &language::BufferSnapshot,
cursor_anchor: language::Anchor,
-) -> String {
+) -> (String, usize) {
+ use text::ToOffset as _;
+
let cursor_point = cursor_anchor.to_point(snapshot);
let (_editable_range, context_range) =
editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
-
let context_start_offset = context_range.start.to_offset(snapshot);
let cursor_offset = cursor_anchor.to_offset(snapshot);
let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
- let mut excerpt = snapshot.text_for_range(context_range).collect::<String>();
- if cursor_offset_in_excerpt <= excerpt.len() {
- excerpt.insert_str(cursor_offset_in_excerpt, zeta_prompt::CURSOR_MARKER);
- }
- excerpt
+ let excerpt = snapshot.text_for_range(context_range).collect::<String>();
+ (excerpt, cursor_offset_in_excerpt)
}
async fn collect_snapshots(
project: &Entity<Project>,
git_store: &Entity<project::git_store::GitStore>,
+ worktree_id: WorktreeId,
events: &[StoredEvent],
cx: &mut gpui::AsyncApp,
) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
let mut snapshots_by_path = HashMap::default();
+ let root_name = project.read_with(cx, |project, cx| {
+ project
+ .worktree_for_id(worktree_id, cx)
+ .unwrap()
+ .read(cx)
+ .root_name()
+ .to_owned()
+ })?;
for stored_event in events {
let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
- let project_path = project.find_project_path(path, cx)?;
- let full_path = project
- .worktree_for_id(project_path.worktree_id, cx)?
- .read(cx)
- .root_name()
- .join(&project_path.path)
- .as_std_path()
- .into();
+ let project_path = project
+ .find_project_path(path, cx)
+ .filter(|path| path.worktree_id == worktree_id)?;
+ let full_path = root_name.join(&project_path.path).as_std_path().into();
Some((project_path, full_path))
})? {
if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
@@ -289,9 +304,7 @@ mod tests {
cx.run_until_parked();
let mut example = cx
- .update(|cx| {
- capture_example(project.clone(), buffer.clone(), Anchor::MIN, false, cx).unwrap()
- })
+ .update(|cx| capture_example(project.clone(), buffer.clone(), Anchor::MIN, cx).unwrap())
.await
.unwrap();
example.name = "test".to_string();
@@ -302,6 +315,8 @@ mod tests {
name: "test".to_string(),
repository_url: "https://github.com/test/repo.git".to_string(),
revision: "abc123def456".to_string(),
+ tags: Vec::new(),
+ reasoning: None,
uncommitted_diff: indoc! {"
--- a/project/src/main.rs
+++ b/project/src/main.rs
@@ -322,7 +337,8 @@ mod tests {
.to_string(),
cursor_path: Path::new("project/src/main.rs").into(),
cursor_position: indoc! {"
- <|user_cursor|>fn main() {
+ fn main() {
+ ^[CURSOR_POSITION]
// comment 1
one();
two();
@@ -355,7 +371,7 @@ mod tests {
seven();
"}
.to_string(),
- expected_patch: "".to_string(),
+ expected_patches: Vec::new()
}
);
}
@@ -688,12 +688,14 @@ impl EditPredictionStore {
pub fn clear_history(&mut self) {
for project_state in self.projects.values_mut() {
project_state.events.clear();
+ project_state.last_event.take();
}
}
pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
project_state.events.clear();
+ project_state.last_event.take();
}
}
@@ -2044,7 +2046,9 @@ impl EditPredictionStore {
"Edit Prediction Rated",
rating,
inputs = prediction.inputs,
- output = prediction.edit_preview.as_unified_diff(&prediction.edits),
+ output = prediction
+ .edit_preview
+ .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
feedback
);
self.client.telemetry().flush_events().detach();
@@ -1,5 +1,9 @@
+use anyhow::{Context as _, Result};
use serde::{Deserialize, Serialize};
-use std::{fmt::Write as _, mem, path::Path, sync::Arc};
+use std::{borrow::Cow, fmt::Write as _, mem, path::Path, sync::Arc};
+
+pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]";
+pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExampleSpec {
@@ -7,34 +11,81 @@ pub struct ExampleSpec {
pub name: String,
pub repository_url: String,
pub revision: String,
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ pub tags: Vec<String>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub reasoning: Option<String>,
#[serde(default)]
pub uncommitted_diff: String,
pub cursor_path: Arc<Path>,
pub cursor_position: String,
pub edit_history: String,
- pub expected_patch: String,
+ pub expected_patches: Vec<String>,
}
+const REASONING_HEADING: &str = "Reasoning";
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
-const REPOSITORY_URL_FIELD: &str = "repository_url";
-const REVISION_FIELD: &str = "revision";
+
+#[derive(Serialize, Deserialize)]
+struct FrontMatter<'a> {
+ repository_url: Cow<'a, str>,
+ revision: Cow<'a, str>,
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ tags: Vec<String>,
+}
impl ExampleSpec {
+ /// Generate a sanitized filename for this example.
+ pub fn filename(&self) -> String {
+ self.name
+ .chars()
+ .map(|c| match c {
+ ' ' | ':' | '~' | '^' | '?' | '*' | '[' | '\\' | '@' | '{' | '/' | '<' | '>'
+ | '|' | '"' => '-',
+ c => c,
+ })
+ .collect()
+ }
+
/// Format this example spec as markdown.
pub fn to_markdown(&self) -> String {
+ use std::fmt::Write as _;
+
+ let front_matter = FrontMatter {
+ repository_url: Cow::Borrowed(&self.repository_url),
+ revision: Cow::Borrowed(&self.revision),
+ tags: self.tags.clone(),
+ };
+ let front_matter_toml =
+ toml::to_string_pretty(&front_matter).unwrap_or_else(|_| String::new());
+
let mut markdown = String::new();
- _ = writeln!(markdown, "# {}", self.name);
+ _ = writeln!(markdown, "+++");
+ markdown.push_str(&front_matter_toml);
+ if !markdown.ends_with('\n') {
+ markdown.push('\n');
+ }
+ _ = writeln!(markdown, "+++");
markdown.push('\n');
- _ = writeln!(markdown, "repository_url = {}", self.repository_url);
- _ = writeln!(markdown, "revision = {}", self.revision);
+ _ = writeln!(markdown, "# {}", self.name);
markdown.push('\n');
+ if let Some(reasoning) = &self.reasoning {
+ _ = writeln!(markdown, "## {}", REASONING_HEADING);
+ markdown.push('\n');
+ markdown.push_str(reasoning);
+ if !markdown.ends_with('\n') {
+ markdown.push('\n');
+ }
+ markdown.push('\n');
+ }
+
if !self.uncommitted_diff.is_empty() {
_ = writeln!(markdown, "## {}", UNCOMMITTED_DIFF_HEADING);
_ = writeln!(markdown);
@@ -75,34 +126,48 @@ impl ExampleSpec {
_ = writeln!(markdown, "## {}", EXPECTED_PATCH_HEADING);
markdown.push('\n');
- _ = writeln!(markdown, "```diff");
- markdown.push_str(&self.expected_patch);
- if !markdown.ends_with('\n') {
+ for patch in &self.expected_patches {
+ _ = writeln!(markdown, "```diff");
+ markdown.push_str(patch);
+ if !markdown.ends_with('\n') {
+ markdown.push('\n');
+ }
+ _ = writeln!(markdown, "```");
markdown.push('\n');
}
- _ = writeln!(markdown, "```");
- markdown.push('\n');
markdown
}
/// Parse an example spec from markdown.
- pub fn from_markdown(name: String, input: &str) -> anyhow::Result<Self> {
+ pub fn from_markdown(mut input: &str) -> anyhow::Result<Self> {
use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
- let parser = Parser::new(input);
-
let mut spec = ExampleSpec {
- name,
+ name: String::new(),
repository_url: String::new(),
revision: String::new(),
+ tags: Vec::new(),
+ reasoning: None,
uncommitted_diff: String::new(),
cursor_path: Path::new("").into(),
cursor_position: String::new(),
edit_history: String::new(),
- expected_patch: String::new(),
+ expected_patches: Vec::new(),
};
+ if let Some(rest) = input.strip_prefix("+++\n")
+ && let Some((front_matter, rest)) = rest.split_once("+++\n")
+ {
+ if let Ok(data) = toml::from_str::<FrontMatter<'_>>(front_matter) {
+ spec.repository_url = data.repository_url.into_owned();
+ spec.revision = data.revision.into_owned();
+ spec.tags = data.tags;
+ }
+ input = rest.trim_start();
+ }
+
+ let parser = Parser::new(input);
let mut text = String::new();
let mut block_info: CowStr = "".into();
@@ -123,20 +188,9 @@ impl ExampleSpec {
match event {
Event::Text(line) => {
text.push_str(&line);
-
- if let Section::Start = current_section
- && let Some((field, value)) = line.split_once('=')
- {
- match field.trim() {
- REPOSITORY_URL_FIELD => {
- spec.repository_url = value.trim().to_string();
- }
- REVISION_FIELD => {
- spec.revision = value.trim().to_string();
- }
- _ => {}
- }
- }
+ }
+ Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
+ spec.name = mem::take(&mut text);
}
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
let title = mem::take(&mut text);
@@ -194,7 +248,7 @@ impl ExampleSpec {
mem::take(&mut text);
}
Section::ExpectedPatch => {
- spec.expected_patch = mem::take(&mut text);
+ spec.expected_patches.push(mem::take(&mut text));
}
Section::Start | Section::Other => {}
}
@@ -209,4 +263,326 @@ impl ExampleSpec {
Ok(spec)
}
+
+ /// Returns the excerpt of text around the cursor, and the offset of the cursor within that
+ /// excerpt.
+ ///
+ /// The cursor's position is marked with a special comment that appears
+ /// below the cursor line, which contains the string `[CURSOR_POSITION]`,
+ /// preceded by an arrow marking the cursor's column. The arrow can be
+ /// either:
+ /// - `^` - The cursor column is at the position of the `^` character (pointing up to the cursor)
+ /// - `<` - The cursor column is at the first non-whitespace character on that line.
+ pub fn cursor_excerpt(&self) -> Result<(String, usize)> {
+ let input = &self.cursor_position;
+
+ // Check for inline cursor marker first
+ if let Some(inline_offset) = input.find(INLINE_CURSOR_MARKER) {
+ let excerpt = input[..inline_offset].to_string()
+ + &input[inline_offset + INLINE_CURSOR_MARKER.len()..];
+ return Ok((excerpt, inline_offset));
+ }
+
+ let marker_offset = input
+ .find(CURSOR_POSITION_MARKER)
+ .context("missing [CURSOR_POSITION] marker")?;
+ let marker_line_start = input[..marker_offset]
+ .rfind('\n')
+ .map(|pos| pos + 1)
+ .unwrap_or(0);
+ let marker_line_end = input[marker_line_start..]
+ .find('\n')
+ .map(|pos| marker_line_start + pos + 1)
+ .unwrap_or(input.len());
+ let marker_line = &input[marker_line_start..marker_line_end].trim_end_matches('\n');
+
+ let cursor_column = if let Some(cursor_offset) = marker_line.find('^') {
+ cursor_offset
+ } else if let Some(less_than_pos) = marker_line.find('<') {
+ marker_line
+ .find(|c: char| !c.is_whitespace())
+ .unwrap_or(less_than_pos)
+ } else {
+ anyhow::bail!(
+ "cursor position marker line must contain '^' or '<' before [CURSOR_POSITION]"
+ );
+ };
+
+ let mut excerpt = input[..marker_line_start].to_string() + &input[marker_line_end..];
+ excerpt.truncate(excerpt.trim_end_matches('\n').len());
+
+ // The cursor is on the line above the marker line.
+ let cursor_line_end = marker_line_start.saturating_sub(1);
+ let cursor_line_start = excerpt[..cursor_line_end]
+ .rfind('\n')
+ .map(|pos| pos + 1)
+ .unwrap_or(0);
+ let cursor_offset = cursor_line_start + cursor_column;
+
+ Ok((excerpt, cursor_offset))
+ }
+
+ /// Sets the cursor position excerpt from a plain excerpt and cursor byte offset.
+ ///
+ /// The `line_comment_prefix` is used to format the marker line as a comment.
+ /// If the cursor column is less than the comment prefix length, the `<` format is used.
+ /// Otherwise, the `^` format is used.
+ pub fn set_cursor_excerpt(
+ &mut self,
+ excerpt: &str,
+ cursor_offset: usize,
+ line_comment_prefix: &str,
+ ) {
+ // Find which line the cursor is on and its column
+ let cursor_line_start = excerpt[..cursor_offset]
+ .rfind('\n')
+ .map(|pos| pos + 1)
+ .unwrap_or(0);
+ let cursor_line_end = excerpt[cursor_line_start..]
+ .find('\n')
+ .map(|pos| cursor_line_start + pos + 1)
+ .unwrap_or(excerpt.len());
+ let cursor_line = &excerpt[cursor_line_start..cursor_line_end];
+ let cursor_line_indent = &cursor_line[..cursor_line.len() - cursor_line.trim_start().len()];
+ let cursor_column = cursor_offset - cursor_line_start;
+
+ // Build the marker line
+ let mut marker_line = String::new();
+ if cursor_column < line_comment_prefix.len() {
+ for _ in 0..cursor_column {
+ marker_line.push(' ');
+ }
+ marker_line.push_str(line_comment_prefix);
+ write!(marker_line, " <{}", CURSOR_POSITION_MARKER).unwrap();
+ } else {
+ if cursor_column >= cursor_line_indent.len() + line_comment_prefix.len() {
+ marker_line.push_str(cursor_line_indent);
+ }
+ marker_line.push_str(line_comment_prefix);
+ while marker_line.len() < cursor_column {
+ marker_line.push(' ');
+ }
+ write!(marker_line, "^{}", CURSOR_POSITION_MARKER).unwrap();
+ }
+
+ // Build the final cursor_position string
+ let mut result = String::with_capacity(excerpt.len() + marker_line.len() + 2);
+ result.push_str(&excerpt[..cursor_line_end]);
+ if !result.ends_with('\n') {
+ result.push('\n');
+ }
+ result.push_str(&marker_line);
+ if cursor_line_end < excerpt.len() {
+ result.push('\n');
+ result.push_str(&excerpt[cursor_line_end..]);
+ }
+
+ self.cursor_position = result;
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use indoc::indoc;
+
+ #[test]
+ fn test_cursor_excerpt_with_caret() {
+ let mut spec = ExampleSpec {
+ name: String::new(),
+ repository_url: String::new(),
+ revision: String::new(),
+ tags: Vec::new(),
+ reasoning: None,
+ uncommitted_diff: String::new(),
+ cursor_path: Path::new("test.rs").into(),
+ cursor_position: String::new(),
+ edit_history: String::new(),
+ expected_patches: Vec::new(),
+ };
+
+ // Cursor before `42`
+ let excerpt = indoc! {"
+ fn main() {
+ let x = 42;
+ println!(\"{}\", x);
+ }"
+ };
+ let offset = excerpt.find("42").unwrap();
+ let position_string = indoc! {"
+ fn main() {
+ let x = 42;
+ // ^[CURSOR_POSITION]
+ println!(\"{}\", x);
+ }"
+ }
+ .to_string();
+
+ spec.set_cursor_excerpt(excerpt, offset, "//");
+ assert_eq!(spec.cursor_position, position_string);
+ assert_eq!(
+ spec.cursor_excerpt().unwrap(),
+ (excerpt.to_string(), offset)
+ );
+
+ // Cursor after `l` in `let`
+ let offset = excerpt.find("et x").unwrap();
+ let position_string = indoc! {"
+ fn main() {
+ let x = 42;
+ // ^[CURSOR_POSITION]
+ println!(\"{}\", x);
+ }"
+ }
+ .to_string();
+
+ spec.set_cursor_excerpt(excerpt, offset, "//");
+ assert_eq!(spec.cursor_position, position_string);
+ assert_eq!(
+ spec.cursor_excerpt().unwrap(),
+ (excerpt.to_string(), offset)
+ );
+
+ // Cursor before `let`
+ let offset = excerpt.find("let").unwrap();
+ let position_string = indoc! {"
+ fn main() {
+ let x = 42;
+ // ^[CURSOR_POSITION]
+ println!(\"{}\", x);
+ }"
+ }
+ .to_string();
+
+ spec.set_cursor_excerpt(excerpt, offset, "//");
+ assert_eq!(spec.cursor_position, position_string);
+ assert_eq!(
+ spec.cursor_excerpt().unwrap(),
+ (excerpt.to_string(), offset)
+ );
+
+ // Cursor at beginning of the line with `let`
+ let offset = excerpt.find(" let").unwrap();
+ let position_string = indoc! {"
+ fn main() {
+ let x = 42;
+ // <[CURSOR_POSITION]
+ println!(\"{}\", x);
+ }"
+ }
+ .to_string();
+
+ spec.set_cursor_excerpt(excerpt, offset, "//");
+ assert_eq!(spec.cursor_position, position_string);
+ assert_eq!(
+ spec.cursor_excerpt().unwrap(),
+ (excerpt.to_string(), offset)
+ );
+
+ // Cursor at end of line, after the semicolon
+ let offset = excerpt.find(';').unwrap() + 1;
+ let position_string = indoc! {"
+ fn main() {
+ let x = 42;
+ // ^[CURSOR_POSITION]
+ println!(\"{}\", x);
+ }"
+ }
+ .to_string();
+
+ spec.set_cursor_excerpt(excerpt, offset, "//");
+ assert_eq!(spec.cursor_position, position_string);
+ assert_eq!(
+ spec.cursor_excerpt().unwrap(),
+ (excerpt.to_string(), offset)
+ );
+
+ // Caret at end of file (no trailing newline)
+ let excerpt = indoc! {"
+ fn main() {
+ let x = 42;"
+ };
+ let offset = excerpt.find(';').unwrap() + 1;
+ let position_string = indoc! {"
+ fn main() {
+ let x = 42;
+ // ^[CURSOR_POSITION]"
+ }
+ .to_string();
+
+ spec.set_cursor_excerpt(excerpt, offset, "//");
+ assert_eq!(spec.cursor_position, position_string);
+ assert_eq!(
+ spec.cursor_excerpt().unwrap(),
+ (excerpt.to_string(), offset)
+ );
+ }
+
+ #[test]
+ fn test_cursor_excerpt_with_inline_marker() {
+ let mut spec = ExampleSpec {
+ name: String::new(),
+ repository_url: String::new(),
+ revision: String::new(),
+ tags: Vec::new(),
+ reasoning: None,
+ uncommitted_diff: String::new(),
+ cursor_path: Path::new("test.rs").into(),
+ cursor_position: String::new(),
+ edit_history: String::new(),
+ expected_patches: Vec::new(),
+ };
+
+ // Cursor before `42` using inline marker
+ spec.cursor_position = indoc! {"
+ fn main() {
+ let x = <|user_cursor|>42;
+ println!(\"{}\", x);
+ }"
+ }
+ .to_string();
+
+ let expected_excerpt = indoc! {"
+ fn main() {
+ let x = 42;
+ println!(\"{}\", x);
+ }"
+ };
+ let expected_offset = expected_excerpt.find("42").unwrap();
+
+ assert_eq!(
+ spec.cursor_excerpt().unwrap(),
+ (expected_excerpt.to_string(), expected_offset)
+ );
+
+ // Cursor at beginning of line
+ spec.cursor_position = indoc! {"
+ fn main() {
+ <|user_cursor|> let x = 42;
+ }"
+ }
+ .to_string();
+
+ let expected_excerpt = indoc! {"
+ fn main() {
+ let x = 42;
+ }"
+ };
+ let expected_offset = expected_excerpt.find(" let").unwrap();
+
+ assert_eq!(
+ spec.cursor_excerpt().unwrap(),
+ (expected_excerpt.to_string(), expected_offset)
+ );
+
+ // Cursor at end of file
+ spec.cursor_position = "fn main() {}<|user_cursor|>".to_string();
+ let expected_excerpt = "fn main() {}";
+ let expected_offset = expected_excerpt.len();
+
+ assert_eq!(
+ spec.cursor_excerpt().unwrap(),
+ (expected_excerpt.to_string(), expected_offset)
+ );
+ }
}
@@ -1,23 +1,19 @@
-use std::borrow::Cow;
-use std::fmt::Display;
-use std::sync::Arc;
use std::{
- fmt::{Debug, Write},
+ borrow::Cow,
+ fmt::{Debug, Display, Write},
mem,
ops::Range,
path::Path,
+ sync::Arc,
};
-use anyhow::Context as _;
-use anyhow::Result;
-use anyhow::anyhow;
+use anyhow::{Context as _, Result, anyhow};
use collections::HashMap;
-use gpui::AsyncApp;
-use gpui::Entity;
-use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot};
-use project::{Project, ProjectPath};
-use util::paths::PathStyle;
-use util::rel_path::RelPath;
+use gpui::{AsyncApp, Entity};
+use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot, text_diff};
+use postage::stream::Stream as _;
+use project::Project;
+use util::{paths::PathStyle, rel_path::RelPath};
#[derive(Clone, Debug)]
pub struct OpenedBuffers(#[allow(unused)] HashMap<String, Entity<Buffer>>);
@@ -28,56 +24,50 @@ pub async fn apply_diff(
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<OpenedBuffers> {
- let mut included_files = HashMap::default();
-
- let worktree_id = project.read_with(cx, |project, cx| {
- anyhow::Ok(
- project
- .visible_worktrees(cx)
- .next()
- .context("no worktrees")?
- .read(cx)
- .id(),
- )
- })??;
+ let worktree = project
+ .read_with(cx, |project, cx| project.visible_worktrees(cx).next())?
+ .context("project has no worktree")?;
+ // Ensure the files in the diff are loaded, since worktree scanning is disabled in
+ // edit prediction CLI.
+ let mut paths = Vec::new();
for line in diff_str.lines() {
- let diff_line = DiffLine::parse(line);
-
- if let DiffLine::OldPath { path } = diff_line {
- let buffer = project
- .update(cx, |project, cx| {
- let project_path = ProjectPath {
- worktree_id,
- path: RelPath::new(Path::new(path.as_ref()), PathStyle::Posix)?.into_arc(),
- };
- anyhow::Ok(project.open_buffer(project_path, cx))
- })??
- .await?;
-
- included_files.insert(path.to_string(), buffer);
+ if let DiffLine::OldPath { path } = DiffLine::parse(line) {
+ paths.push(RelPath::new(Path::new(path.as_ref()), PathStyle::Posix)?.into_arc());
}
}
+ worktree
+ .update(cx, |worktree, _| {
+ worktree
+ .as_local()
+ .unwrap()
+ .refresh_entries_for_paths(paths)
+ })?
+ .recv()
+ .await;
- let ranges = [Anchor::MIN..Anchor::MAX];
+ let mut included_files = HashMap::default();
+ let ranges = [Anchor::MIN..Anchor::MAX];
let mut diff = DiffParser::new(diff_str);
let mut current_file = None;
let mut edits = vec![];
while let Some(event) = diff.next()? {
match event {
- DiffEvent::Hunk {
- path: file_path,
- hunk,
- } => {
- let (buffer, ranges) = match current_file {
+ DiffEvent::Hunk { path, hunk } => {
+ let buffer = match current_file {
None => {
- let buffer = included_files
- .get_mut(file_path.as_ref())
- .expect("Opened all files in diff");
-
- current_file = Some((buffer, ranges.as_slice()));
+ let buffer = project
+ .update(cx, |project, cx| {
+ let project_path = project
+ .find_project_path(path.as_ref(), cx)
+ .context("no such path")?;
+ anyhow::Ok(project.open_buffer(project_path, cx))
+ })??
+ .await?;
+ included_files.insert(path.to_string(), buffer.clone());
+ current_file = Some(buffer);
current_file.as_ref().unwrap()
}
Some(ref current) => current,
@@ -85,14 +75,14 @@ pub async fn apply_diff(
buffer.read_with(cx, |buffer, _| {
edits.extend(
- resolve_hunk_edits_in_buffer(hunk, buffer, ranges)
+ resolve_hunk_edits_in_buffer(hunk, buffer, ranges.as_slice())
.with_context(|| format!("Diff:\n{diff_str}"))?,
);
anyhow::Ok(())
})??;
}
DiffEvent::FileEnd { renamed_to } => {
- let (buffer, _) = current_file
+ let buffer = current_file
.take()
.context("Got a FileEnd event before an Hunk event")?;
@@ -128,6 +118,65 @@ pub async fn apply_diff(
Ok(OpenedBuffers(included_files))
}
+/// Extract the diff for a specific file from a multi-file diff.
+/// Returns an error if the file is not found in the diff.
+pub fn extract_file_diff(full_diff: &str, file_path: &str) -> Result<String> {
+ let mut result = String::new();
+ let mut in_target_file = false;
+ let mut found_file = false;
+
+ for line in full_diff.lines() {
+ if line.starts_with("diff --git") {
+ if in_target_file {
+ break;
+ }
+ in_target_file = line.contains(&format!("a/{}", file_path))
+ || line.contains(&format!("b/{}", file_path));
+ if in_target_file {
+ found_file = true;
+ }
+ }
+
+ if in_target_file {
+ result.push_str(line);
+ result.push('\n');
+ }
+ }
+
+ if !found_file {
+ anyhow::bail!("File '{}' not found in diff", file_path);
+ }
+
+ Ok(result)
+}
+
+/// Strip unnecessary git metadata lines from a diff, keeping only the lines
+/// needed for patch application: path headers (--- and +++), hunk headers (@@),
+/// and content lines (+, -, space).
+pub fn strip_diff_metadata(diff: &str) -> String {
+ let mut result = String::new();
+
+ for line in diff.lines() {
+ let dominated = DiffLine::parse(line);
+ match dominated {
+ // Keep path headers, hunk headers, and content lines
+ DiffLine::OldPath { .. }
+ | DiffLine::NewPath { .. }
+ | DiffLine::HunkHeader(_)
+ | DiffLine::Context(_)
+ | DiffLine::Deletion(_)
+ | DiffLine::Addition(_) => {
+ result.push_str(line);
+ result.push('\n');
+ }
+ // Skip garbage lines (diff --git, index, etc.)
+ DiffLine::Garbage(_) => {}
+ }
+ }
+
+ result
+}
+
pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
let mut diff = DiffParser::new(diff_str);
@@ -151,6 +200,51 @@ pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
Ok(text)
}
+/// Returns the individual edits that would be applied by a diff to the given content.
+/// Each edit is a tuple of (byte_range_in_content, replacement_text).
+/// Uses sub-line diffing to find the precise character positions of changes.
+/// Returns an empty vec if the hunk context is not found or is ambiguous.
+pub fn edits_for_diff(content: &str, diff_str: &str) -> Result<Vec<(Range<usize>, String)>> {
+ let mut diff = DiffParser::new(diff_str);
+ let mut result = Vec::new();
+
+ while let Some(event) = diff.next()? {
+ match event {
+ DiffEvent::Hunk { hunk, .. } => {
+ if hunk.context.is_empty() {
+ return Ok(Vec::new());
+ }
+
+ // Find the context in the content
+ let first_match = content.find(&hunk.context);
+ let Some(context_offset) = first_match else {
+ return Ok(Vec::new());
+ };
+
+ // Check for ambiguity - if context appears more than once, reject
+ if content[context_offset + 1..].contains(&hunk.context) {
+ return Ok(Vec::new());
+ }
+
+ // Use sub-line diffing to find precise edit positions
+ for edit in &hunk.edits {
+ let old_text = &content
+ [context_offset + edit.range.start..context_offset + edit.range.end];
+ let edits_within_hunk = text_diff(old_text, &edit.text);
+ for (inner_range, inner_text) in edits_within_hunk {
+ let absolute_start = context_offset + edit.range.start + inner_range.start;
+ let absolute_end = context_offset + edit.range.start + inner_range.end;
+ result.push((absolute_start..absolute_end, inner_text.to_string()));
+ }
+ }
+ }
+ DiffEvent::FileEnd { .. } => {}
+ }
+ }
+
+ Ok(result)
+}
+
struct PatchFile<'a> {
old_path: Cow<'a, str>,
new_path: Cow<'a, str>,
@@ -873,4 +967,135 @@ mod tests {
FakeFs::new(cx.background_executor.clone())
}
+
+ #[test]
+ fn test_extract_file_diff() {
+ let multi_file_diff = indoc! {r#"
+ diff --git a/file1.txt b/file1.txt
+ index 1234567..abcdefg 100644
+ --- a/file1.txt
+ +++ b/file1.txt
+ @@ -1,3 +1,4 @@
+ line1
+ +added line
+ line2
+ line3
+ diff --git a/file2.txt b/file2.txt
+ index 2345678..bcdefgh 100644
+ --- a/file2.txt
+ +++ b/file2.txt
+ @@ -1,2 +1,2 @@
+ -old line
+ +new line
+ unchanged
+ "#};
+
+ let file1_diff = extract_file_diff(multi_file_diff, "file1.txt").unwrap();
+ assert_eq!(
+ file1_diff,
+ indoc! {r#"
+ diff --git a/file1.txt b/file1.txt
+ index 1234567..abcdefg 100644
+ --- a/file1.txt
+ +++ b/file1.txt
+ @@ -1,3 +1,4 @@
+ line1
+ +added line
+ line2
+ line3
+ "#}
+ );
+
+ let file2_diff = extract_file_diff(multi_file_diff, "file2.txt").unwrap();
+ assert_eq!(
+ file2_diff,
+ indoc! {r#"
+ diff --git a/file2.txt b/file2.txt
+ index 2345678..bcdefgh 100644
+ --- a/file2.txt
+ +++ b/file2.txt
+ @@ -1,2 +1,2 @@
+ -old line
+ +new line
+ unchanged
+ "#}
+ );
+
+ let result = extract_file_diff(multi_file_diff, "nonexistent.txt");
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_edits_for_diff() {
+ let content = indoc! {"
+ fn main() {
+ let x = 1;
+ let y = 2;
+ println!(\"{} {}\", x, y);
+ }
+ "};
+
+ let diff = indoc! {"
+ --- a/file.rs
+ +++ b/file.rs
+ @@ -1,5 +1,5 @@
+ fn main() {
+ - let x = 1;
+ + let x = 42;
+ let y = 2;
+ println!(\"{} {}\", x, y);
+ }
+ "};
+
+ let edits = edits_for_diff(content, diff).unwrap();
+ assert_eq!(edits.len(), 1);
+
+ let (range, replacement) = &edits[0];
+ // With sub-line diffing, the edit should start at "1" (the actual changed character)
+ let expected_start = content.find("let x = 1;").unwrap() + "let x = ".len();
+ assert_eq!(range.start, expected_start);
+ // The deleted text is just "1"
+ assert_eq!(range.end, expected_start + "1".len());
+ // The replacement text
+ assert_eq!(replacement, "42");
+
+ // Verify the cursor would be positioned at the column of "1"
+ let line_start = content[..range.start]
+ .rfind('\n')
+ .map(|p| p + 1)
+ .unwrap_or(0);
+ let cursor_column = range.start - line_start;
+ // " let x = " is 12 characters, so column 12
+ assert_eq!(cursor_column, " let x = ".len());
+ }
+
+ #[test]
+ fn test_strip_diff_metadata() {
+ let diff_with_metadata = indoc! {r#"
+ diff --git a/file.txt b/file.txt
+ index 1234567..abcdefg 100644
+ --- a/file.txt
+ +++ b/file.txt
+ @@ -1,3 +1,4 @@
+ context line
+ -removed line
+ +added line
+ more context
+ "#};
+
+ let stripped = strip_diff_metadata(diff_with_metadata);
+
+ assert_eq!(
+ stripped,
+ indoc! {r#"
+ --- a/file.txt
+ +++ b/file.txt
+ @@ -1,3 +1,4 @@
+ context line
+ -removed line
+ +added line
+ more context
+ "#}
+ );
+ }
}
@@ -1,8 +1,10 @@
use anthropic::{
- ANTHROPIC_API_URL, Message, Request as AnthropicRequest, RequestContent,
- Response as AnthropicResponse, Role, non_streaming_completion,
+ ANTHROPIC_API_URL, Event, Message, Request as AnthropicRequest, RequestContent,
+ Response as AnthropicResponse, ResponseContent, Role, non_streaming_completion,
+ stream_completion,
};
use anyhow::Result;
+use futures::StreamExt as _;
use http_client::HttpClient;
use indoc::indoc;
use reqwest_client::ReqwestClient;
@@ -15,12 +17,12 @@ use std::path::Path;
use std::sync::Arc;
pub struct PlainLlmClient {
- http_client: Arc<dyn HttpClient>,
- api_key: String,
+ pub http_client: Arc<dyn HttpClient>,
+ pub api_key: String,
}
impl PlainLlmClient {
- fn new() -> Result<Self> {
+ pub fn new() -> Result<Self> {
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
let api_key = std::env::var("ANTHROPIC_API_KEY")
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
@@ -30,7 +32,7 @@ impl PlainLlmClient {
})
}
- async fn generate(
+ pub async fn generate(
&self,
model: &str,
max_tokens: u64,
@@ -63,6 +65,72 @@ impl PlainLlmClient {
Ok(response)
}
+
+ pub async fn generate_streaming<F>(
+ &self,
+ model: &str,
+ max_tokens: u64,
+ messages: Vec<Message>,
+ mut on_progress: F,
+ ) -> Result<AnthropicResponse>
+ where
+ F: FnMut(usize, &str),
+ {
+ let request = AnthropicRequest {
+ model: model.to_string(),
+ max_tokens,
+ messages,
+ tools: Vec::new(),
+ thinking: None,
+ tool_choice: None,
+ system: None,
+ metadata: None,
+ stop_sequences: Vec::new(),
+ temperature: None,
+ top_k: None,
+ top_p: None,
+ };
+
+ let mut stream = stream_completion(
+ self.http_client.as_ref(),
+ ANTHROPIC_API_URL,
+ &self.api_key,
+ request,
+ None,
+ )
+ .await
+ .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+ let mut response: Option<AnthropicResponse> = None;
+ let mut text_content = String::new();
+
+ while let Some(event_result) = stream.next().await {
+ let event = event_result.map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+ match event {
+ Event::MessageStart { message } => {
+ response = Some(message);
+ }
+ Event::ContentBlockDelta { delta, .. } => {
+ if let anthropic::ContentDelta::TextDelta { text } = delta {
+ text_content.push_str(&text);
+ on_progress(text_content.len(), &text_content);
+ }
+ }
+ _ => {}
+ }
+ }
+
+ let mut response = response.ok_or_else(|| anyhow::anyhow!("No response received"))?;
+
+ if response.content.is_empty() && !text_content.is_empty() {
+ response
+ .content
+ .push(ResponseContent::Text { text: text_content });
+ }
+
+ Ok(response)
+ }
}
pub struct BatchingLlmClient {
@@ -408,6 +476,29 @@ impl AnthropicClient {
}
}
+ #[allow(dead_code)]
+ pub async fn generate_streaming<F>(
+ &self,
+ model: &str,
+ max_tokens: u64,
+ messages: Vec<Message>,
+ on_progress: F,
+ ) -> Result<Option<AnthropicResponse>>
+ where
+ F: FnMut(usize, &str),
+ {
+ match self {
+ AnthropicClient::Plain(plain_llm_client) => plain_llm_client
+ .generate_streaming(model, max_tokens, messages, on_progress)
+ .await
+ .map(Some),
+ AnthropicClient::Batch(_) => {
+ anyhow::bail!("Streaming not supported with batching client")
+ }
+ AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
+ }
+ }
+
pub async fn sync_batches(&self) -> Result<()> {
match self {
AnthropicClient::Plain(_) => Ok(()),
@@ -1,20 +1,15 @@
-use anyhow::{Result, anyhow};
+use anyhow::Result;
use std::mem;
use crate::example::Example;
pub async fn run_distill(example: &mut Example) -> Result<()> {
- let [prediction]: [_; 1] =
- mem::take(&mut example.predictions)
- .try_into()
- .map_err(|preds: Vec<_>| {
- anyhow!(
- "Example has {} predictions, but it should have exactly one",
- preds.len()
- )
- })?;
+ let predictions = mem::take(&mut example.predictions)
+ .into_iter()
+ .map(|p| p.actual_patch)
+ .collect();
- example.spec.expected_patch = prediction.actual_patch;
+ example.spec.expected_patches = predictions;
example.prompt = None;
example.predictions = Vec::new();
example.score = Vec::new();
@@ -1,4 +1,4 @@
-use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics};
+use crate::{PredictionProvider, PromptFormat};
use anyhow::{Context as _, Result};
use collections::HashMap;
use edit_prediction::example_spec::ExampleSpec;
@@ -87,7 +87,6 @@ pub struct ExamplePrediction {
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
pub delta_chr_f: f32,
- pub line_match: ClassificationMetrics,
}
impl Example {
@@ -190,7 +189,11 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
.collect::<Vec<Example>>(),
),
"md" => {
- examples.push(parse_markdown_example(filename, &content).unwrap());
+ let mut example = parse_markdown_example(&content).unwrap();
+ if example.spec.name.is_empty() {
+ example.spec.name = filename;
+ }
+ examples.push(example);
}
ext => {
panic!("{} has invalid example extension `{ext}`", path.display())
@@ -236,8 +239,8 @@ pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>
examples_by_repo.into_values().collect()
}
-fn parse_markdown_example(name: String, input: &str) -> Result<Example> {
- let spec = ExampleSpec::from_markdown(name, input)?;
+fn parse_markdown_example(input: &str) -> Result<Example> {
+ let spec = ExampleSpec::from_markdown(input)?;
Ok(Example {
spec,
buffer: None,
@@ -30,7 +30,12 @@ pub async fn run_format_prompt(
let prompt = TeacherPrompt::format_prompt(example);
example.prompt = Some(ExamplePrompt {
input: prompt,
- expected_output: example.spec.expected_patch.clone(), // TODO
+ expected_output: example
+ .spec
+ .expected_patches
+ .first()
+ .cloned()
+ .unwrap_or_default(),
format: prompt_format,
});
}
@@ -68,8 +73,15 @@ pub async fn run_format_prompt(
))
})??;
let prompt = format_zeta_prompt(&input);
- let expected_output =
- zeta2_output_for_patch(&input, &example.spec.expected_patch.clone())?;
+ let expected_output = zeta2_output_for_patch(
+ &input,
+ &example
+ .spec
+ .expected_patches
+ .first()
+ .context("expected patches is empty")?
+ .clone(),
+ )?;
example.prompt = Some(ExamplePrompt {
input: prompt,
expected_output,
@@ -86,6 +98,7 @@ impl TeacherPrompt {
const PROMPT: &str = include_str!("teacher.prompt.md");
pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
+ pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
/// Truncate edit history to this number of last lines
const MAX_HISTORY_LINES: usize = 128;
@@ -181,13 +194,15 @@ impl TeacherPrompt {
result.push_str(Self::EDITABLE_REGION_START);
// TODO: control number of lines around cursor
- result.push_str(&example.spec.cursor_position);
- if !example.spec.cursor_position.ends_with('\n') {
+ let (mut excerpt, offset) = example.spec.cursor_excerpt().unwrap();
+ excerpt.insert_str(offset, Self::USER_CURSOR_MARKER);
+ result.push_str(&excerpt);
+ if !result.ends_with('\n') {
result.push('\n');
}
- result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
- result.push_str("`````");
+ result.push_str(Self::EDITABLE_REGION_END);
+ result.push_str("\n`````");
result
}
@@ -0,0 +1,110 @@
+use anyhow::{Context as _, Result};
+use collections::HashMap;
+use futures::lock::{Mutex, OwnedMutexGuard};
+use std::{
+ cell::RefCell,
+ path::{Path, PathBuf},
+ sync::Arc,
+};
+
+use crate::paths::REPOS_DIR;
+
+thread_local! {
+ static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
+}
+
+#[must_use]
+pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
+ REPO_LOCKS
+ .with(|cell| {
+ cell.borrow_mut()
+ .entry(path.as_ref().to_path_buf())
+ .or_default()
+ .clone()
+ })
+ .lock_owned()
+ .await
+}
+
+pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
+ let output = smol::process::Command::new("git")
+ .current_dir(repo_path)
+ .args(args)
+ .output()
+ .await?;
+
+ anyhow::ensure!(
+ output.status.success(),
+ "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
+ args.join(" "),
+ repo_path.display(),
+ output.status,
+ String::from_utf8_lossy(&output.stderr),
+ String::from_utf8_lossy(&output.stdout),
+ );
+ Ok(String::from_utf8(output.stdout)?.trim().to_string())
+}
+
+pub fn parse_repo_url(url: &str) -> Result<(String, String)> {
+ if url.contains('@') {
+ let (_, path) = url.split_once(':').context("expected : in git url")?;
+ let (owner, repo) = path.split_once('/').context("expected / in git url")?;
+ Ok((owner.to_string(), repo.trim_end_matches(".git").to_string()))
+ } else {
+ let parsed = http_client::Url::parse(url)?;
+ let mut segments = parsed.path_segments().context("empty http url")?;
+ let owner = segments.next().context("expected owner")?;
+ let repo = segments.next().context("expected repo")?;
+ Ok((owner.to_string(), repo.trim_end_matches(".git").to_string()))
+ }
+}
+
+pub fn repo_path_for_url(url: &str) -> Result<PathBuf> {
+ let (owner, name) = parse_repo_url(url)?;
+ Ok(REPOS_DIR.join(&owner).join(&name))
+}
+
+pub async fn ensure_repo_cloned(repo_url: &str) -> Result<PathBuf> {
+ let repo_path = repo_path_for_url(repo_url)?;
+ let _lock = lock_repo(&repo_path).await;
+
+ if !repo_path.is_dir() {
+ log::info!("Cloning {} into {:?}", repo_url, repo_path);
+ std::fs::create_dir_all(&repo_path)?;
+ run_git(&repo_path, &["init"]).await?;
+ run_git(&repo_path, &["remote", "add", "origin", repo_url]).await?;
+ }
+
+    // Always fetch to get the latest commits
+ run_git(&repo_path, &["fetch", "origin"]).await?;
+
+    // Check if we have a valid HEAD; if not, reset to FETCH_HEAD
+ let has_head = run_git(&repo_path, &["rev-parse", "HEAD"]).await.is_ok();
+ if !has_head {
+ // Use reset to set HEAD without needing a branch
+ run_git(&repo_path, &["reset", "--hard", "FETCH_HEAD"]).await?;
+ }
+
+ Ok(repo_path)
+}
+
+pub async fn fetch_if_needed(repo_path: &Path, revision: &str) -> Result<String> {
+ let resolved = run_git(
+ repo_path,
+ &["rev-parse", &format!("{}^{{commit}}", revision)],
+ )
+ .await;
+
+ if let Ok(sha) = resolved {
+ return Ok(sha);
+ }
+
+ if run_git(repo_path, &["fetch", "--depth", "1", "origin", revision])
+ .await
+ .is_err()
+ {
+ run_git(repo_path, &["fetch", "origin"]).await?;
+ }
+
+ run_git(repo_path, &["rev-parse", "FETCH_HEAD"]).await
+}
@@ -1,29 +1,19 @@
use crate::{
example::{Example, ExampleBuffer, ExampleState},
+ git,
headless::EpAppState,
- paths::{REPOS_DIR, WORKTREES_DIR},
+ paths::WORKTREES_DIR,
progress::{InfoStyle, Progress, Step, StepProgress},
};
use anyhow::{Context as _, Result};
-use collections::HashMap;
use edit_prediction::EditPredictionStore;
use edit_prediction::udiff::OpenedBuffers;
-use futures::{
- AsyncWriteExt as _,
- lock::{Mutex, OwnedMutexGuard},
-};
+use futures::AsyncWriteExt as _;
use gpui::{AsyncApp, Entity};
use language::{Anchor, Buffer, LanguageNotFound, ToOffset, ToPoint};
+use project::Project;
use project::buffer_store::BufferStoreEvent;
-use project::{Project, ProjectPath};
-use std::{
- cell::RefCell,
- fs,
- path::{Path, PathBuf},
- sync::Arc,
-};
-use util::{paths::PathStyle, rel_path::RelPath};
-use zeta_prompt::CURSOR_MARKER;
+use std::{fs, path::PathBuf, sync::Arc};
pub async fn run_load_project(
example: &mut Example,
@@ -86,37 +76,22 @@ async fn cursor_position(
return Err(error);
}
- let worktree = project.read_with(cx, |project, cx| {
- project
- .visible_worktrees(cx)
- .next()
- .context("No visible worktrees")
- })??;
-
- let cursor_path = RelPath::new(&example.spec.cursor_path, PathStyle::Posix)
- .context("Failed to create RelPath")?
- .into_arc();
- let cursor_buffer = project
- .update(cx, |project, cx| {
- project.open_buffer(
- ProjectPath {
- worktree_id: worktree.read(cx).id(),
- path: cursor_path,
- },
- cx,
- )
+ let cursor_path = project
+ .read_with(cx, |project, cx| {
+ project.find_project_path(&example.spec.cursor_path, cx)
})?
+ .with_context(|| {
+ format!(
+ "failed to find cursor path {}",
+ example.spec.cursor_path.display()
+ )
+ })?;
+ let cursor_buffer = project
+ .update(cx, |project, cx| project.open_buffer(cursor_path, cx))?
.await?;
- let cursor_offset_within_excerpt = example
- .spec
- .cursor_position
- .find(CURSOR_MARKER)
- .context("missing cursor marker")?;
- let mut cursor_excerpt = example.spec.cursor_position.clone();
- cursor_excerpt.replace_range(
- cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
- "",
- );
+
+ let (cursor_excerpt, cursor_offset_within_excerpt) = example.spec.cursor_excerpt()?;
+
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
@@ -212,17 +187,17 @@ async fn setup_project(
async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Result<PathBuf> {
let (repo_owner, repo_name) = example.repo_name().context("failed to get repo name")?;
- let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
+ let repo_dir = git::repo_path_for_url(&example.spec.repository_url)?;
let worktree_path = WORKTREES_DIR
.join(repo_owner.as_ref())
.join(repo_name.as_ref());
- let repo_lock = lock_repo(&repo_dir).await;
+ let repo_lock = git::lock_repo(&repo_dir).await;
if !repo_dir.is_dir() {
step_progress.set_substatus(format!("cloning {}", repo_name));
fs::create_dir_all(&repo_dir)?;
- run_git(&repo_dir, &["init"]).await?;
- run_git(
+ git::run_git(&repo_dir, &["init"]).await?;
+ git::run_git(
&repo_dir,
&["remote", "add", "origin", &example.spec.repository_url],
)
@@ -230,53 +205,26 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
}
// Resolve the example to a revision, fetching it if needed.
- let revision = run_git(
- &repo_dir,
- &[
- "rev-parse",
- &format!("{}^{{commit}}", example.spec.revision),
- ],
- )
- .await;
- let revision = if let Ok(revision) = revision {
- revision
- } else {
- step_progress.set_substatus("fetching");
- if run_git(
- &repo_dir,
- &["fetch", "--depth", "1", "origin", &example.spec.revision],
- )
- .await
- .is_err()
- {
- run_git(&repo_dir, &["fetch", "origin"]).await?;
- }
- let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
- revision
- };
+ step_progress.set_substatus("fetching");
+ let revision = git::fetch_if_needed(&repo_dir, &example.spec.revision).await?;
// Create the worktree for this example if needed.
step_progress.set_substatus("preparing worktree");
if worktree_path.is_dir() {
- run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
- run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
- run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
+ git::run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
+ git::run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
+ git::run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
} else {
let worktree_path_string = worktree_path.to_string_lossy();
- run_git(
+ let branch_name = example.spec.filename();
+ git::run_git(
&repo_dir,
- &["branch", "-f", &example.spec.name, revision.as_str()],
+ &["branch", "-f", &branch_name, revision.as_str()],
)
.await?;
- run_git(
+ git::run_git(
&repo_dir,
- &[
- "worktree",
- "add",
- "-f",
- &worktree_path_string,
- &example.spec.name,
- ],
+ &["worktree", "add", "-f", &worktree_path_string, &branch_name],
)
.await?;
}
@@ -319,39 +267,3 @@ async fn apply_edit_history(
) -> Result<OpenedBuffers> {
edit_prediction::udiff::apply_diff(&example.spec.edit_history, project, cx).await
}
-
-thread_local! {
- static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
-}
-
-#[must_use]
-pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
- REPO_LOCKS
- .with(|cell| {
- cell.borrow_mut()
- .entry(path.as_ref().to_path_buf())
- .or_default()
- .clone()
- })
- .lock_owned()
- .await
-}
-
-async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
- let output = smol::process::Command::new("git")
- .current_dir(repo_path)
- .args(args)
- .output()
- .await?;
-
- anyhow::ensure!(
- output.status.success(),
- "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
- args.join(" "),
- repo_path.display(),
- output.status,
- String::from_utf8_lossy(&output.stderr),
- String::from_utf8_lossy(&output.stdout),
- );
- Ok(String::from_utf8(output.stdout)?.trim().to_string())
-}
@@ -2,6 +2,7 @@ mod anthropic_client;
mod distill;
mod example;
mod format_prompt;
+mod git;
mod headless;
mod load_project;
mod metrics;
@@ -10,6 +11,7 @@ mod predict;
mod progress;
mod retrieve_context;
mod score;
+mod synthesize;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use edit_prediction::EditPredictionStore;
@@ -28,6 +30,7 @@ use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
+use crate::synthesize::{SynthesizeConfig, run_synthesize};
#[derive(Parser, Debug)]
#[command(name = "ep")]
@@ -67,6 +70,8 @@ enum Command {
Distill,
/// Print aggregated scores
Eval(PredictArgs),
+ /// Generate eval examples by analyzing git commits from a repository
+ Synthesize(SynthesizeArgs),
/// Remove git repositories and worktrees
Clean,
}
@@ -118,6 +123,9 @@ impl Display for Command {
.unwrap()
.get_name()
),
+ Command::Synthesize(args) => {
+ write!(f, "synthesize --repo={}", args.repo)
+ }
Command::Clean => write!(f, "clean"),
}
}
@@ -143,7 +151,7 @@ struct PredictArgs {
repetitions: usize,
}
-#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
+#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
Sweep,
Mercury,
@@ -153,6 +161,25 @@ enum PredictionProvider {
TeacherNonBatching,
}
+#[derive(Debug, Args)]
+struct SynthesizeArgs {
+ /// Repository URL (git@github.com:owner/repo or https://...)
+ #[clap(long)]
+ repo: String,
+
+ /// Number of examples to generate
+ #[clap(long, default_value_t = 5)]
+ count: usize,
+
+ /// Maximum commits to scan before giving up
+ #[clap(long, default_value_t = 100)]
+ max_commits: usize,
+
+ /// Ignore state file and reprocess all commits
+ #[clap(long)]
+ fresh: bool,
+}
+
impl EpArgs {
fn output_path(&self) -> Option<PathBuf> {
if self.in_place {
@@ -189,6 +216,25 @@ fn main() {
std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
return;
}
+ Command::Synthesize(synth_args) => {
+ let Some(output_dir) = args.output else {
+ panic!("output dir is required");
+ };
+ let config = SynthesizeConfig {
+ repo_url: synth_args.repo.clone(),
+ count: synth_args.count,
+ max_commits: synth_args.max_commits,
+ output_dir,
+ fresh: synth_args.fresh,
+ };
+ smol::block_on(async {
+ if let Err(e) = run_synthesize(config).await {
+ eprintln!("Error: {:?}", e);
+ std::process::exit(1);
+ }
+ });
+ return;
+ }
_ => {}
}
@@ -256,7 +302,7 @@ fn main() {
run_scoring(example, &args, app_state.clone(), cx.clone())
.await?;
}
- Command::Clean => {
+ Command::Clean | Command::Synthesize(_) => {
unreachable!()
}
}
@@ -1,34 +1,17 @@
-use collections::{HashMap, HashSet};
-use edit_prediction::udiff::DiffLine;
-use serde::{Deserialize, Serialize};
+use collections::HashMap;
type Counts = HashMap<String, usize>;
type CountsDelta = HashMap<String, isize>;
-#[derive(Default, Debug, Clone, Serialize, Deserialize)]
-pub struct ClassificationMetrics {
- pub true_positives: usize,
- pub false_positives: usize,
- pub false_negatives: usize,
+#[derive(Default, Debug, Clone)]
+struct ClassificationMetrics {
+ true_positives: usize,
+ false_positives: usize,
+ false_negatives: usize,
}
impl ClassificationMetrics {
- pub fn from_sets(
- expected: &HashSet<String>,
- actual: &HashSet<String>,
- ) -> ClassificationMetrics {
- let true_positives = expected.intersection(actual).count();
- let false_positives = actual.difference(expected).count();
- let false_negatives = expected.difference(actual).count();
-
- ClassificationMetrics {
- true_positives,
- false_positives,
- false_negatives,
- }
- }
-
- pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
+ fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
let mut true_positives = 0;
let mut false_positives = 0;
let mut false_negatives = 0;
@@ -56,27 +39,7 @@ impl ClassificationMetrics {
}
}
- pub fn aggregate<'a>(
- scores: impl Iterator<Item = &'a ClassificationMetrics>,
- ) -> ClassificationMetrics {
- let mut true_positives = 0;
- let mut false_positives = 0;
- let mut false_negatives = 0;
-
- for score in scores {
- true_positives += score.true_positives;
- false_positives += score.false_positives;
- false_negatives += score.false_negatives;
- }
-
- ClassificationMetrics {
- true_positives,
- false_positives,
- false_negatives,
- }
- }
-
- pub fn precision(&self) -> f64 {
+ fn precision(&self) -> f64 {
if self.true_positives + self.false_positives == 0 {
0.0
} else {
@@ -84,42 +47,13 @@ impl ClassificationMetrics {
}
}
- pub fn recall(&self) -> f64 {
+ fn recall(&self) -> f64 {
if self.true_positives + self.false_negatives == 0 {
0.0
} else {
self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
}
}
-
- pub fn f1_score(&self) -> f64 {
- let recall = self.recall();
- let precision = self.precision();
- if precision + recall == 0.0 {
- 0.0
- } else {
- 2.0 * precision * recall / (precision + recall)
- }
- }
-}
-
-pub fn line_match_score(
- expected_patch: &[DiffLine],
- actual_patch: &[DiffLine],
-) -> ClassificationMetrics {
- let expected_change_lines = expected_patch
- .iter()
- .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
- .map(|line| line.to_string())
- .collect();
-
- let actual_change_lines = actual_patch
- .iter()
- .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
- .map(|line| line.to_string())
- .collect();
-
- ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines)
}
enum ChrfWhitespace {
@@ -135,55 +69,26 @@ const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Ignore;
/// Computes a delta-chrF score that compares two sets of edits.
///
/// This metric works by:
-/// 1. Reconstructing original, golden (expected result), and actual texts from diffs
-/// 2. Computing n-gram count differences (deltas) between original→golden and original→actual
-/// 3. Comparing these deltas to measure how well actual edits match expected edits
-pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
- // Reconstruct texts from diffs
- let mut original_text = String::new(); // state of the text before any edits
- let mut golden_text = String::new(); // text after applying golden edits
- let mut actual_text = String::new(); // text after applying actual edits
-
- for line in expected {
- match line {
- DiffLine::Context(s) => {
- original_text.push_str(s);
- golden_text.push_str(s);
- }
- DiffLine::Deletion(s) => {
- original_text.push_str(s);
- }
- DiffLine::Addition(s) => {
- golden_text.push_str(s);
- }
- _ => {}
- }
- }
-
- for line in actual {
- match line {
- DiffLine::Context(s) | DiffLine::Addition(s) => {
- actual_text.push_str(s);
- }
- _ => {}
- }
- }
-
- // Edge case
- if original_text == golden_text && golden_text == actual_text {
+/// 1. Computing n-gram count differences (deltas) between original→expected and original→actual
+/// 2. Comparing these deltas to measure how well actual edits match expected edits
+///
+/// Returns a score from 0.0 to 100.0, where 100.0 means the actual edits perfectly match
+/// the expected edits.
+pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> f64 {
+ // Edge case: if all texts are identical, the edits match perfectly
+ if original == expected && expected == actual {
return 100.0;
}
- // Compute the metric
- let original_ngrams = chr_f_ngram_counts(&original_text);
- let golden_ngrams = chr_f_ngram_counts(&golden_text);
- let actual_ngrams = chr_f_ngram_counts(&actual_text);
+ let original_ngrams = chr_f_ngram_counts(original);
+ let expected_ngrams = chr_f_ngram_counts(expected);
+ let actual_ngrams = chr_f_ngram_counts(actual);
let mut total_precision = 0.0;
let mut total_recall = 0.0;
for order in 0..CHR_F_CHAR_ORDER {
- let expected_delta = compute_ngram_delta(&golden_ngrams[order], &original_ngrams[order]);
+ let expected_delta = compute_ngram_delta(&expected_ngrams[order], &original_ngrams[order]);
let actual_delta = compute_ngram_delta(&actual_ngrams[order], &original_ngrams[order]);
if expected_delta.is_empty() && actual_delta.is_empty() {
@@ -255,7 +160,7 @@ fn ngram_delta_to_counts(delta: &CountsDelta) -> Counts {
for (ngram, &delta) in delta {
if delta > 0 {
counts.insert(ngram.clone(), delta as usize);
- } else {
+ } else if delta < 0 {
counts.insert(format!("¬{ngram}"), delta.unsigned_abs());
}
}
@@ -278,94 +183,68 @@ fn count_ngrams(text: &str, n: usize) -> Counts {
#[cfg(test)]
mod test {
use super::*;
- use edit_prediction::udiff::DiffLine;
#[test]
fn test_delta_chr_f_perfect_match() {
- let diff = vec![
- DiffLine::Context("fn main() {"),
- DiffLine::Deletion(" println!(\"Hello\");"),
- DiffLine::Addition(" println!(\"Hello, World!\");"),
- DiffLine::Context("}"),
- ];
-
- let score = delta_chr_f(&diff, &diff);
+ let original = "fn main() { println!(\"Hello\");}";
+ let expected = "fn main() { println!(\"Hello, World!\");}";
+
+ let score = delta_chr_f(original, expected, expected);
assert!((score - 100.0).abs() < 1e-2);
}
#[test]
fn test_delta_chr_f_wrong_edit() {
// When the edit is wrong
- let expected = vec![
- DiffLine::Context("one "),
- DiffLine::Deletion("two "),
- DiffLine::Context("three"),
- ];
-
- let actual = vec![
- DiffLine::Context("one "),
- DiffLine::Context("two "),
- DiffLine::Deletion("three"),
- DiffLine::Addition("four"),
- ];
+ let original = "one two three";
+ let expected = "one three"; // deleted "two "
+ let actual = "one two four"; // deleted "three", added "four"
// Then the score should be low
- let score = delta_chr_f(&expected, &actual);
+ let score = delta_chr_f(original, expected, actual);
assert!(score > 20.0 && score < 40.0);
}
#[test]
fn test_delta_chr_f_partial_match() {
- let expected = vec![
- DiffLine::Deletion("let x = 42;"),
- DiffLine::Addition("let x = 100;"),
- ];
-
- let actual = vec![
- DiffLine::Deletion("let x = 42;"),
- DiffLine::Addition("let x = 99;"),
- ];
+ let original = "let x = 42;";
+ let expected = "let x = 100;";
+ let actual = "let x = 99;";
// We got the edit location right, but the replacement text is wrong.
// Deleted ngrams will match, bringing the score somewhere in the middle.
- let score = delta_chr_f(&expected, &actual);
+ let score = delta_chr_f(original, expected, actual);
assert!(score > 40.0 && score < 60.0);
}
#[test]
fn test_delta_chr_f_missed_edit() {
// When predictions makes no changes
- let expected = vec![
- DiffLine::Context("prefix "),
- DiffLine::Deletion("old"),
- DiffLine::Addition("new"),
- DiffLine::Context(" suffix"),
- ];
-
- let actual = vec![
- DiffLine::Context("prefix "),
- DiffLine::Context("old"),
- DiffLine::Context(" suffix"),
- ];
+ let original = "prefix old suffix";
+ let expected = "prefix new suffix";
+ let actual = "prefix old suffix"; // no change
// Then the score should be low (all expected changes are false negatives)
- let score = delta_chr_f(&expected, &actual);
+ let score = delta_chr_f(original, expected, actual);
assert!(score < 20.0);
}
#[test]
fn test_delta_chr_f_extra_edit() {
// When adding unexpected content
- let expected = vec![DiffLine::Context("hello"), DiffLine::Context("world")];
-
- let actual = vec![
- DiffLine::Context("hello"),
- DiffLine::Addition("extra"),
- DiffLine::Context("world"),
- ];
+ let original = "helloworld";
+ let expected = "helloworld"; // no change expected
+ let actual = "helloextraworld"; // added "extra"
// Then the score should be low (all actual changes are false positives)
- let score = delta_chr_f(&expected, &actual);
+ let score = delta_chr_f(original, expected, actual);
assert!(score < 20.0);
}
+
+ #[test]
+ fn test_delta_chr_f_no_changes() {
+ let text = "unchanged text";
+ let score = delta_chr_f(text, text, text);
+ assert!((score - 100.0).abs() < 1e-2);
+ }
}
@@ -17,7 +17,11 @@ pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
.join(chrono::Local::now().format("%d-%m-%y-%H_%M_%S").to_string())
});
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
+pub static LATEST_FAILED_EXAMPLES_DIR: LazyLock<PathBuf> =
+ LazyLock::new(|| DATA_DIR.join("latest_failed"));
pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
+pub static SYNTHESIZE_STATE_FILE: LazyLock<PathBuf> =
+ LazyLock::new(|| DATA_DIR.join("synthesize_state.json"));
pub static FAILED_EXAMPLES_DIR: LazyLock<PathBuf> =
LazyLock::new(|| ensure_dir(&RUN_DIR.join("failed")));
@@ -28,12 +28,16 @@ pub async fn run_prediction(
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) -> anyhow::Result<()> {
- if !example.predictions.is_empty() {
- return Ok(());
- }
-
let provider = provider.context("provider is required")?;
+ if let Some(existing_prediction) = example.predictions.first() {
+ if existing_prediction.provider == provider {
+ return Ok(());
+ } else {
+ example.predictions.clear();
+ }
+ }
+
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
if matches!(
@@ -184,7 +188,9 @@ pub async fn run_prediction(
let actual_patch = prediction
.and_then(|prediction| {
let prediction = prediction.prediction.ok()?;
- prediction.edit_preview.as_unified_diff(&prediction.edits)
+ prediction
+ .edit_preview
+ .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
})
.unwrap_or_default();
@@ -46,6 +46,7 @@ pub enum Step {
FormatPrompt,
Predict,
Score,
+ Synthesize,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -62,6 +63,7 @@ impl Step {
Step::FormatPrompt => "Format",
Step::Predict => "Predict",
Step::Score => "Score",
+ Step::Synthesize => "Synthesize",
}
}
@@ -72,6 +74,7 @@ impl Step {
Step::FormatPrompt => "\x1b[34m",
Step::Predict => "\x1b[32m",
Step::Score => "\x1b[31m",
+ Step::Synthesize => "\x1b[36m",
}
}
}
@@ -2,11 +2,12 @@ use crate::{
PredictArgs,
example::{Example, ExampleScore},
headless::EpAppState,
- metrics::{self, ClassificationMetrics},
+ metrics,
predict::run_prediction,
progress::{Progress, Step},
};
-use edit_prediction::udiff::DiffLine;
+use anyhow::Context as _;
+use edit_prediction::udiff::apply_diff_to_string;
use gpui::AsyncApp;
use std::sync::Arc;
@@ -27,18 +28,32 @@ pub async fn run_scoring(
let _progress = Progress::global().start(Step::Score, &example.spec.name);
- let expected_patch = parse_patch(&example.spec.expected_patch);
+ let original_text = &example.buffer.as_ref().unwrap().content;
+ let expected_texts: Vec<String> = example
+ .spec
+ .expected_patches
+ .iter()
+ .map(|patch| {
+ apply_diff_to_string(original_text, patch)
+ .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
+ })
+ .collect::<Result<Vec<_>, _>>()?;
let mut scores = vec![];
-
- for pred in &example.predictions {
- let actual_patch = parse_patch(&pred.actual_patch);
- let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
- let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
-
+ for prediction in &example.predictions {
+ let actual_text = match apply_diff_to_string(original_text, &prediction.actual_patch) {
+ Ok(text) => text,
+ Err(_) => {
+ scores.push(ExampleScore { delta_chr_f: 0.0 });
+ continue;
+ }
+ };
+ let best_delta_chr_f = expected_texts
+ .iter()
+ .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
+ .fold(0.0, f32::max);
scores.push(ExampleScore {
- delta_chr_f,
- line_match,
+ delta_chr_f: best_delta_chr_f,
});
}
@@ -46,42 +61,25 @@ pub async fn run_scoring(
Ok(())
}
-fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
- patch.lines().map(DiffLine::parse).collect()
-}
-
pub fn print_report(examples: &[Example]) {
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);
- eprintln!(
- "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
- "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
- );
+ eprintln!("{:<50} {:>10}", "Example name", "DeltaChrF");
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);
- let mut all_line_match_scores = Vec::new();
let mut all_delta_chr_f_scores = Vec::new();
for example in examples {
for score in example.score.iter() {
- let line_match = &score.line_match;
-
eprintln!(
- "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
- truncate_name(&example.spec.name, 30),
- line_match.true_positives,
- line_match.false_positives,
- line_match.false_negatives,
- line_match.precision() * 100.0,
- line_match.recall() * 100.0,
- line_match.f1_score() * 100.0,
+ "{:<50} {:>9.2}",
+ truncate_name(&example.spec.name, 50),
score.delta_chr_f
);
- all_line_match_scores.push(line_match.clone());
all_delta_chr_f_scores.push(score.delta_chr_f);
}
}
@@ -90,22 +88,11 @@ pub fn print_report(examples: &[Example]) {
"──────────────────────────────────────────────────────────────────────────────────────"
);
- if !all_line_match_scores.is_empty() {
- let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
+ if !all_delta_chr_f_scores.is_empty() {
let avg_delta_chr_f: f32 =
all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
- eprintln!(
- "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
- "TOTAL",
- total_line_match.true_positives,
- total_line_match.false_positives,
- total_line_match.false_negatives,
- total_line_match.precision() * 100.0,
- total_line_match.recall() * 100.0,
- total_line_match.f1_score() * 100.0,
- avg_delta_chr_f
- );
+ eprintln!("{:<50} {:>9.2}", "AVERAGE", avg_delta_chr_f);
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);
@@ -0,0 +1,892 @@
+use crate::{
+ anthropic_client::PlainLlmClient,
+ git::{ensure_repo_cloned, run_git},
+ paths::{FAILED_EXAMPLES_DIR, LATEST_FAILED_EXAMPLES_DIR, SYNTHESIZE_STATE_FILE},
+ progress::{InfoStyle, Progress, Step, StepProgress},
+};
+use anthropic::ResponseContent;
+use anyhow::{Context as _, Result};
+use chrono::Local;
+use collections::{HashMap, HashSet};
+use edit_prediction::{
+ example_spec::ExampleSpec,
+ udiff::{apply_diff_to_string, edits_for_diff},
+};
+use indoc::indoc;
+use serde::{Deserialize, Serialize};
+use std::{
+ path::{Path, PathBuf},
+ sync::Arc,
+};
+
+#[derive(Debug, Clone)]
+pub struct SynthesizeConfig {
+ pub repo_url: String,
+ pub count: usize,
+ pub max_commits: usize,
+ pub output_dir: PathBuf,
+ pub fresh: bool,
+}
+
+#[derive(Debug, Default, Serialize, Deserialize)]
+struct SynthesizeState {
+ repositories: HashMap<String, RepoState>,
+}
+
+#[derive(Debug, Default, Serialize, Deserialize)]
+struct RepoState {
+ processed_commits: HashSet<String>,
+ examples_generated: usize,
+}
+
+impl SynthesizeState {
+ fn load() -> Self {
+ if SYNTHESIZE_STATE_FILE.exists() {
+ std::fs::read_to_string(&*SYNTHESIZE_STATE_FILE)
+ .ok()
+ .and_then(|s| serde_json::from_str(&s).ok())
+ .unwrap_or_default()
+ } else {
+ Self::default()
+ }
+ }
+
+ fn save(&self) -> Result<()> {
+ let content = serde_json::to_string_pretty(self)?;
+ std::fs::write(&*SYNTHESIZE_STATE_FILE, content)?;
+ Ok(())
+ }
+
+ fn is_processed(&self, repo_url: &str, commit_sha: &str) -> bool {
+ self.repositories
+ .get(repo_url)
+ .is_some_and(|repo| repo.processed_commits.contains(commit_sha))
+ }
+
+ fn mark_processed(&mut self, repo_url: &str, commit_sha: &str, examples_count: usize) {
+ let repo = self.repositories.entry(repo_url.to_string()).or_default();
+ repo.processed_commits.insert(commit_sha.to_string());
+ repo.examples_generated += examples_count;
+ }
+}
+
+#[derive(Debug)]
+struct CommitInfo {
+ sha: String,
+ parent_sha: String,
+ message: String,
+ diff: String,
+ expanded_diff: String,
+}
+
+/// Claude's response parsed into structured form
+#[derive(Debug)]
+struct ClaudeResponse {
+ name: String,
+ reasoning: String,
+ edit_history_hunks: Vec<String>,
+ expected_patch_hunks: Vec<String>,
+}
+
+pub async fn run_synthesize(config: SynthesizeConfig) -> Result<()> {
+ let mut state = if config.fresh {
+ SynthesizeState::default()
+ } else {
+ SynthesizeState::load()
+ };
+
+ std::fs::create_dir_all(&config.output_dir)?;
+ std::fs::create_dir_all(&*FAILED_EXAMPLES_DIR)?;
+
+ // Create "latest_failed" symlink pointing to this run's failed directory
+ if LATEST_FAILED_EXAMPLES_DIR.is_symlink() {
+ std::fs::remove_file(&*LATEST_FAILED_EXAMPLES_DIR)?;
+ }
+ #[cfg(unix)]
+ std::os::unix::fs::symlink(&*FAILED_EXAMPLES_DIR, &*LATEST_FAILED_EXAMPLES_DIR)?;
+ #[cfg(windows)]
+ std::os::windows::fs::symlink_dir(&*FAILED_EXAMPLES_DIR, &*LATEST_FAILED_EXAMPLES_DIR)?;
+
+ let progress = Progress::global();
+ progress.set_total_examples(config.count);
+
+ let clone_progress = progress.start(Step::Synthesize, "clone");
+ let repo_path = ensure_repo_cloned(&config.repo_url).await?;
+ drop(clone_progress);
+
+ let client = PlainLlmClient::new()?;
+ let mut examples_generated = 0;
+ let mut commits_skipped = 0;
+ let batch_size = config.max_commits;
+
+ 'outer: loop {
+ let list_progress = progress.start(Step::Synthesize, "list-commits");
+ let commits = list_commits(&repo_path, batch_size, commits_skipped).await?;
+ drop(list_progress);
+
+ if commits.is_empty() {
+ break;
+ }
+
+ commits_skipped += commits.len();
+
+ for commit in commits {
+ if examples_generated >= config.count {
+ break 'outer;
+ }
+
+ if !config.fresh && state.is_processed(&config.repo_url, &commit.sha) {
+ continue;
+ }
+
+ if should_skip_commit(&commit) {
+ continue;
+ }
+
+ let commit_label = format!(
+ "{} {}",
+ &commit.sha[..8],
+ truncate_message(&commit.message, 40)
+ );
+ let step_progress = Arc::new(progress.start(Step::Synthesize, &commit_label));
+
+ // Single Claude call to identify and copy hunks
+ step_progress.set_substatus("analyzing...");
+ let claude_response =
+ match analyze_commit(&client, &config, &commit, step_progress.clone()).await {
+ Ok(Some(response)) => response,
+ Ok(None) => {
+ step_progress.set_info("no pattern", InfoStyle::Normal);
+ state.mark_processed(&config.repo_url, &commit.sha, 0);
+ state.save()?;
+ continue;
+ }
+ Err(e) => {
+ step_progress.set_info(format!("error: {:?}", e), InfoStyle::Warning);
+ state.mark_processed(&config.repo_url, &commit.sha, 0);
+ state.save()?;
+ continue;
+ }
+ };
+
+ // Validate and build the example
+ step_progress.set_substatus("validating...");
+ match build_example(&config, &commit, &repo_path, &claude_response).await {
+ Ok(spec) => {
+ let timestamp = Local::now().format("%Y-%m-%d--%H-%M-%S");
+ let filename = format!("{}.md", timestamp);
+ let path = config.output_dir.join(&filename);
+ std::fs::write(&path, spec.to_markdown())?;
+ examples_generated += 1;
+ step_progress.set_info(filename, InfoStyle::Normal);
+ }
+ Err(rejection_reason) => {
+ log::debug!("Example rejected: {}", rejection_reason);
+ let timestamp = Local::now().format("%Y-%m-%d--%H-%M-%S%.3f");
+ let filename = format!("{}.md", timestamp);
+ let path = FAILED_EXAMPLES_DIR.join(&filename);
+ let content = format_rejected_example(&claude_response, &rejection_reason);
+ if let Err(e) = std::fs::write(&path, content) {
+ log::warn!("Failed to write rejected example: {:?}", e);
+ }
+ step_progress.set_info(format!("rejected: {}", filename), InfoStyle::Warning);
+ }
+ }
+
+ state.mark_processed(&config.repo_url, &commit.sha, 1);
+ state.save()?;
+ }
+ }
+
+ progress.finalize();
+ Ok(())
+}
+
+fn truncate_message(msg: &str, max_len: usize) -> String {
+ let first_line = msg.lines().next().unwrap_or("");
+ if first_line.len() <= max_len {
+ first_line.to_string()
+ } else {
+ format!("{}...", &first_line[..max_len - 3])
+ }
+}
+
+fn should_skip_commit(commit: &CommitInfo) -> bool {
+ let lines_changed = commit
+ .diff
+ .lines()
+ .filter(|l| l.starts_with('+') || l.starts_with('-'))
+ .count();
+ lines_changed < 10
+ || lines_changed > 1000
+ || is_non_code_commit(commit)
+ || is_rename_commit(commit)
+}
+
+fn is_non_code_commit(commit: &CommitInfo) -> bool {
+ let non_code_extensions = [
+ ".md", ".txt", ".json", ".yaml", ".yml", ".toml", ".lock", ".svg", ".png", ".jpg", ".gif",
+ ".ico", ".woff", ".ttf", ".eot",
+ ];
+
+ let diff_files: Vec<&str> = commit
+ .diff
+ .lines()
+ .filter(|l| l.starts_with("+++ b/") || l.starts_with("--- a/"))
+ .filter_map(|l| {
+ l.strip_prefix("+++ b/")
+ .or_else(|| l.strip_prefix("--- a/"))
+ })
+ .collect();
+
+ if diff_files.is_empty() {
+ return false;
+ }
+
+ diff_files
+ .iter()
+ .all(|f| non_code_extensions.iter().any(|ext| f.ends_with(ext)))
+}
+
+fn is_rename_commit(commit: &CommitInfo) -> bool {
+ commit.diff.contains("similarity index")
+ || commit.diff.contains("rename from")
+ || commit.diff.contains("rename to")
+}
+
+async fn list_commits(
+ repo_path: &Path,
+ max_commits: usize,
+ skip: usize,
+) -> Result<Vec<CommitInfo>> {
+ let output = run_git(
+ repo_path,
+ &[
+ "log",
+ "--no-merges",
+ &format!("--skip={}", skip),
+ &format!("-{}", max_commits),
+ "--format=%H|%P|%s",
+ ],
+ )
+ .await?;
+
+ let mut commits = Vec::new();
+ for line in output.lines() {
+ let parts: Vec<&str> = line.splitn(3, '|').collect();
+ if parts.len() < 3 {
+ continue;
+ }
+ let sha = parts[0].to_string();
+ let parent_sha = parts[1].split_whitespace().next().unwrap_or("").to_string();
+ if parent_sha.is_empty() {
+ continue;
+ }
+
+ // Get standard diff (for skip checks)
+ let diff = run_git(repo_path, &["show", "--format=", &sha])
+ .await
+ .unwrap_or_default();
+
+ // Get expanded diff with 30 lines of context
+ let expanded_diff = run_git(repo_path, &["show", "-U30", "--format=", &sha])
+ .await
+ .unwrap_or_default();
+
+ commits.push(CommitInfo {
+ sha,
+ parent_sha,
+ message: parts[2].to_string(),
+ diff,
+ expanded_diff,
+ });
+ }
+
+ Ok(commits)
+}
+
+fn build_prompt(config: &SynthesizeConfig, commit: &CommitInfo) -> String {
+ format!(
+ indoc! {r#"
+ You are analyzing a git commit to construct a realistic edit prediction example.
+
+ Your goal is to tell the story of a programmer's editing session: what sequence of changes did they make, and what change logically comes next? We use these examples to train a model to predict edits, so the quality of the EDIT HISTORY is what matters most.
+
+ An edit prediction example consists of:
+ 1. **Edit History**: 3-6 hunks showing what the programmer did BEFORE making the expected patch. This is the most important part - it must tell a coherent story of the changes leading up to the prediction.
+ 2. **Expected Patch**: One small hunk that logically follows from the edit history.
+
+ Both single-file and multi-file patterns are acceptable.
+
+ ## What Makes a Good Example
+
+ The edit history should read like a story: "First the programmer changed X, then Y, then Z, and now they need to change W."
+
+ GOOD examples (rich sequences with 3+ steps):
+ - Removing a parameter: docstring update → constructor change → field removal → (predict) usage site update
+ - Adding a feature: type definition → first usage → second usage → (predict) third usage
+ - Bug fix pattern: fix in file A → fix in file B → fix in file C → (predict) fix in file D
+
+ BAD examples (respond NO_PATTERN):
+ - Commits where all changes are independent (no narrative thread)
+ - Simple find-and-replace (renaming, version bumps)
+ - Documentation-only or config-only changes
+ - Changes where you can only find 1-2 hunks for the edit history
+
+ ## Commit Information
+
+ Repository: {repo_url}
+ Commit: {sha}
+ Message: {message}
+
+ ## Diff (30 lines context)
+
+ ```diff
+ {expanded_diff}
+ ```
+
+ ## Your Task
+
+ First, THINK through whether this commit can support a good example:
+
+ 1. What is the high-level pattern in this commit?
+ 2. Can you identify at least 4 related hunks (3 for edit history + 1 for expected patch)?
+ 3. What would be the narrative? (First... then... then... finally predict...)
+ 4. Which specific hunk should be the expected patch (the "punchline")?
+
+ If you cannot construct a coherent 3+ hunk story, respond with just:
+ NO_PATTERN: <brief reason>
+
+ If you CAN construct a good example, respond in this format:
+
+ ANALYSIS:
+ Pattern: <one sentence describing the pattern>
+ Steps:
+ 1. <file:line-range> - <what this hunk does>
+ 2. <file:line-range> - <what this hunk does>
+ 3. <file:line-range> - <what this hunk does>
+ 4. [EXPECTED PATCH] <file:line-range> - <what this hunk does>
+
+ NAME: <short description, like a commit message, under 60 chars>
+
+ EDIT_HISTORY:
+
+ Hunk 1:
+ ```diff
+ --- a/src/models/user.py
+ +++ b/src/models/user.py
+ @@ -15,7 +15,6 @@ class User:
+ """A user in the system.
+
+ Attributes:
+ - email: The user's email address.
+ name: The user's display name.
+ """
+ ```
+
+ Hunk 2:
+ ```diff
+ --- a/src/models/user.py
+ +++ b/src/models/user.py
+ @@ -25,10 +24,9 @@ class User:
+ def __init__(
+ self,
+ name: str,
+ - email: str,
+ created_at: datetime,
+ ):
+ self.name = name
+ - self.email = email
+ self.created_at = created_at
+ ```
+
+ Hunk 3:
+ ```diff
+ --- a/src/api/handlers.py
+ +++ b/src/api/handlers.py
+ @@ -42,7 +42,6 @@ def create_user(request):
+ data = request.json()
+ user = User(
+ name=data["name"],
+ - email=data["email"],
+ created_at=datetime.now(),
+ )
+ return user.save()
+ ```
+
+ EXPECTED_PATCH:
+ ```diff
+ --- a/src/api/handlers.py
+ +++ b/src/api/handlers.py
+ @@ -58,7 +57,6 @@ def update_user(request, user_id):
+ user = User.get(user_id)
+ user.name = data.get("name", user.name)
+ - user.email = data.get("email", user.email)
+ user.save()
+ return user
+ ```
+
+ ## Requirements for the diffs
+
+ Edit history:
+ - MUST have 3-6 hunks (if you cannot find 3+, respond NO_PATTERN instead)
+ - Each hunk needs file headers (--- a/path and +++ b/path)
+ - Hunks must be valid unified diffs that apply to the parent commit
+ - Order hunks as a programmer would naturally make the changes
+
+ Expected patch:
+ - Must be a SINGLE hunk from a SINGLE file
+ - Must be SMALL: 1-15 changed lines (not counting context)
+ - Must be clearly predictable from the edit history narrative
+ "#},
+ repo_url = config.repo_url,
+ sha = commit.sha,
+ message = commit.message,
+ expanded_diff = commit.expanded_diff,
+ )
+}
+
+async fn analyze_commit(
+ client: &PlainLlmClient,
+ config: &SynthesizeConfig,
+ commit: &CommitInfo,
+ step_progress: Arc<StepProgress>,
+) -> Result<Option<ClaudeResponse>> {
+ use anthropic::{Message, RequestContent, Role};
+
+ let prompt = build_prompt(config, commit);
+ let messages = vec![Message {
+ role: Role::User,
+ content: vec![RequestContent::Text {
+ text: prompt,
+ cache_control: None,
+ }],
+ }];
+
+ let response = client
+ .generate_streaming("claude-sonnet-4-5", 8192, messages, |chars, _text| {
+ step_progress.set_substatus(format!("analyzing: {:.1}K", chars as f64 / 1000.0));
+ })
+ .await?;
+
+ // Extract text content from response
+ let response_text: String = response
+ .content
+ .iter()
+ .filter_map(|block| {
+ if let ResponseContent::Text { text } = block {
+ Some(text.as_str())
+ } else {
+ None
+ }
+ })
+ .collect::<Vec<_>>()
+ .join("\n");
+
+ parse_claude_response(&response_text)
+}
+
+fn parse_claude_response(response: &str) -> Result<Option<ClaudeResponse>> {
+ // Check for NO_PATTERN
+ if response.contains("NO_PATTERN:") {
+ return Ok(None);
+ }
+
+ // Parse NAME
+ let name = response
+ .lines()
+ .find(|l| l.starts_with("NAME:"))
+ .map(|l| l.strip_prefix("NAME:").unwrap_or("").trim().to_string())
+ .unwrap_or_else(|| "unnamed example".to_string());
+
+ // Parse ANALYSIS section (Claude's planning) - this is the primary reasoning
+ let reasoning = extract_section(
+ response,
+ "ANALYSIS:",
+ &["NAME:", "REASONING:", "EDIT_HISTORY:", "EXPECTED_PATCH:"],
+ )
+ .unwrap_or_default();
+
+ // Parse EDIT_HISTORY diff block
+ let edit_history_hunks = extract_diff_block(response, "EDIT_HISTORY:")?;
+
+ // Parse EXPECTED_PATCH diff block
+ let expected_patch_hunks = extract_diff_block(response, "EXPECTED_PATCH:")?;
+
+ if edit_history_hunks.is_empty() {
+ anyhow::bail!("No edit history hunks found in response");
+ }
+ if expected_patch_hunks.is_empty() {
+ anyhow::bail!("No expected patch hunks found in response");
+ }
+
+ Ok(Some(ClaudeResponse {
+ name,
+ reasoning,
+ edit_history_hunks,
+ expected_patch_hunks,
+ }))
+}
+
+fn extract_section(text: &str, start_marker: &str, end_markers: &[&str]) -> Option<String> {
+ let start_idx = text.find(start_marker)?;
+ let content_start = start_idx + start_marker.len();
+
+ let end_idx = end_markers
+ .iter()
+ .filter_map(|marker| text[content_start..].find(marker))
+ .min()
+ .map(|idx| content_start + idx)
+ .unwrap_or(text.len());
+
+ Some(text[content_start..end_idx].trim().to_string())
+}
+
+fn extract_diff_block(text: &str, section_marker: &str) -> Result<Vec<String>> {
+ let section_start = text
+ .find(section_marker)
+ .context(format!("Section {} not found", section_marker))?;
+
+ let after_marker = &text[section_start + section_marker.len()..];
+
+ // Find where the next major section starts (to bound our search)
+ let section_end = ["EXPECTED_PATCH:", "## "]
+ .iter()
+ .filter(|&&m| m != section_marker)
+ .filter_map(|marker| after_marker.find(marker))
+ .min()
+ .unwrap_or(after_marker.len());
+
+ let section_content = &after_marker[..section_end];
+
+ // Collect all ```diff blocks in this section
+ let mut hunks = Vec::new();
+ let mut search_start = 0;
+
+ while let Some(diff_start) = section_content[search_start..].find("```diff") {
+ let abs_diff_start = search_start + diff_start;
+ let block_content_start = section_content[abs_diff_start..]
+ .find('\n')
+ .map(|i| abs_diff_start + i + 1)
+ .unwrap_or(abs_diff_start);
+
+ if let Some(block_end_rel) = section_content[block_content_start..].find("```") {
+ let block_end = block_content_start + block_end_rel;
+ let diff_content = section_content[block_content_start..block_end].trim();
+
+ // Split this block into hunks (in case multiple hunks in one block)
+ hunks.extend(split_into_hunks(diff_content));
+
+ search_start = block_end + 3;
+ } else {
+ break;
+ }
+ }
+
+ if hunks.is_empty() {
+ anyhow::bail!("No diff blocks found in section {}", section_marker);
+ }
+
+ Ok(hunks)
+}
+
+/// Split a diff block into individual hunks, preserving file headers
+fn split_into_hunks(diff: &str) -> Vec<String> {
+ let mut hunks = Vec::new();
+ let mut current_file_header: Option<String> = None;
+ let mut current_hunk: Vec<String> = Vec::new();
+ let mut in_hunk = false;
+
+ for line in diff.lines() {
+ if line.starts_with("--- a/") || line.starts_with("--- /") {
+ // Start of file header - flush previous hunk
+ if in_hunk && !current_hunk.is_empty() {
+ let mut hunk_text = String::new();
+ if let Some(ref header) = current_file_header {
+ hunk_text.push_str(header);
+ hunk_text.push('\n');
+ }
+ hunk_text.push_str(&current_hunk.join("\n"));
+ hunks.push(hunk_text);
+ current_hunk.clear();
+ }
+ current_file_header = Some(line.to_string());
+ in_hunk = false;
+ } else if line.starts_with("+++ b/") || line.starts_with("+++ /") {
+ if let Some(ref mut header) = current_file_header {
+ header.push('\n');
+ header.push_str(line);
+ }
+ } else if line.starts_with("@@ ") {
+ // New hunk - flush previous
+ if in_hunk && !current_hunk.is_empty() {
+ let mut hunk_text = String::new();
+ if let Some(ref header) = current_file_header {
+ hunk_text.push_str(header);
+ hunk_text.push('\n');
+ }
+ hunk_text.push_str(&current_hunk.join("\n"));
+ hunks.push(hunk_text);
+ current_hunk.clear();
+ }
+ current_hunk.push(line.to_string());
+ in_hunk = true;
+ } else if in_hunk {
+ current_hunk.push(line.to_string());
+ }
+ }
+
+ // Flush final hunk
+ if !current_hunk.is_empty() {
+ let mut hunk_text = String::new();
+ if let Some(ref header) = current_file_header {
+ hunk_text.push_str(header);
+ hunk_text.push('\n');
+ }
+ hunk_text.push_str(&current_hunk.join("\n"));
+ hunks.push(hunk_text);
+ }
+
+ hunks
+}
+
+/// Validate Claude's output by applying diffs and build the ExampleSpec
+async fn build_example(
+ config: &SynthesizeConfig,
+ commit: &CommitInfo,
+ repo_path: &Path,
+ response: &ClaudeResponse,
+) -> Result<ExampleSpec, String> {
+ // Validate expected patch hunks
+ if response.expected_patch_hunks.len() != 1 {
+ return Err(format!(
+ "Expected exactly 1 expected patch hunk, got {}",
+ response.expected_patch_hunks.len()
+ ));
+ }
+
+ // Parse the expected patch to determine cursor file
+ let expected_patch = &response.expected_patch_hunks[0];
+ let cursor_file = extract_file_from_hunk(expected_patch)
+ .ok_or_else(|| "Could not determine file from expected patch".to_string())?;
+
+ // Get the file content before the commit
+ let before_content = run_git(
+ repo_path,
+ &["show", &format!("{}^:{}", commit.sha, cursor_file)],
+ )
+ .await
+ .map_err(|e| format!("Failed to get file content for {}: {}", cursor_file, e))?;
+
+ // Build edit history diff from Claude's hunks
+ let edit_history = response.edit_history_hunks.join("\n");
+
+ // Apply edit history to get intermediate state (validates edit history)
+ let intermediate_state =
+ apply_edit_history_to_content(&before_content, &edit_history, &cursor_file)?;
+
+ // Validate expected patch applies to intermediate state
+ let expected_patch_with_header = ensure_diff_header(expected_patch, &cursor_file);
+ apply_diff_to_string(&intermediate_state, &expected_patch_with_header)
+ .map_err(|e| format!("Expected patch failed to apply: {}", e))?;
+
+ // Find where the expected patch edits would apply in the intermediate state
+ let edits = edits_for_diff(&intermediate_state, &expected_patch_with_header)
+ .map_err(|e| format!("Failed to parse expected patch: {}", e))?;
+ if edits.is_empty() {
+ return Err(
+ "Could not locate expected patch in file (context not found or ambiguous)".to_string(),
+ );
+ }
+
+ // Use the start of the first edit for cursor positioning
+ let cursor_byte_offset = edits[0].0.start;
+
+ // Extract excerpt around the edit location
+ let (excerpt, cursor_offset) = extract_cursor_excerpt(&intermediate_state, cursor_byte_offset)?;
+
+ // Build the ExampleSpec and use set_cursor_excerpt to format with comment marker
+ let comment_prefix = line_comment_prefix(&cursor_file);
+ let reasoning_with_source = format!(
+ "Source commit: {} ({})\n\n{}",
+ commit.sha,
+ truncate_message(&commit.message, 60),
+ response.reasoning
+ );
+ let mut spec = ExampleSpec {
+ name: response.name.clone(),
+ repository_url: config.repo_url.clone(),
+ revision: commit.parent_sha.clone(),
+ tags: Vec::new(),
+ reasoning: Some(reasoning_with_source),
+ uncommitted_diff: String::new(),
+ cursor_path: Arc::from(Path::new(&cursor_file)),
+ cursor_position: String::new(),
+ edit_history,
+ expected_patches: vec![expected_patch_with_header],
+ };
+ spec.set_cursor_excerpt(&excerpt, cursor_offset, comment_prefix);
+
+ Ok(spec)
+}
+
+/// Extract file path from a hunk (looks for --- a/path or +++ b/path)
+fn extract_file_from_hunk(hunk: &str) -> Option<String> {
+ for line in hunk.lines() {
+ if let Some(path) = line.strip_prefix("+++ b/") {
+ return Some(path.to_string());
+ }
+ if let Some(path) = line.strip_prefix("--- a/") {
+ return Some(path.to_string());
+ }
+ }
+ None
+}
+
+/// Ensure a hunk has proper file headers
+fn ensure_diff_header(hunk: &str, file_path: &str) -> String {
+ if hunk.contains("--- a/") || hunk.contains("+++ b/") {
+ return hunk.to_string();
+ }
+ format!("--- a/{}\n+++ b/{}\n{}", file_path, file_path, hunk)
+}
+
+/// Apply edit history to file content, only if hunks affect this file
+fn apply_edit_history_to_content(
+ content: &str,
+ edit_history: &str,
+ cursor_file: &str,
+) -> Result<String, String> {
+ // Extract just the hunks for this file from the edit history
+ let file_diff = extract_file_diff_from_combined(edit_history, cursor_file);
+
+ if file_diff.is_empty() {
+ return Ok(content.to_string());
+ }
+
+ apply_diff_to_string(content, &file_diff)
+ .map_err(|e| format!("Failed to apply edit history: {}", e))
+}
+
+/// Extract hunks for a specific file from a combined diff
+fn extract_file_diff_from_combined(combined_diff: &str, target_file: &str) -> String {
+ let mut result = String::new();
+ let mut in_target_file = false;
+ let mut found_header = false;
+
+ for line in combined_diff.lines() {
+ if line.starts_with("--- a/") {
+ let file = line.strip_prefix("--- a/").unwrap_or("");
+ in_target_file = file == target_file;
+ if in_target_file {
+ result.push_str(line);
+ result.push('\n');
+ found_header = false;
+ }
+ } else if line.starts_with("+++ b/") && in_target_file {
+ result.push_str(line);
+ result.push('\n');
+ found_header = true;
+ } else if in_target_file && found_header {
+ if line.starts_with("--- a/") {
+ break;
+ }
+ result.push_str(line);
+ result.push('\n');
+ }
+ }
+
+ result
+}
+
+/// Extract a cursor position excerpt from content around a byte offset.
+/// Returns the excerpt and the cursor offset within the excerpt.
+fn extract_cursor_excerpt(
+ content: &str,
+ cursor_byte_offset: usize,
+) -> Result<(String, usize), String> {
+ // Find the line containing the cursor
+ let line_start = content[..cursor_byte_offset]
+ .rfind('\n')
+ .map(|pos| pos + 1)
+ .unwrap_or(0);
+ let line_end = content[cursor_byte_offset..]
+ .find('\n')
+ .map(|pos| cursor_byte_offset + pos)
+ .unwrap_or(content.len());
+
+ // Get context lines before
+ let lines_before: Vec<&str> = content[..line_start].lines().collect();
+ let context_before: Vec<&str> = lines_before.iter().rev().take(3).rev().cloned().collect();
+
+ // Get context lines after
+ let after_line_end = if line_end < content.len() {
+ line_end + 1
+ } else {
+ line_end
+ };
+ let context_after: Vec<&str> = content[after_line_end..].lines().take(4).collect();
+
+ // The line containing the cursor
+ let cursor_line = &content[line_start..line_end];
+ let cursor_column = cursor_byte_offset - line_start;
+
+ // Build the excerpt
+ let mut excerpt = String::new();
+ for line in context_before {
+ excerpt.push_str(line);
+ excerpt.push('\n');
+ }
+ // Track where cursor will be in the excerpt
+ let cursor_offset_in_excerpt = excerpt.len() + cursor_column;
+ // Line containing cursor
+ excerpt.push_str(cursor_line);
+ excerpt.push('\n');
+ for line in context_after {
+ excerpt.push_str(line);
+ excerpt.push('\n');
+ }
+
+ // Trim trailing newline
+ if excerpt.ends_with('\n') {
+ excerpt.pop();
+ }
+
+ Ok((excerpt, cursor_offset_in_excerpt))
+}
+
+/// Get the line comment prefix for a file based on its extension
+fn line_comment_prefix(file_path: &str) -> &'static str {
+ let extension = file_path.rsplit('.').next().unwrap_or("");
+ match extension {
+ "rs" | "c" | "cpp" | "cc" | "h" | "hpp" | "js" | "ts" | "tsx" | "jsx" | "go" | "java"
+ | "swift" | "kt" | "kts" | "scala" | "cs" | "m" | "mm" | "zig" | "v" | "d" => "//",
+ "py" | "rb" | "sh" | "bash" | "zsh" | "pl" | "pm" | "r" | "jl" | "yaml" | "yml"
+ | "toml" | "coffee" | "cr" | "ex" | "exs" | "elixir" => "#",
+ "lua" | "hs" | "sql" => "--",
+ "lisp" | "clj" | "cljs" | "scm" | "rkt" | "el" => ";",
+ "erl" | "hrl" => "%",
+ _ => "//",
+ }
+}
+
+fn format_rejected_example(response: &ClaudeResponse, rejection_reason: &str) -> String {
+ let mut content = String::new();
+ content.push_str("# Rejected Example\n\n");
+ content.push_str(&format!("## Name\n\n{}\n\n", response.name));
+ content.push_str(&format!("## Reasoning\n\n{}\n\n", response.reasoning));
+ content.push_str("## Edit History Hunks\n\n```diff\n");
+ for hunk in &response.edit_history_hunks {
+ content.push_str(hunk);
+ content.push_str("\n\n");
+ }
+ content.push_str("```\n\n");
+ content.push_str("## Expected Patch Hunks\n\n```diff\n");
+ for hunk in &response.expected_patch_hunks {
+ content.push_str(hunk);
+ content.push_str("\n\n");
+ }
+ content.push_str("```\n\n");
+ content.push_str(&format!("## Rejection Reason\n\n{}\n", rejection_reason));
+ content
+}
@@ -150,7 +150,7 @@ fn capture_example_as_markdown(
.buffer()
.read(cx)
.text_anchor_for_position(editor.selections.newest_anchor().head(), cx)?;
- let example = capture_example(project.clone(), buffer, cursor_anchor, true, cx)?;
+ let example = capture_example(project.clone(), buffer, cursor_anchor, cx)?;
let examples_dir = AllLanguageSettings::get_global(cx)
.edit_predictions
@@ -13,7 +13,7 @@ use crate::{
},
task_context::RunnableRange,
text_diff::text_diff,
- unified_diff,
+ unified_diff_with_offsets,
};
pub use crate::{
Grammar, Language, LanguageRegistry,
@@ -773,7 +773,11 @@ pub struct EditPreview {
}
impl EditPreview {
- pub fn as_unified_diff(&self, edits: &[(Range<Anchor>, impl AsRef<str>)]) -> Option<String> {
+ pub fn as_unified_diff(
+ &self,
+ file: Option<&Arc<dyn File>>,
+ edits: &[(Range<Anchor>, impl AsRef<str>)],
+ ) -> Option<String> {
let (first, _) = edits.first()?;
let (last, _) = edits.last()?;
@@ -788,7 +792,7 @@ impl EditPreview {
let old_end = Point::new(old_end.row + 4, 0).min(self.old_snapshot.max_point());
let new_end = Point::new(new_end.row + 4, 0).min(self.applied_edits_snapshot.max_point());
- Some(unified_diff(
+ let diff_body = unified_diff_with_offsets(
&self
.old_snapshot
.text_for_range(start..old_end)
@@ -797,7 +801,17 @@ impl EditPreview {
.applied_edits_snapshot
.text_for_range(start..new_end)
.collect::<String>(),
- ))
+ start.row,
+ start.row,
+ );
+
+ let path = file.map(|f| f.path().as_unix_str());
+ let header = match path {
+ Some(p) => format!("--- a/{}\n+++ b/{}\n", p, p),
+ None => String::new(),
+ };
+
+ Some(format!("{}{}", header, diff_body))
}
pub fn highlight_edits(
@@ -4564,13 +4564,13 @@ impl Project {
for worktree in worktree_store.visible_worktrees(cx) {
let worktree = worktree.read(cx);
- if let Ok(path) = RelPath::new(path, path_style)
- && let Some(entry) = worktree.entry_for_path(&path)
- {
- return Some(ProjectPath {
- worktree_id: worktree.id(),
- path: entry.path.clone(),
- });
+ if let Ok(rel_path) = RelPath::new(path, path_style) {
+ if let Some(entry) = worktree.entry_for_path(&rel_path) {
+ return Some(ProjectPath {
+ worktree_id: worktree.id(),
+ path: entry.path.clone(),
+ });
+ }
}
}
}