diff --git a/Cargo.lock b/Cargo.lock index 13eaf18ce12b15afbaf301a175afa4317939e65b..0ae2b1697a9f6ddbb76e0b26a60199b6af538610 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -21741,6 +21741,7 @@ dependencies = [ "futures 0.3.31", "gpui", "gpui_tokio", + "indoc", "language", "language_extension", "language_model", @@ -21751,6 +21752,7 @@ dependencies = [ "ordered-float 2.10.1", "paths", "polars", + "pretty_assertions", "project", "prompt_store", "pulldown-cmark 0.12.2", diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index bb77c3a5b7f8009093cbf7bc427160ed535e6c62..afa72665f168e7ec341d92df0a094f7880368087 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -1,4 +1,5 @@ pub mod predict_edits_v3; +pub mod udiff; use std::str::FromStr; use std::sync::Arc; diff --git a/crates/cloud_llm_client/src/udiff.rs b/crates/cloud_llm_client/src/udiff.rs new file mode 100644 index 0000000000000000000000000000000000000000..444452e6b7350de1680d51b5b9a34eab69685fa3 --- /dev/null +++ b/crates/cloud_llm_client/src/udiff.rs @@ -0,0 +1,270 @@ +use std::borrow::Cow; + +#[derive(Debug, PartialEq)] +pub enum DiffLine<'a> { + OldPath { path: Cow<'a, str> }, + NewPath { path: Cow<'a, str> }, + HunkHeader(Option), + Context(&'a str), + Deletion(&'a str), + Addition(&'a str), + Garbage, +} + +#[derive(Debug, PartialEq)] +pub struct HunkLocation { + start_line_old: u32, + count_old: u32, + start_line_new: u32, + count_new: u32, +} + +impl<'a> DiffLine<'a> { + pub fn parse(line: &'a str) -> Self { + Self::try_parse(line).unwrap_or(Self::Garbage) + } + + fn try_parse(line: &'a str) -> Option { + if let Some(header) = line.strip_prefix("---").and_then(eat_required_whitespace) { + let path = parse_header_path("a/", header); + Some(Self::OldPath { path }) + } else if let Some(header) = line.strip_prefix("+++").and_then(eat_required_whitespace) { + Some(Self::NewPath { + path: parse_header_path("b/", header), + }) + } else if let Some(header) = line.strip_prefix("@@").and_then(eat_required_whitespace) { + if header.starts_with("...") { + return Some(Self::HunkHeader(None)); + } + + let (start_line_old, header) = header.strip_prefix('-')?.split_once(',')?; + let mut parts = header.split_ascii_whitespace(); + let count_old = parts.next()?; + let (start_line_new, count_new) = parts.next()?.strip_prefix('+')?.split_once(',')?; + + Some(Self::HunkHeader(Some(HunkLocation { + start_line_old: start_line_old.parse::().ok()?.saturating_sub(1), + count_old: count_old.parse().ok()?, + start_line_new: start_line_new.parse::().ok()?.saturating_sub(1), + count_new: count_new.parse().ok()?, + }))) + } else if let Some(deleted_header) = line.strip_prefix("-") { + Some(Self::Deletion(deleted_header)) + } else if line.is_empty() { + Some(Self::Context("")) + } else if let Some(context) = line.strip_prefix(" ") { + Some(Self::Context(context)) + } else { + Some(Self::Addition(line.strip_prefix("+")?)) + } + } +} + +fn parse_header_path<'a>(strip_prefix: &'static str, header: &'a str) -> Cow<'a, str> { + if !header.contains(['"', '\\']) { + let path = header.split_ascii_whitespace().next().unwrap_or(header); + return Cow::Borrowed(path.strip_prefix(strip_prefix).unwrap_or(path)); + } + + let mut path = String::with_capacity(header.len()); + let mut in_quote = false; + let mut chars = header.chars().peekable(); + let mut strip_prefix = Some(strip_prefix); + + while let Some(char) = chars.next() { + if char == '"' { + in_quote = !in_quote; + } else if char == '\\' { + let Some(&next_char) = chars.peek() else { + break; + }; + chars.next(); + path.push(next_char); + } else if char.is_ascii_whitespace() && !in_quote { + break; + } else { + path.push(char); + } + + if let Some(prefix) = strip_prefix + && path == prefix + { + strip_prefix.take(); + path.clear(); + } + } + + Cow::Owned(path) +} + +fn eat_required_whitespace(header: &str) -> Option<&str> { + let trimmed = header.trim_ascii_start(); + + if trimmed.len() == header.len() { + None + } else { + Some(trimmed) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use indoc::indoc; + + #[test] + fn parse_lines_simple() { + let input = indoc! {" + diff --git a/text.txt b/text.txt + index 86c770d..a1fd855 100644 + --- a/file.txt + +++ b/file.txt + @@ -1,2 +1,3 @@ + context + -deleted + +inserted + garbage + + --- b/file.txt + +++ a/file.txt + "}; + + let lines = input.lines().map(DiffLine::parse).collect::>(); + + pretty_assertions::assert_eq!( + lines, + &[ + DiffLine::Garbage, + DiffLine::Garbage, + DiffLine::OldPath { + path: "file.txt".into() + }, + DiffLine::NewPath { + path: "file.txt".into() + }, + DiffLine::HunkHeader(Some(HunkLocation { + start_line_old: 0, + count_old: 2, + start_line_new: 0, + count_new: 3 + })), + DiffLine::Context("context"), + DiffLine::Deletion("deleted"), + DiffLine::Addition("inserted"), + DiffLine::Garbage, + DiffLine::Context(""), + DiffLine::OldPath { + path: "b/file.txt".into() + }, + DiffLine::NewPath { + path: "a/file.txt".into() + }, + ] + ); + } + + #[test] + fn file_header_extra_space() { + let options = ["--- file", "--- file", "---\tfile"]; + + for option in options { + pretty_assertions::assert_eq!( + DiffLine::parse(option), + DiffLine::OldPath { + path: "file".into() + }, + "{option}", + ); + } + } + + #[test] + fn hunk_header_extra_space() { + let options = [ + "@@ -1,2 +1,3 @@", + "@@ -1,2 +1,3 @@", + "@@\t-1,2\t+1,3\t@@", + "@@ -1,2 +1,3 @@", + "@@ -1,2 +1,3 @@", + "@@ -1,2 +1,3 @@", + "@@ -1,2 +1,3 @@ garbage", + ]; + + for option in options { + pretty_assertions::assert_eq!( + DiffLine::parse(option), + DiffLine::HunkHeader(Some(HunkLocation { + start_line_old: 0, + count_old: 2, + start_line_new: 0, + count_new: 3 + })), + "{option}", + ); + } + } + + #[test] + fn hunk_header_without_location() { + pretty_assertions::assert_eq!(DiffLine::parse("@@ ... @@"), DiffLine::HunkHeader(None)); + } + + #[test] + fn test_parse_path() { + assert_eq!(parse_header_path("a/", "foo.txt"), "foo.txt"); + assert_eq!( + parse_header_path("a/", "foo/bar/baz.txt"), + "foo/bar/baz.txt" + ); + assert_eq!(parse_header_path("a/", "a/foo.txt"), "foo.txt"); + assert_eq!( + parse_header_path("a/", "a/foo/bar/baz.txt"), + "foo/bar/baz.txt" + ); + + // Extra + assert_eq!( + parse_header_path("a/", "a/foo/bar/baz.txt 2025"), + "foo/bar/baz.txt" + ); + assert_eq!( + parse_header_path("a/", "a/foo/bar/baz.txt\t2025"), + "foo/bar/baz.txt" + ); + assert_eq!( + parse_header_path("a/", "a/foo/bar/baz.txt \""), + "foo/bar/baz.txt" + ); + + // Quoted + assert_eq!( + parse_header_path("a/", "a/foo/bar/\"baz quox.txt\""), + "foo/bar/baz quox.txt" + ); + assert_eq!( + parse_header_path("a/", "\"a/foo/bar/baz quox.txt\""), + "foo/bar/baz quox.txt" + ); + assert_eq!( + parse_header_path("a/", "\"foo/bar/baz quox.txt\""), + "foo/bar/baz quox.txt" + ); + assert_eq!(parse_header_path("a/", "\"whatever 🤷\""), "whatever 🤷"); + assert_eq!( + parse_header_path("a/", "\"foo/bar/baz quox.txt\" 2025"), + "foo/bar/baz quox.txt" + ); + // unescaped quotes are dropped + assert_eq!(parse_header_path("a/", "foo/\"bar\""), "foo/bar"); + + // Escaped + assert_eq!( + parse_header_path("a/", "\"foo/\\\"bar\\\"/baz.txt\""), + "foo/\"bar\"/baz.txt" + ); + assert_eq!( + parse_header_path("a/", "\"C:\\\\Projects\\\\My App\\\\old file.txt\""), + "C:\\Projects\\My App\\old file.txt" + ); + } +} diff --git a/crates/zeta2/src/related_excerpts.rs b/crates/zeta2/src/related_excerpts.rs index 44388251e32678ff8d1b3ce594ab35996b235759..f1721020d000ec9b7ec308eaa3bac4951c45c3f8 100644 --- a/crates/zeta2/src/related_excerpts.rs +++ b/crates/zeta2/src/related_excerpts.rs @@ -149,6 +149,9 @@ pub fn find_related_excerpts( .find(|model| { model.provider_id() == MODEL_PROVIDER_ID && model.id() == LanguageModelId("claude-haiku-4-5-latest".into()) + // model.provider_id() == LanguageModelProviderId::new("zeta-ctx-qwen-30b") + // model.provider_id() == LanguageModelProviderId::new("ollama") + // && model.id() == LanguageModelId("gpt-oss:20b".into()) }) else { return Task::ready(Err(anyhow!("could not find context model"))); diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index bff091b6f0cd5a37c19ee015f8a0383c8b138b40..92e64f7f332accddbca46ee631f64e5b14be376d 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -35,8 +35,8 @@ use std::str::FromStr as _; use std::sync::Arc; use std::time::{Duration, Instant}; use thiserror::Error; -use util::ResultExt as _; use util::rel_path::RelPathBuf; +use util::{LogErrorFuture, TryFutureExt}; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; pub mod merge_excerpts; @@ -50,8 +50,6 @@ use crate::related_excerpts::find_related_excerpts; pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery}; pub use provider::ZetaEditPredictionProvider; -const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1); - /// Maximum number of events to track. const MAX_EVENT_COUNT: usize = 16; @@ -83,6 +81,7 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { max_diagnostic_bytes: 2048, prompt_format: PromptFormat::DEFAULT, file_indexing_parallelism: 1, + buffer_change_grouping_interval: Duration::from_secs(1), }; pub struct Zeta2FeatureFlag; @@ -118,6 +117,7 @@ pub struct ZetaOptions { pub max_diagnostic_bytes: usize, pub prompt_format: predict_edits_v3::PromptFormat, pub file_indexing_parallelism: usize, + pub buffer_change_grouping_interval: Duration, } #[derive(Debug, Clone, PartialEq)] @@ -135,6 +135,7 @@ impl ContextMode { } } +#[derive(Debug)] pub enum ZetaDebugInfo { ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo), SearchQueriesGenerated(ZetaSearchQueryDebugInfo), @@ -144,17 +145,20 @@ pub enum ZetaDebugInfo { EditPredicted(ZetaEditPredictionDebugInfo), } +#[derive(Debug)] pub struct ZetaContextRetrievalStartedDebugInfo { pub project: Entity, pub timestamp: Instant, pub search_prompt: String, } +#[derive(Debug)] pub struct ZetaContextRetrievalDebugInfo { pub project: Entity, pub timestamp: Instant, } +#[derive(Debug)] pub struct ZetaEditPredictionDebugInfo { pub request: predict_edits_v3::PredictEditsRequest, pub retrieval_time: TimeDelta, @@ -164,6 +168,7 @@ pub struct ZetaEditPredictionDebugInfo { pub response_rx: oneshot::Receiver>, } +#[derive(Debug)] pub struct ZetaSearchQueryDebugInfo { pub project: Entity, pub timestamp: Instant, @@ -178,7 +183,7 @@ struct ZetaProject { registered_buffers: HashMap, current_prediction: Option, context: Option, Vec>>>, - refresh_context_task: Option>>, + refresh_context_task: Option>>>, refresh_context_debounce_task: Option>>, refresh_context_timestamp: Option, } @@ -460,6 +465,7 @@ impl Zeta { project: &Entity, cx: &mut Context, ) -> BufferSnapshot { + let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval; let zeta_project = self.get_or_init_zeta_project(project, cx); let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx); @@ -469,6 +475,7 @@ impl Zeta { std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone()); Self::push_event( zeta_project, + buffer_change_grouping_interval, Event::BufferChange { old_snapshot, new_snapshot: new_snapshot.clone(), @@ -480,14 +487,19 @@ impl Zeta { new_snapshot } - fn push_event(zeta_project: &mut ZetaProject, event: Event) { + fn push_event( + zeta_project: &mut ZetaProject, + buffer_change_grouping_interval: Duration, + event: Event, + ) { let events = &mut zeta_project.events; - if let Some(Event::BufferChange { - new_snapshot: last_new_snapshot, - timestamp: last_timestamp, - .. - }) = events.back_mut() + if buffer_change_grouping_interval > Duration::ZERO + && let Some(Event::BufferChange { + new_snapshot: last_new_snapshot, + timestamp: last_timestamp, + .. + }) = events.back_mut() { // Coalesce edits for the same buffer when they happen one after the other. let Event::BufferChange { @@ -496,7 +508,7 @@ impl Zeta { timestamp, } = &event; - if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL + if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval && old_snapshot.remote_id() == last_new_snapshot.remote_id() && old_snapshot.version == last_new_snapshot.version { @@ -624,7 +636,7 @@ impl Zeta { }) } - fn request_prediction( + pub fn request_prediction( &mut self, project: &Entity, buffer: &Entity, @@ -1068,7 +1080,11 @@ impl Zeta { log::debug!("refetching edit prediction context after pause"); } this.update(cx, |this, cx| { - this.refresh_context(project, buffer, cursor_position, cx); + let task = this.refresh_context(project.clone(), buffer, cursor_position, cx); + + if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) { + zeta_project.refresh_context_task = Some(task.log_err()); + }; }) .ok() } @@ -1077,73 +1093,68 @@ impl Zeta { // Refresh the related excerpts asynchronously. Ensure the task runs to completion, // and avoid spawning more than one concurrent task. - fn refresh_context( + pub fn refresh_context( &mut self, project: Entity, buffer: Entity, cursor_position: language::Anchor, cx: &mut Context, - ) { - let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { - return; - }; - - let debug_tx = self.debug_tx.clone(); - - zeta_project - .refresh_context_task - .get_or_insert(cx.spawn(async move |this, cx| { - let related_excerpts = this - .update(cx, |this, cx| { - let Some(zeta_project) = this.projects.get(&project.entity_id()) else { - return Task::ready(anyhow::Ok(HashMap::default())); - }; + ) -> Task> { + cx.spawn(async move |this, cx| { + let related_excerpts_result = this + .update(cx, |this, cx| { + let Some(zeta_project) = this.projects.get(&project.entity_id()) else { + return Task::ready(anyhow::Ok(HashMap::default())); + }; - let ContextMode::Llm(options) = &this.options().context else { - return Task::ready(anyhow::Ok(HashMap::default())); - }; + let ContextMode::Llm(options) = &this.options().context else { + return Task::ready(anyhow::Ok(HashMap::default())); + }; - let mut edit_history_unified_diff = String::new(); + let mut edit_history_unified_diff = String::new(); - for event in zeta_project.events.iter() { - if let Some(event) = event.to_request_event(cx) { - writeln!(&mut edit_history_unified_diff, "{event}").ok(); - } + for event in zeta_project.events.iter() { + if let Some(event) = event.to_request_event(cx) { + writeln!(&mut edit_history_unified_diff, "{event}").ok(); } + } - find_related_excerpts( - buffer.clone(), - cursor_position, - &project, - edit_history_unified_diff, - options, - debug_tx, - cx, - ) - }) - .ok()? - .await - .log_err() - .unwrap_or_default(); - this.update(cx, |this, _cx| { - let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else { - return; - }; - zeta_project.context = Some(related_excerpts); - zeta_project.refresh_context_task.take(); - if let Some(debug_tx) = &this.debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished( - ZetaContextRetrievalDebugInfo { - project, - timestamp: Instant::now(), - }, - )) - .ok(); + find_related_excerpts( + buffer.clone(), + cursor_position, + &project, + edit_history_unified_diff, + options, + this.debug_tx.clone(), + cx, + ) + })? + .await; + + this.update(cx, |this, _cx| { + let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else { + return Ok(()); + }; + zeta_project.refresh_context_task.take(); + if let Some(debug_tx) = &this.debug_tx { + debug_tx + .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished( + ZetaContextRetrievalDebugInfo { + project, + timestamp: Instant::now(), + }, + )) + .ok(); + } + match related_excerpts_result { + Ok(excerpts) => { + zeta_project.context = Some(excerpts); + Ok(()) } - }) - .ok() - })); + Err(error) => Err(error), + } + })? + }) } fn gather_nearby_diagnostics( diff --git a/crates/zeta2_tools/src/zeta2_tools.rs b/crates/zeta2_tools/src/zeta2_tools.rs index 0b4a59844d7b4a02c2f41ff7654c7df0c4292f7a..89f9dcd5e318c5c21d0121a52b1f39a4f1bd8848 100644 --- a/crates/zeta2_tools/src/zeta2_tools.rs +++ b/crates/zeta2_tools/src/zeta2_tools.rs @@ -335,6 +335,8 @@ impl Zeta2Inspector { max_diagnostic_bytes: zeta_options.max_diagnostic_bytes, prompt_format: zeta_options.prompt_format, file_indexing_parallelism: zeta_options.file_indexing_parallelism, + buffer_change_grouping_interval: zeta_options + .buffer_change_grouping_interval, }, cx, ); diff --git a/crates/zeta_cli/Cargo.toml b/crates/zeta_cli/Cargo.toml index a54298366614c3633cf527cc5746480e66c6caae..5bf90910f18f085db42d5f7934d13601e1c691a2 100644 --- a/crates/zeta_cli/Cargo.toml +++ b/crates/zeta_cli/Cargo.toml @@ -13,6 +13,7 @@ name = "zeta" path = "src/main.rs" [dependencies] + anyhow.workspace = true chrono.workspace = true clap.workspace = true @@ -42,7 +43,6 @@ prompt_store.workspace = true pulldown-cmark.workspace = true release_channel.workspace = true reqwest_client.workspace = true -toml.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true @@ -50,8 +50,15 @@ shellexpand.workspace = true smol.workspace = true soa-rs = "0.8.1" terminal_view.workspace = true +toml.workspace = true util.workspace = true watch.workspace = true zeta.workspace = true zeta2.workspace = true zlog.workspace = true + +[dev-dependencies] +indoc.workspace = true +gpui = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } +pretty_assertions.workspace = true diff --git a/crates/zeta_cli/src/example.rs b/crates/zeta_cli/src/example.rs index de95bbe8d0c97df7c12ce04f75de35ed41a660e4..e742241787cbc714deb6ab934f07bc01218dce10 100644 --- a/crates/zeta_cli/src/example.rs +++ b/crates/zeta_cli/src/example.rs @@ -5,17 +5,23 @@ use std::{ fs, io::Write, mem, + ops::Range, path::{Path, PathBuf}, }; use anyhow::{Context as _, Result}; use clap::ValueEnum; -use gpui::http_client::Url; +use collections::HashSet; +use futures::AsyncWriteExt as _; +use gpui::{AsyncApp, Entity, http_client::Url}; +use language::Buffer; +use project::{Project, ProjectPath}; use pulldown_cmark::CowStr; use serde::{Deserialize, Serialize}; -const CURSOR_POSITION_HEADING: &str = "Cursor Position"; +const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff"; const EDIT_HISTORY_HEADING: &str = "Edit History"; +const CURSOR_POSITION_HEADING: &str = "Cursor Position"; const EXPECTED_PATCH_HEADING: &str = "Expected Patch"; const EXPECTED_EXCERPTS_HEADING: &str = "Expected Excerpts"; const REPOSITORY_URL_FIELD: &str = "repository_url"; @@ -31,9 +37,10 @@ pub struct NamedExample { pub struct Example { pub repository_url: String, pub revision: String, + pub uncommitted_diff: String, pub cursor_path: PathBuf, pub cursor_position: String, - pub edit_history: Vec, + pub edit_history: String, pub expected_patch: String, pub expected_excerpts: Vec, } @@ -59,11 +66,11 @@ impl NamedExample { match ext.and_then(|s| s.to_str()) { Some("json") => Ok(Self { - name: path.file_name().unwrap_or_default().display().to_string(), + name: path.file_stem().unwrap_or_default().display().to_string(), example: serde_json::from_str(&content)?, }), Some("toml") => Ok(Self { - name: path.file_name().unwrap_or_default().display().to_string(), + name: path.file_stem().unwrap_or_default().display().to_string(), example: toml::from_str(&content)?, }), Some("md") => Self::parse_md(&content), @@ -88,9 +95,10 @@ impl NamedExample { example: Example { repository_url: String::new(), revision: String::new(), + uncommitted_diff: String::new(), cursor_path: PathBuf::new(), cursor_position: String::new(), - edit_history: Vec::new(), + edit_history: String::new(), expected_patch: String::new(), expected_excerpts: Vec::new(), }, @@ -152,18 +160,19 @@ impl NamedExample { block_info = "".into(); } Event::End(TagEnd::CodeBlock) => { - if current_section.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) { - named.example.edit_history.push(mem::take(&mut text)); + let block_info = block_info.trim(); + if current_section.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) { + named.example.uncommitted_diff = mem::take(&mut text); + } else if current_section.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) { + named.example.edit_history.push_str(&mem::take(&mut text)); } else if current_section.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) { - let path = PathBuf::from(block_info.trim()); - named.example.cursor_path = path; + named.example.cursor_path = block_info.into(); named.example.cursor_position = mem::take(&mut text); } else if current_section.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) { named.example.expected_patch = mem::take(&mut text); } else if current_section.eq_ignore_ascii_case(EXPECTED_EXCERPTS_HEADING) { - let path = PathBuf::from(block_info.trim()); named.example.expected_excerpts.push(ExpectedExcerpt { - path, + path: block_info.into(), text: mem::take(&mut text), }); } else { @@ -195,13 +204,14 @@ impl NamedExample { #[allow(unused)] pub async fn setup_worktree(&self) -> Result { + let (repo_owner, repo_name) = self.repo_name()?; + let file_name = self.file_name(); + let worktrees_dir = env::current_dir()?.join("target").join("zeta-worktrees"); let repos_dir = env::current_dir()?.join("target").join("zeta-repos"); fs::create_dir_all(&repos_dir)?; fs::create_dir_all(&worktrees_dir)?; - let (repo_owner, repo_name) = self.repo_name()?; - let repo_dir = repos_dir.join(repo_owner.as_ref()).join(repo_name.as_ref()); if !repo_dir.is_dir() { fs::create_dir_all(&repo_dir)?; @@ -213,36 +223,81 @@ impl NamedExample { .await?; } - run_git( - &repo_dir, - &["fetch", "--depth", "1", "origin", &self.example.revision], - ) - .await?; - - let worktree_path = worktrees_dir.join(&self.name); + // Resolve the example to a revision, fetching it if needed. + let revision = run_git(&repo_dir, &["rev-parse", &self.example.revision]).await; + let revision = if let Ok(revision) = revision { + revision + } else { + run_git( + &repo_dir, + &["fetch", "--depth", "1", "origin", &self.example.revision], + ) + .await?; + let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?; + if revision != self.example.revision { + run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?; + } + revision + }; + // Create the worktree for this example if needed. + let worktree_path = worktrees_dir.join(&file_name); if worktree_path.is_dir() { run_git(&worktree_path, &["clean", "--force", "-d"]).await?; run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?; - run_git(&worktree_path, &["checkout", &self.example.revision]).await?; + run_git(&worktree_path, &["checkout", revision.as_str()]).await?; } else { let worktree_path_string = worktree_path.to_string_lossy(); + run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?; run_git( &repo_dir, - &[ - "worktree", - "add", - "-f", - &worktree_path_string, - &self.example.revision, - ], + &["worktree", "add", "-f", &worktree_path_string, &file_name], ) .await?; } + // Apply the uncommitted diff for this example. + if !self.example.uncommitted_diff.is_empty() { + let mut apply_process = smol::process::Command::new("git") + .current_dir(&worktree_path) + .args(&["apply", "-"]) + .stdin(std::process::Stdio::piped()) + .spawn()?; + + let mut stdin = apply_process.stdin.take().unwrap(); + stdin + .write_all(self.example.uncommitted_diff.as_bytes()) + .await?; + stdin.close().await?; + drop(stdin); + + let apply_result = apply_process.output().await?; + if !apply_result.status.success() { + anyhow::bail!( + "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}", + apply_result.status, + String::from_utf8_lossy(&apply_result.stderr), + String::from_utf8_lossy(&apply_result.stdout), + ); + } + } + Ok(worktree_path) } + fn file_name(&self) -> String { + self.name + .chars() + .map(|c| { + if c.is_whitespace() { + '-' + } else { + c.to_ascii_lowercase() + } + }) + .collect() + } + #[allow(unused)] fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> { // git@github.com:owner/repo.git @@ -277,6 +332,15 @@ impl NamedExample { Ok((owner.into(), repo.into())) } } + + #[must_use] + pub async fn apply_edit_history( + &self, + project: &Entity, + cx: &mut AsyncApp, + ) -> Result>> { + apply_diff(&self.example.edit_history, project, cx).await + } } async fn run_git(repo_path: &Path, args: &[&str]) -> Result { @@ -308,6 +372,15 @@ impl Display for NamedExample { )?; write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?; + write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?; + write!(f, "`````diff\n")?; + write!(f, "{}", self.example.uncommitted_diff)?; + write!(f, "`````\n")?; + + if !self.example.edit_history.is_empty() { + write!(f, "`````diff\n{}`````\n", self.example.edit_history)?; + } + write!( f, "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n", @@ -316,14 +389,6 @@ impl Display for NamedExample { )?; write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?; - if !self.example.edit_history.is_empty() { - write!(f, "`````diff\n")?; - for item in &self.example.edit_history { - write!(f, "{item}")?; - } - write!(f, "`````\n")?; - } - if !self.example.expected_patch.is_empty() { write!( f, @@ -353,3 +418,404 @@ impl Display for NamedExample { Ok(()) } } + +#[must_use] +pub async fn apply_diff( + diff: &str, + project: &Entity, + cx: &mut AsyncApp, +) -> Result>> { + use cloud_llm_client::udiff::DiffLine; + use std::fmt::Write; + + #[derive(Debug, Default)] + struct HunkState { + context: String, + edits: Vec, + } + + #[derive(Debug)] + struct Edit { + range: Range, + text: String, + } + + let mut old_path = None; + let mut new_path = None; + let mut hunk = HunkState::default(); + let mut diff_lines = diff.lines().map(DiffLine::parse).peekable(); + let mut open_buffers = HashSet::default(); + + while let Some(diff_line) = diff_lines.next() { + match diff_line { + DiffLine::OldPath { path } => old_path = Some(path), + DiffLine::NewPath { path } => { + if old_path.is_none() { + anyhow::bail!( + "Found a new path header (`+++`) before an (`---`) old path header" + ); + } + new_path = Some(path) + } + DiffLine::Context(ctx) => { + writeln!(&mut hunk.context, "{ctx}")?; + } + DiffLine::Deletion(del) => { + let range = hunk.context.len()..hunk.context.len() + del.len() + '\n'.len_utf8(); + if let Some(last_edit) = hunk.edits.last_mut() + && last_edit.range.end == range.start + { + last_edit.range.end = range.end; + } else { + hunk.edits.push(Edit { + range, + text: String::new(), + }); + } + writeln!(&mut hunk.context, "{del}")?; + } + DiffLine::Addition(add) => { + let range = hunk.context.len()..hunk.context.len(); + if let Some(last_edit) = hunk.edits.last_mut() + && last_edit.range.end == range.start + { + writeln!(&mut last_edit.text, "{add}").unwrap(); + } else { + hunk.edits.push(Edit { + range, + text: format!("{add}\n"), + }); + } + } + DiffLine::HunkHeader(_) | DiffLine::Garbage => {} + } + + let at_hunk_end = match diff_lines.peek() { + Some(DiffLine::OldPath { .. }) | Some(DiffLine::HunkHeader(_)) | None => true, + _ => false, + }; + + if at_hunk_end { + let hunk = mem::take(&mut hunk); + + let Some(old_path) = old_path.as_deref() else { + anyhow::bail!("Missing old path (`---`) header") + }; + + let Some(new_path) = new_path.as_deref() else { + anyhow::bail!("Missing new path (`+++`) header") + }; + + let buffer = project + .update(cx, |project, cx| { + let project_path = project + .find_project_path(old_path, cx) + .context("Failed to find old_path in project")?; + + anyhow::Ok(project.open_buffer(project_path, cx)) + })?? + .await?; + open_buffers.insert(buffer.clone()); + + if old_path != new_path { + project + .update(cx, |project, cx| { + let project_file = project::File::from_dyn(buffer.read(cx).file()).unwrap(); + let new_path = ProjectPath { + worktree_id: project_file.worktree_id(cx), + path: project_file.path.clone(), + }; + project.rename_entry(project_file.entry_id.unwrap(), new_path, cx) + })? + .await?; + } + + // TODO is it worth using project search? + buffer.update(cx, |buffer, cx| { + let context_offset = if hunk.context.is_empty() { + 0 + } else { + let text = buffer.text(); + if let Some(offset) = text.find(&hunk.context) { + if text[offset + 1..].contains(&hunk.context) { + anyhow::bail!("Context is not unique enough:\n{}", hunk.context); + } + offset + } else { + anyhow::bail!( + "Failed to match context:\n{}\n\nBuffer:\n{}", + hunk.context, + text + ); + } + }; + + buffer.edit( + hunk.edits.into_iter().map(|edit| { + ( + context_offset + edit.range.start..context_offset + edit.range.end, + edit.text, + ) + }), + None, + cx, + ); + + anyhow::Ok(()) + })??; + } + } + + anyhow::Ok(open_buffers) +} + +#[cfg(test)] +mod tests { + use super::*; + use ::fs::FakeFs; + use gpui::TestAppContext; + use indoc::indoc; + use pretty_assertions::assert_eq; + use project::Project; + use serde_json::json; + use settings::SettingsStore; + use util::path; + + #[gpui::test] + async fn test_apply_diff_successful(cx: &mut TestAppContext) { + let buffer_1_text = indoc! {r#" + one + two + three + four + five + "# }; + + let buffer_1_text_final = indoc! {r#" + 3 + 4 + 5 + "# }; + + let buffer_2_text = indoc! {r#" + six + seven + eight + nine + ten + "# }; + + let buffer_2_text_final = indoc! {r#" + 5 + six + seven + 7.5 + eight + nine + ten + 11 + "# }; + + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + language::init(cx); + }); + + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + path!("/root"), + json!({ + "file1": buffer_1_text, + "file2": buffer_2_text, + }), + ) + .await; + + let project = Project::test(fs, [path!("/root").as_ref()], cx).await; + + let diff = indoc! {r#" + --- a/root/file1 + +++ b/root/file1 + one + two + -three + +3 + four + five + --- a/root/file1 + +++ b/root/file1 + 3 + -four + -five + +4 + +5 + --- a/root/file1 + +++ b/root/file1 + -one + -two + 3 + 4 + --- a/root/file2 + +++ b/root/file2 + +5 + six + --- a/root/file2 + +++ b/root/file2 + seven + +7.5 + eight + --- a/root/file2 + +++ b/root/file2 + ten + +11 + "#}; + + let _buffers = apply_diff(diff, &project, &mut cx.to_async()) + .await + .unwrap(); + let buffer_1 = project + .update(cx, |project, cx| { + let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap(); + project.open_buffer(project_path, cx) + }) + .await + .unwrap(); + + buffer_1.read_with(cx, |buffer, _cx| { + assert_eq!(buffer.text(), buffer_1_text_final); + }); + let buffer_2 = project + .update(cx, |project, cx| { + let project_path = project.find_project_path(path!("/root/file2"), cx).unwrap(); + project.open_buffer(project_path, cx) + }) + .await + .unwrap(); + + buffer_2.read_with(cx, |buffer, _cx| { + assert_eq!(buffer.text(), buffer_2_text_final); + }); + } + + #[gpui::test] + async fn test_apply_diff_non_unique(cx: &mut TestAppContext) { + let buffer_1_text = indoc! {r#" + one + two + three + four + five + one + two + three + four + five + "# }; + + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + language::init(cx); + }); + + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + path!("/root"), + json!({ + "file1": buffer_1_text, + }), + ) + .await; + + let project = Project::test(fs, [path!("/root").as_ref()], cx).await; + + let diff = indoc! {r#" + --- a/root/file1 + +++ b/root/file1 + one + two + -three + +3 + four + five + "#}; + + apply_diff(diff, &project, &mut cx.to_async()) + .await + .expect_err("Non-unique edits should fail"); + } + + #[gpui::test] + async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) { + let start = indoc! {r#" + one + two + three + four + five + + four + five + "# }; + + let end = indoc! {r#" + one + two + 3 + four + 5 + + four + five + "# }; + + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + language::init(cx); + }); + + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + path!("/root"), + json!({ + "file1": start, + }), + ) + .await; + + let project = Project::test(fs, [path!("/root").as_ref()], cx).await; + + let diff = indoc! {r#" + --- a/root/file1 + +++ b/root/file1 + one + two + -three + +3 + four + -five + +5 + "#}; + + let _buffers = apply_diff(diff, &project, &mut cx.to_async()) + .await + .unwrap(); + + let buffer_1 = project + .update(cx, |project, cx| { + let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap(); + project.open_buffer(project_path, cx) + }) + .await + .unwrap(); + + buffer_1.read_with(cx, |buffer, _cx| { + assert_eq!(buffer.text(), end); + }); + } +} diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 8f19287744697e9f0d2ffd520be8a814790b8345..f0d1cb3fd445d841c2f237c2f828c65c326836ea 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -8,6 +8,7 @@ use crate::example::{ExampleFormat, NamedExample}; use crate::syntax_retrieval_stats::retrieval_stats; use ::serde::Serialize; use ::util::paths::PathStyle; +use ::util::rel_path::RelPath; use anyhow::{Context as _, Result, anyhow}; use clap::{Args, Parser, Subcommand}; use cloud_llm_client::predict_edits_v3::{self, Excerpt}; @@ -21,10 +22,11 @@ use futures::channel::mpsc; use gpui::{Application, AsyncApp, Entity, prelude::*}; use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point}; use language_model::LanguageModelRegistry; -use project::{Project, Worktree}; +use project::{Project, ProjectPath, Worktree}; use reqwest_client::ReqwestClient; use serde_json::json; use std::io; +use std::time::{Duration, Instant}; use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc}; use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery}; @@ -46,8 +48,6 @@ enum Command { command: Zeta1Command, }, Zeta2 { - #[clap(flatten)] - args: Zeta2Args, #[command(subcommand)] command: Zeta2Command, }, @@ -69,15 +69,22 @@ enum Zeta1Command { #[derive(Subcommand, Debug)] enum Zeta2Command { Syntax { + #[clap(flatten)] + args: Zeta2Args, #[clap(flatten)] syntax_args: Zeta2SyntaxArgs, #[command(subcommand)] command: Zeta2SyntaxCommand, }, Llm { + #[clap(flatten)] + args: Zeta2Args, #[command(subcommand)] command: Zeta2LlmCommand, }, + Predict { + example_path: PathBuf, + }, } #[derive(Subcommand, Debug)] @@ -170,6 +177,7 @@ fn syntax_args_to_options( max_prompt_bytes: zeta2_args.max_prompt_bytes, prompt_format: zeta2_args.prompt_format.clone().into(), file_indexing_parallelism: zeta2_args.file_indexing_parallelism, + buffer_change_grouping_interval: Duration::ZERO, } } @@ -319,6 +327,208 @@ async fn load_context( }) } +async fn zeta2_predict( + example: NamedExample, + app_state: &Arc, + cx: &mut AsyncApp, +) -> Result<()> { + let worktree_path = example.setup_worktree().await?; + + cx.update(|cx| { + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry + .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID) + .unwrap() + .authenticate(cx) + }) + })? + .await?; + + app_state + .client + .sign_in_with_optional_connect(true, cx) + .await?; + + let project = cx.update(|cx| { + Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + app_state.fs.clone(), + None, + cx, + ) + })?; + + let worktree = project + .update(cx, |project, cx| { + project.create_worktree(&worktree_path, true, cx) + })? + .await?; + worktree + .read_with(cx, |worktree, _cx| { + worktree.as_local().unwrap().scan_complete() + })? + .await; + + let _edited_buffers = example.apply_edit_history(&project, cx).await?; + + let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc(); + + let cursor_buffer = project + .update(cx, |project, cx| { + project.open_buffer( + ProjectPath { + worktree_id: worktree.read(cx).id(), + path: cursor_path, + }, + cx, + ) + })? + .await?; + + let cursor_offset_within_excerpt = example + .example + .cursor_position + .find(CURSOR_MARKER) + .ok_or_else(|| anyhow!("missing cursor marker"))?; + let mut cursor_excerpt = example.example.cursor_position.clone(); + cursor_excerpt.replace_range( + cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()), + "", + ); + let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| { + let text = buffer.text(); + + let mut matches = text.match_indices(&cursor_excerpt); + let Some((excerpt_offset, _)) = matches.next() else { + anyhow::bail!( + "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n" + ); + }; + assert!(matches.next().is_none()); + + Ok(excerpt_offset) + })??; + + let cursor_offset = excerpt_offset + cursor_offset_within_excerpt; + let cursor_anchor = + cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?; + + let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?; + + let refresh_task = zeta.update(cx, |zeta, cx| { + zeta.register_buffer(&cursor_buffer, &project, cx); + zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) + })?; + + let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?; + let mut context_retrieval_started_at = None; + let mut context_retrieval_finished_at = None; + let mut search_queries_generated_at = None; + let mut search_queries_executed_at = None; + let mut prediction_started_at = None; + let mut prediction_finished_at = None; + let mut excerpts_text = String::new(); + let mut prediction_task = None; + while let Some(event) = debug_rx.next().await { + match event { + zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => { + context_retrieval_started_at = Some(info.timestamp); + } + zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => { + search_queries_generated_at = Some(info.timestamp); + } + zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => { + search_queries_executed_at = Some(info.timestamp); + } + zeta2::ZetaDebugInfo::ContextRetrievalFinished(info) => { + context_retrieval_finished_at = Some(info.timestamp); + + prediction_task = Some(zeta.update(cx, |zeta, cx| { + zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx) + })?); + } + zeta2::ZetaDebugInfo::EditPredicted(request) => { + prediction_started_at = Some(Instant::now()); + request.response_rx.await?.map_err(|err| anyhow!(err))?; + prediction_finished_at = Some(Instant::now()); + + for included_file in request.request.included_files { + let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)]; + write_codeblock( + &included_file.path, + included_file.excerpts.iter(), + if included_file.path == request.request.excerpt_path { + &insertions + } else { + &[] + }, + included_file.max_row, + false, + &mut excerpts_text, + ); + } + break; + } + _ => {} + } + } + + refresh_task.await.context("context retrieval failed")?; + let prediction = prediction_task.unwrap().await?.context("No prediction")?; + + println!("## Excerpts\n"); + println!("{excerpts_text}"); + + let old_text = prediction.snapshot.text(); + let new_text = prediction.buffer.update(cx, |buffer, cx| { + buffer.edit(prediction.edits.iter().cloned(), None, cx); + buffer.text() + })?; + let diff = language::unified_diff(&old_text, &new_text); + + println!("## Prediction\n"); + println!("{diff}"); + + println!("## Time\n"); + + let planning_search_time = + search_queries_generated_at.unwrap() - context_retrieval_started_at.unwrap(); + + println!("Planning searches: {}ms", planning_search_time.as_millis()); + println!( + "Running searches: {}ms", + (search_queries_executed_at.unwrap() - search_queries_generated_at.unwrap()).as_millis() + ); + + let filtering_search_time = + context_retrieval_finished_at.unwrap() - search_queries_executed_at.unwrap(); + println!( + "Filtering context results: {}ms", + filtering_search_time.as_millis() + ); + + let prediction_time = prediction_finished_at.unwrap() - prediction_started_at.unwrap(); + println!("Making Prediction: {}ms", prediction_time.as_millis()); + + println!("-------------------"); + let total_time = + (prediction_finished_at.unwrap() - context_retrieval_started_at.unwrap()).as_millis(); + println!("Total: {}ms", total_time); + + let inference_time = + (planning_search_time + filtering_search_time + prediction_time).as_millis(); + println!( + "Inference: {}ms ({:.2}%)", + inference_time, + (inference_time as f64 / total_time as f64) * 100. + ); + + anyhow::Ok(()) +} + async fn zeta2_syntax_context( zeta2_args: Zeta2Args, syntax_args: Zeta2SyntaxArgs, @@ -616,8 +826,15 @@ fn main() { let context = zeta1_context(context_args, &app_state, cx).await.unwrap(); serde_json::to_string_pretty(&context.body).map_err(|err| anyhow::anyhow!(err)) } - Command::Zeta2 { args, command } => match command { + Command::Zeta2 { command } => match command { + Zeta2Command::Predict { example_path } => { + let example = NamedExample::load(example_path).unwrap(); + zeta2_predict(example, &app_state, cx).await.unwrap(); + let _ = cx.update(|cx| cx.quit()); + return; + } Zeta2Command::Syntax { + args, syntax_args, command, } => match command { @@ -643,7 +860,7 @@ fn main() { .await } }, - Zeta2Command::Llm { command } => match command { + Zeta2Command::Llm { args, command } => match command { Zeta2LlmCommand::Context { context_args } => { zeta2_llm_context(args, context_args, &app_state, cx).await }