diff --git a/Cargo.lock b/Cargo.lock index c0eea670a77f03c4dbb5afdb7d1197b6d9b76159..fd8c81b2ceda42cc2e5421b9c80c96bd6e62b7c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -21745,6 +21745,7 @@ dependencies = [ "futures 0.3.31", "gpui", "gpui_tokio", + "indoc", "language", "language_extension", "language_model", diff --git a/crates/zeta_cli/Cargo.toml b/crates/zeta_cli/Cargo.toml index a54298366614c3633cf527cc5746480e66c6caae..739960a0a32c21640944aebf13372f3ace5ece63 100644 --- a/crates/zeta_cli/Cargo.toml +++ b/crates/zeta_cli/Cargo.toml @@ -13,6 +13,7 @@ name = "zeta" path = "src/main.rs" [dependencies] + anyhow.workspace = true chrono.workspace = true clap.workspace = true @@ -42,7 +43,6 @@ prompt_store.workspace = true pulldown-cmark.workspace = true release_channel.workspace = true reqwest_client.workspace = true -toml.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true @@ -50,8 +50,14 @@ shellexpand.workspace = true smol.workspace = true soa-rs = "0.8.1" terminal_view.workspace = true +toml.workspace = true util.workspace = true watch.workspace = true zeta.workspace = true zeta2.workspace = true zlog.workspace = true + +[dev-dependencies] +indoc.workspace = true +gpui = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } diff --git a/crates/zeta_cli/src/example.rs b/crates/zeta_cli/src/example.rs index 083b021a8da9d09fb134483ac36b543398576f5f..6a6ce9e57e91a3284cfc1464366cd868f601a813 100644 --- a/crates/zeta_cli/src/example.rs +++ b/crates/zeta_cli/src/example.rs @@ -338,136 +338,7 @@ impl NamedExample { project: &Entity, cx: &mut AsyncApp, ) -> Result>> { - use cloud_llm_client::udiff::DiffLine; - use std::fmt::Write; - - #[derive(Debug, Default)] - struct Edit { - context: String, - deletion_start: Option, - addition: String, - } - - let mut old_path = None; - let mut new_path = None; - let mut pending = Edit::default(); - let mut diff_lines = self - .example - .edit_history - .lines() - .map(DiffLine::parse) - .peekable(); - let mut open_buffers = HashSet::default(); - - while let Some(diff_line) = diff_lines.next() { - match diff_line { - DiffLine::OldPath { path } => { - mem::take(&mut pending); - old_path = Some(path) - } - DiffLine::HunkHeader(_) => { - mem::take(&mut pending); - } - DiffLine::NewPath { path } => { - if old_path.is_none() { - anyhow::bail!( - "Found a new path header (`+++`) before an (`---`) old path header" - ); - } - new_path = Some(path) - } - DiffLine::Context(ctx) => { - writeln!(&mut pending.context, "{ctx}")?; - } - DiffLine::Deletion(del) => { - pending.deletion_start.get_or_insert(pending.context.len()); - writeln!(&mut pending.context, "{del}")?; - } - DiffLine::Addition(add) => { - if pending.context.is_empty() { - anyhow::bail!("Found an addition before any context or deletion lines"); - } - - writeln!(&mut pending.addition, "{add}")?; - } - DiffLine::Garbage => {} - } - - let commit_pending = match diff_lines.peek() { - Some(DiffLine::OldPath { .. }) - | Some(DiffLine::HunkHeader(_)) - | Some(DiffLine::Context(_)) - | None => { - // commit pending edit cluster - !pending.addition.is_empty() || pending.deletion_start.is_some() - } - Some(DiffLine::Deletion(_)) => { - // start a new cluster if we have any additions specifically - // if we only have deletions, we continue to aggregate them - !pending.addition.is_empty() - } - _ => false, - }; - - if commit_pending { - let edit = mem::take(&mut pending); - - let Some(old_path) = old_path.as_deref() else { - anyhow::bail!("Missing old path (`---`) header") - }; - - let Some(new_path) = new_path.as_deref() else { - anyhow::bail!("Missing new path (`+++`) header") - }; - - let buffer = project - .update(cx, |project, cx| { - let project_path = project - .find_project_path(old_path, cx) - .context("Failed to find old_path in project")?; - - anyhow::Ok(project.open_buffer(project_path, cx)) - })?? - .await?; - open_buffers.insert(buffer.clone()); - - if old_path != new_path { - project - .update(cx, |project, cx| { - let project_file = - project::File::from_dyn(buffer.read(cx).file()).unwrap(); - let new_path = ProjectPath { - worktree_id: project_file.worktree_id(cx), - path: project_file.path.clone(), - }; - project.rename_entry(project_file.entry_id.unwrap(), new_path, cx) - })? - .await?; - } - - // TODO is it worth using project search? - buffer.update(cx, |buffer, cx| { - let text = buffer.text(); - // todo! check there's only one - if let Some(context_offset) = text.find(&edit.context) { - let end = context_offset + edit.context.len(); - let start = if let Some(deletion_start) = edit.deletion_start { - context_offset + deletion_start - } else { - end - }; - - buffer.edit([(start..end, edit.addition)], None, cx); - - anyhow::Ok(()) - } else { - anyhow::bail!("Failed to match context:\n{}", edit.context); - } - })??; - } - } - - anyhow::Ok(open_buffers) + apply_diff(&self.example.edit_history, project, cx).await } } @@ -546,3 +417,201 @@ impl Display for NamedExample { Ok(()) } } + +#[must_use] +pub async fn apply_diff( + diff: &str, + project: &Entity, + cx: &mut AsyncApp, +) -> Result>> { + use cloud_llm_client::udiff::DiffLine; + use std::fmt::Write; + + #[derive(Debug, Default)] + struct Edit { + context: String, + deletion_start: Option, + addition: String, + } + + let mut old_path = None; + let mut new_path = None; + let mut pending = Edit::default(); + let mut diff_lines = diff.lines().map(DiffLine::parse).peekable(); + let mut open_buffers = HashSet::default(); + + while let Some(diff_line) = diff_lines.next() { + match diff_line { + DiffLine::OldPath { path } => { + mem::take(&mut pending); + old_path = Some(path) + } + DiffLine::HunkHeader(_) => { + mem::take(&mut pending); + } + DiffLine::NewPath { path } => { + if old_path.is_none() { + anyhow::bail!( + "Found a new path header (`+++`) before an (`---`) old path header" + ); + } + new_path = Some(path) + } + DiffLine::Context(ctx) => { + writeln!(&mut pending.context, "{ctx}")?; + } + DiffLine::Deletion(del) => { + pending.deletion_start.get_or_insert(pending.context.len()); + writeln!(&mut pending.context, "{del}")?; + } + DiffLine::Addition(add) => { + if pending.context.is_empty() { + anyhow::bail!("Found an addition before any context or deletion lines"); + } + + writeln!(&mut pending.addition, "{add}")?; + } + DiffLine::Garbage => {} + } + + let commit_pending = match diff_lines.peek() { + Some(DiffLine::OldPath { .. }) + | Some(DiffLine::HunkHeader(_)) + | Some(DiffLine::Context(_)) + | None => { + // commit pending edit cluster + !pending.addition.is_empty() || pending.deletion_start.is_some() + } + Some(DiffLine::Deletion(_)) => { + // start a new cluster if we have any additions specifically + // if we only have deletions, we continue to aggregate them + !pending.addition.is_empty() + } + _ => false, + }; + + if commit_pending { + let edit = mem::take(&mut pending); + + let Some(old_path) = old_path.as_deref() else { + anyhow::bail!("Missing old path (`---`) header") + }; + + let Some(new_path) = new_path.as_deref() else { + anyhow::bail!("Missing new path (`+++`) header") + }; + + let buffer = project + .update(cx, |project, cx| { + let project_path = project + .find_project_path(old_path, cx) + .context("Failed to find old_path in project")?; + + anyhow::Ok(project.open_buffer(project_path, cx)) + })?? + .await?; + open_buffers.insert(buffer.clone()); + + if old_path != new_path { + project + .update(cx, |project, cx| { + let project_file = project::File::from_dyn(buffer.read(cx).file()).unwrap(); + let new_path = ProjectPath { + worktree_id: project_file.worktree_id(cx), + path: project_file.path.clone(), + }; + project.rename_entry(project_file.entry_id.unwrap(), new_path, cx) + })? + .await?; + } + + // TODO is it worth using project search? + buffer.update(cx, |buffer, cx| { + let text = buffer.text(); + // todo! check there's only one + if let Some(context_offset) = text.find(&edit.context) { + let end = context_offset + edit.context.len(); + let start = if let Some(deletion_start) = edit.deletion_start { + context_offset + deletion_start + } else { + end + }; + + buffer.edit([(start..end, edit.addition)], None, cx); + + anyhow::Ok(()) + } else { + anyhow::bail!("Failed to match context:\n{}", edit.context); + } + })??; + } + } + + anyhow::Ok(open_buffers) +} + +#[cfg(test)] +mod tests { + use super::*; + use fs::FakeFs; + use gpui::TestAppContext; + use indoc::indoc; + use project::Project; + use serde_json::json; + + #[gpui::test] + async fn test_apply_diff(cx: &mut TestAppContext) { + let buffer_1_text = indoc! {r#" + one + two + three + four + five + "# }; + + let buffer_2_text = indoc! {r#" + six + seven + eight + nine + ten + "# }; + + let fs = FakeFs::new(cx.background_executor().clone()); + fs.insert_tree( + "/root", + json!({ + "file1": buffer_1_text, + "file2": buffer_2_text, + }), + ) + .await; + + let project = Project::test(fs, ["/root".as_ref()], cx).await; + + let diff = indoc! {r#" + --- a/root/file1 + +++ b/root/file1 + one + two + -three + +3 + four + five + "#}; + + let _buffers = apply_diff(diff, &project, &mut cx.to_async()) + .await + .unwrap(); + let buffer_1 = project + .update(cx, |project, cx| { + let project_path = project.find_project_path("/root/file1", cx).unwrap(); + project.open_buffer(project_path, cx) + })? + .await?; + + buffer_1.read_with(cx, |buffer, cx| { + pretty_assertions::assert_eq!(buffer.text()) + }) + } +}