wip

Max Brunsfeld created

Change summary

Cargo.lock                     |   1 
crates/zeta_cli/Cargo.toml     |   8 
crates/zeta_cli/src/example.rs | 329 +++++++++++++++++++++--------------
3 files changed, 207 insertions(+), 131 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -21745,6 +21745,7 @@ dependencies = [
  "futures 0.3.31",
  "gpui",
  "gpui_tokio",
+ "indoc",
  "language",
  "language_extension",
  "language_model",

crates/zeta_cli/Cargo.toml 🔗

@@ -13,6 +13,7 @@ name = "zeta"
 path = "src/main.rs"
 
 [dependencies]
+
 anyhow.workspace = true
 chrono.workspace = true
 clap.workspace = true
@@ -42,7 +43,6 @@ prompt_store.workspace = true
 pulldown-cmark.workspace = true
 release_channel.workspace = true
 reqwest_client.workspace = true
-toml.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true
@@ -50,8 +50,14 @@ shellexpand.workspace = true
 smol.workspace = true
 soa-rs = "0.8.1"
 terminal_view.workspace = true
+toml.workspace = true
 util.workspace = true
 watch.workspace = true
 zeta.workspace = true
 zeta2.workspace = true
 zlog.workspace = true
+
+[dev-dependencies]
+indoc.workspace = true
+gpui = { workspace = true, features = ["test-support"] }
+project = { workspace = true, features = ["test-support"] }

crates/zeta_cli/src/example.rs 🔗

@@ -338,136 +338,7 @@ impl NamedExample {
         project: &Entity<Project>,
         cx: &mut AsyncApp,
     ) -> Result<HashSet<Entity<Buffer>>> {
-        use cloud_llm_client::udiff::DiffLine;
-        use std::fmt::Write;
-
-        #[derive(Debug, Default)]
-        struct Edit {
-            context: String,
-            deletion_start: Option<usize>,
-            addition: String,
-        }
-
-        let mut old_path = None;
-        let mut new_path = None;
-        let mut pending = Edit::default();
-        let mut diff_lines = self
-            .example
-            .edit_history
-            .lines()
-            .map(DiffLine::parse)
-            .peekable();
-        let mut open_buffers = HashSet::default();
-
-        while let Some(diff_line) = diff_lines.next() {
-            match diff_line {
-                DiffLine::OldPath { path } => {
-                    mem::take(&mut pending);
-                    old_path = Some(path)
-                }
-                DiffLine::HunkHeader(_) => {
-                    mem::take(&mut pending);
-                }
-                DiffLine::NewPath { path } => {
-                    if old_path.is_none() {
-                        anyhow::bail!(
-                            "Found a new path header (`+++`) before an (`---`) old path header"
-                        );
-                    }
-                    new_path = Some(path)
-                }
-                DiffLine::Context(ctx) => {
-                    writeln!(&mut pending.context, "{ctx}")?;
-                }
-                DiffLine::Deletion(del) => {
-                    pending.deletion_start.get_or_insert(pending.context.len());
-                    writeln!(&mut pending.context, "{del}")?;
-                }
-                DiffLine::Addition(add) => {
-                    if pending.context.is_empty() {
-                        anyhow::bail!("Found an addition before any context or deletion lines");
-                    }
-
-                    writeln!(&mut pending.addition, "{add}")?;
-                }
-                DiffLine::Garbage => {}
-            }
-
-            let commit_pending = match diff_lines.peek() {
-                Some(DiffLine::OldPath { .. })
-                | Some(DiffLine::HunkHeader(_))
-                | Some(DiffLine::Context(_))
-                | None => {
-                    // commit pending edit cluster
-                    !pending.addition.is_empty() || pending.deletion_start.is_some()
-                }
-                Some(DiffLine::Deletion(_)) => {
-                    // start a new cluster if we have any additions specifically
-                    // if we only have deletions, we continue to aggregate them
-                    !pending.addition.is_empty()
-                }
-                _ => false,
-            };
-
-            if commit_pending {
-                let edit = mem::take(&mut pending);
-
-                let Some(old_path) = old_path.as_deref() else {
-                    anyhow::bail!("Missing old path (`---`) header")
-                };
-
-                let Some(new_path) = new_path.as_deref() else {
-                    anyhow::bail!("Missing new path (`+++`) header")
-                };
-
-                let buffer = project
-                    .update(cx, |project, cx| {
-                        let project_path = project
-                            .find_project_path(old_path, cx)
-                            .context("Failed to find old_path in project")?;
-
-                        anyhow::Ok(project.open_buffer(project_path, cx))
-                    })??
-                    .await?;
-                open_buffers.insert(buffer.clone());
-
-                if old_path != new_path {
-                    project
-                        .update(cx, |project, cx| {
-                            let project_file =
-                                project::File::from_dyn(buffer.read(cx).file()).unwrap();
-                            let new_path = ProjectPath {
-                                worktree_id: project_file.worktree_id(cx),
-                                path: project_file.path.clone(),
-                            };
-                            project.rename_entry(project_file.entry_id.unwrap(), new_path, cx)
-                        })?
-                        .await?;
-                }
-
-                // TODO is it worth using project search?
-                buffer.update(cx, |buffer, cx| {
-                    let text = buffer.text();
-                    // todo! check there's only one
-                    if let Some(context_offset) = text.find(&edit.context) {
-                        let end = context_offset + edit.context.len();
-                        let start = if let Some(deletion_start) = edit.deletion_start {
-                            context_offset + deletion_start
-                        } else {
-                            end
-                        };
-
-                        buffer.edit([(start..end, edit.addition)], None, cx);
-
-                        anyhow::Ok(())
-                    } else {
-                        anyhow::bail!("Failed to match context:\n{}", edit.context);
-                    }
-                })??;
-            }
-        }
-
-        anyhow::Ok(open_buffers)
+        apply_diff(&self.example.edit_history, project, cx).await
     }
 }
 
@@ -546,3 +417,201 @@ impl Display for NamedExample {
         Ok(())
     }
 }
+
+#[must_use]
+pub async fn apply_diff(
+    diff: &str,
+    project: &Entity<Project>,
+    cx: &mut AsyncApp,
+) -> Result<HashSet<Entity<Buffer>>> {
+    use cloud_llm_client::udiff::DiffLine;
+    use std::fmt::Write;
+
+    #[derive(Debug, Default)]
+    struct Edit {
+        context: String,
+        deletion_start: Option<usize>,
+        addition: String,
+    }
+
+    let mut old_path = None;
+    let mut new_path = None;
+    let mut pending = Edit::default();
+    let mut diff_lines = diff.lines().map(DiffLine::parse).peekable();
+    let mut open_buffers = HashSet::default();
+
+    while let Some(diff_line) = diff_lines.next() {
+        match diff_line {
+            DiffLine::OldPath { path } => {
+                mem::take(&mut pending);
+                old_path = Some(path)
+            }
+            DiffLine::HunkHeader(_) => {
+                mem::take(&mut pending);
+            }
+            DiffLine::NewPath { path } => {
+                if old_path.is_none() {
+                    anyhow::bail!(
+                        "Found a new path header (`+++`) before an (`---`) old path header"
+                    );
+                }
+                new_path = Some(path)
+            }
+            DiffLine::Context(ctx) => {
+                writeln!(&mut pending.context, "{ctx}")?;
+            }
+            DiffLine::Deletion(del) => {
+                pending.deletion_start.get_or_insert(pending.context.len());
+                writeln!(&mut pending.context, "{del}")?;
+            }
+            DiffLine::Addition(add) => {
+                if pending.context.is_empty() {
+                    anyhow::bail!("Found an addition before any context or deletion lines");
+                }
+
+                writeln!(&mut pending.addition, "{add}")?;
+            }
+            DiffLine::Garbage => {}
+        }
+
+        let commit_pending = match diff_lines.peek() {
+            Some(DiffLine::OldPath { .. })
+            | Some(DiffLine::HunkHeader(_))
+            | Some(DiffLine::Context(_))
+            | None => {
+                // commit pending edit cluster
+                !pending.addition.is_empty() || pending.deletion_start.is_some()
+            }
+            Some(DiffLine::Deletion(_)) => {
+                // start a new cluster if we have any additions specifically
+                // if we only have deletions, we continue to aggregate them
+                !pending.addition.is_empty()
+            }
+            _ => false,
+        };
+
+        if commit_pending {
+            let edit = mem::take(&mut pending);
+
+            let Some(old_path) = old_path.as_deref() else {
+                anyhow::bail!("Missing old path (`---`) header")
+            };
+
+            let Some(new_path) = new_path.as_deref() else {
+                anyhow::bail!("Missing new path (`+++`) header")
+            };
+
+            let buffer = project
+                .update(cx, |project, cx| {
+                    let project_path = project
+                        .find_project_path(old_path, cx)
+                        .context("Failed to find old_path in project")?;
+
+                    anyhow::Ok(project.open_buffer(project_path, cx))
+                })??
+                .await?;
+            open_buffers.insert(buffer.clone());
+
+            if old_path != new_path {
+                project
+                    .update(cx, |project, cx| {
+                        let project_file = project::File::from_dyn(buffer.read(cx).file()).unwrap();
+                        let new_path = ProjectPath {
+                            worktree_id: project_file.worktree_id(cx),
+                            path: project_file.path.clone(),
+                        };
+                        project.rename_entry(project_file.entry_id.unwrap(), new_path, cx)
+                    })?
+                    .await?;
+            }
+
+            // TODO is it worth using project search?
+            buffer.update(cx, |buffer, cx| {
+                let text = buffer.text();
+                // todo! check there's only one
+                if let Some(context_offset) = text.find(&edit.context) {
+                    let end = context_offset + edit.context.len();
+                    let start = if let Some(deletion_start) = edit.deletion_start {
+                        context_offset + deletion_start
+                    } else {
+                        end
+                    };
+
+                    buffer.edit([(start..end, edit.addition)], None, cx);
+
+                    anyhow::Ok(())
+                } else {
+                    anyhow::bail!("Failed to match context:\n{}", edit.context);
+                }
+            })??;
+        }
+    }
+
+    anyhow::Ok(open_buffers)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use fs::FakeFs;
+    use gpui::TestAppContext;
+    use indoc::indoc;
+    use project::Project;
+    use serde_json::json;
+
+    #[gpui::test]
+    async fn test_apply_diff(cx: &mut TestAppContext) {
+        let buffer_1_text = indoc! {r#"
+            one
+            two
+            three
+            four
+            five
+        "# };
+
+        let buffer_2_text = indoc! {r#"
+            six
+            seven
+            eight
+            nine
+            ten
+        "# };
+
+        let fs = FakeFs::new(cx.background_executor().clone());
+        fs.insert_tree(
+            "/root",
+            json!({
+                "file1": buffer_1_text,
+                "file2": buffer_2_text,
+            }),
+        )
+        .await;
+
+        let project = Project::test(fs, ["/root".as_ref()], cx).await;
+
+        let diff = indoc! {r#"
+            --- a/root/file1
+            +++ b/root/file1
+             one
+             two
+            -three
+            +3
+             four
+             five
+        "#};
+
+        let _buffers = apply_diff(diff, &project, &mut cx.to_async())
+            .await
+            .unwrap();
+        let buffer_1 = project
+            .update(cx, |project, cx| {
+                let project_path = project.find_project_path("/root/file1", cx).unwrap();
+                project.open_buffer(project_path, cx)
+            })?
+            .await?;
+
+        buffer_1.read_with(cx, |buffer, cx| {
+            pretty_assertions::assert_eq!(buffer.text())
+        })
+    }
+}