crates/cloud_llm_client/src/cloud_llm_client.rs 🔗
@@ -1,4 +1,5 @@
 pub mod predict_edits_v3;
+pub mod udiff;
 
 use std::str::FromStr;
 use std::sync::Arc;
  Agus Zubiaga created
crates/cloud_llm_client/src/cloud_llm_client.rs |   1 
crates/cloud_llm_client/src/udiff.rs            | 270 +++++++++++++++++++
crates/zeta2/src/zeta2.rs                       |  25 +
crates/zeta2_tools/src/zeta2_tools.rs           |   2 
crates/zeta_cli/src/example.rs                  | 150 +++++++++
crates/zeta_cli/src/main.rs                     |   8 
6 files changed, 438 insertions(+), 18 deletions(-)
@@ -1,4 +1,5 @@
 pub mod predict_edits_v3;
+pub mod udiff;
 
 use std::str::FromStr;
 use std::sync::Arc;
  @@ -0,0 +1,270 @@
+use std::borrow::Cow;
+
+#[derive(Debug, PartialEq)]
+pub enum DiffLine<'a> {
+    OldPath { path: Cow<'a, str> },
+    NewPath { path: Cow<'a, str> },
+    HunkHeader(Option<HunkLocation>),
+    Context(&'a str),
+    Deletion(&'a str),
+    Addition(&'a str),
+    Garbage,
+}
+
+#[derive(Debug, PartialEq)]
+pub struct HunkLocation {
+    start_line_old: u32,
+    count_old: u32,
+    start_line_new: u32,
+    count_new: u32,
+}
+
+impl<'a> DiffLine<'a> {
+    pub fn parse(line: &'a str) -> Self {
+        Self::try_parse(line).unwrap_or(Self::Garbage)
+    }
+
+    fn try_parse(line: &'a str) -> Option<Self> {
+        if let Some(header) = line.strip_prefix("---").and_then(eat_required_whitespace) {
+            let path = parse_header_path("a/", header);
+            Some(Self::OldPath { path })
+        } else if let Some(header) = line.strip_prefix("+++").and_then(eat_required_whitespace) {
+            Some(Self::NewPath {
+                path: parse_header_path("b/", header),
+            })
+        } else if let Some(header) = line.strip_prefix("@@").and_then(eat_required_whitespace) {
+            if header.starts_with("...") {
+                return Some(Self::HunkHeader(None));
+            }
+
+            let (start_line_old, header) = header.strip_prefix('-')?.split_once(',')?;
+            let mut parts = header.split_ascii_whitespace();
+            let count_old = parts.next()?;
+            let (start_line_new, count_new) = parts.next()?.strip_prefix('+')?.split_once(',')?;
+
+            Some(Self::HunkHeader(Some(HunkLocation {
+                start_line_old: start_line_old.parse::<u32>().ok()?.saturating_sub(1),
+                count_old: count_old.parse().ok()?,
+                start_line_new: start_line_new.parse::<u32>().ok()?.saturating_sub(1),
+                count_new: count_new.parse().ok()?,
+            })))
+        } else if let Some(deleted_header) = line.strip_prefix("-") {
+            Some(Self::Deletion(deleted_header))
+        } else if line.is_empty() {
+            Some(Self::Context(""))
+        } else if let Some(context) = line.strip_prefix(" ") {
+            Some(Self::Context(context))
+        } else {
+            Some(Self::Addition(line.strip_prefix("+")?))
+        }
+    }
+}
+
+fn parse_header_path<'a>(strip_prefix: &'static str, header: &'a str) -> Cow<'a, str> {
+    if !header.contains(['"', '\\']) {
+        let path = header.split_ascii_whitespace().next().unwrap_or(header);
+        return Cow::Borrowed(path.strip_prefix(strip_prefix).unwrap_or(path));
+    }
+
+    let mut path = String::with_capacity(header.len());
+    let mut in_quote = false;
+    let mut chars = header.chars().peekable();
+    let mut strip_prefix = Some(strip_prefix);
+
+    while let Some(char) = chars.next() {
+        if char == '"' {
+            in_quote = !in_quote;
+        } else if char == '\\' {
+            let Some(&next_char) = chars.peek() else {
+                break;
+            };
+            chars.next();
+            path.push(next_char);
+        } else if char.is_ascii_whitespace() && !in_quote {
+            break;
+        } else {
+            path.push(char);
+        }
+
+        if let Some(prefix) = strip_prefix
+            && path == prefix
+        {
+            strip_prefix.take();
+            path.clear();
+        }
+    }
+
+    Cow::Owned(path)
+}
+
+fn eat_required_whitespace(header: &str) -> Option<&str> {
+    let trimmed = header.trim_ascii_start();
+
+    if trimmed.len() == header.len() {
+        None
+    } else {
+        Some(trimmed)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use indoc::indoc;
+
+    #[test]
+    fn parse_lines_simple() {
+        let input = indoc! {"
+            diff --git a/text.txt b/text.txt
+            index 86c770d..a1fd855 100644
+            --- a/file.txt
+            +++ b/file.txt
+            @@ -1,2 +1,3 @@
+             context
+            -deleted
+            +inserted
+            garbage
+
+            --- b/file.txt
+            +++ a/file.txt
+        "};
+
+        let lines = input.lines().map(DiffLine::parse).collect::<Vec<_>>();
+
+        pretty_assertions::assert_eq!(
+            lines,
+            &[
+                DiffLine::Garbage,
+                DiffLine::Garbage,
+                DiffLine::OldPath {
+                    path: "file.txt".into()
+                },
+                DiffLine::NewPath {
+                    path: "file.txt".into()
+                },
+                DiffLine::HunkHeader(Some(HunkLocation {
+                    start_line_old: 0,
+                    count_old: 2,
+                    start_line_new: 0,
+                    count_new: 3
+                })),
+                DiffLine::Context("context"),
+                DiffLine::Deletion("deleted"),
+                DiffLine::Addition("inserted"),
+                DiffLine::Garbage,
+                DiffLine::Context(""),
+                DiffLine::OldPath {
+                    path: "b/file.txt".into()
+                },
+                DiffLine::NewPath {
+                    path: "a/file.txt".into()
+                },
+            ]
+        );
+    }
+
+    #[test]
+    fn file_header_extra_space() {
+        let options = ["--- file", "---   file", "---\tfile"];
+
+        for option in options {
+            pretty_assertions::assert_eq!(
+                DiffLine::parse(option),
+                DiffLine::OldPath {
+                    path: "file".into()
+                },
+                "{option}",
+            );
+        }
+    }
+
+    #[test]
+    fn hunk_header_extra_space() {
+        let options = [
+            "@@ -1,2 +1,3 @@",
+            "@@  -1,2  +1,3 @@",
+            "@@\t-1,2\t+1,3\t@@",
+            "@@ -1,2  +1,3 @@",
+            "@@ -1,2   +1,3 @@",
+            "@@ -1,2 +1,3   @@",
+            "@@ -1,2 +1,3 @@ garbage",
+        ];
+
+        for option in options {
+            pretty_assertions::assert_eq!(
+                DiffLine::parse(option),
+                DiffLine::HunkHeader(Some(HunkLocation {
+                    start_line_old: 0,
+                    count_old: 2,
+                    start_line_new: 0,
+                    count_new: 3
+                })),
+                "{option}",
+            );
+        }
+    }
+
+    #[test]
+    fn hunk_header_without_location() {
+        pretty_assertions::assert_eq!(DiffLine::parse("@@ ... @@"), DiffLine::HunkHeader(None));
+    }
+
+    #[test]
+    fn test_parse_path() {
+        assert_eq!(parse_header_path("a/", "foo.txt"), "foo.txt");
+        assert_eq!(
+            parse_header_path("a/", "foo/bar/baz.txt"),
+            "foo/bar/baz.txt"
+        );
+        assert_eq!(parse_header_path("a/", "a/foo.txt"), "foo.txt");
+        assert_eq!(
+            parse_header_path("a/", "a/foo/bar/baz.txt"),
+            "foo/bar/baz.txt"
+        );
+
+        // Extra
+        assert_eq!(
+            parse_header_path("a/", "a/foo/bar/baz.txt  2025"),
+            "foo/bar/baz.txt"
+        );
+        assert_eq!(
+            parse_header_path("a/", "a/foo/bar/baz.txt\t2025"),
+            "foo/bar/baz.txt"
+        );
+        assert_eq!(
+            parse_header_path("a/", "a/foo/bar/baz.txt \""),
+            "foo/bar/baz.txt"
+        );
+
+        // Quoted
+        assert_eq!(
+            parse_header_path("a/", "a/foo/bar/\"baz quox.txt\""),
+            "foo/bar/baz quox.txt"
+        );
+        assert_eq!(
+            parse_header_path("a/", "\"a/foo/bar/baz quox.txt\""),
+            "foo/bar/baz quox.txt"
+        );
+        assert_eq!(
+            parse_header_path("a/", "\"foo/bar/baz quox.txt\""),
+            "foo/bar/baz quox.txt"
+        );
+        assert_eq!(parse_header_path("a/", "\"whatever 🤷\""), "whatever 🤷");
+        assert_eq!(
+            parse_header_path("a/", "\"foo/bar/baz quox.txt\"  2025"),
+            "foo/bar/baz quox.txt"
+        );
+        // unescaped quotes are dropped
+        assert_eq!(parse_header_path("a/", "foo/\"bar\""), "foo/bar");
+
+        // Escaped
+        assert_eq!(
+            parse_header_path("a/", "\"foo/\\\"bar\\\"/baz.txt\""),
+            "foo/\"bar\"/baz.txt"
+        );
+        assert_eq!(
+            parse_header_path("a/", "\"C:\\\\Projects\\\\My App\\\\old file.txt\""),
+            "C:\\Projects\\My App\\old file.txt"
+        );
+    }
+}
  @@ -50,8 +50,6 @@ use crate::related_excerpts::find_related_excerpts;
 pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery};
 pub use provider::ZetaEditPredictionProvider;
 
-const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
-
 /// Maximum number of events to track.
 const MAX_EVENT_COUNT: usize = 16;
 
@@ -83,6 +81,7 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
     max_diagnostic_bytes: 2048,
     prompt_format: PromptFormat::DEFAULT,
     file_indexing_parallelism: 1,
+    buffer_change_grouping_interval: Duration::from_secs(1),
 };
 
 pub struct Zeta2FeatureFlag;
@@ -118,6 +117,7 @@ pub struct ZetaOptions {
     pub max_diagnostic_bytes: usize,
     pub prompt_format: predict_edits_v3::PromptFormat,
     pub file_indexing_parallelism: usize,
+    pub buffer_change_grouping_interval: Duration,
 }
 
 #[derive(Debug, Clone, PartialEq)]
@@ -460,6 +460,7 @@ impl Zeta {
         project: &Entity<Project>,
         cx: &mut Context<Self>,
     ) -> BufferSnapshot {
+        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
         let zeta_project = self.get_or_init_zeta_project(project, cx);
         let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 
@@ -469,6 +470,7 @@ impl Zeta {
                 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
             Self::push_event(
                 zeta_project,
+                buffer_change_grouping_interval,
                 Event::BufferChange {
                     old_snapshot,
                     new_snapshot: new_snapshot.clone(),
@@ -480,14 +482,19 @@ impl Zeta {
         new_snapshot
     }
 
-    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
+    fn push_event(
+        zeta_project: &mut ZetaProject,
+        buffer_change_grouping_interval: Duration,
+        event: Event,
+    ) {
         let events = &mut zeta_project.events;
 
-        if let Some(Event::BufferChange {
-            new_snapshot: last_new_snapshot,
-            timestamp: last_timestamp,
-            ..
-        }) = events.back_mut()
+        if buffer_change_grouping_interval > Duration::ZERO
+            && let Some(Event::BufferChange {
+                new_snapshot: last_new_snapshot,
+                timestamp: last_timestamp,
+                ..
+            }) = events.back_mut()
         {
             // Coalesce edits for the same buffer when they happen one after the other.
             let Event::BufferChange {
@@ -496,7 +503,7 @@ impl Zeta {
                 timestamp,
             } = &event;
 
-            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
+            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                 && old_snapshot.version == last_new_snapshot.version
             {
  @@ -335,6 +335,8 @@ impl Zeta2Inspector {
                         max_diagnostic_bytes: zeta_options.max_diagnostic_bytes,
                         prompt_format: zeta_options.prompt_format,
                         file_indexing_parallelism: zeta_options.file_indexing_parallelism,
+                        buffer_change_grouping_interval: zeta_options
+                            .buffer_change_grouping_interval,
                     },
                     cx,
                 );
  @@ -10,8 +10,10 @@ use std::{
 
 use anyhow::{Context as _, Result};
 use clap::ValueEnum;
+use collections::HashSet;
 use futures::AsyncWriteExt as _;
-use gpui::http_client::Url;
+use gpui::{AsyncApp, Entity, http_client::Url};
+use project::{Project, ProjectPath};
 use pulldown_cmark::CowStr;
 use serde::{Deserialize, Serialize};
 
@@ -36,7 +38,7 @@ pub struct Example {
     pub uncommitted_diff: String,
     pub cursor_path: PathBuf,
     pub cursor_position: String,
-    pub edit_history: Vec<String>,
+    pub edit_history: String,
     pub expected_patch: String,
     pub expected_excerpts: Vec<ExpectedExcerpt>,
 }
@@ -94,7 +96,7 @@ impl NamedExample {
                 uncommitted_diff: String::new(),
                 cursor_path: PathBuf::new(),
                 cursor_position: String::new(),
-                edit_history: Vec::new(),
+                edit_history: String::new(),
                 expected_patch: String::new(),
                 expected_excerpts: Vec::new(),
             },
@@ -160,7 +162,7 @@ impl NamedExample {
                     if current_section.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
                         named.example.uncommitted_diff = mem::take(&mut text);
                     } else if current_section.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
-                        named.example.edit_history.push(mem::take(&mut text));
+                        named.example.edit_history.push_str(&mem::take(&mut text));
                     } else if current_section.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
                         named.example.cursor_path = block_info.into();
                         named.example.cursor_position = mem::take(&mut text);
@@ -328,6 +330,140 @@ impl NamedExample {
             Ok((owner.into(), repo.into()))
         }
     }
+
+    pub async fn apply_edit_history(
+        &self,
+        project: &Entity<Project>,
+        cx: &mut AsyncApp,
+    ) -> Result<()> {
+        use cloud_llm_client::udiff::DiffLine;
+        use std::fmt::Write;
+
+        #[derive(Default)]
+        struct Edit {
+            context: String,
+            deletion_start: Option<usize>,
+            addition: String,
+        }
+
+        let mut old_path = None;
+        let mut new_path = None;
+        let mut pending = Edit::default();
+        let mut diff_lines = self
+            .example
+            .edit_history
+            .lines()
+            .map(DiffLine::parse)
+            .peekable();
+        let mut open_buffers = HashSet::default();
+
+        while let Some(diff_line) = diff_lines.next() {
+            match diff_line {
+                DiffLine::OldPath { path } => old_path = Some(path),
+                DiffLine::NewPath { path } => {
+                    if old_path.is_none() {
+                        anyhow::bail!(
+                            "Found a new path header (`+++`) before an (`---`) old path header"
+                        );
+                    }
+                    new_path = Some(path)
+                }
+                DiffLine::Context(ctx) => {
+                    writeln!(&mut pending.context, "{ctx}")?;
+                }
+                DiffLine::Deletion(del) => {
+                    pending.deletion_start.get_or_insert(pending.context.len());
+                    writeln!(&mut pending.context, "{del}")?;
+                }
+                DiffLine::Addition(add) => {
+                    if pending.context.is_empty() {
+                        anyhow::bail!("Found an addition before any context or deletion lines");
+                    }
+
+                    writeln!(&mut pending.addition, "{add}")?;
+                }
+                DiffLine::HunkHeader(_) | DiffLine::Garbage => {}
+            }
+
+            let commit_pending = match diff_lines.peek() {
+                Some(DiffLine::OldPath { .. })
+                | Some(DiffLine::HunkHeader(_))
+                | Some(DiffLine::Context(_))
+                | None => {
+                    // commit pending edit cluster
+                    !pending.addition.is_empty() || pending.deletion_start.is_some()
+                }
+                Some(DiffLine::Deletion(_)) => {
+                    // start a new cluster if we have any additions specifically
+                    // if we only have deletions, we continue to aggregate them
+                    pending.addition.is_empty()
+                }
+                _ => false,
+            };
+
+            if commit_pending {
+                let edit = mem::take(&mut pending);
+
+                if edit.addition.is_empty() || edit.deletion_start.is_none() {
+                    return anyhow::Ok(());
+                }
+
+                let Some(old_path) = old_path.as_deref() else {
+                    anyhow::bail!("Missing old path (`---`) header")
+                };
+
+                let Some(new_path) = new_path.as_deref() else {
+                    anyhow::bail!("Missing new path (`+++`) header")
+                };
+
+                let buffer = project
+                    .update(cx, |project, cx| {
+                        let project_path = project
+                            .find_project_path(old_path, cx)
+                            .context("Failed to find old_path in project")?;
+
+                        anyhow::Ok(project.open_buffer(project_path, cx))
+                    })??
+                    .await?;
+                open_buffers.insert(buffer.clone());
+
+                if old_path != new_path {
+                    project
+                        .update(cx, |project, cx| {
+                            let project_file =
+                                project::File::from_dyn(buffer.read(cx).file()).unwrap();
+                            let new_path = ProjectPath {
+                                worktree_id: project_file.worktree_id(cx),
+                                path: project_file.path.clone(),
+                            };
+                            project.rename_entry(project_file.entry_id.unwrap(), new_path, cx)
+                        })?
+                        .await?;
+                }
+
+                // TODO is it worth using project search?
+                buffer.update(cx, |buffer, cx| {
+                    let text = buffer.text();
+                    if let Some(context_offset) = text.find(&edit.context) {
+                        let end = context_offset + edit.context.len();
+                        let start = if let Some(deletion_start) = edit.deletion_start {
+                            context_offset + deletion_start
+                        } else {
+                            end
+                        };
+
+                        buffer.edit([(start..end, edit.addition)], None, cx);
+
+                        anyhow::Ok(())
+                    } else {
+                        anyhow::bail!("Failed to match context");
+                    }
+                })??;
+            }
+        }
+
+        anyhow::Ok(())
+    }
 }
 
 async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
@@ -365,11 +501,7 @@ impl Display for NamedExample {
         write!(f, "`````\n")?;
 
         if !self.example.edit_history.is_empty() {
-            write!(f, "`````diff\n")?;
-            for item in &self.example.edit_history {
-                write!(f, "{item}")?;
-            }
-            write!(f, "`````\n")?;
+            write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
         }
 
         write!(
  @@ -26,6 +26,7 @@ use project::{Project, ProjectPath, Worktree};
 use reqwest_client::ReqwestClient;
 use serde_json::json;
 use std::io;
+use std::time::Duration;
 use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
 use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery};
 
@@ -176,6 +177,7 @@ fn syntax_args_to_options(
         max_prompt_bytes: zeta2_args.max_prompt_bytes,
         prompt_format: zeta2_args.prompt_format.clone().into(),
         file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
+        buffer_change_grouping_interval: Duration::ZERO,
     }
 }
 
@@ -414,6 +416,12 @@ async fn zeta2_predict(
 
     let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
 
+    zeta.update(cx, |zeta, cx| {
+        zeta.register_buffer(&cursor_buffer, &project, cx);
+    })?;
+
+    example.apply_edit_history(&project, cx).await?;
+
     let (prediction_task, mut debug_rx) = zeta.update(cx, |zeta, cx| {
         let receiver = zeta.debug_info();
         let prediction_task = zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx);