From 4c261b2cd84d04f7c617a5b7637fc98b1d5568de Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Mon, 3 Nov 2025 15:06:32 -0300 Subject: [PATCH] Checkpoint: Applying edit history (untested) --- .../cloud_llm_client/src/cloud_llm_client.rs | 1 + crates/cloud_llm_client/src/udiff.rs | 270 ++++++++++++++++++ crates/zeta2/src/zeta2.rs | 25 +- crates/zeta2_tools/src/zeta2_tools.rs | 2 + crates/zeta_cli/src/example.rs | 150 +++++++++- crates/zeta_cli/src/main.rs | 8 + 6 files changed, 438 insertions(+), 18 deletions(-) create mode 100644 crates/cloud_llm_client/src/udiff.rs diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index bb77c3a5b7f8009093cbf7bc427160ed535e6c62..afa72665f168e7ec341d92df0a094f7880368087 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -1,4 +1,5 @@ pub mod predict_edits_v3; +pub mod udiff; use std::str::FromStr; use std::sync::Arc; diff --git a/crates/cloud_llm_client/src/udiff.rs b/crates/cloud_llm_client/src/udiff.rs new file mode 100644 index 0000000000000000000000000000000000000000..444452e6b7350de1680d51b5b9a34eab69685fa3 --- /dev/null +++ b/crates/cloud_llm_client/src/udiff.rs @@ -0,0 +1,270 @@ +use std::borrow::Cow; + +#[derive(Debug, PartialEq)] +pub enum DiffLine<'a> { + OldPath { path: Cow<'a, str> }, + NewPath { path: Cow<'a, str> }, + HunkHeader(Option), + Context(&'a str), + Deletion(&'a str), + Addition(&'a str), + Garbage, +} + +#[derive(Debug, PartialEq)] +pub struct HunkLocation { + start_line_old: u32, + count_old: u32, + start_line_new: u32, + count_new: u32, +} + +impl<'a> DiffLine<'a> { + pub fn parse(line: &'a str) -> Self { + Self::try_parse(line).unwrap_or(Self::Garbage) + } + + fn try_parse(line: &'a str) -> Option { + if let Some(header) = line.strip_prefix("---").and_then(eat_required_whitespace) { + let path = parse_header_path("a/", header); + Some(Self::OldPath { path }) + } else if let Some(header) = line.strip_prefix("+++").and_then(eat_required_whitespace) { + Some(Self::NewPath { + path: parse_header_path("b/", header), + }) + } else if let Some(header) = line.strip_prefix("@@").and_then(eat_required_whitespace) { + if header.starts_with("...") { + return Some(Self::HunkHeader(None)); + } + + let (start_line_old, header) = header.strip_prefix('-')?.split_once(',')?; + let mut parts = header.split_ascii_whitespace(); + let count_old = parts.next()?; + let (start_line_new, count_new) = parts.next()?.strip_prefix('+')?.split_once(',')?; + + Some(Self::HunkHeader(Some(HunkLocation { + start_line_old: start_line_old.parse::().ok()?.saturating_sub(1), + count_old: count_old.parse().ok()?, + start_line_new: start_line_new.parse::().ok()?.saturating_sub(1), + count_new: count_new.parse().ok()?, + }))) + } else if let Some(deleted_header) = line.strip_prefix("-") { + Some(Self::Deletion(deleted_header)) + } else if line.is_empty() { + Some(Self::Context("")) + } else if let Some(context) = line.strip_prefix(" ") { + Some(Self::Context(context)) + } else { + Some(Self::Addition(line.strip_prefix("+")?)) + } + } +} + +fn parse_header_path<'a>(strip_prefix: &'static str, header: &'a str) -> Cow<'a, str> { + if !header.contains(['"', '\\']) { + let path = header.split_ascii_whitespace().next().unwrap_or(header); + return Cow::Borrowed(path.strip_prefix(strip_prefix).unwrap_or(path)); + } + + let mut path = String::with_capacity(header.len()); + let mut in_quote = false; + let mut chars = header.chars().peekable(); + let mut strip_prefix = Some(strip_prefix); + + while let Some(char) = chars.next() { + if char == '"' { + in_quote = !in_quote; + } else if char == '\\' { + let Some(&next_char) = chars.peek() else { + break; + }; + chars.next(); + path.push(next_char); + } else if char.is_ascii_whitespace() && !in_quote { + break; + } else { + path.push(char); + } + + if let Some(prefix) = strip_prefix + && path == prefix + { + strip_prefix.take(); + path.clear(); + } + } + + Cow::Owned(path) +} + +fn eat_required_whitespace(header: &str) -> Option<&str> { + let trimmed = header.trim_ascii_start(); + + if trimmed.len() == header.len() { + None + } else { + Some(trimmed) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use indoc::indoc; + + #[test] + fn parse_lines_simple() { + let input = indoc! {" + diff --git a/text.txt b/text.txt + index 86c770d..a1fd855 100644 + --- a/file.txt + +++ b/file.txt + @@ -1,2 +1,3 @@ + context + -deleted + +inserted + garbage + + --- b/file.txt + +++ a/file.txt + "}; + + let lines = input.lines().map(DiffLine::parse).collect::>(); + + pretty_assertions::assert_eq!( + lines, + &[ + DiffLine::Garbage, + DiffLine::Garbage, + DiffLine::OldPath { + path: "file.txt".into() + }, + DiffLine::NewPath { + path: "file.txt".into() + }, + DiffLine::HunkHeader(Some(HunkLocation { + start_line_old: 0, + count_old: 2, + start_line_new: 0, + count_new: 3 + })), + DiffLine::Context("context"), + DiffLine::Deletion("deleted"), + DiffLine::Addition("inserted"), + DiffLine::Garbage, + DiffLine::Context(""), + DiffLine::OldPath { + path: "b/file.txt".into() + }, + DiffLine::NewPath { + path: "a/file.txt".into() + }, + ] + ); + } + + #[test] + fn file_header_extra_space() { + let options = ["--- file", "--- file", "---\tfile"]; + + for option in options { + pretty_assertions::assert_eq!( + DiffLine::parse(option), + DiffLine::OldPath { + path: "file".into() + }, + "{option}", + ); + } + } + + #[test] + fn hunk_header_extra_space() { + let options = [ + "@@ -1,2 +1,3 @@", + "@@ -1,2 +1,3 @@", + "@@\t-1,2\t+1,3\t@@", + "@@ -1,2 +1,3 @@", + "@@ -1,2 +1,3 @@", + "@@ -1,2 +1,3 @@", + "@@ -1,2 +1,3 @@ garbage", + ]; + + for option in options { + pretty_assertions::assert_eq!( + DiffLine::parse(option), + DiffLine::HunkHeader(Some(HunkLocation { + start_line_old: 0, + count_old: 2, + start_line_new: 0, + count_new: 3 + })), + "{option}", + ); + } + } + + #[test] + fn hunk_header_without_location() { + pretty_assertions::assert_eq!(DiffLine::parse("@@ ... @@"), DiffLine::HunkHeader(None)); + } + + #[test] + fn test_parse_path() { + assert_eq!(parse_header_path("a/", "foo.txt"), "foo.txt"); + assert_eq!( + parse_header_path("a/", "foo/bar/baz.txt"), + "foo/bar/baz.txt" + ); + assert_eq!(parse_header_path("a/", "a/foo.txt"), "foo.txt"); + assert_eq!( + parse_header_path("a/", "a/foo/bar/baz.txt"), + "foo/bar/baz.txt" + ); + + // Extra + assert_eq!( + parse_header_path("a/", "a/foo/bar/baz.txt 2025"), + "foo/bar/baz.txt" + ); + assert_eq!( + parse_header_path("a/", "a/foo/bar/baz.txt\t2025"), + "foo/bar/baz.txt" + ); + assert_eq!( + parse_header_path("a/", "a/foo/bar/baz.txt \""), + "foo/bar/baz.txt" + ); + + // Quoted + assert_eq!( + parse_header_path("a/", "a/foo/bar/\"baz quox.txt\""), + "foo/bar/baz quox.txt" + ); + assert_eq!( + parse_header_path("a/", "\"a/foo/bar/baz quox.txt\""), + "foo/bar/baz quox.txt" + ); + assert_eq!( + parse_header_path("a/", "\"foo/bar/baz quox.txt\""), + "foo/bar/baz quox.txt" + ); + assert_eq!(parse_header_path("a/", "\"whatever 🤷\""), "whatever 🤷"); + assert_eq!( + parse_header_path("a/", "\"foo/bar/baz quox.txt\" 2025"), + "foo/bar/baz quox.txt" + ); + // unescaped quotes are dropped + assert_eq!(parse_header_path("a/", "foo/\"bar\""), "foo/bar"); + + // Escaped + assert_eq!( + parse_header_path("a/", "\"foo/\\\"bar\\\"/baz.txt\""), + "foo/\"bar\"/baz.txt" + ); + assert_eq!( + parse_header_path("a/", "\"C:\\\\Projects\\\\My App\\\\old file.txt\""), + "C:\\Projects\\My App\\old file.txt" + ); + } +} diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 65bd16ef598bafc4f92329a4699cb513d2220bc0..ed8f0b12c79374aaac4f115bed77e03132f63889 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -50,8 +50,6 @@ use crate::related_excerpts::find_related_excerpts; pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery}; pub use provider::ZetaEditPredictionProvider; -const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1); - /// Maximum number of events to track. const MAX_EVENT_COUNT: usize = 16; @@ -83,6 +81,7 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { max_diagnostic_bytes: 2048, prompt_format: PromptFormat::DEFAULT, file_indexing_parallelism: 1, + buffer_change_grouping_interval: Duration::from_secs(1), }; pub struct Zeta2FeatureFlag; @@ -118,6 +117,7 @@ pub struct ZetaOptions { pub max_diagnostic_bytes: usize, pub prompt_format: predict_edits_v3::PromptFormat, pub file_indexing_parallelism: usize, + pub buffer_change_grouping_interval: Duration, } #[derive(Debug, Clone, PartialEq)] @@ -460,6 +460,7 @@ impl Zeta { project: &Entity, cx: &mut Context, ) -> BufferSnapshot { + let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval; let zeta_project = self.get_or_init_zeta_project(project, cx); let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx); @@ -469,6 +470,7 @@ impl Zeta { std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone()); Self::push_event( zeta_project, + buffer_change_grouping_interval, Event::BufferChange { old_snapshot, new_snapshot: new_snapshot.clone(), @@ -480,14 +482,19 @@ impl Zeta { new_snapshot } - fn push_event(zeta_project: &mut ZetaProject, event: Event) { + fn push_event( + zeta_project: &mut ZetaProject, + buffer_change_grouping_interval: Duration, + event: Event, + ) { let events = &mut zeta_project.events; - if let Some(Event::BufferChange { - new_snapshot: last_new_snapshot, - timestamp: last_timestamp, - .. - }) = events.back_mut() + if buffer_change_grouping_interval > Duration::ZERO + && let Some(Event::BufferChange { + new_snapshot: last_new_snapshot, + timestamp: last_timestamp, + .. + }) = events.back_mut() { // Coalesce edits for the same buffer when they happen one after the other. let Event::BufferChange { @@ -496,7 +503,7 @@ impl Zeta { timestamp, } = &event; - if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL + if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval && old_snapshot.remote_id() == last_new_snapshot.remote_id() && old_snapshot.version == last_new_snapshot.version { diff --git a/crates/zeta2_tools/src/zeta2_tools.rs b/crates/zeta2_tools/src/zeta2_tools.rs index 0b4a59844d7b4a02c2f41ff7654c7df0c4292f7a..89f9dcd5e318c5c21d0121a52b1f39a4f1bd8848 100644 --- a/crates/zeta2_tools/src/zeta2_tools.rs +++ b/crates/zeta2_tools/src/zeta2_tools.rs @@ -335,6 +335,8 @@ impl Zeta2Inspector { max_diagnostic_bytes: zeta_options.max_diagnostic_bytes, prompt_format: zeta_options.prompt_format, file_indexing_parallelism: zeta_options.file_indexing_parallelism, + buffer_change_grouping_interval: zeta_options + .buffer_change_grouping_interval, }, cx, ); diff --git a/crates/zeta_cli/src/example.rs b/crates/zeta_cli/src/example.rs index 3eec3bc0f508744c4c1e3b6552a82b9fdbfe1704..e0f9e51c0c68234d48d03c5e598b0ddf8497b2f8 100644 --- a/crates/zeta_cli/src/example.rs +++ b/crates/zeta_cli/src/example.rs @@ -10,8 +10,10 @@ use std::{ use anyhow::{Context as _, Result}; use clap::ValueEnum; +use collections::HashSet; use futures::AsyncWriteExt as _; -use gpui::http_client::Url; +use gpui::{AsyncApp, Entity, http_client::Url}; +use project::{Project, ProjectPath}; use pulldown_cmark::CowStr; use serde::{Deserialize, Serialize}; @@ -36,7 +38,7 @@ pub struct Example { pub uncommitted_diff: String, pub cursor_path: PathBuf, pub cursor_position: String, - pub edit_history: Vec, + pub edit_history: String, pub expected_patch: String, pub expected_excerpts: Vec, } @@ -94,7 +96,7 @@ impl NamedExample { uncommitted_diff: String::new(), cursor_path: PathBuf::new(), cursor_position: String::new(), - edit_history: Vec::new(), + edit_history: String::new(), expected_patch: String::new(), expected_excerpts: Vec::new(), }, @@ -160,7 +162,7 @@ impl NamedExample { if current_section.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) { named.example.uncommitted_diff = mem::take(&mut text); } else if current_section.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) { - named.example.edit_history.push(mem::take(&mut text)); + named.example.edit_history.push_str(&mem::take(&mut text)); } else if current_section.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) { named.example.cursor_path = block_info.into(); named.example.cursor_position = mem::take(&mut text); @@ -328,6 +330,140 @@ impl NamedExample { Ok((owner.into(), repo.into())) } } + + pub async fn apply_edit_history( + &self, + project: &Entity, + cx: &mut AsyncApp, + ) -> Result<()> { + use cloud_llm_client::udiff::DiffLine; + use std::fmt::Write; + + #[derive(Default)] + struct Edit { + context: String, + deletion_start: Option, + addition: String, + } + + let mut old_path = None; + let mut new_path = None; + let mut pending = Edit::default(); + let mut diff_lines = self + .example + .edit_history + .lines() + .map(DiffLine::parse) + .peekable(); + let mut open_buffers = HashSet::default(); + + while let Some(diff_line) = diff_lines.next() { + match diff_line { + DiffLine::OldPath { path } => old_path = Some(path), + DiffLine::NewPath { path } => { + if old_path.is_none() { + anyhow::bail!( + "Found a new path header (`+++`) before an (`---`) old path header" + ); + } + new_path = Some(path) + } + DiffLine::Context(ctx) => { + writeln!(&mut pending.context, "{ctx}")?; + } + DiffLine::Deletion(del) => { + pending.deletion_start.get_or_insert(pending.context.len()); + writeln!(&mut pending.context, "{del}")?; + } + DiffLine::Addition(add) => { + if pending.context.is_empty() { + anyhow::bail!("Found an addition before any context or deletion lines"); + } + + writeln!(&mut pending.addition, "{add}")?; + } + DiffLine::HunkHeader(_) | DiffLine::Garbage => {} + } + + let commit_pending = match diff_lines.peek() { + Some(DiffLine::OldPath { .. }) + | Some(DiffLine::HunkHeader(_)) + | Some(DiffLine::Context(_)) + | None => { + // commit pending edit cluster + !pending.addition.is_empty() || pending.deletion_start.is_some() + } + Some(DiffLine::Deletion(_)) => { + // start a new cluster if we have any additions specifically + // if we only have deletions, we continue to aggregate them + pending.addition.is_empty() + } + _ => false, + }; + + if commit_pending { + let edit = mem::take(&mut pending); + + if edit.addition.is_empty() || edit.deletion_start.is_none() { + return anyhow::Ok(()); + } + + let Some(old_path) = old_path.as_deref() else { + anyhow::bail!("Missing old path (`---`) header") + }; + + let Some(new_path) = new_path.as_deref() else { + anyhow::bail!("Missing new path (`+++`) header") + }; + + let buffer = project + .update(cx, |project, cx| { + let project_path = project + .find_project_path(old_path, cx) + .context("Failed to find old_path in project")?; + + anyhow::Ok(project.open_buffer(project_path, cx)) + })?? + .await?; + open_buffers.insert(buffer.clone()); + + if old_path != new_path { + project + .update(cx, |project, cx| { + let project_file = + project::File::from_dyn(buffer.read(cx).file()).unwrap(); + let new_path = ProjectPath { + worktree_id: project_file.worktree_id(cx), + path: project_file.path.clone(), + }; + project.rename_entry(project_file.entry_id.unwrap(), new_path, cx) + })? + .await?; + } + + // TODO is it worth using project search? + buffer.update(cx, |buffer, cx| { + let text = buffer.text(); + if let Some(context_offset) = text.find(&edit.context) { + let end = context_offset + edit.context.len(); + let start = if let Some(deletion_start) = edit.deletion_start { + context_offset + deletion_start + } else { + end + }; + + buffer.edit([(start..end, edit.addition)], None, cx); + + anyhow::Ok(()) + } else { + anyhow::bail!("Failed to match context"); + } + })??; + } + } + + anyhow::Ok(()) + } } async fn run_git(repo_path: &Path, args: &[&str]) -> Result { @@ -365,11 +501,7 @@ impl Display for NamedExample { write!(f, "`````\n")?; if !self.example.edit_history.is_empty() { - write!(f, "`````diff\n")?; - for item in &self.example.edit_history { - write!(f, "{item}")?; - } - write!(f, "`````\n")?; + write!(f, "`````diff\n{}`````\n", self.example.edit_history)?; } write!( diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 3f2b51203fbb587a43936d49d14c0af83182e16e..168264f8f845522b5e29315f21496cbdb8a65dd8 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -26,6 +26,7 @@ use project::{Project, ProjectPath, Worktree}; use reqwest_client::ReqwestClient; use serde_json::json; use std::io; +use std::time::Duration; use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc}; use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery}; @@ -176,6 +177,7 @@ fn syntax_args_to_options( max_prompt_bytes: zeta2_args.max_prompt_bytes, prompt_format: zeta2_args.prompt_format.clone().into(), file_indexing_parallelism: zeta2_args.file_indexing_parallelism, + buffer_change_grouping_interval: Duration::ZERO, } } @@ -414,6 +416,12 @@ async fn zeta2_predict( let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?; + zeta.update(cx, |zeta, cx| { + zeta.register_buffer(&cursor_buffer, &project, cx); + })?; + + example.apply_edit_history(&project, cx).await?; + let (prediction_task, mut debug_rx) = zeta.update(cx, |zeta, cx| { let receiver = zeta.debug_info(); let prediction_task = zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx);