diff --git a/Cargo.lock b/Cargo.lock index 63734b552d7475eacdb2ee3eac66371f7c029d28..93961b4181aa1ad721ba8d740736d86c2ae32ca2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5309,7 +5309,6 @@ dependencies = [ "workspace", "zed_actions", "zeta", - "zeta2", ] [[package]] @@ -21316,7 +21315,6 @@ dependencies = [ "zed_actions", "zed_env_vars", "zeta", - "zeta2", "zeta2_tools", "zlog", "zlog_settings", @@ -21636,48 +21634,52 @@ dependencies = [ "ai_onboarding", "anyhow", "arrayvec", - "call", + "brotli", + "buffer_diff", "client", "clock", "cloud_api_types", "cloud_llm_client", + "cloud_zeta2_prompt", "collections", "command_palette_hooks", "copilot", "ctor", "db", "edit_prediction", + "edit_prediction_context", "editor", "feature_flags", "fs", "futures 0.3.31", "gpui", - "http_client", "indoc", "itertools 0.14.0", "language", "language_model", "log", + "lsp", + "markdown", "menu", + "open_ai", "parking_lot", "postage", + "pretty_assertions", "project", "rand 0.9.2", "regex", "release_channel", - "reqwest_client", - "rpc", "semver", "serde", "serde_json", "settings", + "smol", + "strsim", "strum 0.27.2", "telemetry", "telemetry_events", "theme", "thiserror 2.0.17", - "tree-sitter-go", - "tree-sitter-rust", "ui", "util", "uuid", @@ -21687,53 +21689,11 @@ dependencies = [ "zlog", ] -[[package]] -name = "zeta2" -version = "0.1.0" -dependencies = [ - "anyhow", - "arrayvec", - "brotli", - "chrono", - "client", - "clock", - "cloud_llm_client", - "cloud_zeta2_prompt", - "collections", - "edit_prediction", - "edit_prediction_context", - "feature_flags", - "futures 0.3.31", - "gpui", - "indoc", - "language", - "language_model", - "log", - "lsp", - "open_ai", - "pretty_assertions", - "project", - "release_channel", - "semver", - "serde", - "serde_json", - "settings", - "smol", - "strsim", - "thiserror 2.0.17", - "util", - "uuid", - "workspace", - "worktree", - "zlog", -] - [[package]] name = "zeta2_tools" version = "0.1.0" dependencies = [ "anyhow", - "chrono", "clap", "client", "cloud_llm_client", @@ -21746,9 +21706,7 @@ dependencies = [ "gpui", "indoc", "language", - "log", "multi_buffer", - "ordered-float 2.10.1", "pretty_assertions", "project", "serde", @@ -21760,7 +21718,7 @@ dependencies = [ "ui_input", "util", "workspace", - "zeta2", + "zeta", "zlog", ] @@ -21810,7 +21768,6 @@ dependencies = [ "util", "watch", "zeta", - "zeta2", "zlog", ] diff --git a/Cargo.toml b/Cargo.toml index e3ba2cb817357f5733179864bc23161d01aa1123..ab18418939e1b7100684e3c0acec277e7ec75a88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -201,7 +201,6 @@ members = [ "crates/zed_actions", "crates/zed_env_vars", "crates/zeta", - "crates/zeta2", "crates/zeta_cli", "crates/zlog", "crates/zlog_settings", @@ -433,7 +432,6 @@ zed = { path = "crates/zed" } zed_actions = { path = "crates/zed_actions" } zed_env_vars = { path = "crates/zed_env_vars" } zeta = { path = "crates/zeta" } -zeta2 = { path = "crates/zeta2" } zlog = { path = "crates/zlog" } zlog_settings = { path = "crates/zlog_settings" } diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index 2f7c25a3560e09bccb9f45c64df38048eefdddd6..a298db28e63fd761f2f6d58827a7bcf5c8b39962 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -1218,23 +1218,23 @@ } }, { - "context": "RateCompletionModal", + "context": "RatePredictionsModal", "use_key_equivalents": true, "bindings": { - "cmd-shift-enter": "zeta::ThumbsUpActiveCompletion", - "cmd-shift-backspace": "zeta::ThumbsDownActiveCompletion", + "cmd-shift-enter": "zeta::ThumbsUpActivePrediction", + "cmd-shift-backspace": "zeta::ThumbsDownActivePrediction", "shift-down": "zeta::NextEdit", "shift-up": "zeta::PreviousEdit", - "right": "zeta::PreviewCompletion" + "right": "zeta::PreviewPrediction" } }, { - "context": "RateCompletionModal > Editor", + "context": "RatePredictionsModal > Editor", "use_key_equivalents": true, "bindings": { - "escape": "zeta::FocusCompletions", - "cmd-shift-enter": "zeta::ThumbsUpActiveCompletion", - "cmd-shift-backspace": "zeta::ThumbsDownActiveCompletion" + "escape": "zeta::FocusPredictions", + "cmd-shift-enter": "zeta::ThumbsUpActivePrediction", + "cmd-shift-backspace": "zeta::ThumbsDownActivePrediction" } }, { diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index 32a5a34d9d3b63332008a9f7df84a1990f87f17c..47e5e71589c806f71725ee4f218ca4a86bee62d0 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use std::{ fmt::{Display, Write as _}, ops::{Add, Range, Sub}, - path::{Path, PathBuf}, + path::Path, sync::Arc, }; use strum::EnumIter; @@ -17,7 +17,7 @@ pub struct PlanContextRetrievalRequest { pub excerpt_path: Arc, pub excerpt_line_range: Range, pub cursor_file_max_row: Line, - pub events: Vec, + pub events: Vec>, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -36,7 +36,7 @@ pub struct PredictEditsRequest { pub signatures: Vec, #[serde(skip_serializing_if = "Vec::is_empty", default)] pub referenced_declarations: Vec, - pub events: Vec, + pub events: Vec>, #[serde(default)] pub can_collect_data: bool, #[serde(skip_serializing_if = "Vec::is_empty", default)] @@ -120,10 +120,11 @@ impl std::fmt::Display for PromptFormat { #[serde(tag = "event")] pub enum Event { BufferChange { - path: Option, - old_path: Option, + path: Arc, + old_path: Arc, diff: String, predicted: bool, + in_open_source_repo: bool, }, } @@ -135,23 +136,21 @@ impl Display for Event { old_path, diff, predicted, + .. } => { - let new_path = path.as_deref().unwrap_or(Path::new("untitled")); - let old_path = old_path.as_deref().unwrap_or(new_path); - if *predicted { write!( f, "// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}", DiffPathFmt(old_path), - DiffPathFmt(new_path) + DiffPathFmt(path) ) } else { write!( f, "--- a/{}\n+++ b/{}\n{diff}", DiffPathFmt(old_path), - DiffPathFmt(new_path) + DiffPathFmt(path) ) } } @@ -300,10 +299,11 @@ mod tests { #[test] fn test_event_display() { let ev = Event::BufferChange { - path: None, - old_path: None, + path: Path::new("untitled").into(), + old_path: Path::new("untitled").into(), diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(), predicted: false, + in_open_source_repo: true, }; assert_eq!( ev.to_string(), @@ -317,10 +317,11 @@ mod tests { ); let ev = Event::BufferChange { - path: Some(PathBuf::from("foo/bar.txt")), - old_path: Some(PathBuf::from("foo/bar.txt")), + path: Path::new("foo/bar.txt").into(), + old_path: Path::new("foo/bar.txt").into(), diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(), predicted: false, + in_open_source_repo: true, }; assert_eq!( ev.to_string(), @@ -334,10 +335,11 @@ mod tests { ); let ev = Event::BufferChange { - path: Some(PathBuf::from("abc.txt")), - old_path: Some(PathBuf::from("123.txt")), + path: Path::new("abc.txt").into(), + old_path: Path::new("123.txt").into(), diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(), predicted: false, + in_open_source_repo: true, }; assert_eq!( ev.to_string(), @@ -351,10 +353,11 @@ mod tests { ); let ev = Event::BufferChange { - path: Some(PathBuf::from("abc.txt")), - old_path: Some(PathBuf::from("123.txt")), + path: Path::new("abc.txt").into(), + old_path: Path::new("123.txt").into(), diff: "@@ -1,2 +1,2 @@\n-a\n-b\n".into(), predicted: true, + in_open_source_repo: true, }; assert_eq!( ev.to_string(), diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs index 2ddabf750be763542bfc10b794afcb034ff08443..d67190c17556c5eb8b901e9baad73cc2691a9c78 100644 --- a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs +++ b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs @@ -432,7 +432,7 @@ pub fn write_excerpts<'a>( } } -pub fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) { +pub fn push_events(output: &mut String, events: &[Arc]) { if events.is_empty() { return; }; @@ -910,7 +910,7 @@ fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle } struct PromptData { - events: Vec, + events: Vec>, cursor_point: Point, cursor_path: Arc, // TODO: make a common struct with cursor_point included_files: Vec, diff --git a/crates/edit_prediction_button/Cargo.toml b/crates/edit_prediction_button/Cargo.toml index 9877b70161b3fdd16a0f667d85085520c9fe4f86..9062aca3c56f527385aecb000ebcd625f588eb9a 100644 --- a/crates/edit_prediction_button/Cargo.toml +++ b/crates/edit_prediction_button/Cargo.toml @@ -35,7 +35,6 @@ ui.workspace = true workspace.workspace = true zed_actions.workspace = true zeta.workspace = true -zeta2.workspace = true [dev-dependencies] copilot = { workspace = true, features = ["test-support"] } diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs index 051ca6e85ccb985ba6b325cda725f83029aa3193..254caa698aa05214f73a749e540233952db4978b 100644 --- a/crates/edit_prediction_button/src/edit_prediction_button.rs +++ b/crates/edit_prediction_button/src/edit_prediction_button.rs @@ -21,7 +21,9 @@ use language::{ use project::DisableAiSettings; use regex::Regex; use settings::{ - EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore, update_settings_file, + EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, + EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore, + update_settings_file, }; use std::{ sync::{Arc, LazyLock}, @@ -38,7 +40,7 @@ use workspace::{ }; use zed_actions::OpenBrowser; use zeta::RateCompletions; -use zeta2::SweepFeatureFlag; +use zeta::{SweepFeatureFlag, Zeta2FeatureFlag}; actions!( edit_prediction, @@ -300,10 +302,7 @@ impl Render for EditPredictionButton { .with_handle(self.popover_menu_handle.clone()), ) } - provider @ (EditPredictionProvider::Experimental( - EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, - ) - | EditPredictionProvider::Zed) => { + provider @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => { let enabled = self.editor_enabled.unwrap_or(true); let is_sweep = matches!( @@ -430,9 +429,7 @@ impl Render for EditPredictionButton { div().child(popover_menu.into_any_element()) } - EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => { - div().hidden() - } + EditPredictionProvider::None => div().hidden(), } } } @@ -497,6 +494,12 @@ impl EditPredictionButton { )); } + if cx.has_flag::() { + providers.push(EditPredictionProvider::Experimental( + EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, + )); + } + providers } @@ -554,7 +557,7 @@ impl EditPredictionButton { EditPredictionProvider::Experimental( EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, ) => { - let has_api_token = zeta2::Zeta::try_global(cx) + let has_api_token = zeta::Zeta::try_global(cx) .map_or(false, |zeta| zeta.read(cx).has_sweep_api_token()); let entry = ContextMenuEntry::new("Sweep") @@ -571,6 +574,11 @@ impl EditPredictionButton { menu.item(entry) } + EditPredictionProvider::Experimental( + EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, + ) => menu.entry("Zeta2", None, move |_, cx| { + set_completion_provider(fs.clone(), cx, provider); + }), EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => { continue; } diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index fd5e6fcaf6435a2836ab1ad828933a9d0763f5b9..c599a4751b60f150e31b7ddf6e32a6234a510c74 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -13,6 +13,7 @@ use crate::{ }, task_context::RunnableRange, text_diff::text_diff, + unified_diff, }; pub use crate::{ Grammar, Language, LanguageRegistry, @@ -745,6 +746,33 @@ pub struct EditPreview { } impl EditPreview { + pub fn as_unified_diff(&self, edits: &[(Range, impl AsRef)]) -> Option { + let (first, _) = edits.first()?; + let (last, _) = edits.last()?; + + let start = first.start.to_point(&self.old_snapshot); + let old_end = last.end.to_point(&self.old_snapshot); + let new_end = last + .end + .bias_right(&self.old_snapshot) + .to_point(&self.applied_edits_snapshot); + + let start = Point::new(start.row.saturating_sub(3), 0); + let old_end = Point::new(old_end.row + 3, 0).min(self.old_snapshot.max_point()); + let new_end = Point::new(new_end.row + 3, 0).min(self.applied_edits_snapshot.max_point()); + + Some(unified_diff( + &self + .old_snapshot + .text_for_range(start..old_end) + .collect::(), + &self + .applied_edits_snapshot + .text_for_range(start..new_end) + .collect::(), + )) + } + pub fn highlight_edits( &self, current_snapshot: &BufferSnapshot, @@ -758,6 +786,8 @@ impl EditPreview { let mut highlighted_text = HighlightedTextBuilder::default(); + let visible_range_in_preview_snapshot = + visible_range_in_preview_snapshot.to_offset(&self.applied_edits_snapshot); let mut offset_in_preview_snapshot = visible_range_in_preview_snapshot.start; let insertion_highlight_style = HighlightStyle { @@ -825,7 +855,19 @@ impl EditPreview { highlighted_text.build() } - fn compute_visible_range(&self, edits: &[(Range, T)]) -> Option> { + pub fn build_result_buffer(&self, cx: &mut App) -> Entity { + cx.new(|cx| { + let mut buffer = Buffer::local_normalized( + self.applied_edits_snapshot.as_rope().clone(), + self.applied_edits_snapshot.line_ending(), + cx, + ); + buffer.set_language(self.syntax_snapshot.root_language(), cx); + buffer + }) + } + + pub fn compute_visible_range(&self, edits: &[(Range, T)]) -> Option> { let (first, _) = edits.first()?; let (last, _) = edits.last()?; @@ -842,7 +884,7 @@ impl EditPreview { let range = Point::new(start.row, 0) ..Point::new(end.row, self.applied_edits_snapshot.line_len(end.row)); - Some(range.to_offset(&self.applied_edits_snapshot)) + Some(range) } } diff --git a/crates/language/src/syntax_map.rs b/crates/language/src/syntax_map.rs index a9ac2faad9da9d5e07261ec826dda138921717a6..33a652b6fdeb32a2adbc1743cf8a70fe453518f5 100644 --- a/crates/language/src/syntax_map.rs +++ b/crates/language/src/syntax_map.rs @@ -279,6 +279,13 @@ impl SyntaxSnapshot { self.layers.is_empty() } + pub fn root_language(&self) -> Option> { + match &self.layers.first()?.content { + SyntaxLayerContent::Parsed { language, .. } => Some(language.clone()), + SyntaxLayerContent::Pending { .. } => None, + } + } + pub fn update_count(&self) -> usize { self.update_count } diff --git a/crates/settings/src/settings_content/language.rs b/crates/settings/src/settings_content/language.rs index 78ecc270166483b13af7e169b2390ad9f76d595d..166444c44b28133cfe20933c5b12acc42edb2399 100644 --- a/crates/settings/src/settings_content/language.rs +++ b/crates/settings/src/settings_content/language.rs @@ -78,6 +78,7 @@ pub enum EditPredictionProvider { } pub const EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME: &str = "sweep"; +pub const EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME: &str = "zeta2"; impl<'de> Deserialize<'de> for EditPredictionProvider { fn deserialize(deserializer: D) -> Result @@ -101,17 +102,25 @@ impl<'de> Deserialize<'de> for EditPredictionProvider { Content::Supermaven => EditPredictionProvider::Supermaven, Content::Zed => EditPredictionProvider::Zed, Content::Codestral => EditPredictionProvider::Codestral, + Content::Experimental(name) + if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME => + { + EditPredictionProvider::Experimental( + EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, + ) + } + Content::Experimental(name) + if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME => + { + EditPredictionProvider::Experimental( + EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, + ) + } Content::Experimental(name) => { - if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME { - EditPredictionProvider::Experimental( - EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, - ) - } else { - return Err(D::Error::custom(format!( - "Unknown experimental edit prediction provider: {}", - name - ))); - } + return Err(D::Error::custom(format!( + "Unknown experimental edit prediction provider: {}", + name + ))); } }) } diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 68ba338102202f1803ab97746ec8372adb45a66a..470f1ea28a3663838080b7e7bf98f58215a0a8fc 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -161,7 +161,6 @@ workspace.workspace = true zed_actions.workspace = true zed_env_vars.workspace = true zeta.workspace = true -zeta2.workspace = true zlog.workspace = true zlog_settings.workspace = true chrono.workspace = true diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 577e81c6a9b36bc29a4b1d1f0cda63170c75d5a2..f413fd94cb1a48adb213120364ed2f59c4cf58e0 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -7,13 +7,14 @@ use feature_flags::FeatureFlagAppExt; use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity}; use language::language_settings::{EditPredictionProvider, all_language_settings}; use language_models::MistralLanguageModelProvider; -use settings::{EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore}; +use settings::{ + EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, + EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore, +}; use std::{cell::RefCell, rc::Rc, sync::Arc}; use supermaven::{Supermaven, SupermavenCompletionProvider}; use ui::Window; -use zeta::ZetaEditPredictionProvider; -use zeta2::SweepFeatureFlag; -use zeta2::Zeta2FeatureFlag; +use zeta::{SweepFeatureFlag, Zeta2FeatureFlag, ZetaEditPredictionProvider}; pub fn init(client: Arc, user_store: Entity, cx: &mut App) { let editors: Rc, AnyWindowHandle>>> = Rc::default(); @@ -100,9 +101,7 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { } fn clear_zeta_edit_history(_: &zeta::ClearHistory, cx: &mut App) { - if let Some(zeta) = zeta::Zeta::global(cx) { - zeta.update(cx, |zeta, _| zeta.clear_history()); - } else if let Some(zeta) = zeta2::Zeta::try_global(cx) { + if let Some(zeta) = zeta::Zeta::try_global(cx) { zeta.update(cx, |zeta, _| zeta.clear_history()); } } @@ -204,86 +203,41 @@ fn assign_edit_prediction_provider( editor.set_edit_prediction_provider(Some(provider), window, cx); } value @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => { - let zeta2 = zeta2::Zeta::global(client, &user_store, cx); - - if let Some(project) = editor.project() { - let mut worktree = None; - if let Some(buffer) = &singleton_buffer - && let Some(file) = buffer.read(cx).file() - { - let id = file.worktree_id(cx); - worktree = project.read(cx).worktree_for_id(id, cx); - } - - if let EditPredictionProvider::Experimental(name) = value - && name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME - && cx.has_flag::() - { - let provider = cx.new(|cx| { - zeta2::ZetaEditPredictionProvider::new( - project.clone(), - &client, - &user_store, - cx, - ) - }); - - if let Some(buffer) = &singleton_buffer - && buffer.read(cx).file().is_some() - { - zeta2.update(cx, |zeta, cx| { - zeta.set_edit_prediction_model(zeta2::ZetaEditPredictionModel::Sweep); - zeta.register_buffer(buffer, project, cx); - }); - } - - editor.set_edit_prediction_provider(Some(provider), window, cx); - } else if user_store.read(cx).current_user().is_some() { - if cx.has_flag::() { - let zeta = zeta2::Zeta::global(client, &user_store, cx); - let provider = cx.new(|cx| { - zeta2::ZetaEditPredictionProvider::new( - project.clone(), - &client, - &user_store, - cx, - ) - }); - - // TODO [zeta2] handle multibuffers - if let Some(buffer) = &singleton_buffer - && buffer.read(cx).file().is_some() + let zeta = zeta::Zeta::global(client, &user_store, cx); + + if let Some(project) = editor.project() + && let Some(buffer) = &singleton_buffer + && buffer.read(cx).file().is_some() + { + let has_model = zeta.update(cx, |zeta, cx| { + let model = if let EditPredictionProvider::Experimental(name) = value { + if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME + && cx.has_flag::() + { + zeta::ZetaEditPredictionModel::Sweep + } else if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME + && cx.has_flag::() { - zeta.update(cx, |zeta, cx| { - zeta.set_edit_prediction_model( - zeta2::ZetaEditPredictionModel::ZedCloud, - ); - zeta.register_buffer(buffer, project, cx); - }); + zeta::ZetaEditPredictionModel::Zeta2 + } else { + return false; } - - editor.set_edit_prediction_provider(Some(provider), window, cx); + } else if user_store.read(cx).current_user().is_some() { + zeta::ZetaEditPredictionModel::Zeta1 } else { - let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx); + return false; + }; - if let Some(buffer) = &singleton_buffer - && buffer.read(cx).file().is_some() - { - zeta.update(cx, |zeta, cx| { - zeta.register_buffer(buffer, project, cx); - }); - } + zeta.set_edit_prediction_model(model); + zeta.register_buffer(buffer, project, cx); + true + }); - let provider = cx.new(|cx| { - zeta::ZetaEditPredictionProvider::new( - zeta, - project.clone(), - singleton_buffer, - cx, - ) - }); - editor.set_edit_prediction_provider(Some(provider), window, cx); - } + if has_model { + let provider = cx.new(|cx| { + ZetaEditPredictionProvider::new(project.clone(), &client, &user_store, cx) + }); + editor.set_edit_prediction_provider(Some(provider), window, cx); } } } diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index df569c7bc39655d99ee01b464a05e0ef3873f8d6..61eeab16229d82dc01d800f37bf729aa11469afd 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -4,81 +4,80 @@ version = "0.1.0" edition.workspace = true publish.workspace = true license = "GPL-3.0-or-later" -exclude = ["fixtures"] [lints] workspace = true [lib] path = "src/zeta.rs" -doctest = false [features] -test-support = [] +eval-support = [] [dependencies] ai_onboarding.workspace = true anyhow.workspace = true arrayvec.workspace = true +brotli.workspace = true +buffer_diff.workspace = true client.workspace = true cloud_llm_client.workspace = true +cloud_zeta2_prompt.workspace = true +copilot.workspace = true collections.workspace = true command_palette_hooks.workspace = true -copilot.workspace = true db.workspace = true edit_prediction.workspace = true +edit_prediction_context.workspace = true editor.workspace = true feature_flags.workspace = true fs.workspace = true futures.workspace = true gpui.workspace = true -http_client.workspace = true indoc.workspace = true itertools.workspace = true language.workspace = true language_model.workspace = true log.workspace = true +lsp.workspace = true +markdown.workspace = true menu.workspace = true +open_ai.workspace = true +pretty_assertions.workspace = true postage.workspace = true project.workspace = true rand.workspace = true -regex.workspace = true release_channel.workspace = true +regex.workspace = true semver.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true +smol.workspace = true +strsim.workspace = true strum.workspace = true telemetry.workspace = true telemetry_events.workspace = true theme.workspace = true thiserror.workspace = true -ui.workspace = true util.workspace = true +ui.workspace = true uuid.workspace = true workspace.workspace = true worktree.workspace = true zed_actions.workspace = true [dev-dependencies] -call = { workspace = true, features = ["test-support"] } -client = { workspace = true, features = ["test-support"] } clock = { workspace = true, features = ["test-support"] } cloud_api_types.workspace = true -collections = { workspace = true, features = ["test-support"] } +cloud_llm_client = { workspace = true, features = ["test-support"] } ctor.workspace = true -editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } -http_client = { workspace = true, features = ["test-support"] } indoc.workspace = true language = { workspace = true, features = ["test-support"] } +language_model = { workspace = true, features = ["test-support"] } +lsp.workspace = true parking_lot.workspace = true -reqwest_client = { workspace = true, features = ["test-support"] } -rpc = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } settings = { workspace = true, features = ["test-support"] } -theme = { workspace = true, features = ["test-support"] } -tree-sitter-go.workspace = true -tree-sitter-rust.workspace = true -workspace = { workspace = true, features = ["test-support"] } -worktree = { workspace = true, features = ["test-support"] } zlog.workspace = true diff --git a/crates/zeta2/src/assemble_excerpts.rs b/crates/zeta/src/assemble_excerpts.rs similarity index 100% rename from crates/zeta2/src/assemble_excerpts.rs rename to crates/zeta/src/assemble_excerpts.rs diff --git a/crates/zeta/src/completion_diff_element.rs b/crates/zeta/src/completion_diff_element.rs deleted file mode 100644 index 73c3cb20cd7de5da92fbf6e5a32a8ca8d42a5933..0000000000000000000000000000000000000000 --- a/crates/zeta/src/completion_diff_element.rs +++ /dev/null @@ -1,173 +0,0 @@ -use std::cmp; - -use crate::EditPrediction; -use gpui::{ - AnyElement, App, BorderStyle, Bounds, Corners, Edges, HighlightStyle, Hsla, StyledText, - TextLayout, TextStyle, point, prelude::*, quad, size, -}; -use language::OffsetRangeExt; -use settings::Settings; -use theme::ThemeSettings; -use ui::prelude::*; - -pub struct CompletionDiffElement { - element: AnyElement, - text_layout: TextLayout, - cursor_offset: usize, -} - -impl CompletionDiffElement { - pub fn new(completion: &EditPrediction, cx: &App) -> Self { - let mut diff = completion - .snapshot - .text_for_range(completion.excerpt_range.clone()) - .collect::(); - - let mut cursor_offset_in_diff = None; - let mut delta = 0; - let mut diff_highlights = Vec::new(); - for (old_range, new_text) in completion.edits.iter() { - let old_range = old_range.to_offset(&completion.snapshot); - - if cursor_offset_in_diff.is_none() && completion.cursor_offset <= old_range.end { - cursor_offset_in_diff = - Some(completion.cursor_offset - completion.excerpt_range.start + delta); - } - - let old_start_in_diff = old_range.start - completion.excerpt_range.start + delta; - let old_end_in_diff = old_range.end - completion.excerpt_range.start + delta; - if old_start_in_diff < old_end_in_diff { - diff_highlights.push(( - old_start_in_diff..old_end_in_diff, - HighlightStyle { - background_color: Some(cx.theme().status().deleted_background), - strikethrough: Some(gpui::StrikethroughStyle { - thickness: px(1.), - color: Some(cx.theme().colors().text_muted), - }), - ..Default::default() - }, - )); - } - - if !new_text.is_empty() { - diff.insert_str(old_end_in_diff, new_text); - diff_highlights.push(( - old_end_in_diff..old_end_in_diff + new_text.len(), - HighlightStyle { - background_color: Some(cx.theme().status().created_background), - ..Default::default() - }, - )); - delta += new_text.len(); - } - } - - let cursor_offset_in_diff = cursor_offset_in_diff - .unwrap_or_else(|| completion.cursor_offset - completion.excerpt_range.start + delta); - - let settings = ThemeSettings::get_global(cx).clone(); - let text_style = TextStyle { - color: cx.theme().colors().editor_foreground, - font_size: settings.buffer_font_size(cx).into(), - font_family: settings.buffer_font.family, - font_features: settings.buffer_font.features, - font_fallbacks: settings.buffer_font.fallbacks, - line_height: relative(settings.buffer_line_height.value()), - font_weight: settings.buffer_font.weight, - font_style: settings.buffer_font.style, - ..Default::default() - }; - let element = StyledText::new(diff).with_default_highlights(&text_style, diff_highlights); - let text_layout = element.layout().clone(); - - CompletionDiffElement { - element: element.into_any_element(), - text_layout, - cursor_offset: cursor_offset_in_diff, - } - } -} - -impl IntoElement for CompletionDiffElement { - type Element = Self; - - fn into_element(self) -> Self { - self - } -} - -impl Element for CompletionDiffElement { - type RequestLayoutState = (); - type PrepaintState = (); - - fn id(&self) -> Option { - None - } - - fn source_location(&self) -> Option<&'static core::panic::Location<'static>> { - None - } - - fn request_layout( - &mut self, - _id: Option<&gpui::GlobalElementId>, - _inspector_id: Option<&gpui::InspectorElementId>, - window: &mut Window, - cx: &mut App, - ) -> (gpui::LayoutId, Self::RequestLayoutState) { - (self.element.request_layout(window, cx), ()) - } - - fn prepaint( - &mut self, - _id: Option<&gpui::GlobalElementId>, - _inspector_id: Option<&gpui::InspectorElementId>, - _bounds: gpui::Bounds, - _request_layout: &mut Self::RequestLayoutState, - window: &mut Window, - cx: &mut App, - ) -> Self::PrepaintState { - self.element.prepaint(window, cx); - } - - fn paint( - &mut self, - _id: Option<&gpui::GlobalElementId>, - _inspector_id: Option<&gpui::InspectorElementId>, - _bounds: gpui::Bounds, - _request_layout: &mut Self::RequestLayoutState, - _prepaint: &mut Self::PrepaintState, - window: &mut Window, - cx: &mut App, - ) { - if let Some(position) = self.text_layout.position_for_index(self.cursor_offset) { - let bounds = self.text_layout.bounds(); - let line_height = self.text_layout.line_height(); - let line_width = self - .text_layout - .line_layout_for_index(self.cursor_offset) - .map_or(bounds.size.width, |layout| layout.width()); - window.paint_quad(quad( - Bounds::new( - point(bounds.origin.x, position.y), - size(cmp::max(bounds.size.width, line_width), line_height), - ), - Corners::default(), - cx.theme().colors().editor_active_line_background, - Edges::default(), - Hsla::transparent_black(), - BorderStyle::default(), - )); - self.element.paint(window, cx); - window.paint_quad(quad( - Bounds::new(position, size(px(2.), line_height)), - Corners::default(), - cx.theme().players().local().cursor, - Edges::default(), - Hsla::transparent_black(), - BorderStyle::default(), - )); - } - } -} diff --git a/crates/zeta/src/init.rs b/crates/zeta/src/init.rs deleted file mode 100644 index 0167d878fa34976d7175a64269d9dfe29d18d8fe..0000000000000000000000000000000000000000 --- a/crates/zeta/src/init.rs +++ /dev/null @@ -1,110 +0,0 @@ -use std::any::{Any, TypeId}; - -use command_palette_hooks::CommandPaletteFilter; -use feature_flags::{FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag}; -use gpui::actions; -use language::language_settings::EditPredictionProvider; -use project::DisableAiSettings; -use settings::{Settings, SettingsStore, update_settings_file}; -use ui::App; -use workspace::Workspace; - -use crate::{RateCompletionModal, onboarding_modal::ZedPredictModal}; - -actions!( - edit_prediction, - [ - /// Resets the edit prediction onboarding state. - ResetOnboarding, - /// Opens the rate completions modal. - RateCompletions - ] -); - -pub fn init(cx: &mut App) { - feature_gate_predict_edits_actions(cx); - - cx.observe_new(move |workspace: &mut Workspace, _, _cx| { - workspace.register_action(|workspace, _: &RateCompletions, window, cx| { - if cx.has_flag::() { - RateCompletionModal::toggle(workspace, window, cx); - } - }); - - workspace.register_action( - move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| { - ZedPredictModal::toggle( - workspace, - workspace.user_store().clone(), - workspace.client().clone(), - window, - cx, - ) - }, - ); - - workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| { - update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| { - settings - .project - .all_languages - .features - .get_or_insert_default() - .edit_prediction_provider = Some(EditPredictionProvider::None) - }); - }); - }) - .detach(); -} - -fn feature_gate_predict_edits_actions(cx: &mut App) { - let rate_completion_action_types = [TypeId::of::()]; - let reset_onboarding_action_types = [TypeId::of::()]; - let zeta_all_action_types = [ - TypeId::of::(), - TypeId::of::(), - zed_actions::OpenZedPredictOnboarding.type_id(), - TypeId::of::(), - TypeId::of::(), - TypeId::of::(), - TypeId::of::(), - TypeId::of::(), - ]; - - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.hide_action_types(&rate_completion_action_types); - filter.hide_action_types(&reset_onboarding_action_types); - filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]); - }); - - cx.observe_global::(move |cx| { - let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; - let has_feature_flag = cx.has_flag::(); - - CommandPaletteFilter::update_global(cx, |filter, _cx| { - if is_ai_disabled { - filter.hide_action_types(&zeta_all_action_types); - } else if has_feature_flag { - filter.show_action_types(&rate_completion_action_types); - } else { - filter.hide_action_types(&rate_completion_action_types); - } - }); - }) - .detach(); - - cx.observe_flag::(move |is_enabled, cx| { - if !DisableAiSettings::get_global(cx).disable_ai { - if is_enabled { - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.show_action_types(&rate_completion_action_types); - }); - } else { - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.hide_action_types(&rate_completion_action_types); - }); - } - } - }) - .detach(); -} diff --git a/crates/zeta/src/onboarding_modal.rs b/crates/zeta/src/onboarding_modal.rs index 94480add3053bece5017cf478e9f74065491639b..ed7adfc75476afb07f9c56b9c9c03abbbcef1134 100644 --- a/crates/zeta/src/onboarding_modal.rs +++ b/crates/zeta/src/onboarding_modal.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::{ZedPredictUpsell, onboarding_event}; +use crate::ZedPredictUpsell; use ai_onboarding::EditPredictionOnboarding; use client::{Client, UserStore}; use db::kvp::Dismissable; @@ -14,6 +14,16 @@ use settings::update_settings_file; use ui::{Vector, VectorName, prelude::*}; use workspace::{ModalView, Workspace}; +#[macro_export] +macro_rules! onboarding_event { + ($name:expr) => { + telemetry::event!($name, source = "Edit Prediction Onboarding"); + }; + ($name:expr, $($key:ident $(= $value:expr)?),+ $(,)?) => { + telemetry::event!($name, source = "Edit Prediction Onboarding", $($key $(= $value)?),+); + }; +} + /// Introduces user to Zed's Edit Prediction feature pub struct ZedPredictModal { onboarding: Entity, diff --git a/crates/zeta/src/onboarding_telemetry.rs b/crates/zeta/src/onboarding_telemetry.rs deleted file mode 100644 index 3c7d5e1442947c3e8cea446ebf37597a3cce1f80..0000000000000000000000000000000000000000 --- a/crates/zeta/src/onboarding_telemetry.rs +++ /dev/null @@ -1,9 +0,0 @@ -#[macro_export] -macro_rules! onboarding_event { - ($name:expr) => { - telemetry::event!($name, source = "Edit Prediction Onboarding"); - }; - ($name:expr, $($key:ident $(= $value:expr)?),+ $(,)?) => { - telemetry::event!($name, source = "Edit Prediction Onboarding", $($key $(= $value)?),+); - }; -} diff --git a/crates/zeta2/src/prediction.rs b/crates/zeta/src/prediction.rs similarity index 86% rename from crates/zeta2/src/prediction.rs rename to crates/zeta/src/prediction.rs index e9f726ce00c36b5235919c0e185876996f4fda03..0125e739f335fc133cbff84dcd8b4c4bac3e6e7b 100644 --- a/crates/zeta2/src/prediction.rs +++ b/crates/zeta/src/prediction.rs @@ -1,7 +1,13 @@ -use std::{ops::Range, sync::Arc}; +use std::{ + ops::Range, + path::Path, + sync::Arc, + time::{Duration, Instant}, +}; use gpui::{AsyncApp, Entity, SharedString}; use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot}; +use serde::Serialize; #[derive(Clone, Default, Debug, PartialEq, Eq, Hash)] pub struct EditPredictionId(pub SharedString); @@ -26,6 +32,17 @@ pub struct EditPrediction { pub edit_preview: EditPreview, // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction. pub buffer: Entity, + pub buffer_snapshotted_at: Instant, + pub response_received_at: Instant, + pub inputs: EditPredictionInputs, +} + +#[derive(Debug, Clone, Serialize)] +pub struct EditPredictionInputs { + pub events: Vec>, + pub included_files: Vec, + pub cursor_point: cloud_llm_client::predict_edits_v3::Point, + pub cursor_path: Arc, } impl EditPrediction { @@ -33,14 +50,17 @@ impl EditPrediction { id: EditPredictionId, edited_buffer: &Entity, edited_buffer_snapshot: &BufferSnapshot, - edits: Vec<(Range, Arc)>, + edits: Arc<[(Range, Arc)]>, + buffer_snapshotted_at: Instant, + response_received_at: Instant, + inputs: EditPredictionInputs, cx: &mut AsyncApp, ) -> Option { let (edits, snapshot, edit_preview_task) = edited_buffer .read_with(cx, |buffer, cx| { let new_snapshot = buffer.snapshot(); let edits: Arc<[_]> = - interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits.into())?.into(); + interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits)?.into(); Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx))) }) @@ -53,7 +73,10 @@ impl EditPrediction { edits, snapshot, edit_preview, + inputs, buffer: edited_buffer.clone(), + buffer_snapshotted_at, + response_received_at, }) } @@ -67,6 +90,10 @@ impl EditPrediction { pub fn targets_buffer(&self, buffer: &Buffer) -> bool { self.snapshot.remote_id() == buffer.remote_id() } + + pub fn latency(&self) -> Duration { + self.response_received_at - self.buffer_snapshotted_at + } } impl std::fmt::Debug for EditPrediction { @@ -147,6 +174,17 @@ mod tests { snapshot: cx.read(|cx| buffer.read(cx).snapshot()), buffer: buffer.clone(), edit_preview, + inputs: EditPredictionInputs { + events: vec![], + included_files: vec![], + cursor_point: cloud_llm_client::predict_edits_v3::Point { + line: cloud_llm_client::predict_edits_v3::Line(0), + column: 0, + }, + cursor_path: Path::new("path.txt").into(), + }, + buffer_snapshotted_at: Instant::now(), + response_received_at: Instant::now(), }; cx.update(|cx| { diff --git a/crates/zeta2/src/provider.rs b/crates/zeta/src/provider.rs similarity index 93% rename from crates/zeta2/src/provider.rs rename to crates/zeta/src/provider.rs index 768af6253fe1a2aa60ef9cb0a10fcee0035dc3e2..a2b3eed1b5efe953ebdf5a2448ca06e7866bea86 100644 --- a/crates/zeta2/src/provider.rs +++ b/crates/zeta/src/provider.rs @@ -131,8 +131,14 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { } fn discard(&mut self, cx: &mut Context) { - self.zeta.update(cx, |zeta, _cx| { - zeta.discard_current_prediction(&self.project); + self.zeta.update(cx, |zeta, cx| { + zeta.discard_current_prediction(&self.project, cx); + }); + } + + fn did_show(&mut self, cx: &mut Context) { + self.zeta.update(cx, |zeta, cx| { + zeta.did_show_current_prediction(&self.project, cx); }); } @@ -162,8 +168,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { let snapshot = buffer.snapshot(); let Some(edits) = prediction.interpolate(&snapshot) else { - self.zeta.update(cx, |zeta, _cx| { - zeta.discard_current_prediction(&self.project); + self.zeta.update(cx, |zeta, cx| { + zeta.discard_current_prediction(&self.project, cx); }); return None; }; diff --git a/crates/zeta/src/rate_completion_modal.rs b/crates/zeta/src/rate_prediction_modal.rs similarity index 60% rename from crates/zeta/src/rate_completion_modal.rs rename to crates/zeta/src/rate_prediction_modal.rs index a081538f5528946ea5b959981b7bd70d44b8b11b..0cceb86608ed609122c81d406c71280894789e88 100644 --- a/crates/zeta/src/rate_completion_modal.rs +++ b/crates/zeta/src/rate_prediction_modal.rs @@ -1,8 +1,18 @@ -use crate::{CompletionDiffElement, EditPrediction, EditPredictionRating, Zeta}; -use editor::Editor; -use gpui::{App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, actions, prelude::*}; -use language::language_settings; +use crate::{EditPrediction, EditPredictionRating, Zeta}; +use buffer_diff::{BufferDiff, BufferDiffSnapshot}; +use cloud_zeta2_prompt::write_codeblock; +use editor::{Editor, ExcerptRange, MultiBuffer}; +use gpui::{ + App, BorderStyle, DismissEvent, EdgesRefinement, Entity, EventEmitter, FocusHandle, Focusable, + Length, StyleRefinement, TextStyleRefinement, Window, actions, prelude::*, +}; +use language::{LanguageRegistry, Point, language_settings}; +use markdown::{Markdown, MarkdownStyle}; +use settings::Settings as _; +use std::fmt::Write; +use std::sync::Arc; use std::time::Duration; +use theme::ThemeSettings; use ui::{KeyBinding, List, ListItem, ListItemSpacing, Tooltip, prelude::*}; use workspace::{ModalView, Workspace}; @@ -10,41 +20,44 @@ actions!( zeta, [ /// Rates the active completion with a thumbs up. - ThumbsUpActiveCompletion, + ThumbsUpActivePrediction, /// Rates the active completion with a thumbs down. - ThumbsDownActiveCompletion, + ThumbsDownActivePrediction, /// Navigates to the next edit in the completion history. NextEdit, /// Navigates to the previous edit in the completion history. PreviousEdit, /// Focuses on the completions list. - FocusCompletions, + FocusPredictions, /// Previews the selected completion. - PreviewCompletion, + PreviewPrediction, ] ); -pub struct RateCompletionModal { +pub struct RatePredictionsModal { zeta: Entity, - active_completion: Option, + language_registry: Arc, + active_prediction: Option, selected_index: usize, + diff_editor: Entity, focus_handle: FocusHandle, _subscription: gpui::Subscription, - current_view: RateCompletionView, + current_view: RatePredictionView, } -struct ActiveCompletion { - completion: EditPrediction, +struct ActivePrediction { + prediction: EditPrediction, feedback_editor: Entity, + formatted_inputs: Entity, } #[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] -enum RateCompletionView { +enum RatePredictionView { SuggestedEdits, RawInput, } -impl RateCompletionView { +impl RatePredictionView { pub fn name(&self) -> &'static str { match self { Self::SuggestedEdits => "Suggested Edits", @@ -53,25 +66,42 @@ impl RateCompletionView { } } -impl RateCompletionModal { +impl RatePredictionsModal { pub fn toggle(workspace: &mut Workspace, window: &mut Window, cx: &mut Context) { - if let Some(zeta) = Zeta::global(cx) { - workspace.toggle_modal(window, cx, |_window, cx| RateCompletionModal::new(zeta, cx)); + if let Some(zeta) = Zeta::try_global(cx) { + let language_registry = workspace.app_state().languages.clone(); + workspace.toggle_modal(window, cx, |window, cx| { + RatePredictionsModal::new(zeta, language_registry, window, cx) + }); - telemetry::event!("Rate Completion Modal Open", source = "Edit Prediction"); + telemetry::event!("Rate Prediction Modal Open", source = "Edit Prediction"); } } - pub fn new(zeta: Entity, cx: &mut Context) -> Self { + pub fn new( + zeta: Entity, + language_registry: Arc, + window: &mut Window, + cx: &mut Context, + ) -> Self { let subscription = cx.observe(&zeta, |_, _, cx| cx.notify()); Self { zeta, + language_registry, selected_index: 0, focus_handle: cx.focus_handle(), - active_completion: None, + active_prediction: None, _subscription: subscription, - current_view: RateCompletionView::SuggestedEdits, + diff_editor: cx.new(|cx| { + let multibuffer = cx.new(|_| MultiBuffer::new(language::Capability::ReadOnly)); + let mut editor = Editor::for_multibuffer(multibuffer, None, window, cx); + editor.disable_inline_diagnostics(); + editor.set_expand_all_diff_hunks(cx); + editor.set_show_git_diff_gutter(false, cx); + editor + }), + current_view: RatePredictionView::SuggestedEdits, } } @@ -83,7 +113,7 @@ impl RateCompletionModal { self.selected_index += 1; self.selected_index = usize::min( self.selected_index, - self.zeta.read(cx).shown_completions().count(), + self.zeta.read(cx).shown_predictions().count(), ); cx.notify(); } @@ -102,7 +132,7 @@ impl RateCompletionModal { let next_index = self .zeta .read(cx) - .shown_completions() + .shown_predictions() .skip(self.selected_index) .enumerate() .skip(1) // Skip straight to the next item @@ -122,7 +152,7 @@ impl RateCompletionModal { let prev_index = self .zeta .read(cx) - .shown_completions() + .shown_predictions() .rev() .skip((completions_len - 1) - self.selected_index) .enumerate() @@ -149,14 +179,14 @@ impl RateCompletionModal { pub fn thumbs_up_active( &mut self, - _: &ThumbsUpActiveCompletion, + _: &ThumbsUpActivePrediction, window: &mut Window, cx: &mut Context, ) { self.zeta.update(cx, |zeta, cx| { - if let Some(active) = &self.active_completion { - zeta.rate_completion( - &active.completion, + if let Some(active) = &self.active_prediction { + zeta.rate_prediction( + &active.prediction, EditPredictionRating::Positive, active.feedback_editor.read(cx).text(cx), cx, @@ -165,9 +195,9 @@ impl RateCompletionModal { }); let current_completion = self - .active_completion + .active_prediction .as_ref() - .map(|completion| completion.completion.clone()); + .map(|completion| completion.prediction.clone()); self.select_completion(current_completion, false, window, cx); self.select_next_edit(&Default::default(), window, cx); self.confirm(&Default::default(), window, cx); @@ -177,18 +207,18 @@ impl RateCompletionModal { pub fn thumbs_down_active( &mut self, - _: &ThumbsDownActiveCompletion, + _: &ThumbsDownActivePrediction, window: &mut Window, cx: &mut Context, ) { - if let Some(active) = &self.active_completion { + if let Some(active) = &self.active_prediction { if active.feedback_editor.read(cx).text(cx).is_empty() { return; } self.zeta.update(cx, |zeta, cx| { - zeta.rate_completion( - &active.completion, + zeta.rate_prediction( + &active.prediction, EditPredictionRating::Negative, active.feedback_editor.read(cx).text(cx), cx, @@ -197,9 +227,9 @@ impl RateCompletionModal { } let current_completion = self - .active_completion + .active_prediction .as_ref() - .map(|completion| completion.completion.clone()); + .map(|completion| completion.prediction.clone()); self.select_completion(current_completion, false, window, cx); self.select_next_edit(&Default::default(), window, cx); self.confirm(&Default::default(), window, cx); @@ -209,7 +239,7 @@ impl RateCompletionModal { fn focus_completions( &mut self, - _: &FocusCompletions, + _: &FocusPredictions, window: &mut Window, cx: &mut Context, ) { @@ -219,14 +249,14 @@ impl RateCompletionModal { fn preview_completion( &mut self, - _: &PreviewCompletion, + _: &PreviewPrediction, window: &mut Window, cx: &mut Context, ) { let completion = self .zeta .read(cx) - .shown_completions() + .shown_predictions() .skip(self.selected_index) .take(1) .next() @@ -239,7 +269,7 @@ impl RateCompletionModal { let completion = self .zeta .read(cx) - .shown_completions() + .shown_predictions() .skip(self.selected_index) .take(1) .next() @@ -250,54 +280,145 @@ impl RateCompletionModal { pub fn select_completion( &mut self, - completion: Option, + prediction: Option, focus: bool, window: &mut Window, cx: &mut Context, ) { // Avoid resetting completion rating if it's already selected. - if let Some(completion) = completion.as_ref() { + if let Some(prediction) = prediction { self.selected_index = self .zeta .read(cx) - .shown_completions() + .shown_predictions() .enumerate() - .find(|(_, completion_b)| completion.id == completion_b.id) + .find(|(_, completion_b)| prediction.id == completion_b.id) .map(|(ix, _)| ix) .unwrap_or(self.selected_index); cx.notify(); - if let Some(prev_completion) = self.active_completion.as_ref() - && completion.id == prev_completion.completion.id + if let Some(prev_prediction) = self.active_prediction.as_ref() + && prediction.id == prev_prediction.prediction.id { if focus { - window.focus(&prev_completion.feedback_editor.focus_handle(cx)); + window.focus(&prev_prediction.feedback_editor.focus_handle(cx)); } return; } + + self.diff_editor.update(cx, |editor, cx| { + let new_buffer = prediction.edit_preview.build_result_buffer(cx); + let new_buffer_snapshot = new_buffer.read(cx).snapshot(); + let old_buffer_snapshot = prediction.snapshot.clone(); + let new_buffer_id = new_buffer_snapshot.remote_id(); + + let range = prediction + .edit_preview + .compute_visible_range(&prediction.edits) + .unwrap_or(Point::zero()..Point::zero()); + let start = Point::new(range.start.row.saturating_sub(5), 0); + let end = Point::new(range.end.row + 5, 0).min(new_buffer_snapshot.max_point()); + + let diff = cx.new::(|cx| { + let diff_snapshot = BufferDiffSnapshot::new_with_base_buffer( + new_buffer_snapshot.text.clone(), + Some(old_buffer_snapshot.text().into()), + old_buffer_snapshot.clone(), + cx, + ); + let diff = BufferDiff::new(&new_buffer_snapshot, cx); + cx.spawn(async move |diff, cx| { + let diff_snapshot = diff_snapshot.await; + diff.update(cx, |diff, cx| { + diff.set_snapshot(diff_snapshot, &new_buffer_snapshot.text, cx); + }) + }) + .detach(); + diff + }); + + editor.disable_header_for_buffer(new_buffer_id, cx); + editor.buffer().update(cx, |multibuffer, cx| { + multibuffer.clear(cx); + multibuffer.push_excerpts( + new_buffer, + vec![ExcerptRange { + context: start..end, + primary: start..end, + }], + cx, + ); + multibuffer.add_diff(diff, cx); + }); + }); + + let mut formatted_inputs = String::new(); + + write!(&mut formatted_inputs, "## Events\n\n").unwrap(); + + for event in &prediction.inputs.events { + write!(&mut formatted_inputs, "```diff\n{event}```\n\n").unwrap(); + } + + write!(&mut formatted_inputs, "## Included files\n\n").unwrap(); + + for included_file in &prediction.inputs.included_files { + let cursor_insertions = &[(prediction.inputs.cursor_point, "<|CURSOR|>")]; + + write!( + &mut formatted_inputs, + "### {}\n\n", + included_file.path.display() + ) + .unwrap(); + + write_codeblock( + &included_file.path, + &included_file.excerpts, + if included_file.path == prediction.inputs.cursor_path { + cursor_insertions + } else { + &[] + }, + included_file.max_row, + false, + &mut formatted_inputs, + ); + } + + self.active_prediction = Some(ActivePrediction { + prediction, + feedback_editor: cx.new(|cx| { + let mut editor = Editor::multi_line(window, cx); + editor.disable_scrollbars_and_minimap(window, cx); + editor.set_soft_wrap_mode(language_settings::SoftWrap::EditorWidth, cx); + editor.set_show_line_numbers(false, cx); + editor.set_show_git_diff_gutter(false, cx); + editor.set_show_code_actions(false, cx); + editor.set_show_runnables(false, cx); + editor.set_show_breakpoints(false, cx); + editor.set_show_wrap_guides(false, cx); + editor.set_show_indent_guides(false, cx); + editor.set_show_edit_predictions(Some(false), window, cx); + editor.set_placeholder_text("Add your feedback…", window, cx); + if focus { + cx.focus_self(window); + } + editor + }), + formatted_inputs: cx.new(|cx| { + Markdown::new( + formatted_inputs.into(), + Some(self.language_registry.clone()), + None, + cx, + ) + }), + }); + } else { + self.active_prediction = None; } - self.active_completion = completion.map(|completion| ActiveCompletion { - completion, - feedback_editor: cx.new(|cx| { - let mut editor = Editor::multi_line(window, cx); - editor.disable_scrollbars_and_minimap(window, cx); - editor.set_soft_wrap_mode(language_settings::SoftWrap::EditorWidth, cx); - editor.set_show_line_numbers(false, cx); - editor.set_show_git_diff_gutter(false, cx); - editor.set_show_code_actions(false, cx); - editor.set_show_runnables(false, cx); - editor.set_show_breakpoints(false, cx); - editor.set_show_wrap_guides(false, cx); - editor.set_show_indent_guides(false, cx); - editor.set_show_edit_predictions(Some(false), window, cx); - editor.set_placeholder_text("Add your feedback…", window, cx); - if focus { - cx.focus_self(window); - } - editor - }), - }); cx.notify(); } @@ -312,33 +433,31 @@ impl RateCompletionModal { .child( Button::new( ElementId::Name("suggested-edits".into()), - RateCompletionView::SuggestedEdits.name(), + RatePredictionView::SuggestedEdits.name(), ) .label_size(LabelSize::Small) .on_click(cx.listener(move |this, _, _window, cx| { - this.current_view = RateCompletionView::SuggestedEdits; + this.current_view = RatePredictionView::SuggestedEdits; cx.notify(); })) - .toggle_state(self.current_view == RateCompletionView::SuggestedEdits), + .toggle_state(self.current_view == RatePredictionView::SuggestedEdits), ) .child( Button::new( ElementId::Name("raw-input".into()), - RateCompletionView::RawInput.name(), + RatePredictionView::RawInput.name(), ) .label_size(LabelSize::Small) .on_click(cx.listener(move |this, _, _window, cx| { - this.current_view = RateCompletionView::RawInput; + this.current_view = RatePredictionView::RawInput; cx.notify(); })) - .toggle_state(self.current_view == RateCompletionView::RawInput), + .toggle_state(self.current_view == RatePredictionView::RawInput), ) } fn render_suggested_edits(&self, cx: &mut Context) -> Option> { - let active_completion = self.active_completion.as_ref()?; let bg_color = cx.theme().colors().editor_background; - Some( div() .id("diff") @@ -347,14 +466,18 @@ impl RateCompletionModal { .bg(bg_color) .overflow_scroll() .whitespace_nowrap() - .child(CompletionDiffElement::new( - &active_completion.completion, - cx, - )), + .child(self.diff_editor.clone()), ) } - fn render_raw_input(&self, cx: &mut Context) -> Option> { + fn render_raw_input( + &self, + window: &mut Window, + cx: &mut Context, + ) -> Option> { + let theme_settings = ThemeSettings::get_global(cx); + let buffer_font_size = theme_settings.buffer_font_size(cx); + Some( v_flex() .size_full() @@ -368,30 +491,81 @@ impl RateCompletionModal { .size_full() .bg(cx.theme().colors().editor_background) .overflow_scroll() - .child(if let Some(active_completion) = &self.active_completion { - format!( - "{}\n{}", - active_completion.completion.input_events, - active_completion.completion.input_excerpt + .child(if let Some(active_prediction) = &self.active_prediction { + markdown::MarkdownElement::new( + active_prediction.formatted_inputs.clone(), + MarkdownStyle { + base_text_style: window.text_style(), + syntax: cx.theme().syntax().clone(), + code_block: StyleRefinement { + text: Some(TextStyleRefinement { + font_family: Some( + theme_settings.buffer_font.family.clone(), + ), + font_size: Some(buffer_font_size.into()), + ..Default::default() + }), + padding: EdgesRefinement { + top: Some(DefiniteLength::Absolute( + AbsoluteLength::Pixels(px(8.)), + )), + left: Some(DefiniteLength::Absolute( + AbsoluteLength::Pixels(px(8.)), + )), + right: Some(DefiniteLength::Absolute( + AbsoluteLength::Pixels(px(8.)), + )), + bottom: Some(DefiniteLength::Absolute( + AbsoluteLength::Pixels(px(8.)), + )), + }, + margin: EdgesRefinement { + top: Some(Length::Definite(px(8.).into())), + left: Some(Length::Definite(px(0.).into())), + right: Some(Length::Definite(px(0.).into())), + bottom: Some(Length::Definite(px(12.).into())), + }, + border_style: Some(BorderStyle::Solid), + border_widths: EdgesRefinement { + top: Some(AbsoluteLength::Pixels(px(1.))), + left: Some(AbsoluteLength::Pixels(px(1.))), + right: Some(AbsoluteLength::Pixels(px(1.))), + bottom: Some(AbsoluteLength::Pixels(px(1.))), + }, + border_color: Some(cx.theme().colors().border_variant), + background: Some( + cx.theme().colors().editor_background.into(), + ), + ..Default::default() + }, + ..Default::default() + }, ) + .into_any_element() } else { - "No active completion".to_string() + div() + .child("No active completion".to_string()) + .into_any_element() }), ) .id("raw-input-view"), ) } - fn render_active_completion(&mut self, cx: &mut Context) -> Option { - let active_completion = self.active_completion.as_ref()?; - let completion_id = active_completion.completion.id; + fn render_active_completion( + &mut self, + window: &mut Window, + cx: &mut Context, + ) -> Option { + let active_prediction = self.active_prediction.as_ref()?; + let completion_id = active_prediction.prediction.id.clone(); let focus_handle = &self.focus_handle(cx); let border_color = cx.theme().colors().border; let bg_color = cx.theme().colors().editor_background; - let rated = self.zeta.read(cx).is_completion_rated(completion_id); - let feedback_empty = active_completion + let rated = self.zeta.read(cx).is_prediction_rated(&completion_id); + let feedback_empty = active_prediction .feedback_editor .read(cx) .text(cx) @@ -412,10 +586,10 @@ impl RateCompletionModal { .child(self.render_view_nav(cx)) .when_some( match self.current_view { - RateCompletionView::SuggestedEdits => { + RatePredictionView::SuggestedEdits => { self.render_suggested_edits(cx) } - RateCompletionView::RawInput => self.render_raw_input(cx), + RatePredictionView::RawInput => self.render_raw_input(window, cx), }, |this, element| this.child(element), ), @@ -450,7 +624,7 @@ impl RateCompletionModal { .h_40() .pt_1() .bg(bg_color) - .child(active_completion.feedback_editor.clone()), + .child(active_prediction.feedback_editor.clone()), ) }) .child( @@ -472,7 +646,7 @@ impl RateCompletionModal { ) .child(Label::new("Rated completion.").color(Color::Muted)), ) - } else if active_completion.completion.edits.is_empty() { + } else if active_prediction.prediction.edits.is_empty() { Some( label_container .child( @@ -489,7 +663,7 @@ impl RateCompletionModal { h_flex() .gap_1() .child( - Button::new("bad", "Bad Completion") + Button::new("bad", "Bad Prediction") .icon(IconName::ThumbsDown) .icon_size(IconSize::Small) .icon_position(IconPosition::Start) @@ -500,14 +674,14 @@ impl RateCompletionModal { )) }) .key_binding(KeyBinding::for_action_in( - &ThumbsDownActiveCompletion, + &ThumbsDownActivePrediction, focus_handle, cx, )) .on_click(cx.listener(move |this, _, window, cx| { - if this.active_completion.is_some() { + if this.active_prediction.is_some() { this.thumbs_down_active( - &ThumbsDownActiveCompletion, + &ThumbsDownActivePrediction, window, cx, ); @@ -515,20 +689,20 @@ impl RateCompletionModal { })), ) .child( - Button::new("good", "Good Completion") + Button::new("good", "Good Prediction") .icon(IconName::ThumbsUp) .icon_size(IconSize::Small) .icon_position(IconPosition::Start) .disabled(rated) .key_binding(KeyBinding::for_action_in( - &ThumbsUpActiveCompletion, + &ThumbsUpActivePrediction, focus_handle, cx, )) .on_click(cx.listener(move |this, _, window, cx| { - if this.active_completion.is_some() { + if this.active_prediction.is_some() { this.thumbs_up_active( - &ThumbsUpActiveCompletion, + &ThumbsUpActivePrediction, window, cx, ); @@ -543,34 +717,32 @@ impl RateCompletionModal { fn render_shown_completions(&self, cx: &Context) -> impl Iterator { self.zeta .read(cx) - .shown_completions() + .shown_predictions() .cloned() .enumerate() .map(|(index, completion)| { let selected = self - .active_completion + .active_prediction .as_ref() - .is_some_and(|selected| selected.completion.id == completion.id); - let rated = self.zeta.read(cx).is_completion_rated(completion.id); + .is_some_and(|selected| selected.prediction.id == completion.id); + let rated = self.zeta.read(cx).is_prediction_rated(&completion.id); let (icon_name, icon_color, tooltip_text) = match (rated, completion.edits.is_empty()) { - (true, _) => (IconName::Check, Color::Success, "Rated Completion"), + (true, _) => (IconName::Check, Color::Success, "Rated Prediction"), (false, true) => (IconName::File, Color::Muted, "No Edits Produced"), (false, false) => (IconName::FileDiff, Color::Accent, "Edits Available"), }; - let file_name = completion - .path - .file_name() - .map(|f| f.to_string_lossy().into_owned()) - .unwrap_or("untitled".to_string()); - let file_path = completion - .path - .parent() - .map(|p| p.to_string_lossy().into_owned()); - - ListItem::new(completion.id) + let file = completion.buffer.read(cx).file(); + let file_name = file + .as_ref() + .map_or(SharedString::new_static("untitled"), |file| { + file.file_name(cx).to_string().into() + }); + let file_path = file.map(|file| file.path().as_unix_str().to_string()); + + ListItem::new(completion.id.clone()) .inset(true) .spacing(ListItemSpacing::Sparse) .focused(index == self.selected_index) @@ -615,12 +787,12 @@ impl RateCompletionModal { } } -impl Render for RateCompletionModal { +impl Render for RatePredictionsModal { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let border_color = cx.theme().colors().border; h_flex() - .key_context("RateCompletionModal") + .key_context("RatePredictionModal") .track_focus(&self.focus_handle) .on_action(cx.listener(Self::dismiss)) .on_action(cx.listener(Self::confirm)) @@ -688,20 +860,20 @@ impl Render for RateCompletionModal { ), ), ) - .children(self.render_active_completion(cx)) + .children(self.render_active_completion(window, cx)) .on_mouse_down_out(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))) } } -impl EventEmitter for RateCompletionModal {} +impl EventEmitter for RatePredictionsModal {} -impl Focusable for RateCompletionModal { +impl Focusable for RatePredictionsModal { fn focus_handle(&self, _cx: &App) -> FocusHandle { self.focus_handle.clone() } } -impl ModalView for RateCompletionModal {} +impl ModalView for RatePredictionsModal {} fn format_time_ago(elapsed: Duration) -> String { let seconds = elapsed.as_secs(); diff --git a/crates/zeta2/src/retrieval_search.rs b/crates/zeta/src/retrieval_search.rs similarity index 100% rename from crates/zeta2/src/retrieval_search.rs rename to crates/zeta/src/retrieval_search.rs diff --git a/crates/zeta2/src/sweep_ai.rs b/crates/zeta/src/sweep_ai.rs similarity index 77% rename from crates/zeta2/src/sweep_ai.rs rename to crates/zeta/src/sweep_ai.rs index c56d7409fa212734c5f5a73a6b24319c27c7494f..0e226ab9df26ffc945a2d8e810790d0b00d0f198 100644 --- a/crates/zeta2/src/sweep_ai.rs +++ b/crates/zeta/src/sweep_ai.rs @@ -2,7 +2,6 @@ use std::fmt; use std::{path::Path, sync::Arc}; use serde::{Deserialize, Serialize}; -use util::rel_path::RelPath; #[derive(Debug, Clone, Serialize)] pub struct AutocompleteRequest { @@ -91,34 +90,24 @@ pub struct AdditionalCompletion { pub finish_reason: Option, } -pub(crate) fn write_event(event: crate::Event, f: &mut impl fmt::Write) -> fmt::Result { +pub(crate) fn write_event( + event: &cloud_llm_client::predict_edits_v3::Event, + f: &mut impl fmt::Write, +) -> fmt::Result { match event { - crate::Event::BufferChange { - old_snapshot, - new_snapshot, + cloud_llm_client::predict_edits_v3::Event::BufferChange { + old_path, + path, + diff, .. } => { - let old_path = old_snapshot - .file() - .map(|f| f.path().as_ref()) - .unwrap_or(RelPath::unix("untitled").unwrap()); - let new_path = new_snapshot - .file() - .map(|f| f.path().as_ref()) - .unwrap_or(RelPath::unix("untitled").unwrap()); - if old_path != new_path { + if old_path != path { // TODO confirm how to do this for sweep // writeln!(f, "User renamed {:?} to {:?}\n", old_path, new_path)?; } - let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text()); if !diff.is_empty() { - write!( - f, - "File: {}:\n{}\n", - new_path.display(util::paths::PathStyle::Posix), - diff - )? + write!(f, "File: {}:\n{}\n", path.display(), diff)? } fmt::Result::Ok(()) diff --git a/crates/zeta2/src/udiff.rs b/crates/zeta/src/udiff.rs similarity index 100% rename from crates/zeta2/src/udiff.rs rename to crates/zeta/src/udiff.rs diff --git a/crates/zeta2/src/xml_edits.rs b/crates/zeta/src/xml_edits.rs similarity index 100% rename from crates/zeta2/src/xml_edits.rs rename to crates/zeta/src/xml_edits.rs diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 5b2c3856eda2cd984e6675d671f8c99aa183e883..6464ce19ebaf1f95ad58e2954fb68e934600dac4 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -1,130 +1,178 @@ -mod completion_diff_element; -mod init; -mod input_excerpt; -mod license_detection; -mod onboarding_modal; -mod onboarding_telemetry; -mod rate_completion_modal; - -pub(crate) use completion_diff_element::*; -use db::kvp::{Dismissable, KEY_VALUE_STORE}; -use db::smol::stream::StreamExt as _; -use edit_prediction::DataCollectionState; -use futures::channel::mpsc; -pub use init::*; -use license_detection::LicenseDetectionWatcher; -pub use rate_completion_modal::*; - -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Context as _, Result, anyhow, bail}; use arrayvec::ArrayVec; use client::{Client, EditPredictionUsage, UserStore}; +use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature}; use cloud_llm_client::{ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME, - PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, RejectEditPredictionsBody, - ZED_VERSION_HEADER_NAME, + RejectEditPredictionsBody, ZED_VERSION_HEADER_NAME, }; -use collections::{HashMap, HashSet, VecDeque}; -use futures::AsyncReadExt; +use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery}; +use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES}; +use collections::{HashMap, HashSet}; +use command_palette_hooks::CommandPaletteFilter; +use db::kvp::{Dismissable, KEY_VALUE_STORE}; +use edit_prediction_context::{ + DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions, + EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line, + SyntaxIndex, SyntaxIndexState, +}; +use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag}; +use futures::channel::{mpsc, oneshot}; +use futures::{AsyncReadExt as _, StreamExt as _}; use gpui::{ - App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SharedString, Subscription, - Task, actions, + App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions, + http_client::{self, AsyncBody, Method}, + prelude::*, }; -use http_client::{AsyncBody, HttpClient, Method, Request, Response}; -use input_excerpt::excerpt_for_cursor_position; use language::{ - Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff, + Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint, }; +use language::{BufferSnapshot, OffsetRangeExt}; use language_model::{LlmApiToken, RefreshLlmTokenListener}; -use project::{Project, ProjectPath}; +use lsp::DiagnosticSeverity; +use open_ai::FunctionDefinition; +use project::{DisableAiSettings, Project, ProjectPath, WorktreeId}; use release_channel::AppVersion; use semver::Version; -use settings::WorktreeId; -use std::collections::hash_map; -use std::mem; -use std::str::FromStr; -use std::{ - cmp, - fmt::Write, - future::Future, - ops::Range, - path::Path, - rc::Rc, - sync::Arc, - time::{Duration, Instant}, -}; +use serde::de::DeserializeOwned; +use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file}; +use std::any::{Any as _, TypeId}; +use std::collections::{VecDeque, hash_map}; use telemetry_events::EditPredictionRating; +use workspace::Workspace; + +use std::fmt::Write as _; +use std::ops::Range; +use std::path::Path; +use std::rc::Rc; +use std::str::FromStr as _; +use std::sync::{Arc, LazyLock}; +use std::time::{Duration, Instant}; +use std::{env, mem}; use thiserror::Error; -use util::ResultExt; -use util::rel_path::RelPath; -use uuid::Uuid; +use util::rel_path::RelPathBuf; +use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt}; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; -use worktree::Worktree; - -const CURSOR_MARKER: &str = "<|user_cursor_is_here|>"; -const START_OF_FILE_MARKER: &str = "<|start_of_file|>"; -const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>"; -const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>"; -const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1); -const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice"; -const MAX_CONTEXT_TOKENS: usize = 150; -const MAX_REWRITE_TOKENS: usize = 350; -const MAX_EVENT_TOKENS: usize = 500; +pub mod assemble_excerpts; +mod license_detection; +mod onboarding_modal; +mod prediction; +mod provider; +mod rate_prediction_modal; +pub mod retrieval_search; +mod sweep_ai; +pub mod udiff; +mod xml_edits; +pub mod zeta1; -/// Maximum number of events to track. -const MAX_EVENT_COUNT: usize = 16; +#[cfg(test)] +mod zeta_tests; + +use crate::assemble_excerpts::assemble_excerpts; +use crate::license_detection::LicenseDetectionWatcher; +use crate::onboarding_modal::ZedPredictModal; +pub use crate::prediction::EditPrediction; +pub use crate::prediction::EditPredictionId; +pub use crate::prediction::EditPredictionInputs; +use crate::rate_prediction_modal::{ + NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction, + ThumbsUpActivePrediction, +}; +use crate::zeta1::request_prediction_with_zeta1; +pub use provider::ZetaEditPredictionProvider; actions!( edit_prediction, [ + /// Resets the edit prediction onboarding state. + ResetOnboarding, + /// Opens the rate completions modal. + RateCompletions, /// Clears the edit prediction history. - ClearHistory + ClearHistory, ] ); -#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] -pub struct EditPredictionId(Uuid); +/// Maximum number of events to track. +const EVENT_COUNT_MAX: usize = 6; +const CHANGE_GROUPING_LINE_SPAN: u32 = 8; +const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice"; -impl From for gpui::ElementId { - fn from(value: EditPredictionId) -> Self { - gpui::ElementId::Uuid(value.0) - } -} +pub struct SweepFeatureFlag; -impl std::fmt::Display for EditPredictionId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } +impl FeatureFlag for SweepFeatureFlag { + const NAME: &str = "sweep-ai"; } +pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions { + max_bytes: 512, + min_bytes: 128, + target_before_cursor_over_total_bytes: 0.5, +}; -struct ZedPredictUpsell; +pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = + ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS); -impl Dismissable for ZedPredictUpsell { - const KEY: &'static str = "dismissed-edit-predict-upsell"; +pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions { + excerpt: DEFAULT_EXCERPT_OPTIONS, +}; - fn dismissed() -> bool { - // To make this backwards compatible with older versions of Zed, we - // check if the user has seen the previous Edit Prediction Onboarding - // before, by checking the data collection choice which was written to - // the database once the user clicked on "Accept and Enable" - if KEY_VALUE_STORE - .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE) - .log_err() - .is_some_and(|s| s.is_some()) - { - return true; +pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions = + EditPredictionContextOptions { + use_imports: true, + max_retrieved_declarations: 0, + excerpt: DEFAULT_EXCERPT_OPTIONS, + score: EditPredictionScoreOptions { + omit_excerpt_overlaps: true, + }, + }; + +pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { + context: DEFAULT_CONTEXT_OPTIONS, + max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES, + max_diagnostic_bytes: 2048, + prompt_format: PromptFormat::DEFAULT, + file_indexing_parallelism: 1, + buffer_change_grouping_interval: Duration::from_secs(1), +}; + +static USE_OLLAMA: LazyLock = + LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty())); +static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock = LazyLock::new(|| { + env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA { + "qwen3-coder:30b".to_string() + } else { + "yqvev8r3".to_string() + }) +}); +static EDIT_PREDICTIONS_MODEL_ID: LazyLock = LazyLock::new(|| { + match env::var("ZED_ZETA2_MODEL").as_deref() { + Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten + Ok(model) => model, + Err(_) if *USE_OLLAMA => "qwen3-coder:30b", + Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten + } + .to_string() +}); +static PREDICT_EDITS_URL: LazyLock> = LazyLock::new(|| { + env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| { + if *USE_OLLAMA { + Some("http://localhost:11434/v1/chat/completions".into()) + } else { + None } + }) +}); - KEY_VALUE_STORE - .read_kvp(Self::KEY) - .log_err() - .is_some_and(|s| s.is_some()) - } -} +pub struct Zeta2FeatureFlag; -pub fn should_show_upsell_modal() -> bool { - !ZedPredictUpsell::dismissed() +impl FeatureFlag for Zeta2FeatureFlag { + const NAME: &'static str = "zeta2"; + + fn enabled_for_staff() -> bool { + false + } } #[derive(Clone)] @@ -132,108 +180,291 @@ struct ZetaGlobal(Entity); impl Global for ZetaGlobal {} -#[derive(Clone)] -pub struct EditPrediction { - id: EditPredictionId, - path: Arc, - excerpt_range: Range, - cursor_offset: usize, - edits: Arc<[(Range, Arc)]>, - snapshot: BufferSnapshot, - edit_preview: EditPreview, - input_outline: Arc, - input_events: Arc, - input_excerpt: Arc, - output_excerpt: Arc, - buffer_snapshotted_at: Instant, - response_received_at: Instant, +pub struct Zeta { + client: Arc, + user_store: Entity, + llm_token: LlmApiToken, + _llm_token_subscription: Subscription, + projects: HashMap, + options: ZetaOptions, + update_required: bool, + debug_tx: Option>, + #[cfg(feature = "eval-support")] + eval_cache: Option>, + edit_prediction_model: ZetaEditPredictionModel, + sweep_api_token: Option, + sweep_ai_debug_info: Arc, + data_collection_choice: DataCollectionChoice, + rejected_predictions: Vec, + reject_predictions_tx: mpsc::UnboundedSender<()>, + reject_predictions_debounce_task: Option>, + shown_predictions: VecDeque, + rated_predictions: HashSet, } -impl EditPrediction { - fn latency(&self) -> Duration { - self.response_received_at - .duration_since(self.buffer_snapshotted_at) - } +#[derive(Default, PartialEq, Eq)] +pub enum ZetaEditPredictionModel { + #[default] + Zeta1, + Zeta2, + Sweep, +} - fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option, Arc)>> { - edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits) - } +#[derive(Debug, Clone, PartialEq)] +pub struct ZetaOptions { + pub context: ContextMode, + pub max_prompt_bytes: usize, + pub max_diagnostic_bytes: usize, + pub prompt_format: predict_edits_v3::PromptFormat, + pub file_indexing_parallelism: usize, + pub buffer_change_grouping_interval: Duration, } -impl std::fmt::Debug for EditPrediction { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("EditPrediction") - .field("id", &self.id) - .field("path", &self.path) - .field("edits", &self.edits) - .finish_non_exhaustive() +#[derive(Debug, Clone, PartialEq)] +pub enum ContextMode { + Agentic(AgenticContextOptions), + Syntax(EditPredictionContextOptions), +} + +#[derive(Debug, Clone, PartialEq)] +pub struct AgenticContextOptions { + pub excerpt: EditPredictionExcerptOptions, +} + +impl ContextMode { + pub fn excerpt(&self) -> &EditPredictionExcerptOptions { + match self { + ContextMode::Agentic(options) => &options.excerpt, + ContextMode::Syntax(options) => &options.excerpt, + } } } -pub struct Zeta { - projects: HashMap, - client: Arc, - shown_completions: VecDeque, - rated_completions: HashSet, - data_collection_choice: DataCollectionChoice, - discarded_completions: Vec, - llm_token: LlmApiToken, - _llm_token_subscription: Subscription, - /// Whether an update to a newer version of Zed is required to continue using Zeta. - update_required: bool, - user_store: Entity, - license_detection_watchers: HashMap>, - discard_completions_debounce_task: Option>, - discard_completions_tx: mpsc::UnboundedSender<()>, +#[derive(Debug)] +pub enum ZetaDebugInfo { + ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo), + SearchQueriesGenerated(ZetaSearchQueryDebugInfo), + SearchQueriesExecuted(ZetaContextRetrievalDebugInfo), + ContextRetrievalFinished(ZetaContextRetrievalDebugInfo), + EditPredictionRequested(ZetaEditPredictionDebugInfo), +} + +#[derive(Debug)] +pub struct ZetaContextRetrievalStartedDebugInfo { + pub project: Entity, + pub timestamp: Instant, + pub search_prompt: String, +} + +#[derive(Debug)] +pub struct ZetaContextRetrievalDebugInfo { + pub project: Entity, + pub timestamp: Instant, +} + +#[derive(Debug)] +pub struct ZetaEditPredictionDebugInfo { + pub inputs: EditPredictionInputs, + pub retrieval_time: Duration, + pub buffer: WeakEntity, + pub position: language::Anchor, + pub local_prompt: Result, + pub response_rx: oneshot::Receiver<(Result, Duration)>, +} + +#[derive(Debug)] +pub struct ZetaSearchQueryDebugInfo { + pub project: Entity, + pub timestamp: Instant, + pub search_queries: Vec, } +pub type RequestDebugInfo = predict_edits_v3::DebugInfo; + struct ZetaProject { - events: VecDeque, + syntax_index: Option>, + events: VecDeque>, + last_event: Option, + recent_paths: VecDeque, registered_buffers: HashMap, + current_prediction: Option, + next_pending_prediction_id: usize, + pending_predictions: ArrayVec, + last_prediction_refresh: Option<(EntityId, Instant)>, + context: Option, Vec>>>, + refresh_context_task: Option>>>, + refresh_context_debounce_task: Option>>, + refresh_context_timestamp: Option, + license_detection_watchers: HashMap>, + _subscription: gpui::Subscription, } -impl Zeta { - pub fn global(cx: &mut App) -> Option> { - cx.try_global::().map(|global| global.0.clone()) +impl ZetaProject { + pub fn events(&self, cx: &App) -> Vec> { + self.events + .iter() + .cloned() + .chain( + self.last_event + .as_ref() + .and_then(|event| event.finalize(&self.license_detection_watchers, cx)), + ) + .collect() } +} - pub fn register( - worktree: Option>, - client: Arc, - user_store: Entity, - cx: &mut App, - ) -> Entity { - let this = Self::global(cx).unwrap_or_else(|| { - let entity = cx.new(|cx| Self::new(client, user_store, cx)); - cx.set_global(ZetaGlobal(entity.clone())); - entity - }); +#[derive(Debug, Clone)] +struct CurrentEditPrediction { + pub requested_by: PredictionRequestedBy, + pub prediction: EditPrediction, + pub was_shown: bool, +} - this.update(cx, move |this, cx| { - if let Some(worktree) = worktree { - let worktree_id = worktree.read(cx).id(); - this.license_detection_watchers - .entry(worktree_id) - .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx))); - } - }); +impl CurrentEditPrediction { + fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool { + let Some(new_edits) = self + .prediction + .interpolate(&self.prediction.buffer.read(cx)) + else { + return false; + }; + + if self.prediction.buffer != old_prediction.prediction.buffer { + return true; + } + + let Some(old_edits) = old_prediction + .prediction + .interpolate(&old_prediction.prediction.buffer.read(cx)) + else { + return true; + }; - this + let requested_by_buffer_id = self.requested_by.buffer_id(); + + // This reduces the occurrence of UI thrash from replacing edits + // + // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits. + if requested_by_buffer_id == Some(self.prediction.buffer.entity_id()) + && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id()) + && old_edits.len() == 1 + && new_edits.len() == 1 + { + let (old_range, old_text) = &old_edits[0]; + let (new_range, new_text) = &new_edits[0]; + new_range == old_range && new_text.starts_with(old_text.as_ref()) + } else { + true + } } +} - pub fn clear_history(&mut self) { - for zeta_project in self.projects.values_mut() { - zeta_project.events.clear(); +#[derive(Debug, Clone)] +enum PredictionRequestedBy { + DiagnosticsUpdate, + Buffer(EntityId), +} + +impl PredictionRequestedBy { + pub fn buffer_id(&self) -> Option { + match self { + PredictionRequestedBy::DiagnosticsUpdate => None, + PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id), } } +} - pub fn usage(&self, cx: &App) -> Option { - self.user_store.read(cx).edit_prediction_usage() +struct PendingPrediction { + id: usize, + task: Task>, +} + +/// A prediction from the perspective of a buffer. +#[derive(Debug)] +enum BufferEditPrediction<'a> { + Local { prediction: &'a EditPrediction }, + Jump { prediction: &'a EditPrediction }, +} + +struct RegisteredBuffer { + snapshot: BufferSnapshot, + _subscriptions: [gpui::Subscription; 2], +} + +struct LastEvent { + old_snapshot: BufferSnapshot, + new_snapshot: BufferSnapshot, + end_edit_anchor: Option, +} + +impl LastEvent { + pub fn finalize( + &self, + license_detection_watchers: &HashMap>, + cx: &App, + ) -> Option> { + let path = buffer_path_with_id_fallback(&self.new_snapshot, cx); + let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx); + + let file = self.new_snapshot.file(); + let old_file = self.old_snapshot.file(); + + let in_open_source_repo = [file, old_file].iter().all(|file| { + file.is_some_and(|file| { + license_detection_watchers + .get(&file.worktree_id(cx)) + .is_some_and(|watcher| watcher.is_project_open_source()) + }) + }); + + let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text()); + + if path == old_path && diff.is_empty() { + None + } else { + Some(Arc::new(predict_edits_v3::Event::BufferChange { + old_path, + path, + diff, + in_open_source_repo, + // TODO: Actually detect if this edit was predicted or not + predicted: false, + })) + } + } +} + +fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc { + if let Some(file) = snapshot.file() { + file.full_path(cx).into() + } else { + Path::new(&format!("untitled-{}", snapshot.remote_id())).into() + } +} + +impl Zeta { + pub fn try_global(cx: &App) -> Option> { + cx.try_global::().map(|global| global.0.clone()) + } + + pub fn global( + client: &Arc, + user_store: &Entity, + cx: &mut App, + ) -> Entity { + cx.try_global::() + .map(|global| global.0.clone()) + .unwrap_or_else(|| { + let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx)); + cx.set_global(ZetaGlobal(zeta.clone())); + zeta + }) } - fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { + pub fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); let data_collection_choice = Self::load_data_collection_choice(); + let (reject_tx, mut reject_rx) = mpsc::unbounded(); cx.spawn(async move |this, cx| { while let Some(()) = reject_rx.next().await { @@ -248,12 +479,8 @@ impl Zeta { Self { projects: HashMap::default(), client, - shown_completions: VecDeque::new(), - rated_completions: HashSet::default(), - discarded_completions: Vec::new(), - discard_completions_debounce_task: None, - discard_completions_tx: reject_tx, - data_collection_choice, + user_store, + options: DEFAULT_OPTIONS, llm_token: LlmApiToken::default(), _llm_token_subscription: cx.subscribe( &refresh_llm_token_listener, @@ -268,64 +495,85 @@ impl Zeta { }, ), update_required: false, - license_detection_watchers: HashMap::default(), - user_store, + debug_tx: None, + #[cfg(feature = "eval-support")] + eval_cache: None, + edit_prediction_model: ZetaEditPredictionModel::Zeta2, + sweep_api_token: std::env::var("SWEEP_AI_TOKEN") + .context("No SWEEP_AI_TOKEN environment variable set") + .log_err(), + data_collection_choice, + sweep_ai_debug_info: sweep_ai::debug_info(cx), + rejected_predictions: Vec::new(), + reject_predictions_debounce_task: None, + reject_predictions_tx: reject_tx, + rated_predictions: Default::default(), + shown_predictions: Default::default(), } } - fn get_or_init_zeta_project( - &mut self, - project: &Entity, - cx: &mut Context, - ) -> &mut ZetaProject { - let project_id = project.entity_id(); - match self.projects.entry(project_id) { - hash_map::Entry::Occupied(entry) => entry.into_mut(), - hash_map::Entry::Vacant(entry) => { - cx.observe_release(project, move |this, _, _cx| { - this.projects.remove(&project_id); - }) - .detach(); - entry.insert(ZetaProject { - events: VecDeque::with_capacity(MAX_EVENT_COUNT), - registered_buffers: HashMap::default(), - }) - } - } + pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) { + self.edit_prediction_model = model; } - fn push_event(zeta_project: &mut ZetaProject, event: Event) { - let events = &mut zeta_project.events; + pub fn has_sweep_api_token(&self) -> bool { + self.sweep_api_token.is_some() + } - if let Some(Event::BufferChange { - new_snapshot: last_new_snapshot, - timestamp: last_timestamp, - .. - }) = events.back_mut() - { - // Coalesce edits for the same buffer when they happen one after the other. - let Event::BufferChange { - old_snapshot, - new_snapshot, - timestamp, - } = &event; - - if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL - && old_snapshot.remote_id() == last_new_snapshot.remote_id() - && old_snapshot.version == last_new_snapshot.version - { - *last_new_snapshot = new_snapshot.clone(); - *last_timestamp = *timestamp; - return; - } + #[cfg(feature = "eval-support")] + pub fn with_eval_cache(&mut self, cache: Arc) { + self.eval_cache = Some(cache); + } + + pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver { + let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded(); + self.debug_tx = Some(debug_watch_tx); + debug_watch_rx + } + + pub fn options(&self) -> &ZetaOptions { + &self.options + } + + pub fn set_options(&mut self, options: ZetaOptions) { + self.options = options; + } + + pub fn clear_history(&mut self) { + for zeta_project in self.projects.values_mut() { + zeta_project.events.clear(); } + } + + pub fn context_for_project( + &self, + project: &Entity, + ) -> impl Iterator, &[Range])> { + self.projects + .get(&project.entity_id()) + .and_then(|project| { + Some( + project + .context + .as_ref()? + .iter() + .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())), + ) + }) + .into_iter() + .flatten() + } - if events.len() >= MAX_EVENT_COUNT { - // These are halved instead of popping to improve prompt caching. - events.drain(..MAX_EVENT_COUNT / 2); + pub fn usage(&self, cx: &App) -> Option { + if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 { + self.user_store.read(cx).edit_prediction_usage() + } else { + None } + } - events.push_back(event); + pub fn register_project(&mut self, project: &Entity, cx: &mut Context) { + self.get_or_init_zeta_project(project, cx); } pub fn register_buffer( @@ -338,6 +586,69 @@ impl Zeta { Self::register_buffer_impl(zeta_project, buffer, project, cx); } + fn get_or_init_zeta_project( + &mut self, + project: &Entity, + cx: &mut Context, + ) -> &mut ZetaProject { + self.projects + .entry(project.entity_id()) + .or_insert_with(|| ZetaProject { + syntax_index: if let ContextMode::Syntax(_) = &self.options.context { + Some(cx.new(|cx| { + SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx) + })) + } else { + None + }, + events: VecDeque::new(), + last_event: None, + recent_paths: VecDeque::new(), + registered_buffers: HashMap::default(), + current_prediction: None, + pending_predictions: ArrayVec::new(), + next_pending_prediction_id: 0, + last_prediction_refresh: None, + context: None, + refresh_context_task: None, + refresh_context_debounce_task: None, + refresh_context_timestamp: None, + license_detection_watchers: HashMap::default(), + _subscription: cx.subscribe(&project, Self::handle_project_event), + }) + } + + fn handle_project_event( + &mut self, + project: Entity, + event: &project::Event, + cx: &mut Context, + ) { + // TODO [zeta2] init with recent paths + match event { + project::Event::ActiveEntryChanged(Some(active_entry_id)) => { + let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { + return; + }; + let path = project.read(cx).path_for_entry(*active_entry_id, cx); + if let Some(path) = path { + if let Some(ix) = zeta_project + .recent_paths + .iter() + .position(|probe| probe == &path) + { + zeta_project.recent_paths.remove(ix); + } + zeta_project.recent_paths.push_front(path); + } + } + project::Event::DiagnosticsUpdated { .. } => { + self.refresh_prediction_from_diagnostics(project, cx); + } + _ => (), + } + } + fn register_buffer_impl<'a>( zeta_project: &'a mut ZetaProject, buffer: &Entity, @@ -345,6 +656,28 @@ impl Zeta { cx: &mut Context, ) -> &'a mut RegisteredBuffer { let buffer_id = buffer.entity_id(); + + if let Some(file) = buffer.read(cx).file() { + let worktree_id = file.worktree_id(cx); + if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) { + zeta_project + .license_detection_watchers + .entry(worktree_id) + .or_insert_with(|| { + let project_entity_id = project.entity_id(); + cx.observe_release(&worktree, move |this, _worktree, _cx| { + let Some(zeta_project) = this.projects.get_mut(&project_entity_id) + else { + return; + }; + zeta_project.license_detection_watchers.remove(&worktree_id); + }) + .detach(); + Rc::new(LicenseDetectionWatcher::new(&worktree, cx)) + }); + } + } + match zeta_project.registered_buffers.entry(buffer_id) { hash_map::Entry::Occupied(entry) => entry.into_mut(), hash_map::Entry::Vacant(entry) => { @@ -376,2037 +709,2755 @@ impl Zeta { } } - fn request_completion_impl( + fn report_changes_for_buffer( &mut self, - project: &Entity, buffer: &Entity, - cursor: language::Anchor, + project: &Entity, cx: &mut Context, - perform_predict_edits: F, - ) -> Task>> - where - F: FnOnce(PerformPredictEditsParams) -> R + 'static, - R: Future)>> - + Send - + 'static, - { - let buffer = buffer.clone(); - let buffer_snapshotted_at = Instant::now(); - let snapshot = self.report_changes_for_buffer(&buffer, project, cx); - let zeta = cx.entity(); - let client = self.client.clone(); - let llm_token = self.llm_token.clone(); - let app_version = AppVersion::global(cx); - - let zeta_project = self.get_or_init_zeta_project(project, cx); - let mut events = Vec::with_capacity(zeta_project.events.len()); - events.extend(zeta_project.events.iter().cloned()); - let events = Arc::new(events); - - let (git_info, can_collect_file) = if let Some(file) = snapshot.file() { - let can_collect_file = self.can_collect_file(file, cx); - let git_info = if can_collect_file { - git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) - } else { - None - }; - (git_info, can_collect_file) - } else { - (None, false) - }; - - let full_path: Arc = snapshot - .file() - .map(|f| Arc::from(f.full_path(cx).as_path())) - .unwrap_or_else(|| Arc::from(Path::new("untitled"))); - let full_path_str = full_path.to_string_lossy().into_owned(); - let cursor_point = cursor.to_point(&snapshot); - let cursor_offset = cursor_point.to_offset(&snapshot); - let prompt_for_events = { - let events = events.clone(); - move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS) - }; - let gather_task = gather_context( - full_path_str, - &snapshot, - cursor_point, - prompt_for_events, - cx, - ); - - cx.spawn(async move |this, cx| { - let GatherContextOutput { - mut body, - editable_range, - included_events_count, - } = gather_task.await?; - let done_gathering_context_at = Instant::now(); - - let included_events = &events[events.len() - included_events_count..events.len()]; - body.can_collect_data = can_collect_file - && this - .read_with(cx, |this, cx| this.can_collect_events(included_events, cx)) - .unwrap_or(false); - if body.can_collect_data { - body.git_info = git_info; - } - - log::debug!( - "Events:\n{}\nExcerpt:\n{:?}", - body.input_events, - body.input_excerpt - ); - - let input_outline = body.outline.clone().unwrap_or_default(); - let input_events = body.input_events.clone(); - let input_excerpt = body.input_excerpt.clone(); - - let response = perform_predict_edits(PerformPredictEditsParams { - client, - llm_token, - app_version, - body, - }) - .await; - let (response, usage) = match response { - Ok(response) => response, - Err(err) => { - if err.is::() { - cx.update(|cx| { - zeta.update(cx, |zeta, _cx| { - zeta.update_required = true; - }); - - let error_message: SharedString = err.to_string().into(); - show_app_notification( - NotificationId::unique::(), - cx, - move |cx| { - cx.new(|cx| { - ErrorMessagePrompt::new(error_message.clone(), cx) - .with_link_button( - "Update Zed", - "https://zed.dev/releases", - ) - }) - }, - ); - }) - .ok(); - } + ) { + let project_state = self.get_or_init_zeta_project(project, cx); + let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx); - return Err(err); - } - }; + let new_snapshot = buffer.read(cx).snapshot(); + if new_snapshot.version == registered_buffer.snapshot.version { + return; + } - let received_response_at = Instant::now(); - log::debug!("completion response: {}", &response.output_excerpt); + let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone()); + let end_edit_anchor = new_snapshot + .anchored_edits_since::(&old_snapshot.version) + .last() + .map(|(_, range)| range.end); + let events = &mut project_state.events; - if let Some(usage) = usage { - this.update(cx, |this, cx| { - this.user_store.update(cx, |user_store, cx| { - user_store.update_edit_prediction_usage(usage, cx); + if let Some(LastEvent { + new_snapshot: last_new_snapshot, + end_edit_anchor: last_end_edit_anchor, + .. + }) = project_state.last_event.as_mut() + { + let is_next_snapshot_of_same_buffer = old_snapshot.remote_id() + == last_new_snapshot.remote_id() + && old_snapshot.version == last_new_snapshot.version; + + let should_coalesce = is_next_snapshot_of_same_buffer + && end_edit_anchor + .as_ref() + .zip(last_end_edit_anchor.as_ref()) + .is_some_and(|(a, b)| { + let a = a.to_point(&new_snapshot); + let b = b.to_point(&new_snapshot); + a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN }); - }) - .ok(); + + if should_coalesce { + *last_end_edit_anchor = end_edit_anchor; + *last_new_snapshot = new_snapshot; + return; } + } - let edit_prediction = Self::process_completion_response( - response, - buffer, - &snapshot, - editable_range, - cursor_offset, - full_path, - input_outline, - input_events, - input_excerpt, - buffer_snapshotted_at, - cx, - ) - .await; + if events.len() + 1 >= EVENT_COUNT_MAX { + events.pop_front(); + } - let finished_at = Instant::now(); - - // record latency for ~1% of requests - if rand::random::() <= 2 { - telemetry::event!( - "Edit Prediction Request", - context_latency = done_gathering_context_at - .duration_since(buffer_snapshotted_at) - .as_millis(), - request_latency = received_response_at - .duration_since(done_gathering_context_at) - .as_millis(), - process_latency = finished_at.duration_since(received_response_at).as_millis() - ); - } + if let Some(event) = project_state.last_event.take() { + events.extend(event.finalize(&project_state.license_detection_watchers, cx)); + } - edit_prediction - }) + project_state.last_event = Some(LastEvent { + old_snapshot, + new_snapshot, + end_edit_anchor, + }); } - #[cfg(any(test, feature = "test-support"))] - pub fn fake_completion( - &mut self, - project: &Entity, + fn current_prediction_for_buffer( + &self, buffer: &Entity, - position: language::Anchor, - response: PredictEditsResponse, - cx: &mut Context, - ) -> Task>> { - self.request_completion_impl(project, buffer, position, cx, |_params| { - std::future::ready(Ok((response, None))) - }) - } - - pub fn request_completion( - &mut self, project: &Entity, - buffer: &Entity, - position: language::Anchor, - cx: &mut Context, - ) -> Task>> { - self.request_completion_impl(project, buffer, position, cx, Self::perform_predict_edits) - } - - pub fn perform_predict_edits( - params: PerformPredictEditsParams, - ) -> impl Future)>> { - async move { - let PerformPredictEditsParams { - client, - llm_token, - app_version, - body, - .. - } = params; - - let http_client = client.http_client(); - let mut token = llm_token.acquire(&client).await?; - let mut did_retry = false; - - loop { - let request_builder = http_client::Request::builder().method(Method::POST); - let request_builder = - if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") { - request_builder.uri(predict_edits_url) - } else { - request_builder.uri( - http_client - .build_zed_llm_url("/predict_edits/v2", &[])? - .as_ref(), - ) - }; - let request = request_builder - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", token)) - .header(ZED_VERSION_HEADER_NAME, app_version.to_string()) - .body(serde_json::to_string(&body)?.into())?; + cx: &App, + ) -> Option> { + let project_state = self.projects.get(&project.entity_id())?; - let mut response = http_client.send(request).await?; + let CurrentEditPrediction { + requested_by, + prediction, + .. + } = project_state.current_prediction.as_ref()?; - if let Some(minimum_required_version) = response - .headers() - .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME) - .and_then(|version| Version::from_str(version.to_str().ok()?).ok()) - { - anyhow::ensure!( - app_version >= minimum_required_version, - ZedUpdateRequiredError { - minimum_version: minimum_required_version - } - ); + if prediction.targets_buffer(buffer.read(cx)) { + Some(BufferEditPrediction::Local { prediction }) + } else { + let show_jump = match requested_by { + PredictionRequestedBy::Buffer(requested_by_buffer_id) => { + requested_by_buffer_id == &buffer.entity_id() } + PredictionRequestedBy::DiagnosticsUpdate => true, + }; - if response.status().is_success() { - let usage = EditPredictionUsage::from_headers(response.headers()).ok(); - - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - return Ok((serde_json::from_str(&body)?, usage)); - } else if !did_retry - && response - .headers() - .get(EXPIRED_LLM_TOKEN_HEADER_NAME) - .is_some() - { - did_retry = true; - token = llm_token.refresh(&client).await?; - } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - anyhow::bail!( - "error predicting edits.\nStatus: {:?}\nBody: {}", - response.status(), - body - ); - } + if show_jump { + Some(BufferEditPrediction::Jump { prediction }) + } else { + None } } } - fn accept_edit_prediction( - &mut self, - request_id: EditPredictionId, - cx: &mut Context, - ) -> Task> { + fn accept_current_prediction(&mut self, project: &Entity, cx: &mut Context) { + match self.edit_prediction_model { + ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {} + ZetaEditPredictionModel::Sweep => return, + } + + let Some(project_state) = self.projects.get_mut(&project.entity_id()) else { + return; + }; + + let Some(prediction) = project_state.current_prediction.take() else { + return; + }; + let request_id = prediction.prediction.id.to_string(); + for pending_prediction in mem::take(&mut project_state.pending_predictions) { + self.cancel_pending_prediction(pending_prediction, cx); + } + let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); cx.spawn(async move |this, cx| { - let http_client = client.http_client(); - let mut response = llm_token_retry(&llm_token, &client, |token| { - let request_builder = http_client::Request::builder().method(Method::POST); - let request_builder = - if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") { - request_builder.uri(accept_prediction_url) - } else { - request_builder.uri( - http_client - .build_zed_llm_url("/predict_edits/accept", &[])? - .as_ref(), - ) - }; - Ok(request_builder - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", token)) - .header(ZED_VERSION_HEADER_NAME, app_version.to_string()) - .body( - serde_json::to_string(&AcceptEditPredictionBody { - request_id: request_id.0.to_string(), - })? - .into(), - )?) - }) - .await?; - - if let Some(minimum_required_version) = response - .headers() - .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME) - .and_then(|version| Version::from_str(version.to_str().ok()?).ok()) - && app_version < minimum_required_version - { - return Err(anyhow!(ZedUpdateRequiredError { - minimum_version: minimum_required_version - })); - } - - if response.status().is_success() { - if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() { - this.update(cx, |this, cx| { - this.user_store.update(cx, |user_store, cx| { - user_store.update_edit_prediction_usage(usage, cx); - }); - })?; - } - - Ok(()) + let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") { + http_client::Url::parse(&predict_edits_url)? } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - Err(anyhow!( - "error accepting edit prediction.\nStatus: {:?}\nBody: {}", - response.status(), - body + client + .http_client() + .build_zed_llm_url("/predict_edits/accept", &[])? + }; + + let response = cx + .background_spawn(Self::send_api_request::<()>( + move |builder| { + let req = builder.uri(url.as_ref()).body( + serde_json::to_string(&AcceptEditPredictionBody { + request_id: request_id.clone(), + })? + .into(), + ); + Ok(req?) + }, + client, + llm_token, + app_version, )) - } + .await; + + Self::handle_api_response(&this, response, cx)?; + anyhow::Ok(()) }) + .detach_and_log_err(cx); } fn reject_edit_predictions(&mut self, cx: &mut Context) -> Task> { + match self.edit_prediction_model { + ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {} + ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())), + } + let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); - let last_rejection = self.discarded_completions.last().cloned(); - let body = serde_json::to_string(&RejectEditPredictionsBody { - rejections: self.discarded_completions.clone(), - }) - .ok(); - + let last_rejection = self.rejected_predictions.last().cloned(); let Some(last_rejection) = last_rejection else { return Task::ready(anyhow::Ok(())); }; + let body = serde_json::to_string(&RejectEditPredictionsBody { + rejections: self.rejected_predictions.clone(), + }) + .ok(); + cx.spawn(async move |this, cx| { - let http_client = client.http_client(); - let mut response = llm_token_retry(&llm_token, &client, |token| { - let request_builder = http_client::Request::builder().method(Method::POST); - let request_builder = request_builder.uri( - http_client - .build_zed_llm_url("/predict_edits/reject", &[])? - .as_ref(), - ); - Ok(request_builder - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", token)) - .header(ZED_VERSION_HEADER_NAME, app_version.to_string()) - .body( - body.as_ref() - .context("failed to serialize body")? - .clone() - .into(), - )?) + let url = client + .http_client() + .build_zed_llm_url("/predict_edits/reject", &[])?; + + cx.background_spawn(Self::send_api_request::<()>( + move |builder| { + let req = builder.uri(url.as_ref()).body(body.clone().into()); + Ok(req?) + }, + client, + llm_token, + app_version, + )) + .await + .context("Failed to reject edit predictions")?; + + this.update(cx, |this, _| { + if let Some(ix) = this + .rejected_predictions + .iter() + .position(|rejection| rejection.request_id == last_rejection.request_id) + { + this.rejected_predictions.drain(..ix + 1); + } }) - .await?; + }) + } - if let Some(minimum_required_version) = response - .headers() - .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME) - .and_then(|version| Version::from_str(version.to_str().ok()?).ok()) - && app_version < minimum_required_version - { - return Err(anyhow!(ZedUpdateRequiredError { - minimum_version: minimum_required_version - })); + fn discard_current_prediction(&mut self, project: &Entity, cx: &mut Context) { + if let Some(project_state) = self.projects.get_mut(&project.entity_id()) { + project_state.pending_predictions.clear(); + if let Some(prediction) = project_state.current_prediction.take() { + self.discard_prediction(prediction.prediction.id, prediction.was_shown, cx); } + }; + } - if response.status().is_success() { - this.update(cx, |this, _| { - if let Some(ix) = this - .discarded_completions - .iter() - .position(|rejection| rejection.request_id == last_rejection.request_id) - { - this.discarded_completions.drain(..ix + 1); + fn did_show_current_prediction(&mut self, project: &Entity, _cx: &mut Context) { + if let Some(project_state) = self.projects.get_mut(&project.entity_id()) { + if let Some(current_prediction) = project_state.current_prediction.as_mut() { + if !current_prediction.was_shown { + current_prediction.was_shown = true; + self.shown_predictions + .push_front(current_prediction.prediction.clone()); + if self.shown_predictions.len() > 50 { + let completion = self.shown_predictions.pop_back().unwrap(); + self.rated_predictions.remove(&completion.id); } - }) - } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - Err(anyhow!( - "error rejecting edit predictions.\nStatus: {:?}\nBody: {}", - response.status(), - body - )) + } + } + } + } + + fn discard_prediction( + &mut self, + prediction_id: EditPredictionId, + was_shown: bool, + cx: &mut Context, + ) { + self.rejected_predictions.push(EditPredictionRejection { + request_id: prediction_id.to_string(), + was_shown, + }); + + let reached_request_limit = + self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST; + let reject_tx = self.reject_predictions_tx.clone(); + self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| { + const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15); + if !reached_request_limit { + cx.background_executor() + .timer(DISCARD_COMPLETIONS_DEBOUNCE) + .await; } + reject_tx.unbounded_send(()).log_err(); + })); + } + + fn cancel_pending_prediction( + &self, + pending_prediction: PendingPrediction, + cx: &mut Context, + ) { + cx.spawn(async move |this, cx| { + let Some(prediction_id) = pending_prediction.task.await else { + return; + }; + + this.update(cx, |this, cx| { + this.discard_prediction(prediction_id, false, cx); + }) + .ok(); }) + .detach() + } + + fn is_refreshing(&self, project: &Entity) -> bool { + self.projects + .get(&project.entity_id()) + .is_some_and(|project_state| !project_state.pending_predictions.is_empty()) } - fn process_completion_response( - prediction_response: PredictEditsResponse, + pub fn refresh_prediction_from_buffer( + &mut self, + project: Entity, buffer: Entity, - snapshot: &BufferSnapshot, - editable_range: Range, - cursor_offset: usize, - path: Arc, - input_outline: String, - input_events: String, - input_excerpt: String, - buffer_snapshotted_at: Instant, - cx: &AsyncApp, - ) -> Task>> { - let snapshot = snapshot.clone(); - let request_id = prediction_response.request_id; - let output_excerpt = prediction_response.output_excerpt; - cx.spawn(async move |cx| { - let output_excerpt: Arc = output_excerpt.into(); - - let edits: Arc<[(Range, Arc)]> = cx - .background_spawn({ - let output_excerpt = output_excerpt.clone(); - let editable_range = editable_range.clone(); - let snapshot = snapshot.clone(); - async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) } + position: language::Anchor, + cx: &mut Context, + ) { + self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| { + let Some(request_task) = this + .update(cx, |this, cx| { + this.request_prediction(&project, &buffer, position, cx) }) - .await? - .into(); - - let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, { - let edits = edits.clone(); - move |buffer, cx| { - let new_snapshot = buffer.snapshot(); - let edits: Arc<[(Range, Arc)]> = - edit_prediction::interpolate_edits(&snapshot, &new_snapshot, &edits)? - .into(); - Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx))) - } - })? + .log_err() else { - return anyhow::Ok(None); + return Task::ready(anyhow::Ok(None)); }; - let request_id = Uuid::from_str(&request_id).context("failed to parse request id")?; - - let edit_preview = edit_preview.await; - - Ok(Some(EditPrediction { - id: EditPredictionId(request_id), - path, - excerpt_range: editable_range, - cursor_offset, - edits, - edit_preview, - snapshot, - input_outline: input_outline.into(), - input_events: input_events.into(), - input_excerpt: input_excerpt.into(), - output_excerpt, - buffer_snapshotted_at, - response_received_at: Instant::now(), - })) + let project = project.clone(); + cx.spawn(async move |cx| { + if let Some(prediction) = request_task.await? { + let id = prediction.id.clone(); + this.update(cx, |this, cx| { + let project_state = this + .projects + .get_mut(&project.entity_id()) + .context("Project not found")?; + + let new_prediction = CurrentEditPrediction { + requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()), + prediction: prediction, + was_shown: false, + }; + + if project_state + .current_prediction + .as_ref() + .is_none_or(|old_prediction| { + new_prediction.should_replace_prediction(&old_prediction, cx) + }) + { + project_state.current_prediction = Some(new_prediction); + cx.notify(); + } + anyhow::Ok(()) + })??; + Ok(Some(id)) + } else { + Ok(None) + } + }) }) } - fn parse_edits( - output_excerpt: Arc, - editable_range: Range, - snapshot: &BufferSnapshot, - ) -> Result, Arc)>> { - let content = output_excerpt.replace(CURSOR_MARKER, ""); - - let start_markers = content - .match_indices(EDITABLE_REGION_START_MARKER) - .collect::>(); - anyhow::ensure!( - start_markers.len() == 1, - "expected exactly one start marker, found {}", - start_markers.len() - ); + pub fn refresh_prediction_from_diagnostics( + &mut self, + project: Entity, + cx: &mut Context, + ) { + let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { + return; + }; - let end_markers = content - .match_indices(EDITABLE_REGION_END_MARKER) - .collect::>(); - anyhow::ensure!( - end_markers.len() == 1, - "expected exactly one end marker, found {}", - end_markers.len() - ); - - let sof_markers = content - .match_indices(START_OF_FILE_MARKER) - .collect::>(); - anyhow::ensure!( - sof_markers.len() <= 1, - "expected at most one start-of-file marker, found {}", - sof_markers.len() - ); + // Prefer predictions from buffer + if zeta_project.current_prediction.is_some() { + return; + }; - let codefence_start = start_markers[0].0; - let content = &content[codefence_start..]; + self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| { + let Some(open_buffer_task) = project + .update(cx, |project, cx| { + project + .active_entry() + .and_then(|entry| project.path_for_entry(entry, cx)) + .map(|path| project.open_buffer(path, cx)) + }) + .log_err() + .flatten() + else { + return Task::ready(anyhow::Ok(None)); + }; - let newline_ix = content.find('\n').context("could not find newline")?; - let content = &content[newline_ix + 1..]; + cx.spawn(async move |cx| { + let active_buffer = open_buffer_task.await?; + let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( + active_buffer, + &snapshot, + Default::default(), + Default::default(), + &project, + cx, + ) + .await? + else { + return anyhow::Ok(None); + }; - let codefence_end = content - .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}")) - .context("could not find end marker")?; - let new_text = &content[..codefence_end]; + let Some(prediction) = this + .update(cx, |this, cx| { + this.request_prediction(&project, &jump_buffer, jump_position, cx) + })? + .await? + else { + return anyhow::Ok(None); + }; - let old_text = snapshot - .text_for_range(editable_range.clone()) - .collect::(); + let id = prediction.id.clone(); + this.update(cx, |this, cx| { + if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) { + zeta_project.current_prediction.get_or_insert_with(|| { + cx.notify(); + CurrentEditPrediction { + requested_by: PredictionRequestedBy::DiagnosticsUpdate, + prediction, + was_shown: false, + } + }); + } + })?; - Ok(Self::compute_edits( - old_text, - new_text, - editable_range.start, - snapshot, - )) + anyhow::Ok(Some(id)) + }) + }); } - pub fn compute_edits( - old_text: String, - new_text: &str, - offset: usize, - snapshot: &BufferSnapshot, - ) -> Vec<(Range, Arc)> { - text_diff(&old_text, new_text) - .into_iter() - .map(|(mut old_range, new_text)| { - old_range.start += offset; - old_range.end += offset; + #[cfg(not(test))] + pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); + #[cfg(test)] + pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO; - let prefix_len = common_prefix( - snapshot.chars_for_range(old_range.clone()), - new_text.chars(), - ); - old_range.start += prefix_len; + fn queue_prediction_refresh( + &mut self, + project: Entity, + throttle_entity: EntityId, + cx: &mut Context, + do_refresh: impl FnOnce( + WeakEntity, + &mut AsyncApp, + ) -> Task>> + + 'static, + ) { + let zeta_project = self.get_or_init_zeta_project(&project, cx); + let pending_prediction_id = zeta_project.next_pending_prediction_id; + zeta_project.next_pending_prediction_id += 1; + let last_request = zeta_project.last_prediction_refresh; - let suffix_len = common_prefix( - snapshot.reversed_chars_for_range(old_range.clone()), - new_text[prefix_len..].chars().rev(), - ); - old_range.end = old_range.end.saturating_sub(suffix_len); + // TODO report cancelled requests like in zeta1 + let task = cx.spawn(async move |this, cx| { + if let Some((last_entity, last_timestamp)) = last_request + && throttle_entity == last_entity + && let Some(timeout) = + (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now()) + { + cx.background_executor().timer(timeout).await; + } - let new_text = new_text[prefix_len..new_text.len() - suffix_len].into(); - let range = if old_range.is_empty() { - let anchor = snapshot.anchor_after(old_range.start); - anchor..anchor - } else { - snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end) - }; - (range, new_text) + let edit_prediction_id = do_refresh(this.clone(), cx).await.log_err().flatten(); + + // When a prediction completes, remove it from the pending list, and cancel + // any pending predictions that were enqueued before it. + this.update(cx, |this, cx| { + let zeta_project = this.get_or_init_zeta_project(&project, cx); + let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions); + for (ix, pending_prediction) in pending_predictions.iter().enumerate() { + if pending_prediction.id == pending_prediction_id { + pending_predictions.remove(ix); + for pending_prediction in pending_predictions.drain(0..ix) { + this.cancel_pending_prediction(pending_prediction, cx) + } + break; + } + } + this.get_or_init_zeta_project(&project, cx) + .pending_predictions = pending_predictions; + cx.notify(); }) - .collect() - } + .ok(); - pub fn is_completion_rated(&self, completion_id: EditPredictionId) -> bool { - self.rated_completions.contains(&completion_id) - } + edit_prediction_id + }); - pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context) { - self.shown_completions.push_front(completion.clone()); - if self.shown_completions.len() > 50 { - let completion = self.shown_completions.pop_back().unwrap(); - self.rated_completions.remove(&completion.id); + if zeta_project.pending_predictions.len() <= 1 { + zeta_project.pending_predictions.push(PendingPrediction { + id: pending_prediction_id, + task, + }); + } else if zeta_project.pending_predictions.len() == 2 { + let pending_prediction = zeta_project.pending_predictions.pop().unwrap(); + zeta_project.pending_predictions.push(PendingPrediction { + id: pending_prediction_id, + task, + }); + self.cancel_pending_prediction(pending_prediction, cx); } - cx.notify(); } - pub fn rate_completion( + pub fn request_prediction( &mut self, - completion: &EditPrediction, - rating: EditPredictionRating, - feedback: String, + project: &Entity, + active_buffer: &Entity, + position: language::Anchor, cx: &mut Context, - ) { - self.rated_completions.insert(completion.id); - telemetry::event!( - "Edit Prediction Rated", - rating, - input_events = completion.input_events, - input_excerpt = completion.input_excerpt, - input_outline = completion.input_outline, - output_excerpt = completion.output_excerpt, - feedback - ); - self.client.telemetry().flush_events().detach(); - cx.notify(); - } - - pub fn shown_completions(&self) -> impl DoubleEndedIterator { - self.shown_completions.iter() - } - - pub fn shown_completions_len(&self) -> usize { - self.shown_completions.len() + ) -> Task>> { + match self.edit_prediction_model { + ZetaEditPredictionModel::Zeta1 => { + request_prediction_with_zeta1(self, project, active_buffer, position, cx) + } + ZetaEditPredictionModel::Zeta2 => { + self.request_prediction_with_zeta2(project, active_buffer, position, cx) + } + ZetaEditPredictionModel::Sweep => { + self.request_prediction_with_sweep(project, active_buffer, position, true, cx) + } + } } - fn report_changes_for_buffer( + fn request_prediction_with_sweep( &mut self, - buffer: &Entity, project: &Entity, + active_buffer: &Entity, + position: language::Anchor, + allow_jump: bool, cx: &mut Context, - ) -> BufferSnapshot { - let zeta_project = self.get_or_init_zeta_project(project, cx); - let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx); + ) -> Task>> { + let snapshot = active_buffer.read(cx).snapshot(); + let debug_info = self.sweep_ai_debug_info.clone(); + let Some(api_token) = self.sweep_api_token.clone() else { + return Task::ready(Ok(None)); + }; + let full_path: Arc = snapshot + .file() + .map(|file| file.full_path(cx)) + .unwrap_or_else(|| "untitled".into()) + .into(); + + let project_file = project::File::from_dyn(snapshot.file()); + let repo_name = project_file + .map(|file| file.worktree.read(cx).root_name_str()) + .unwrap_or("untitled") + .into(); + let offset = position.to_offset(&snapshot); + + let project_state = self.get_or_init_zeta_project(project, cx); + let events = project_state.events(cx); + let has_events = !events.is_empty(); + let recent_buffers = project_state.recent_paths.iter().cloned(); + let http_client = cx.http_client(); + + let recent_buffer_snapshots = recent_buffers + .filter_map(|project_path| { + let buffer = project.read(cx).get_open_buffer(&project_path, cx)?; + if active_buffer == &buffer { + None + } else { + Some(buffer.read(cx).snapshot()) + } + }) + .take(3) + .collect::>(); - let new_snapshot = buffer.read(cx).snapshot(); - if new_snapshot.version != registered_buffer.snapshot.version { - let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone()); - Self::push_event( - zeta_project, - Event::BufferChange { - old_snapshot, - new_snapshot: new_snapshot.clone(), - timestamp: Instant::now(), - }, - ); - } + const DIAGNOSTIC_LINES_RANGE: u32 = 20; - new_snapshot - } + let cursor_point = position.to_point(&snapshot); + let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE); + let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE; + let diagnostic_search_range = + Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0); + let buffer_snapshotted_at = Instant::now(); - fn can_collect_file(&self, file: &Arc, cx: &App) -> bool { - self.data_collection_choice.is_enabled() && self.is_file_open_source(file, cx) - } + let result = cx.background_spawn({ + let snapshot = snapshot.clone(); + let diagnostic_search_range = diagnostic_search_range.clone(); + async move { + let text = snapshot.text(); - fn can_collect_events(&self, events: &[Event], cx: &App) -> bool { - if !self.data_collection_choice.is_enabled() { - return false; - } - let mut last_checked_file = None; - for event in events { - match event { - Event::BufferChange { - old_snapshot, - new_snapshot, - .. - } => { - if let Some(old_file) = old_snapshot.file() - && let Some(new_file) = new_snapshot.file() - { - if let Some(last_checked_file) = last_checked_file - && Arc::ptr_eq(last_checked_file, old_file) - && Arc::ptr_eq(last_checked_file, new_file) - { - continue; - } - if !self.can_collect_file(old_file, cx) { - return false; - } - if !Arc::ptr_eq(old_file, new_file) && !self.can_collect_file(new_file, cx) - { - return false; + let mut recent_changes = String::new(); + for event in &events { + sweep_ai::write_event(event.as_ref(), &mut recent_changes).unwrap(); + } + + let mut file_chunks = recent_buffer_snapshots + .into_iter() + .map(|snapshot| { + let end_point = Point::new(30, 0).min(snapshot.max_point()); + sweep_ai::FileChunk { + content: snapshot.text_for_range(Point::zero()..end_point).collect(), + file_path: snapshot + .file() + .map(|f| f.path().as_unix_str()) + .unwrap_or("untitled") + .to_string(), + start_line: 0, + end_line: end_point.row as usize, + timestamp: snapshot.file().and_then(|file| { + Some( + file.disk_state() + .mtime()? + .to_seconds_and_nanos_for_persistence()? + .0, + ) + }), } - last_checked_file = Some(new_file); - } else { - return false; - } + }) + .collect::>(); + + let diagnostic_entries = + snapshot.diagnostics_in_range(diagnostic_search_range, false); + let mut diagnostic_content = String::new(); + let mut diagnostic_count = 0; + + for entry in diagnostic_entries { + let start_point: Point = entry.range.start; + + let severity = match entry.diagnostic.severity { + DiagnosticSeverity::ERROR => "error", + DiagnosticSeverity::WARNING => "warning", + DiagnosticSeverity::INFORMATION => "info", + DiagnosticSeverity::HINT => "hint", + _ => continue, + }; + + diagnostic_count += 1; + + writeln!( + &mut diagnostic_content, + "{} at line {}: {}", + severity, + start_point.row + 1, + entry.diagnostic.message + )?; + } + + if !diagnostic_content.is_empty() { + file_chunks.push(sweep_ai::FileChunk { + file_path: format!("Diagnostics for {}", full_path.display()), + start_line: 0, + end_line: diagnostic_count, + content: diagnostic_content, + timestamp: None, + }); + } + + let request_body = sweep_ai::AutocompleteRequest { + debug_info, + repo_name, + file_path: full_path.clone(), + file_contents: text.clone(), + original_file_contents: text, + cursor_position: offset, + recent_changes: recent_changes.clone(), + changes_above_cursor: true, + multiple_suggestions: false, + branch: None, + file_chunks, + retrieval_chunks: vec![], + recent_user_actions: vec![], + // TODO + privacy_mode_enabled: false, + }; + + let mut buf: Vec = Vec::new(); + let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22); + serde_json::to_writer(writer, &request_body)?; + let body: AsyncBody = buf.into(); + + let inputs = EditPredictionInputs { + events, + included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile { + path: full_path.clone(), + max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row), + excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt { + start_line: cloud_llm_client::predict_edits_v3::Line(0), + text: request_body.file_contents.into(), + }], + }], + cursor_point: cloud_llm_client::predict_edits_v3::Point { + column: cursor_point.column, + line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row), + }, + cursor_path: full_path.clone(), + }; + + const SWEEP_API_URL: &str = + "https://autocomplete.sweep.dev/backend/next_edit_autocomplete"; + + let request = http_client::Request::builder() + .uri(SWEEP_API_URL) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_token)) + .header("Connection", "keep-alive") + .header("Content-Encoding", "br") + .method(Method::POST) + .body(body)?; + + let mut response = http_client.send(request).await?; + + let mut body: Vec = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + let response_received_at = Instant::now(); + if !response.status().is_success() { + anyhow::bail!( + "Request failed with status: {:?}\nBody: {}", + response.status(), + String::from_utf8_lossy(&body), + ); + }; + + let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?; + + let old_text = snapshot + .text_for_range(response.start_index..response.end_index) + .collect::(); + let edits = language::text_diff(&old_text, &response.completion) + .into_iter() + .map(|(range, text)| { + ( + snapshot.anchor_after(response.start_index + range.start) + ..snapshot.anchor_before(response.start_index + range.end), + text, + ) + }) + .collect::>(); + + anyhow::Ok(( + response.autocomplete_id, + edits, + snapshot, + response_received_at, + inputs, + )) + } + }); + + let buffer = active_buffer.clone(); + let project = project.clone(); + let active_buffer = active_buffer.clone(); + + cx.spawn(async move |this, cx| { + let (id, edits, old_snapshot, response_received_at, inputs) = result.await?; + + if edits.is_empty() { + if has_events + && allow_jump + && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( + active_buffer, + &snapshot, + diagnostic_search_range, + cursor_point, + &project, + cx, + ) + .await? + { + return this + .update(cx, |this, cx| { + this.request_prediction_with_sweep( + &project, + &jump_buffer, + jump_position, + false, + cx, + ) + })? + .await; } + + return anyhow::Ok(None); } - } - true - } - fn is_file_open_source(&self, file: &Arc, cx: &App) -> bool { - if !file.is_local() || file.is_private() { - return false; - } - self.license_detection_watchers - .get(&file.worktree_id(cx)) - .is_some_and(|watcher| watcher.is_project_open_source()) + anyhow::Ok( + EditPrediction::new( + EditPredictionId(id.into()), + &buffer, + &old_snapshot, + edits.into(), + buffer_snapshotted_at, + response_received_at, + inputs, + cx, + ) + .await, + ) + }) } - fn load_data_collection_choice() -> DataCollectionChoice { - let choice = KEY_VALUE_STORE - .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE) - .log_err() - .flatten(); + async fn next_diagnostic_location( + active_buffer: Entity, + active_buffer_snapshot: &BufferSnapshot, + active_buffer_diagnostic_search_range: Range, + active_buffer_cursor_point: Point, + project: &Entity, + cx: &mut AsyncApp, + ) -> Result, language::Anchor)>> { + // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request + let mut jump_location = active_buffer_snapshot + .diagnostic_groups(None) + .into_iter() + .filter_map(|(_, group)| { + let range = &group.entries[group.primary_ix] + .range + .to_point(&active_buffer_snapshot); + if range.overlaps(&active_buffer_diagnostic_search_range) { + None + } else { + Some(range.start) + } + }) + .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row)) + .map(|position| { + ( + active_buffer.clone(), + active_buffer_snapshot.anchor_before(position), + ) + }); - match choice.as_deref() { - Some("true") => DataCollectionChoice::Enabled, - Some("false") => DataCollectionChoice::Disabled, - Some(_) => { - log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'"); - DataCollectionChoice::NotAnswered + if jump_location.is_none() { + let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| { + let file = buffer.file()?; + + Some(ProjectPath { + worktree_id: file.worktree_id(cx), + path: file.path().clone(), + }) + })?; + + let buffer_task = project.update(cx, |project, cx| { + let (path, _, _) = project + .diagnostic_summaries(false, cx) + .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref()) + .max_by_key(|(path, _, _)| { + // find the buffer with errors that shares most parent directories + path.path + .components() + .zip( + active_buffer_path + .as_ref() + .map(|p| p.path.components()) + .unwrap_or_default(), + ) + .take_while(|(a, b)| a == b) + .count() + })?; + + Some(project.open_buffer(path, cx)) + })?; + + if let Some(buffer_task) = buffer_task { + let closest_buffer = buffer_task.await?; + + jump_location = closest_buffer + .read_with(cx, |buffer, _cx| { + buffer + .buffer_diagnostics(None) + .into_iter() + .min_by_key(|entry| entry.diagnostic.severity) + .map(|entry| entry.range.start) + })? + .map(|position| (closest_buffer, position)); } - None => DataCollectionChoice::NotAnswered, } - } - fn toggle_data_collection_choice(&mut self, cx: &mut Context) { - self.data_collection_choice = self.data_collection_choice.toggle(); - let new_choice = self.data_collection_choice; - db::write_and_log(cx, move || { - KEY_VALUE_STORE.write_kvp( - ZED_PREDICT_DATA_COLLECTION_CHOICE.into(), - new_choice.is_enabled().to_string(), - ) - }); + anyhow::Ok(jump_location) } - fn discard_completion( + fn request_prediction_with_zeta2( &mut self, - completion_id: EditPredictionId, - was_shown: bool, + project: &Entity, + active_buffer: &Entity, + position: language::Anchor, cx: &mut Context, - ) { - self.discarded_completions.push(EditPredictionRejection { - request_id: completion_id.to_string(), - was_shown, - }); + ) -> Task>> { + let project_state = self.projects.get(&project.entity_id()); - let reached_request_limit = - self.discarded_completions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST; - let discard_completions_tx = self.discard_completions_tx.clone(); - self.discard_completions_debounce_task = Some(cx.spawn(async move |_this, cx| { - const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15); - if !reached_request_limit { - cx.background_executor() - .timer(DISCARD_COMPLETIONS_DEBOUNCE) - .await; - } - discard_completions_tx.unbounded_send(()).log_err(); - })); - } -} + let index_state = project_state.and_then(|state| { + state + .syntax_index + .as_ref() + .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone())) + }); + let options = self.options.clone(); + let active_snapshot = active_buffer.read(cx).snapshot(); + let buffer_snapshotted_at = Instant::now(); + let Some(excerpt_path) = active_snapshot + .file() + .map(|path| -> Arc { path.full_path(cx).into() }) + else { + return Task::ready(Err(anyhow!("No file path for excerpt"))); + }; + let client = self.client.clone(); + let llm_token = self.llm_token.clone(); + let app_version = AppVersion::global(cx); + let worktree_snapshots = project + .read(cx) + .worktrees(cx) + .map(|worktree| worktree.read(cx).snapshot()) + .collect::>(); + let debug_tx = self.debug_tx.clone(); -pub struct PerformPredictEditsParams { - pub client: Arc, - pub llm_token: LlmApiToken, - pub app_version: Version, - pub body: PredictEditsBody, -} + let events = project_state + .map(|state| state.events(cx)) + .unwrap_or_default(); -#[derive(Error, Debug)] -#[error( - "You must update to Zed version {minimum_version} or higher to continue using edit predictions." -)] -pub struct ZedUpdateRequiredError { - minimum_version: Version, -} + let diagnostics = active_snapshot.diagnostic_sets().clone(); -fn common_prefix, T2: Iterator>(a: T1, b: T2) -> usize { - a.zip(b) - .take_while(|(a, b)| a == b) - .map(|(a, _)| a.len_utf8()) - .sum() -} + let file = active_buffer.read(cx).file(); + let parent_abs_path = project::File::from_dyn(file).and_then(|f| { + let mut path = f.worktree.read(cx).absolutize(&f.path); + if path.pop() { Some(path) } else { None } + }); -fn git_info_for_file( - project: &Entity, - project_path: &ProjectPath, - cx: &App, -) -> Option { - let git_store = project.read(cx).git_store().read(cx); - if let Some((repository, _repo_path)) = - git_store.repository_and_path_for_project_path(project_path, cx) - { - let repository = repository.read(cx); - let head_sha = repository - .head_commit + // TODO data collection + let can_collect_data = file .as_ref() - .map(|head_commit| head_commit.sha.to_string()); - let remote_origin_url = repository.remote_origin_url.clone(); - let remote_upstream_url = repository.remote_upstream_url.clone(); - if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() { - return None; - } - Some(PredictEditsGitInfo { - head_sha, - remote_origin_url, - remote_upstream_url, - }) - } else { - None - } -} + .map_or(false, |file| self.can_collect_file(project, file, cx)); + + let empty_context_files = HashMap::default(); + let context_files = project_state + .and_then(|project_state| project_state.context.as_ref()) + .unwrap_or(&empty_context_files); + + #[cfg(feature = "eval-support")] + let parsed_fut = futures::future::join_all( + context_files + .keys() + .map(|buffer| buffer.read(cx).parsing_idle()), + ); -pub struct GatherContextOutput { - pub body: PredictEditsBody, - pub editable_range: Range, - pub included_events_count: usize, -} + let mut included_files = context_files + .iter() + .filter_map(|(buffer_entity, ranges)| { + let buffer = buffer_entity.read(cx); + Some(( + buffer_entity.clone(), + buffer.snapshot(), + buffer.file()?.full_path(cx).into(), + ranges.clone(), + )) + }) + .collect::>(); -pub fn gather_context( - full_path_str: String, - snapshot: &BufferSnapshot, - cursor_point: language::Point, - prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static, - cx: &App, -) -> Task> { - cx.background_spawn({ - let snapshot = snapshot.clone(); - async move { - let input_excerpt = excerpt_for_cursor_position( - cursor_point, - &full_path_str, - &snapshot, - MAX_REWRITE_TOKENS, - MAX_CONTEXT_TOKENS, - ); - let (input_events, included_events_count) = prompt_for_events(); - let editable_range = input_excerpt.editable_range.to_offset(&snapshot); - - let body = PredictEditsBody { - input_events, - input_excerpt: input_excerpt.prompt, - can_collect_data: false, - diagnostic_groups: None, - git_info: None, - outline: None, - speculated_output: None, - }; + included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| { + (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len())) + }); - Ok(GatherContextOutput { - body, - editable_range, - included_events_count, - }) - } - }) -} + #[cfg(feature = "eval-support")] + let eval_cache = self.eval_cache.clone(); -fn prompt_for_events_impl(events: &[Event], mut remaining_tokens: usize) -> (String, usize) { - let mut result = String::new(); - for (ix, event) in events.iter().rev().enumerate() { - let event_string = event.to_prompt(); - let event_tokens = guess_token_count(event_string.len()); - if event_tokens > remaining_tokens { - return (result, ix); - } + let request_task = cx.background_spawn({ + let active_buffer = active_buffer.clone(); + async move { + #[cfg(feature = "eval-support")] + parsed_fut.await; - if !result.is_empty() { - result.insert_str(0, "\n\n"); - } - result.insert_str(0, &event_string); - remaining_tokens -= event_tokens; - } - return (result, events.len()); -} + let index_state = if let Some(index_state) = index_state { + Some(index_state.lock_owned().await) + } else { + None + }; -struct RegisteredBuffer { - snapshot: BufferSnapshot, - _subscriptions: [gpui::Subscription; 2], -} + let cursor_offset = position.to_offset(&active_snapshot); + let cursor_point = cursor_offset.to_point(&active_snapshot); -#[derive(Clone)] -pub enum Event { - BufferChange { - old_snapshot: BufferSnapshot, - new_snapshot: BufferSnapshot, - timestamp: Instant, - }, -} + let before_retrieval = Instant::now(); -impl Event { - fn to_prompt(&self) -> String { - match self { - Event::BufferChange { - old_snapshot, - new_snapshot, - .. - } => { - let mut prompt = String::new(); - - let old_path = old_snapshot - .file() - .map(|f| f.path().as_ref()) - .unwrap_or(RelPath::unix("untitled").unwrap()); - let new_path = new_snapshot - .file() - .map(|f| f.path().as_ref()) - .unwrap_or(RelPath::unix("untitled").unwrap()); - if old_path != new_path { - writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap(); + let (diagnostic_groups, diagnostic_groups_truncated) = + Self::gather_nearby_diagnostics( + cursor_offset, + &diagnostics, + &active_snapshot, + options.max_diagnostic_bytes, + ); + + let cloud_request = match options.context { + ContextMode::Agentic(context_options) => { + let Some(excerpt) = EditPredictionExcerpt::select_from_buffer( + cursor_point, + &active_snapshot, + &context_options.excerpt, + index_state.as_deref(), + ) else { + return Ok((None, None)); + }; + + let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start) + ..active_snapshot.anchor_before(excerpt.range.end); + + if let Some(buffer_ix) = + included_files.iter().position(|(_, snapshot, _, _)| { + snapshot.remote_id() == active_snapshot.remote_id() + }) + { + let (_, buffer, _, ranges) = &mut included_files[buffer_ix]; + ranges.push(excerpt_anchor_range); + retrieval_search::merge_anchor_ranges(ranges, buffer); + let last_ix = included_files.len() - 1; + included_files.swap(buffer_ix, last_ix); + } else { + included_files.push(( + active_buffer.clone(), + active_snapshot.clone(), + excerpt_path.clone(), + vec![excerpt_anchor_range], + )); + } + + let included_files = included_files + .iter() + .map(|(_, snapshot, path, ranges)| { + let ranges = ranges + .iter() + .map(|range| { + let point_range = range.to_point(&snapshot); + Line(point_range.start.row)..Line(point_range.end.row) + }) + .collect::>(); + let excerpts = assemble_excerpts(&snapshot, ranges); + predict_edits_v3::IncludedFile { + path: path.clone(), + max_row: Line(snapshot.max_point().row), + excerpts, + } + }) + .collect::>(); + + predict_edits_v3::PredictEditsRequest { + excerpt_path, + excerpt: String::new(), + excerpt_line_range: Line(0)..Line(0), + excerpt_range: 0..0, + cursor_point: predict_edits_v3::Point { + line: predict_edits_v3::Line(cursor_point.row), + column: cursor_point.column, + }, + included_files, + referenced_declarations: vec![], + events, + can_collect_data, + diagnostic_groups, + diagnostic_groups_truncated, + debug_info: debug_tx.is_some(), + prompt_max_bytes: Some(options.max_prompt_bytes), + prompt_format: options.prompt_format, + // TODO [zeta2] + signatures: vec![], + excerpt_parent: None, + git_info: None, + } + } + ContextMode::Syntax(context_options) => { + let Some(context) = EditPredictionContext::gather_context( + cursor_point, + &active_snapshot, + parent_abs_path.as_deref(), + &context_options, + index_state.as_deref(), + ) else { + return Ok((None, None)); + }; + + make_syntax_context_cloud_request( + excerpt_path, + context, + events, + can_collect_data, + diagnostic_groups, + diagnostic_groups_truncated, + None, + debug_tx.is_some(), + &worktree_snapshots, + index_state.as_deref(), + Some(options.max_prompt_bytes), + options.prompt_format, + ) + } + }; + + let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request); + + let inputs = EditPredictionInputs { + included_files: cloud_request.included_files, + events: cloud_request.events, + cursor_point: cloud_request.cursor_point, + cursor_path: cloud_request.excerpt_path, + }; + + let retrieval_time = Instant::now() - before_retrieval; + + let debug_response_tx = if let Some(debug_tx) = &debug_tx { + let (response_tx, response_rx) = oneshot::channel(); + + debug_tx + .unbounded_send(ZetaDebugInfo::EditPredictionRequested( + ZetaEditPredictionDebugInfo { + inputs: inputs.clone(), + retrieval_time, + buffer: active_buffer.downgrade(), + local_prompt: match prompt_result.as_ref() { + Ok((prompt, _)) => Ok(prompt.clone()), + Err(err) => Err(err.to_string()), + }, + position, + response_rx, + }, + )) + .ok(); + Some(response_tx) + } else { + None + }; + + if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() { + if let Some(debug_response_tx) = debug_response_tx { + debug_response_tx + .send((Err("Request skipped".to_string()), Duration::ZERO)) + .ok(); + } + anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set") } - let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text()); - if !diff.is_empty() { - write!( - prompt, - "User edited {:?}:\n```diff\n{}\n```", - new_path, diff - ) - .unwrap(); + let (prompt, _) = prompt_result?; + let generation_params = + cloud_zeta2_prompt::generation_params(cloud_request.prompt_format); + let request = open_ai::Request { + model: EDIT_PREDICTIONS_MODEL_ID.clone(), + messages: vec![open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(prompt), + }], + stream: false, + max_completion_tokens: None, + stop: generation_params.stop.unwrap_or_default(), + temperature: generation_params.temperature.unwrap_or(0.7), + tool_choice: None, + parallel_tool_calls: None, + tools: vec![], + prompt_cache_key: None, + reasoning_effort: None, + }; + + log::trace!("Sending edit prediction request"); + + let before_request = Instant::now(); + let response = Self::send_raw_llm_request( + request, + client, + llm_token, + app_version, + #[cfg(feature = "eval-support")] + eval_cache, + #[cfg(feature = "eval-support")] + EvalCacheEntryKind::Prediction, + ) + .await; + let received_response_at = Instant::now(); + let request_time = received_response_at - before_request; + + log::trace!("Got edit prediction response"); + + if let Some(debug_response_tx) = debug_response_tx { + debug_response_tx + .send(( + response + .as_ref() + .map_err(|err| err.to_string()) + .map(|response| response.0.clone()), + request_time, + )) + .ok(); } - prompt + let (res, usage) = response?; + let request_id = EditPredictionId(res.id.clone().into()); + let Some(mut output_text) = text_from_response(res) else { + return Ok((None, usage)); + }; + + if output_text.contains(CURSOR_MARKER) { + log::trace!("Stripping out {CURSOR_MARKER} from response"); + output_text = output_text.replace(CURSOR_MARKER, ""); + } + + let get_buffer_from_context = |path: &Path| { + included_files + .iter() + .find_map(|(_, buffer, probe_path, ranges)| { + if probe_path.as_ref() == path { + Some((buffer, ranges.as_slice())) + } else { + None + } + }) + }; + + let (edited_buffer_snapshot, edits) = match options.prompt_format { + PromptFormat::NumLinesUniDiff => { + // TODO: Implement parsing of multi-file diffs + crate::udiff::parse_diff(&output_text, get_buffer_from_context).await? + } + PromptFormat::Minimal + | PromptFormat::MinimalQwen + | PromptFormat::SeedCoder1120 => { + if output_text.contains("--- a/\n+++ b/\nNo edits") { + let edits = vec![]; + (&active_snapshot, edits) + } else { + crate::udiff::parse_diff(&output_text, get_buffer_from_context).await? + } + } + PromptFormat::OldTextNewText => { + crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context) + .await? + } + _ => { + bail!("unsupported prompt format {}", options.prompt_format) + } + }; + + let edited_buffer = included_files + .iter() + .find_map(|(buffer, snapshot, _, _)| { + if snapshot.remote_id() == edited_buffer_snapshot.remote_id() { + Some(buffer.clone()) + } else { + None + } + }) + .context("Failed to find buffer in included_buffers")?; + + anyhow::Ok(( + Some(( + request_id, + inputs, + edited_buffer, + edited_buffer_snapshot.clone(), + edits, + received_response_at, + )), + usage, + )) } - } - } -} + }); -#[derive(Debug, Clone)] -struct CurrentEditPrediction { - buffer_id: EntityId, - completion: EditPrediction, - was_shown: bool, - was_accepted: bool, -} + cx.spawn({ + async move |this, cx| { + let Some(( + id, + inputs, + edited_buffer, + edited_buffer_snapshot, + edits, + received_response_at, + )) = Self::handle_api_response(&this, request_task.await, cx)? + else { + return Ok(None); + }; -impl CurrentEditPrediction { - fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool { - if self.buffer_id != old_completion.buffer_id { - return true; - } + // TODO telemetry: duration, etc + Ok(EditPrediction::new( + id, + &edited_buffer, + &edited_buffer_snapshot, + edits.into(), + buffer_snapshotted_at, + received_response_at, + inputs, + cx, + ) + .await) + } + }) + } - let Some(old_edits) = old_completion.completion.interpolate(snapshot) else { - return true; - }; - let Some(new_edits) = self.completion.interpolate(snapshot) else { - return false; + async fn send_raw_llm_request( + request: open_ai::Request, + client: Arc, + llm_token: LlmApiToken, + app_version: Version, + #[cfg(feature = "eval-support")] eval_cache: Option>, + #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind, + ) -> Result<(open_ai::Response, Option)> { + let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() { + http_client::Url::parse(&predict_edits_url)? + } else { + client + .http_client() + .build_zed_llm_url("/predict_edits/raw", &[])? }; - if old_edits.len() == 1 && new_edits.len() == 1 { - let (old_range, old_text) = &old_edits[0]; - let (new_range, new_text) = &new_edits[0]; - new_range == old_range && new_text.starts_with(old_text.as_ref()) - } else { - true - } - } -} + #[cfg(feature = "eval-support")] + let cache_key = if let Some(cache) = eval_cache { + use collections::FxHasher; + use std::hash::{Hash, Hasher}; -struct PendingCompletion { - id: usize, - task: Task<()>, -} + let mut hasher = FxHasher::default(); + url.hash(&mut hasher); + let request_str = serde_json::to_string_pretty(&request)?; + request_str.hash(&mut hasher); + let hash = hasher.finish(); -#[derive(Debug, Clone, Copy)] -pub enum DataCollectionChoice { - NotAnswered, - Enabled, - Disabled, -} + let key = (eval_cache_kind, hash); + if let Some(response_str) = cache.read(key) { + return Ok((serde_json::from_str(&response_str)?, None)); + } -impl DataCollectionChoice { - pub fn is_enabled(self) -> bool { - match self { - Self::Enabled => true, - Self::NotAnswered | Self::Disabled => false, - } - } + Some((cache, request_str, key)) + } else { + None + }; - pub fn is_answered(self) -> bool { - match self { - Self::Enabled | Self::Disabled => true, - Self::NotAnswered => false, + let (response, usage) = Self::send_api_request( + |builder| { + let req = builder + .uri(url.as_ref()) + .body(serde_json::to_string(&request)?.into()); + Ok(req?) + }, + client, + llm_token, + app_version, + ) + .await?; + + #[cfg(feature = "eval-support")] + if let Some((cache, request, key)) = cache_key { + cache.write(key, &request, &serde_json::to_string_pretty(&response)?); } + + Ok((response, usage)) } - #[must_use] - pub fn toggle(&self) -> DataCollectionChoice { - match self { - Self::Enabled => Self::Disabled, - Self::Disabled => Self::Enabled, - Self::NotAnswered => Self::Enabled, + fn handle_api_response( + this: &WeakEntity, + response: Result<(T, Option)>, + cx: &mut gpui::AsyncApp, + ) -> Result { + match response { + Ok((data, usage)) => { + if let Some(usage) = usage { + this.update(cx, |this, cx| { + this.user_store.update(cx, |user_store, cx| { + user_store.update_edit_prediction_usage(usage, cx); + }); + }) + .ok(); + } + Ok(data) + } + Err(err) => { + if err.is::() { + cx.update(|cx| { + this.update(cx, |this, _cx| { + this.update_required = true; + }) + .ok(); + + let error_message: SharedString = err.to_string().into(); + show_app_notification( + NotificationId::unique::(), + cx, + move |cx| { + cx.new(|cx| { + ErrorMessagePrompt::new(error_message.clone(), cx) + .with_link_button("Update Zed", "https://zed.dev/releases") + }) + }, + ); + }) + .ok(); + } + Err(err) + } } } -} -impl From for DataCollectionChoice { - fn from(value: bool) -> Self { - match value { - true => DataCollectionChoice::Enabled, - false => DataCollectionChoice::Disabled, + async fn send_api_request( + build: impl Fn(http_client::http::request::Builder) -> Result>, + client: Arc, + llm_token: LlmApiToken, + app_version: Version, + ) -> Result<(Res, Option)> + where + Res: DeserializeOwned, + { + let http_client = client.http_client(); + let mut token = llm_token.acquire(&client).await?; + let mut did_retry = false; + + loop { + let request_builder = http_client::Request::builder().method(Method::POST); + + let request = build( + request_builder + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", token)) + .header(ZED_VERSION_HEADER_NAME, app_version.to_string()), + )?; + + let mut response = http_client.send(request).await?; + + if let Some(minimum_required_version) = response + .headers() + .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME) + .and_then(|version| Version::from_str(version.to_str().ok()?).ok()) + { + anyhow::ensure!( + app_version >= minimum_required_version, + ZedUpdateRequiredError { + minimum_version: minimum_required_version + } + ); + } + + if response.status().is_success() { + let usage = EditPredictionUsage::from_headers(response.headers()).ok(); + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + return Ok((serde_json::from_slice(&body)?, usage)); + } else if !did_retry + && response + .headers() + .get(EXPIRED_LLM_TOKEN_HEADER_NAME) + .is_some() + { + did_retry = true; + token = llm_token.refresh(&client).await?; + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + anyhow::bail!( + "Request failed with status: {:?}\nBody: {}", + response.status(), + body + ); + } } } -} -async fn llm_token_retry( - llm_token: &LlmApiToken, - client: &Arc, - build_request: impl Fn(String) -> Result>, -) -> Result> { - let mut did_retry = false; - let http_client = client.http_client(); - let mut token = llm_token.acquire(client).await?; - loop { - let request = build_request(token.clone())?; - let response = http_client.send(request).await?; - - if !did_retry - && !response.status().is_success() - && response - .headers() - .get(EXPIRED_LLM_TOKEN_HEADER_NAME) - .is_some() - { - did_retry = true; - token = llm_token.refresh(client).await?; - continue; + pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10); + pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3); + + // Refresh the related excerpts when the user just beguns editing after + // an idle period, and after they pause editing. + fn refresh_context_if_needed( + &mut self, + project: &Entity, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &mut Context, + ) { + if !matches!(&self.options().context, ContextMode::Agentic { .. }) { + return; } - return Ok(response); - } -} + let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { + return; + }; -pub struct ZetaEditPredictionProvider { - zeta: Entity, - singleton_buffer: Option>, - pending_completions: ArrayVec, - canceled_completions: HashMap>, - next_pending_completion_id: usize, - current_completion: Option, - last_request_timestamp: Instant, - project: Entity, -} + let now = Instant::now(); + let was_idle = zeta_project + .refresh_context_timestamp + .map_or(true, |timestamp| { + now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION + }); + zeta_project.refresh_context_timestamp = Some(now); + zeta_project.refresh_context_debounce_task = Some(cx.spawn({ + let buffer = buffer.clone(); + let project = project.clone(); + async move |this, cx| { + if was_idle { + log::debug!("refetching edit prediction context after idle"); + } else { + cx.background_executor() + .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION) + .await; + log::debug!("refetching edit prediction context after pause"); + } + this.update(cx, |this, cx| { + let task = this.refresh_context(project.clone(), buffer, cursor_position, cx); -impl ZetaEditPredictionProvider { - pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); + if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) { + zeta_project.refresh_context_task = Some(task.log_err()); + }; + }) + .ok() + } + })); + } - pub fn new( - zeta: Entity, + // Refresh the related excerpts asynchronously. Ensure the task runs to completion, + // and avoid spawning more than one concurrent task. + pub fn refresh_context( + &mut self, project: Entity, - singleton_buffer: Option>, + buffer: Entity, + cursor_position: language::Anchor, cx: &mut Context, - ) -> Self { - cx.on_release(|this, cx| { - this.take_current_edit_prediction(cx); - }) - .detach(); + ) -> Task> { + let Some(zeta_project) = self.projects.get(&project.entity_id()) else { + return Task::ready(anyhow::Ok(())); + }; - Self { - zeta, - singleton_buffer, - pending_completions: ArrayVec::new(), - canceled_completions: HashMap::default(), - next_pending_completion_id: 0, - current_completion: None, - last_request_timestamp: Instant::now(), - project, + let ContextMode::Agentic(options) = &self.options().context else { + return Task::ready(anyhow::Ok(())); + }; + + let snapshot = buffer.read(cx).snapshot(); + let cursor_point = cursor_position.to_point(&snapshot); + let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer( + cursor_point, + &snapshot, + &options.excerpt, + None, + ) else { + return Task::ready(Ok(())); + }; + + let app_version = AppVersion::global(cx); + let client = self.client.clone(); + let llm_token = self.llm_token.clone(); + let debug_tx = self.debug_tx.clone(); + let current_file_path: Arc = snapshot + .file() + .map(|f| f.full_path(cx).into()) + .unwrap_or_else(|| Path::new("untitled").into()); + + let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt( + predict_edits_v3::PlanContextRetrievalRequest { + excerpt: cursor_excerpt.text(&snapshot).body, + excerpt_path: current_file_path, + excerpt_line_range: cursor_excerpt.line_range, + cursor_file_max_row: Line(snapshot.max_point().row), + events: zeta_project.events(cx), + }, + ) { + Ok(prompt) => prompt, + Err(err) => { + return Task::ready(Err(err)); + } + }; + + if let Some(debug_tx) = &debug_tx { + debug_tx + .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted( + ZetaContextRetrievalStartedDebugInfo { + project: project.clone(), + timestamp: Instant::now(), + search_prompt: prompt.clone(), + }, + )) + .ok(); } - } - fn take_current_edit_prediction(&mut self, cx: &mut App) { - if let Some(completion) = self.current_completion.take() { - if !completion.was_accepted { - self.zeta.update(cx, |zeta, cx| { - zeta.discard_completion(completion.completion.id, completion.was_shown, cx); - }); + pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| { + let schema = language_model::tool_schema::root_schema_for::( + language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset, + ); + + let description = schema + .get("description") + .and_then(|description| description.as_str()) + .unwrap() + .to_string(); + + (schema.into(), description) + }); + + let (tool_schema, tool_description) = TOOL_SCHEMA.clone(); + + let request = open_ai::Request { + model: CONTEXT_RETRIEVAL_MODEL_ID.clone(), + messages: vec![open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(prompt), + }], + stream: false, + max_completion_tokens: None, + stop: Default::default(), + temperature: 0.7, + tool_choice: None, + parallel_tool_calls: None, + tools: vec![open_ai::ToolDefinition::Function { + function: FunctionDefinition { + name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(), + description: Some(tool_description), + parameters: Some(tool_schema), + }, + }], + prompt_cache_key: None, + reasoning_effort: None, + }; + + #[cfg(feature = "eval-support")] + let eval_cache = self.eval_cache.clone(); + + cx.spawn(async move |this, cx| { + log::trace!("Sending search planning request"); + let response = Self::send_raw_llm_request( + request, + client, + llm_token, + app_version, + #[cfg(feature = "eval-support")] + eval_cache.clone(), + #[cfg(feature = "eval-support")] + EvalCacheEntryKind::Context, + ) + .await; + let mut response = Self::handle_api_response(&this, response, cx)?; + log::trace!("Got search planning response"); + + let choice = response + .choices + .pop() + .context("No choices in retrieval response")?; + let open_ai::RequestMessage::Assistant { + content: _, + tool_calls, + } = choice.message + else { + anyhow::bail!("Retrieval response didn't include an assistant message"); + }; + + let mut queries: Vec = Vec::new(); + for tool_call in tool_calls { + let open_ai::ToolCallContent::Function { function } = tool_call.content; + if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME { + log::warn!( + "Context retrieval response tried to call an unknown tool: {}", + function.name + ); + + continue; + } + + let input: SearchToolInput = serde_json::from_str(&function.arguments) + .with_context(|| format!("invalid search json {}", &function.arguments))?; + queries.extend(input.queries); } - } - } -} -impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { - fn name() -> &'static str { - "zed-predict" - } + if let Some(debug_tx) = &debug_tx { + debug_tx + .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated( + ZetaSearchQueryDebugInfo { + project: project.clone(), + timestamp: Instant::now(), + search_queries: queries.clone(), + }, + )) + .ok(); + } - fn display_name() -> &'static str { - "Zed's Edit Predictions" - } + log::trace!("Running retrieval search: {queries:#?}"); - fn show_completions_in_menu() -> bool { - true - } + let related_excerpts_result = retrieval_search::run_retrieval_searches( + queries, + project.clone(), + #[cfg(feature = "eval-support")] + eval_cache, + cx, + ) + .await; - fn show_tab_accept_marker() -> bool { - true - } + log::trace!("Search queries executed"); + + if let Some(debug_tx) = &debug_tx { + debug_tx + .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted( + ZetaContextRetrievalDebugInfo { + project: project.clone(), + timestamp: Instant::now(), + }, + )) + .ok(); + } - fn data_collection_state(&self, cx: &App) -> DataCollectionState { - if let Some(buffer) = &self.singleton_buffer - && let Some(file) = buffer.read(cx).file() - { - let is_project_open_source = self.zeta.read(cx).is_file_open_source(file, cx); - if self.zeta.read(cx).data_collection_choice.is_enabled() { - DataCollectionState::Enabled { - is_project_open_source, + this.update(cx, |this, _cx| { + let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else { + return Ok(()); + }; + zeta_project.refresh_context_task.take(); + if let Some(debug_tx) = &this.debug_tx { + debug_tx + .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished( + ZetaContextRetrievalDebugInfo { + project, + timestamp: Instant::now(), + }, + )) + .ok(); } - } else { - DataCollectionState::Disabled { - is_project_open_source, + match related_excerpts_result { + Ok(excerpts) => { + zeta_project.context = Some(excerpts); + Ok(()) + } + Err(error) => Err(error), } - } - } else { - return DataCollectionState::Disabled { - is_project_open_source: false, - }; - } - } - - fn toggle_data_collection(&mut self, cx: &mut App) { - self.zeta - .update(cx, |zeta, cx| zeta.toggle_data_collection_choice(cx)); - } - - fn usage(&self, cx: &App) -> Option { - self.zeta.read(cx).usage(cx) - } - - fn is_enabled( - &self, - _buffer: &Entity, - _cursor_position: language::Anchor, - _cx: &App, - ) -> bool { - true - } - fn is_refreshing(&self, _cx: &App) -> bool { - !self.pending_completions.is_empty() + })? + }) } - fn refresh( + pub fn set_context( &mut self, - buffer: Entity, - position: language::Anchor, - _debounce: bool, - cx: &mut Context, + project: Entity, + context: HashMap, Vec>>, ) { - if self.zeta.read(cx).update_required { - return; + if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) { + zeta_project.context = Some(context); } + } - if self - .zeta - .read(cx) - .user_store - .read_with(cx, |user_store, _cx| { - user_store.account_too_young() || user_store.has_overdue_invoices() - }) - { - return; + fn gather_nearby_diagnostics( + cursor_offset: usize, + diagnostic_sets: &[(LanguageServerId, DiagnosticSet)], + snapshot: &BufferSnapshot, + max_diagnostics_bytes: usize, + ) -> (Vec, bool) { + // TODO: Could make this more efficient + let mut diagnostic_groups = Vec::new(); + for (language_server_id, diagnostics) in diagnostic_sets { + let mut groups = Vec::new(); + diagnostics.groups(*language_server_id, &mut groups, &snapshot); + diagnostic_groups.extend( + groups + .into_iter() + .map(|(_, group)| group.resolve::(&snapshot)), + ); } - if let Some(current_completion) = self.current_completion.as_ref() { - let snapshot = buffer.read(cx).snapshot(); - if current_completion - .completion - .interpolate(&snapshot) - .is_some() - { - return; + // sort by proximity to cursor + diagnostic_groups.sort_by_key(|group| { + let range = &group.entries[group.primary_ix].range; + if range.start >= cursor_offset { + range.start - cursor_offset + } else if cursor_offset >= range.end { + cursor_offset - range.end + } else { + (cursor_offset - range.start).min(range.end - cursor_offset) + } + }); + + let mut results = Vec::new(); + let mut diagnostic_groups_truncated = false; + let mut diagnostics_byte_count = 0; + for group in diagnostic_groups { + let raw_value = serde_json::value::to_raw_value(&group).unwrap(); + diagnostics_byte_count += raw_value.get().len(); + if diagnostics_byte_count > max_diagnostics_bytes { + diagnostic_groups_truncated = true; + break; } + results.push(predict_edits_v3::DiagnosticGroup(raw_value)); } - let pending_completion_id = self.next_pending_completion_id; - self.next_pending_completion_id += 1; - let last_request_timestamp = self.last_request_timestamp; + (results, diagnostic_groups_truncated) + } - let project = self.project.clone(); - let task = cx.spawn(async move |this, cx| { - if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT) - .checked_duration_since(Instant::now()) - { - cx.background_executor().timer(timeout).await; - } + // TODO: Dedupe with similar code in request_prediction? + pub fn cloud_request_for_zeta_cli( + &mut self, + project: &Entity, + buffer: &Entity, + position: language::Anchor, + cx: &mut Context, + ) -> Task> { + let project_state = self.projects.get(&project.entity_id()); + + let index_state = project_state.and_then(|state| { + state + .syntax_index + .as_ref() + .map(|index| index.read_with(cx, |index, _cx| index.state().clone())) + }); + let options = self.options.clone(); + let snapshot = buffer.read(cx).snapshot(); + let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else { + return Task::ready(Err(anyhow!("No file path for excerpt"))); + }; + let worktree_snapshots = project + .read(cx) + .worktrees(cx) + .map(|worktree| worktree.read(cx).snapshot()) + .collect::>(); - let completion_request = this.update(cx, |this, cx| { - this.last_request_timestamp = Instant::now(); - this.zeta.update(cx, |zeta, cx| { - zeta.request_completion(&project, &buffer, position, cx) - }) - }); + let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| { + let mut path = f.worktree.read(cx).absolutize(&f.path); + if path.pop() { Some(path) } else { None } + }); - let completion = match completion_request { - Ok(completion_request) => { - let completion_request = completion_request.await; - completion_request.map(|c| { - c.map(|completion| CurrentEditPrediction { - buffer_id: buffer.entity_id(), - completion, - was_shown: false, - was_accepted: false, - }) - }) - } - Err(error) => Err(error), + cx.background_spawn(async move { + let index_state = if let Some(index_state) = index_state { + Some(index_state.lock_owned().await) + } else { + None }; - let discarded = this - .update(cx, |this, cx| { - if this - .pending_completions - .first() - .is_some_and(|completion| completion.id == pending_completion_id) - { - this.pending_completions.remove(0); - } else { - if let Some(discarded) = this.pending_completions.drain(..).next() { - this.canceled_completions - .insert(discarded.id, discarded.task); - } - } - - let canceled = this.canceled_completions.remove(&pending_completion_id); + let cursor_point = position.to_point(&snapshot); - if canceled.is_some() - && let Ok(Some(new_completion)) = &completion - { - this.zeta.update(cx, |zeta, cx| { - zeta.discard_completion(new_completion.completion.id, false, cx); - }); - return true; + let debug_info = true; + EditPredictionContext::gather_context( + cursor_point, + &snapshot, + parent_abs_path.as_deref(), + match &options.context { + ContextMode::Agentic(_) => { + // TODO + panic!("Llm mode not supported in zeta cli yet"); } - - cx.notify(); - false - }) - .ok() - .unwrap_or(true); - - if discarded { - return; - } - - let Some(new_completion) = completion - .context("edit prediction failed") - .log_err() - .flatten() - else { - return; - }; - - this.update(cx, |this, cx| { - if let Some(old_completion) = this.current_completion.as_ref() { - let snapshot = buffer.read(cx).snapshot(); - if new_completion.should_replace_completion(old_completion, &snapshot) { - this.zeta.update(cx, |zeta, cx| { - zeta.completion_shown(&new_completion.completion, cx); - }); - this.take_current_edit_prediction(cx); - this.current_completion = Some(new_completion); + ContextMode::Syntax(edit_prediction_context_options) => { + edit_prediction_context_options } - } else { - this.zeta.update(cx, |zeta, cx| { - zeta.completion_shown(&new_completion.completion, cx); - }); - this.current_completion = Some(new_completion); - } - - cx.notify(); + }, + index_state.as_deref(), + ) + .context("Failed to select excerpt") + .map(|context| { + make_syntax_context_cloud_request( + excerpt_path.into(), + context, + // TODO pass everything + Vec::new(), + false, + Vec::new(), + false, + None, + debug_info, + &worktree_snapshots, + index_state.as_deref(), + Some(options.max_prompt_bytes), + options.prompt_format, + ) }) - .ok(); - }); - - // We always maintain at most two pending completions. When we already - // have two, we replace the newest one. - if self.pending_completions.len() <= 1 { - self.pending_completions.push(PendingCompletion { - id: pending_completion_id, - task, - }); - } else if self.pending_completions.len() == 2 { - if let Some(discarded) = self.pending_completions.pop() { - self.canceled_completions - .insert(discarded.id, discarded.task); - } - self.pending_completions.push(PendingCompletion { - id: pending_completion_id, - task, - }); - } + }) } - fn cycle( + pub fn wait_for_initial_indexing( &mut self, - _buffer: Entity, - _cursor_position: language::Anchor, - _direction: edit_prediction::Direction, - _cx: &mut Context, - ) { - // Right now we don't support cycling. + project: &Entity, + cx: &mut Context, + ) -> Task> { + let zeta_project = self.get_or_init_zeta_project(project, cx); + if let Some(syntax_index) = &zeta_project.syntax_index { + syntax_index.read(cx).wait_for_initial_file_indexing(cx) + } else { + Task::ready(Ok(())) + } } - fn accept(&mut self, cx: &mut Context) { - let completion = self.current_completion.as_mut(); - if let Some(completion) = completion { - completion.was_accepted = true; - self.zeta - .update(cx, |zeta, cx| { - zeta.accept_edit_prediction(completion.completion.id, cx) - }) - .detach(); + fn is_file_open_source( + &self, + project: &Entity, + file: &Arc, + cx: &App, + ) -> bool { + if !file.is_local() || file.is_private() { + return false; } - self.pending_completions.clear(); + let Some(zeta_project) = self.projects.get(&project.entity_id()) else { + return false; + }; + zeta_project + .license_detection_watchers + .get(&file.worktree_id(cx)) + .as_ref() + .is_some_and(|watcher| watcher.is_project_open_source()) } - fn discard(&mut self, cx: &mut Context) { - self.pending_completions.clear(); - self.take_current_edit_prediction(cx); + fn can_collect_file(&self, project: &Entity, file: &Arc, cx: &App) -> bool { + self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx) } - fn did_show(&mut self, _cx: &mut Context) { - if let Some(current_completion) = self.current_completion.as_mut() { - current_completion.was_shown = true; + fn can_collect_events(&self, events: &[Arc]) -> bool { + if !self.data_collection_choice.is_enabled() { + return false; } + events.iter().all(|event| { + matches!( + event.as_ref(), + Event::BufferChange { + in_open_source_repo: true, + .. + } + ) + }) } - fn suggest( - &mut self, - buffer: &Entity, - cursor_position: language::Anchor, - cx: &mut Context, - ) -> Option { - let CurrentEditPrediction { - buffer_id, - completion, - .. - } = self.current_completion.as_mut()?; - - // Invalidate previous completion if it was generated for a different buffer. - if *buffer_id != buffer.entity_id() { - self.take_current_edit_prediction(cx); - return None; - } - - let buffer = buffer.read(cx); - let Some(edits) = completion.interpolate(&buffer.snapshot()) else { - self.take_current_edit_prediction(cx); - return None; - }; - - let cursor_row = cursor_position.to_point(buffer).row; - let (closest_edit_ix, (closest_edit_range, _)) = - edits.iter().enumerate().min_by_key(|(_, (range, _))| { - let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row); - let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row); - cmp::min(distance_from_start, distance_from_end) - })?; + fn load_data_collection_choice() -> DataCollectionChoice { + let choice = KEY_VALUE_STORE + .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE) + .log_err() + .flatten(); - let mut edit_start_ix = closest_edit_ix; - for (range, _) in edits[..edit_start_ix].iter().rev() { - let distance_from_closest_edit = - closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row; - if distance_from_closest_edit <= 1 { - edit_start_ix -= 1; - } else { - break; + match choice.as_deref() { + Some("true") => DataCollectionChoice::Enabled, + Some("false") => DataCollectionChoice::Disabled, + Some(_) => { + log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'"); + DataCollectionChoice::NotAnswered } + None => DataCollectionChoice::NotAnswered, } + } - let mut edit_end_ix = closest_edit_ix + 1; - for (range, _) in &edits[edit_end_ix..] { - let distance_from_closest_edit = - range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row; - if distance_from_closest_edit <= 1 { - edit_end_ix += 1; - } else { - break; - } - } + pub fn shown_predictions(&self) -> impl DoubleEndedIterator { + self.shown_predictions.iter() + } - Some(edit_prediction::EditPrediction::Local { - id: Some(completion.id.to_string().into()), - edits: edits[edit_start_ix..edit_end_ix].to_vec(), - edit_preview: Some(completion.edit_preview.clone()), - }) + pub fn shown_completions_len(&self) -> usize { + self.shown_predictions.len() } -} -/// Typical number of string bytes per token for the purposes of limiting model input. This is -/// intentionally low to err on the side of underestimating limits. -const BYTES_PER_TOKEN_GUESS: usize = 3; + pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool { + self.rated_predictions.contains(id) + } -fn guess_token_count(bytes: usize) -> usize { - bytes / BYTES_PER_TOKEN_GUESS + pub fn rate_prediction( + &mut self, + prediction: &EditPrediction, + rating: EditPredictionRating, + feedback: String, + cx: &mut Context, + ) { + self.rated_predictions.insert(prediction.id.clone()); + telemetry::event!( + "Edit Prediction Rated", + rating, + inputs = prediction.inputs, + output = prediction.edit_preview.as_unified_diff(&prediction.edits), + feedback + ); + self.client.telemetry().flush_events().detach(); + cx.notify(); + } } -#[cfg(test)] -mod tests { - use client::test::FakeServer; - use clock::{FakeSystemClock, ReplicaId}; - use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; - use gpui::TestAppContext; - use http_client::FakeHttpClient; - use indoc::indoc; - use language::Point; - use parking_lot::Mutex; - use serde_json::json; - use settings::SettingsStore; - use util::{path, rel_path::rel_path}; - - use super::*; - - const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt"); +pub fn text_from_response(mut res: open_ai::Response) -> Option { + let choice = res.choices.pop()?; + let output_text = match choice.message { + open_ai::RequestMessage::Assistant { + content: Some(open_ai::MessageContent::Plain(content)), + .. + } => content, + open_ai::RequestMessage::Assistant { + content: Some(open_ai::MessageContent::Multipart(mut content)), + .. + } => { + if content.is_empty() { + log::error!("No output from Baseten completion response"); + return None; + } - #[gpui::test] - async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { - let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); - let edits: Arc<[(Range, Arc)]> = cx.update(|cx| { - to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into() - }); + match content.remove(0) { + open_ai::MessagePart::Text { text } => text, + open_ai::MessagePart::Image { .. } => { + log::error!("Expected text, got an image"); + return None; + } + } + } + _ => { + log::error!("Invalid response message: {:?}", choice.message); + return None; + } + }; + Some(output_text) +} - let edit_preview = cx - .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx)) - .await; +#[derive(Error, Debug)] +#[error( + "You must update to Zed version {minimum_version} or higher to continue using edit predictions." +)] +pub struct ZedUpdateRequiredError { + minimum_version: Version, +} - let completion = EditPrediction { - edits, - edit_preview, - path: Path::new("").into(), - snapshot: cx.read(|cx| buffer.read(cx).snapshot()), - id: EditPredictionId(Uuid::new_v4()), - excerpt_range: 0..0, - cursor_offset: 0, - input_outline: "".into(), - input_events: "".into(), - input_excerpt: "".into(), - output_excerpt: "".into(), - buffer_snapshotted_at: Instant::now(), - response_received_at: Instant::now(), +fn make_syntax_context_cloud_request( + excerpt_path: Arc, + context: EditPredictionContext, + events: Vec>, + can_collect_data: bool, + diagnostic_groups: Vec, + diagnostic_groups_truncated: bool, + git_info: Option, + debug_info: bool, + worktrees: &Vec, + index_state: Option<&SyntaxIndexState>, + prompt_max_bytes: Option, + prompt_format: PromptFormat, +) -> predict_edits_v3::PredictEditsRequest { + let mut signatures = Vec::new(); + let mut declaration_to_signature_index = HashMap::default(); + let mut referenced_declarations = Vec::new(); + + for snippet in context.declarations { + let project_entry_id = snippet.declaration.project_entry_id(); + let Some(path) = worktrees.iter().find_map(|worktree| { + worktree.entry_for_id(project_entry_id).map(|entry| { + let mut full_path = RelPathBuf::new(); + full_path.push(worktree.root_name()); + full_path.push(&entry.path); + full_path + }) + }) else { + continue; }; - cx.update(|cx| { - assert_eq!( - from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(2..5, "REM".into()), (9..11, "".into())] - ); - - buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx)); - assert_eq!( - from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(2..2, "REM".into()), (6..8, "".into())] - ); - - buffer.update(cx, |buffer, cx| buffer.undo(cx)); - assert_eq!( - from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(2..5, "REM".into()), (9..11, "".into())] - ); + let parent_index = index_state.and_then(|index_state| { + snippet.declaration.parent().and_then(|parent| { + add_signature( + parent, + &mut declaration_to_signature_index, + &mut signatures, + index_state, + ) + }) + }); - buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx)); - assert_eq!( - from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(3..3, "EM".into()), (7..9, "".into())] - ); + let (text, text_is_truncated) = snippet.declaration.item_text(); + referenced_declarations.push(predict_edits_v3::ReferencedDeclaration { + path: path.as_std_path().into(), + text: text.into(), + range: snippet.declaration.item_line_range(), + text_is_truncated, + signature_range: snippet.declaration.signature_range_in_item_text(), + parent_index, + signature_score: snippet.score(DeclarationStyle::Signature), + declaration_score: snippet.score(DeclarationStyle::Declaration), + score_components: snippet.components, + }); + } - buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx)); - assert_eq!( - from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(4..4, "M".into()), (8..10, "".into())] - ); + let excerpt_parent = index_state.and_then(|index_state| { + context + .excerpt + .parent_declarations + .last() + .and_then(|(parent, _)| { + add_signature( + *parent, + &mut declaration_to_signature_index, + &mut signatures, + index_state, + ) + }) + }); + + predict_edits_v3::PredictEditsRequest { + excerpt_path, + excerpt: context.excerpt_text.body, + excerpt_line_range: context.excerpt.line_range, + excerpt_range: context.excerpt.range, + cursor_point: predict_edits_v3::Point { + line: predict_edits_v3::Line(context.cursor_point.row), + column: context.cursor_point.column, + }, + referenced_declarations, + included_files: vec![], + signatures, + excerpt_parent, + events, + can_collect_data, + diagnostic_groups, + diagnostic_groups_truncated, + git_info, + debug_info, + prompt_max_bytes, + prompt_format, + } +} - buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx)); - assert_eq!( - from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(9..11, "".into())] - ); +fn add_signature( + declaration_id: DeclarationId, + declaration_to_signature_index: &mut HashMap, + signatures: &mut Vec, + index: &SyntaxIndexState, +) -> Option { + if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) { + return Some(*signature_index); + } + let Some(parent_declaration) = index.declaration(declaration_id) else { + log::error!("bug: missing parent declaration"); + return None; + }; + let parent_index = parent_declaration.parent().and_then(|parent| { + add_signature(parent, declaration_to_signature_index, signatures, index) + }); + let (text, text_is_truncated) = parent_declaration.signature_text(); + let signature_index = signatures.len(); + signatures.push(Signature { + text: text.into(), + text_is_truncated, + parent_index, + range: parent_declaration.signature_line_range(), + }); + declaration_to_signature_index.insert(declaration_id, signature_index); + Some(signature_index) +} - buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx)); - assert_eq!( - from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(4..4, "M".into()), (8..10, "".into())] - ); +#[cfg(feature = "eval-support")] +pub type EvalCacheKey = (EvalCacheEntryKind, u64); - buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx)); - assert_eq!( - from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(4..4, "M".into())] - ); +#[cfg(feature = "eval-support")] +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum EvalCacheEntryKind { + Context, + Search, + Prediction, +} - buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx)); - assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None); - }) +#[cfg(feature = "eval-support")] +impl std::fmt::Display for EvalCacheEntryKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + EvalCacheEntryKind::Search => write!(f, "search"), + EvalCacheEntryKind::Context => write!(f, "context"), + EvalCacheEntryKind::Prediction => write!(f, "prediction"), + } } +} - #[gpui::test] - async fn test_clean_up_diff(cx: &mut TestAppContext) { - init_test(cx); - - assert_eq!( - apply_edit_prediction( - indoc! {" - fn main() { - let word_1 = \"lorem\"; - let range = word.len()..word.len(); - } - "}, - indoc! {" - <|editable_region_start|> - fn main() { - let word_1 = \"lorem\"; - let range = word_1.len()..word_1.len(); - } +#[cfg(feature = "eval-support")] +pub trait EvalCache: Send + Sync { + fn read(&self, key: EvalCacheKey) -> Option; + fn write(&self, key: EvalCacheKey, input: &str, value: &str); +} - <|editable_region_end|> - "}, - cx, - ) - .await, - indoc! {" - fn main() { - let word_1 = \"lorem\"; - let range = word_1.len()..word_1.len(); - } - "}, - ); +#[derive(Debug, Clone, Copy)] +pub enum DataCollectionChoice { + NotAnswered, + Enabled, + Disabled, +} - assert_eq!( - apply_edit_prediction( - indoc! {" - fn main() { - let story = \"the quick\" - } - "}, - indoc! {" - <|editable_region_start|> - fn main() { - let story = \"the quick brown fox jumps over the lazy dog\"; - } +impl DataCollectionChoice { + pub fn is_enabled(self) -> bool { + match self { + Self::Enabled => true, + Self::NotAnswered | Self::Disabled => false, + } + } - <|editable_region_end|> - "}, - cx, - ) - .await, - indoc! {" - fn main() { - let story = \"the quick brown fox jumps over the lazy dog\"; - } - "}, - ); + pub fn is_answered(self) -> bool { + match self { + Self::Enabled | Self::Disabled => true, + Self::NotAnswered => false, + } } - #[gpui::test] - async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) { - init_test(cx); - - let buffer_content = "lorem\n"; - let completion_response = indoc! {" - ```animals.js - <|start_of_file|> - <|editable_region_start|> - lorem - ipsum - <|editable_region_end|> - ```"}; + #[must_use] + pub fn toggle(&self) -> DataCollectionChoice { + match self { + Self::Enabled => Self::Disabled, + Self::Disabled => Self::Enabled, + Self::NotAnswered => Self::Enabled, + } + } +} - assert_eq!( - apply_edit_prediction(buffer_content, completion_response, cx).await, - "lorem\nipsum" - ); +impl From for DataCollectionChoice { + fn from(value: bool) -> Self { + match value { + true => DataCollectionChoice::Enabled, + false => DataCollectionChoice::Disabled, + } } +} - #[gpui::test] - async fn test_can_collect_data(cx: &mut TestAppContext) { - init_test(cx); +struct ZedPredictUpsell; - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT })) - .await; +impl Dismissable for ZedPredictUpsell { + const KEY: &'static str = "dismissed-edit-predict-upsell"; - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/project/src/main.rs"), cx) - }) - .await - .unwrap(); + fn dismissed() -> bool { + // To make this backwards compatible with older versions of Zed, we + // check if the user has seen the previous Edit Prediction Onboarding + // before, by checking the data collection choice which was written to + // the database once the user clicked on "Accept and Enable" + if KEY_VALUE_STORE + .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE) + .log_err() + .is_some_and(|s| s.is_some()) + { + return true; + } - let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; - zeta.update(cx, |zeta, _cx| { - zeta.data_collection_choice = DataCollectionChoice::Enabled - }); + KEY_VALUE_STORE + .read_kvp(Self::KEY) + .log_err() + .is_some_and(|s| s.is_some()) + } +} - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - true - ); +pub fn should_show_upsell_modal() -> bool { + !ZedPredictUpsell::dismissed() +} + +pub fn init(cx: &mut App) { + feature_gate_predict_edits_actions(cx); - zeta.update(cx, |zeta, _cx| { - zeta.data_collection_choice = DataCollectionChoice::Disabled + cx.observe_new(move |workspace: &mut Workspace, _, _cx| { + workspace.register_action(|workspace, _: &RateCompletions, window, cx| { + if cx.has_flag::() { + RatePredictionsModal::toggle(workspace, window, cx); + } }); - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false + workspace.register_action( + move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| { + ZedPredictModal::toggle( + workspace, + workspace.user_store().clone(), + workspace.client().clone(), + window, + cx, + ) + }, ); - } - #[gpui::test] - async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - let project = Project::test(fs.clone(), [], cx).await; - - let buffer = cx.new(|_cx| { - Buffer::remote( - language::BufferId::new(1).unwrap(), - ReplicaId::new(1), - language::Capability::ReadWrite, - "fn main() {\n println!(\"Hello\");\n}", - ) + workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| { + update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| { + settings + .project + .all_languages + .features + .get_or_insert_default() + .edit_prediction_provider = Some(EditPredictionProvider::None) + }); }); + }) + .detach(); +} - let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; - zeta.update(cx, |zeta, _cx| { - zeta.data_collection_choice = DataCollectionChoice::Enabled +fn feature_gate_predict_edits_actions(cx: &mut App) { + let rate_completion_action_types = [TypeId::of::()]; + let reset_onboarding_action_types = [TypeId::of::()]; + let zeta_all_action_types = [ + TypeId::of::(), + TypeId::of::(), + zed_actions::OpenZedPredictOnboarding.type_id(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + ]; + + CommandPaletteFilter::update_global(cx, |filter, _cx| { + filter.hide_action_types(&rate_completion_action_types); + filter.hide_action_types(&reset_onboarding_action_types); + filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]); + }); + + cx.observe_global::(move |cx| { + let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; + let has_feature_flag = cx.has_flag::(); + + CommandPaletteFilter::update_global(cx, |filter, _cx| { + if is_ai_disabled { + filter.hide_action_types(&zeta_all_action_types); + } else if has_feature_flag { + filter.show_action_types(&rate_completion_action_types); + } else { + filter.hide_action_types(&rate_completion_action_types); + } }); + }) + .detach(); - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); - } + cx.observe_flag::(move |is_enabled, cx| { + if !DisableAiSettings::get_global(cx).disable_ai { + if is_enabled { + CommandPaletteFilter::update_global(cx, |filter, _cx| { + filter.show_action_types(&rate_completion_action_types); + }); + } else { + CommandPaletteFilter::update_global(cx, |filter, _cx| { + filter.hide_action_types(&rate_completion_action_types); + }); + } + } + }) + .detach(); +} - #[gpui::test] - async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) { - init_test(cx); +#[cfg(test)] +mod tests { + use std::{path::Path, sync::Arc}; + + use client::UserStore; + use clock::FakeSystemClock; + use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery}; + use futures::{ + AsyncReadExt, StreamExt, + channel::{mpsc, oneshot}, + }; + use gpui::{ + Entity, TestAppContext, + http_client::{FakeHttpClient, Response}, + prelude::*, + }; + use indoc::indoc; + use language::OffsetRangeExt as _; + use open_ai::Usage; + use pretty_assertions::{assert_eq, assert_matches}; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::SettingsStore; + use util::path; + use uuid::Uuid; - let fs = project::FakeFs::new(cx.executor()); + use crate::{BufferEditPrediction, Zeta}; + + #[gpui::test] + async fn test_current_state(cx: &mut TestAppContext) { + let (zeta, mut req_rx) = init_test(cx); + let fs = FakeFs::new(cx.executor()); fs.insert_tree( - path!("/project"), + "/root", json!({ - "LICENSE": BSD_0_TXT, - ".env": "SECRET_KEY=secret" + "1.txt": "Hello!\nHow\nBye\n", + "2.txt": "Hola!\nComo\nAdios\n" }), ) .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = project + zeta.update(cx, |zeta, cx| { + zeta.register_project(&project, cx); + }); + + let buffer1 = project .update(cx, |project, cx| { - project.open_local_buffer("/project/.env", cx) + let path = project.find_project_path(path!("root/1.txt"), cx).unwrap(); + project.open_buffer(path, cx) }) .await .unwrap(); + let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot1.anchor_before(language::Point::new(1, 3)); - let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; - zeta.update(cx, |zeta, _cx| { - zeta.data_collection_choice = DataCollectionChoice::Enabled - }); - - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); - } + // Prediction for current file - #[gpui::test] - async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) { - init_test(cx); + zeta.update(cx, |zeta, cx| { + zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx) + }); + let (_request, respond_tx) = req_rx.next().await.unwrap(); + + respond_tx + .send(model_response(indoc! {r" + --- a/root/1.txt + +++ b/root/1.txt + @@ ... @@ + Hello! + -How + +How are you? + Bye + "})) + .unwrap(); - let fs = project::FakeFs::new(cx.executor()); - let project = Project::test(fs.clone(), [], cx).await; - let buffer = cx.new(|cx| Buffer::local("", cx)); + cx.run_until_parked(); - let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; - zeta.update(cx, |zeta, _cx| { - zeta.data_collection_choice = DataCollectionChoice::Enabled + zeta.read_with(cx, |zeta, cx| { + let prediction = zeta + .current_prediction_for_buffer(&buffer1, &project, cx) + .unwrap(); + assert_matches!(prediction, BufferEditPrediction::Local { .. }); }); - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); - } + // Context refresh + let refresh_task = zeta.update(cx, |zeta, cx| { + zeta.refresh_context(project.clone(), buffer1.clone(), position, cx) + }); + let (_request, respond_tx) = req_rx.next().await.unwrap(); + respond_tx + .send(open_ai::Response { + id: Uuid::new_v4().to_string(), + object: "response".into(), + created: 0, + model: "model".into(), + choices: vec![open_ai::Choice { + index: 0, + message: open_ai::RequestMessage::Assistant { + content: None, + tool_calls: vec![open_ai::ToolCall { + id: "search".into(), + content: open_ai::ToolCallContent::Function { + function: open_ai::FunctionContent { + name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME + .to_string(), + arguments: serde_json::to_string(&SearchToolInput { + queries: Box::new([SearchToolQuery { + glob: "root/2.txt".to_string(), + syntax_node: vec![], + content: Some(".".into()), + }]), + }) + .unwrap(), + }, + }, + }], + }, + finish_reason: None, + }], + usage: Usage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }, + }) + .unwrap(); + refresh_task.await.unwrap(); - #[gpui::test] - async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) { - init_test(cx); + zeta.update(cx, |zeta, cx| { + zeta.discard_current_prediction(&project, cx); + }); - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" })) - .await; + // Prediction for another file + zeta.update(cx, |zeta, cx| { + zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx) + }); + let (_request, respond_tx) = req_rx.next().await.unwrap(); + respond_tx + .send(model_response(indoc! {r#" + --- a/root/2.txt + +++ b/root/2.txt + Hola! + -Como + +Como estas? + Adios + "#})) + .unwrap(); + cx.run_until_parked(); + + zeta.read_with(cx, |zeta, cx| { + let prediction = zeta + .current_prediction_for_buffer(&buffer1, &project, cx) + .unwrap(); + assert_matches!( + prediction, + BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt")) + ); + }); - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = project + let buffer2 = project .update(cx, |project, cx| { - project.open_local_buffer("/project/main.rs", cx) + let path = project.find_project_path(path!("root/2.txt"), cx).unwrap(); + project.open_buffer(path, cx) }) .await .unwrap(); - let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; - zeta.update(cx, |zeta, _cx| { - zeta.data_collection_choice = DataCollectionChoice::Enabled + zeta.read_with(cx, |zeta, cx| { + let prediction = zeta + .current_prediction_for_buffer(&buffer2, &project, cx) + .unwrap(); + assert_matches!(prediction, BufferEditPrediction::Local { .. }); }); - - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); } #[gpui::test] - async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); + async fn test_simple_request(cx: &mut TestAppContext) { + let (zeta, mut req_rx) = init_test(cx); + let fs = FakeFs::new(cx.executor()); fs.insert_tree( - path!("/open_source_worktree"), - json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }), + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), ) .await; - fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" })) - .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - let project = Project::test( - fs.clone(), - [ - path!("/open_source_worktree").as_ref(), - path!("/closed_source_worktree").as_ref(), - ], - cx, - ) - .await; let buffer = project .update(cx, |project, cx| { - project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx) + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) }) .await .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); - let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; - zeta.update(cx, |zeta, _cx| { - zeta.data_collection_choice = DataCollectionChoice::Enabled + let prediction_task = zeta.update(cx, |zeta, cx| { + zeta.request_prediction(&project, &buffer, position, cx) }); - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - true - ); - - let closed_source_file = project - .update(cx, |project, cx| { - let worktree2 = project - .worktree_for_root_name("closed_source_worktree", cx) - .unwrap(); - worktree2.update(cx, |worktree2, cx| { - worktree2.load_file(rel_path("main.rs"), cx) - }) - }) - .await - .unwrap() - .file; + let (_, respond_tx) = req_rx.next().await.unwrap(); + + // TODO Put back when we have a structured request again + // assert_eq!( + // request.excerpt_path.as_ref(), + // Path::new(path!("root/foo.md")) + // ); + // assert_eq!( + // request.cursor_point, + // Point { + // line: Line(1), + // column: 3 + // } + // ); + + respond_tx + .send(model_response(indoc! { r" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How are you? + Bye + "})) + .unwrap(); - buffer.update(cx, |buffer, cx| { - buffer.file_updated(closed_source_file, cx); - }); + let prediction = prediction_task.await.unwrap().unwrap(); - run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!(prediction.edits.len(), 1); assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false + prediction.edits[0].0.to_point(&snapshot).start, + language::Point::new(1, 3) ); + assert_eq!(prediction.edits[0].1.as_ref(), " are you?"); } #[gpui::test] - async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); + async fn test_request_events(cx: &mut TestAppContext) { + let (zeta, mut req_rx) = init_test(cx); + let fs = FakeFs::new(cx.executor()); fs.insert_tree( - path!("/worktree1"), - json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }), + "/root", + json!({ + "foo.md": "Hello!\n\nBye\n" + }), ) .await; - fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" })) - .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - let project = Project::test( - fs.clone(), - [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()], - cx, - ) - .await; let buffer = project .update(cx, |project, cx| { - project.open_local_buffer(path!("/worktree1/main.rs"), cx) - }) - .await - .unwrap(); - let private_buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/worktree2/file.rs"), cx) + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) }) .await .unwrap(); - let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; - zeta.update(cx, |zeta, _cx| { - zeta.data_collection_choice = DataCollectionChoice::Enabled + zeta.update(cx, |zeta, cx| { + zeta.register_buffer(&buffer, &project, cx); }); - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - true - ); + buffer.update(cx, |buffer, cx| { + buffer.edit(vec![(7..7, "How")], None, cx); + }); - // this has a side effect of registering the buffer to watch for edits - run_edit_prediction(&private_buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false - ); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); - private_buffer.update(cx, |private_buffer, cx| { - private_buffer.edit([(0..0, "An edit for the history!")], None, cx); + let prediction_task = zeta.update(cx, |zeta, cx| { + zeta.request_prediction(&project, &buffer, position, cx) }); - run_edit_prediction(&buffer, &project, &zeta, cx).await; - assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - false + let (request, respond_tx) = req_rx.next().await.unwrap(); + + let prompt = prompt_from_request(&request); + assert!( + prompt.contains(indoc! {" + --- a/root/foo.md + +++ b/root/foo.md + @@ -1,3 +1,3 @@ + Hello! + - + +How + Bye + "}), + "{prompt}" ); - // make an edit that uses too many bytes, causing private_buffer edit to not be able to be - // included - buffer.update(cx, |buffer, cx| { - buffer.edit( - [(0..0, " ".repeat(MAX_EVENT_TOKENS * BYTES_PER_TOKEN_GUESS))], - None, - cx, - ); - }); + respond_tx + .send(model_response(indoc! {r#" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How are you? + Bye + "#})) + .unwrap(); + + let prediction = prediction_task.await.unwrap().unwrap(); - run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!(prediction.edits.len(), 1); assert_eq!( - captured_request.lock().clone().unwrap().can_collect_data, - true + prediction.edits[0].0.to_point(&snapshot).start, + language::Point::new(1, 3) ); + assert_eq!(prediction.edits[0].1.as_ref(), " are you?"); + } + + // Skipped until we start including diagnostics in prompt + // #[gpui::test] + // async fn test_request_diagnostics(cx: &mut TestAppContext) { + // let (zeta, mut req_rx) = init_test(cx); + // let fs = FakeFs::new(cx.executor()); + // fs.insert_tree( + // "/root", + // json!({ + // "foo.md": "Hello!\nBye" + // }), + // ) + // .await; + // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap(); + // let diagnostic = lsp::Diagnostic { + // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)), + // severity: Some(lsp::DiagnosticSeverity::ERROR), + // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(), + // ..Default::default() + // }; + + // project.update(cx, |project, cx| { + // project.lsp_store().update(cx, |lsp_store, cx| { + // // Create some diagnostics + // lsp_store + // .update_diagnostics( + // LanguageServerId(0), + // lsp::PublishDiagnosticsParams { + // uri: path_to_buffer_uri.clone(), + // diagnostics: vec![diagnostic], + // version: None, + // }, + // None, + // language::DiagnosticSourceKind::Pushed, + // &[], + // cx, + // ) + // .unwrap(); + // }); + // }); + + // let buffer = project + // .update(cx, |project, cx| { + // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + // project.open_buffer(path, cx) + // }) + // .await + // .unwrap(); + + // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + // let position = snapshot.anchor_before(language::Point::new(0, 0)); + + // let _prediction_task = zeta.update(cx, |zeta, cx| { + // zeta.request_prediction(&project, &buffer, position, cx) + // }); + + // let (request, _respond_tx) = req_rx.next().await.unwrap(); + + // assert_eq!(request.diagnostic_groups.len(), 1); + // let value = serde_json::from_str::(request.diagnostic_groups[0].0.get()) + // .unwrap(); + // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3 + // assert_eq!( + // value, + // json!({ + // "entries": [{ + // "range": { + // "start": 8, + // "end": 10 + // }, + // "diagnostic": { + // "source": null, + // "code": null, + // "code_description": null, + // "severity": 1, + // "message": "\"Hello\" deprecated. Use \"Hi\" instead", + // "markdown": null, + // "group_id": 0, + // "is_primary": true, + // "is_disk_based": false, + // "is_unnecessary": false, + // "source_kind": "Pushed", + // "data": null, + // "underline": true + // } + // }], + // "primary_ix": 0 + // }) + // ); + // } + + fn model_response(text: &str) -> open_ai::Response { + open_ai::Response { + id: Uuid::new_v4().to_string(), + object: "response".into(), + created: 0, + model: "model".into(), + choices: vec![open_ai::Choice { + index: 0, + message: open_ai::RequestMessage::Assistant { + content: Some(open_ai::MessageContent::Plain(text.to_string())), + tool_calls: vec![], + }, + finish_reason: None, + }], + usage: Usage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }, + } } - fn init_test(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - }); - } - - async fn apply_edit_prediction( - buffer_content: &str, - completion_response: &str, - cx: &mut TestAppContext, - ) -> String { - let fs = project::FakeFs::new(cx.executor()); - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); - let (zeta, _, response) = make_test_zeta(&project, cx).await; - *response.lock() = completion_response.to_string(); - let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await; - buffer.update(cx, |buffer, cx| { - buffer.edit(edit_prediction.edits.iter().cloned(), None, cx) - }); - buffer.read_with(cx, |buffer, _| buffer.text()) - } - - async fn run_edit_prediction( - buffer: &Entity, - project: &Entity, - zeta: &Entity, - cx: &mut TestAppContext, - ) -> EditPrediction { - let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); - zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx)); - cx.background_executor.run_until_parked(); - let completion_task = zeta.update(cx, |zeta, cx| { - zeta.request_completion(&project, buffer, cursor, cx) - }); - completion_task.await.unwrap().unwrap() + fn prompt_from_request(request: &open_ai::Request) -> &str { + assert_eq!(request.messages.len(), 1); + let open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(content), + .. + } = &request.messages[0] + else { + panic!( + "Request does not have single user message of type Plain. {:#?}", + request + ); + }; + content } - async fn make_test_zeta( - project: &Entity, + fn init_test( cx: &mut TestAppContext, ) -> ( Entity, - Arc>>, - Arc>, + mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender)>, ) { - let default_response = indoc! {" - ```main.rs - <|start_of_file|> - <|editable_region_start|> - hello world - <|editable_region_end|> - ```" - }; - let captured_request: Arc>> = Arc::new(Mutex::new(None)); - let completion_response: Arc> = - Arc::new(Mutex::new(default_response.to_string())); - let http_client = FakeHttpClient::create({ - let captured_request = captured_request.clone(); - let completion_response = completion_response.clone(); - move |req| { - let captured_request = captured_request.clone(); - let completion_response = completion_response.clone(); - async move { - match (req.method(), req.uri().path()) { - (&Method::POST, "/client/llm_tokens") => { - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&CreateLlmTokenResponse { - token: LlmToken("the-llm-token".to_string()), - }) - .unwrap() - .into(), - ) - .unwrap()) - } - (&Method::POST, "/predict_edits/v2") => { - let mut request_body = String::new(); - req.into_body().read_to_string(&mut request_body).await?; - *captured_request.lock() = - Some(serde_json::from_str(&request_body).unwrap()); - Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::new_v4().to_string(), - output_excerpt: completion_response.lock().clone(), - }) - .unwrap() - .into(), - ) - .unwrap()) - } - _ => Ok(http_client::Response::builder() - .status(404) - .body("Not Found".into()) - .unwrap()), + cx.update(move |cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + zlog::init_test(); + + let (req_tx, req_rx) = mpsc::unbounded(); + + let http_client = FakeHttpClient::create({ + move |req| { + let uri = req.uri().path().to_string(); + let mut body = req.into_body(); + let req_tx = req_tx.clone(); + async move { + let resp = match uri.as_str() { + "/client/llm_tokens" => serde_json::to_string(&json!({ + "token": "test" + })) + .unwrap(), + "/predict_edits/raw" => { + let mut buf = Vec::new(); + body.read_to_end(&mut buf).await.ok(); + let req = serde_json::from_slice(&buf).unwrap(); + + let (res_tx, res_rx) = oneshot::channel(); + req_tx.unbounded_send((req, res_tx)).unwrap(); + serde_json::to_string(&res_rx.await?).unwrap() + } + _ => { + panic!("Unexpected path: {}", uri) + } + }; + + Ok(Response::builder().body(resp.into()).unwrap()) } } - } - }); - - let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); - cx.update(|cx| { - RefreshLlmTokenListener::register(client.clone(), cx); - }); - let _server = FakeServer::for_client(42, &client, cx).await; - - let zeta = cx.new(|cx| { - let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx); - - let worktrees = project.read(cx).worktrees(cx).collect::>(); - for worktree in worktrees { - let worktree_id = worktree.read(cx).id(); - zeta.license_detection_watchers - .entry(worktree_id) - .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx))); - } - - zeta - }); + }); - (zeta, captured_request, completion_response) - } + let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx); + client.cloud_client().set_credentials(1, "test".into()); - fn to_completion_edits( - iterator: impl IntoIterator, Arc)>, - buffer: &Entity, - cx: &App, - ) -> Vec<(Range, Arc)> { - let buffer = buffer.read(cx); - iterator - .into_iter() - .map(|(range, text)| { - ( - buffer.anchor_after(range.start)..buffer.anchor_before(range.end), - text, - ) - }) - .collect() - } + language_model::init(client.clone(), cx); - fn from_completion_edits( - editor_edits: &[(Range, Arc)], - buffer: &Entity, - cx: &App, - ) -> Vec<(Range, Arc)> { - let buffer = buffer.read(cx); - editor_edits - .iter() - .map(|(range, text)| { - ( - range.start.to_offset(buffer)..range.end.to_offset(buffer), - text.clone(), - ) - }) - .collect() - } + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + let zeta = Zeta::global(&client, &user_store, cx); - #[ctor::ctor] - fn init_logger() { - zlog::init_test(); + (zeta, req_rx) + }) } } diff --git a/crates/zeta/src/zeta1.rs b/crates/zeta/src/zeta1.rs new file mode 100644 index 0000000000000000000000000000000000000000..5a779cabeceac0bcb58340f7bbb98175409916e8 --- /dev/null +++ b/crates/zeta/src/zeta1.rs @@ -0,0 +1,500 @@ +mod input_excerpt; + +use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant}; + +use crate::{ + EditPredictionId, ZedUpdateRequiredError, Zeta, + prediction::{EditPrediction, EditPredictionInputs}, +}; +use anyhow::{Context as _, Result}; +use cloud_llm_client::{ + PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, predict_edits_v3::Event, +}; +use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task}; +use input_excerpt::excerpt_for_cursor_position; +use language::{ + Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff, +}; +use project::{Project, ProjectPath}; +use release_channel::AppVersion; +use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; + +const CURSOR_MARKER: &str = "<|user_cursor_is_here|>"; +const START_OF_FILE_MARKER: &str = "<|start_of_file|>"; +const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>"; +const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>"; + +pub(crate) const MAX_CONTEXT_TOKENS: usize = 150; +pub(crate) const MAX_REWRITE_TOKENS: usize = 350; +pub(crate) const MAX_EVENT_TOKENS: usize = 500; + +pub(crate) fn request_prediction_with_zeta1( + zeta: &mut Zeta, + project: &Entity, + buffer: &Entity, + position: language::Anchor, + cx: &mut Context, +) -> Task>> { + let buffer = buffer.clone(); + let buffer_snapshotted_at = Instant::now(); + let snapshot = buffer.read(cx).snapshot(); + let client = zeta.client.clone(); + let llm_token = zeta.llm_token.clone(); + let app_version = AppVersion::global(cx); + + let zeta_project = zeta.get_or_init_zeta_project(project, cx); + let events = Arc::new(zeta_project.events(cx)); + + let (git_info, can_collect_file) = if let Some(file) = snapshot.file() { + let can_collect_file = zeta.can_collect_file(project, file, cx); + let git_info = if can_collect_file { + git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) + } else { + None + }; + (git_info, can_collect_file) + } else { + (None, false) + }; + + let full_path: Arc = snapshot + .file() + .map(|f| Arc::from(f.full_path(cx).as_path())) + .unwrap_or_else(|| Arc::from(Path::new("untitled"))); + let full_path_str = full_path.to_string_lossy().into_owned(); + let cursor_point = position.to_point(&snapshot); + let prompt_for_events = { + let events = events.clone(); + move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS) + }; + let gather_task = gather_context( + full_path_str, + &snapshot, + cursor_point, + prompt_for_events, + cx, + ); + + cx.spawn(async move |this, cx| { + let GatherContextOutput { + mut body, + context_range, + editable_range, + included_events_count, + } = gather_task.await?; + let done_gathering_context_at = Instant::now(); + + let included_events = &events[events.len() - included_events_count..events.len()]; + body.can_collect_data = can_collect_file + && this + .read_with(cx, |this, _| this.can_collect_events(included_events)) + .unwrap_or(false); + if body.can_collect_data { + body.git_info = git_info; + } + + log::debug!( + "Events:\n{}\nExcerpt:\n{:?}", + body.input_events, + body.input_excerpt + ); + + let http_client = client.http_client(); + + let response = Zeta::send_api_request::( + |request| { + let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") { + predict_edits_url + } else { + http_client + .build_zed_llm_url("/predict_edits/v2", &[])? + .as_str() + .into() + }; + Ok(request + .uri(uri) + .body(serde_json::to_string(&body)?.into())?) + }, + client, + llm_token, + app_version, + ) + .await; + + let inputs = EditPredictionInputs { + events: included_events.into(), + included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile { + path: full_path.clone(), + max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row), + excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt { + start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row), + text: snapshot + .text_for_range(context_range) + .collect::() + .into(), + }], + }], + cursor_point: cloud_llm_client::predict_edits_v3::Point { + column: cursor_point.column, + line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row), + }, + cursor_path: full_path, + }; + + // let response = perform_predict_edits(PerformPredictEditsParams { + // client, + // llm_token, + // app_version, + // body, + // }) + // .await; + + let (response, usage) = match response { + Ok(response) => response, + Err(err) => { + if err.is::() { + cx.update(|cx| { + this.update(cx, |zeta, _cx| { + zeta.update_required = true; + }) + .ok(); + + let error_message: SharedString = err.to_string().into(); + show_app_notification( + NotificationId::unique::(), + cx, + move |cx| { + cx.new(|cx| { + ErrorMessagePrompt::new(error_message.clone(), cx) + .with_link_button("Update Zed", "https://zed.dev/releases") + }) + }, + ); + }) + .ok(); + } + + return Err(err); + } + }; + + let received_response_at = Instant::now(); + log::debug!("completion response: {}", &response.output_excerpt); + + if let Some(usage) = usage { + this.update(cx, |this, cx| { + this.user_store.update(cx, |user_store, cx| { + user_store.update_edit_prediction_usage(usage, cx); + }); + }) + .ok(); + } + + let edit_prediction = process_completion_response( + response, + buffer, + &snapshot, + editable_range, + inputs, + buffer_snapshotted_at, + received_response_at, + cx, + ) + .await; + + let finished_at = Instant::now(); + + // record latency for ~1% of requests + if rand::random::() <= 2 { + telemetry::event!( + "Edit Prediction Request", + context_latency = done_gathering_context_at + .duration_since(buffer_snapshotted_at) + .as_millis(), + request_latency = received_response_at + .duration_since(done_gathering_context_at) + .as_millis(), + process_latency = finished_at.duration_since(received_response_at).as_millis() + ); + } + + edit_prediction + }) +} + +fn process_completion_response( + prediction_response: PredictEditsResponse, + buffer: Entity, + snapshot: &BufferSnapshot, + editable_range: Range, + inputs: EditPredictionInputs, + buffer_snapshotted_at: Instant, + received_response_at: Instant, + cx: &AsyncApp, +) -> Task>> { + let snapshot = snapshot.clone(); + let request_id = prediction_response.request_id; + let output_excerpt = prediction_response.output_excerpt; + cx.spawn(async move |cx| { + let output_excerpt: Arc = output_excerpt.into(); + + let edits: Arc<[(Range, Arc)]> = cx + .background_spawn({ + let output_excerpt = output_excerpt.clone(); + let editable_range = editable_range.clone(); + let snapshot = snapshot.clone(); + async move { parse_edits(output_excerpt, editable_range, &snapshot) } + }) + .await? + .into(); + + Ok(EditPrediction::new( + EditPredictionId(request_id.into()), + &buffer, + &snapshot, + edits, + buffer_snapshotted_at, + received_response_at, + inputs, + cx, + ) + .await) + }) +} + +fn parse_edits( + output_excerpt: Arc, + editable_range: Range, + snapshot: &BufferSnapshot, +) -> Result, Arc)>> { + let content = output_excerpt.replace(CURSOR_MARKER, ""); + + let start_markers = content + .match_indices(EDITABLE_REGION_START_MARKER) + .collect::>(); + anyhow::ensure!( + start_markers.len() == 1, + "expected exactly one start marker, found {}", + start_markers.len() + ); + + let end_markers = content + .match_indices(EDITABLE_REGION_END_MARKER) + .collect::>(); + anyhow::ensure!( + end_markers.len() == 1, + "expected exactly one end marker, found {}", + end_markers.len() + ); + + let sof_markers = content + .match_indices(START_OF_FILE_MARKER) + .collect::>(); + anyhow::ensure!( + sof_markers.len() <= 1, + "expected at most one start-of-file marker, found {}", + sof_markers.len() + ); + + let codefence_start = start_markers[0].0; + let content = &content[codefence_start..]; + + let newline_ix = content.find('\n').context("could not find newline")?; + let content = &content[newline_ix + 1..]; + + let codefence_end = content + .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}")) + .context("could not find end marker")?; + let new_text = &content[..codefence_end]; + + let old_text = snapshot + .text_for_range(editable_range.clone()) + .collect::(); + + Ok(compute_edits( + old_text, + new_text, + editable_range.start, + snapshot, + )) +} + +pub fn compute_edits( + old_text: String, + new_text: &str, + offset: usize, + snapshot: &BufferSnapshot, +) -> Vec<(Range, Arc)> { + text_diff(&old_text, new_text) + .into_iter() + .map(|(mut old_range, new_text)| { + old_range.start += offset; + old_range.end += offset; + + let prefix_len = common_prefix( + snapshot.chars_for_range(old_range.clone()), + new_text.chars(), + ); + old_range.start += prefix_len; + + let suffix_len = common_prefix( + snapshot.reversed_chars_for_range(old_range.clone()), + new_text[prefix_len..].chars().rev(), + ); + old_range.end = old_range.end.saturating_sub(suffix_len); + + let new_text = new_text[prefix_len..new_text.len() - suffix_len].into(); + let range = if old_range.is_empty() { + let anchor = snapshot.anchor_after(old_range.start); + anchor..anchor + } else { + snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end) + }; + (range, new_text) + }) + .collect() +} + +fn common_prefix, T2: Iterator>(a: T1, b: T2) -> usize { + a.zip(b) + .take_while(|(a, b)| a == b) + .map(|(a, _)| a.len_utf8()) + .sum() +} + +fn git_info_for_file( + project: &Entity, + project_path: &ProjectPath, + cx: &App, +) -> Option { + let git_store = project.read(cx).git_store().read(cx); + if let Some((repository, _repo_path)) = + git_store.repository_and_path_for_project_path(project_path, cx) + { + let repository = repository.read(cx); + let head_sha = repository + .head_commit + .as_ref() + .map(|head_commit| head_commit.sha.to_string()); + let remote_origin_url = repository.remote_origin_url.clone(); + let remote_upstream_url = repository.remote_upstream_url.clone(); + if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() { + return None; + } + Some(PredictEditsGitInfo { + head_sha, + remote_origin_url, + remote_upstream_url, + }) + } else { + None + } +} + +pub struct GatherContextOutput { + pub body: PredictEditsBody, + pub context_range: Range, + pub editable_range: Range, + pub included_events_count: usize, +} + +pub fn gather_context( + full_path_str: String, + snapshot: &BufferSnapshot, + cursor_point: language::Point, + prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static, + cx: &App, +) -> Task> { + cx.background_spawn({ + let snapshot = snapshot.clone(); + async move { + let input_excerpt = excerpt_for_cursor_position( + cursor_point, + &full_path_str, + &snapshot, + MAX_REWRITE_TOKENS, + MAX_CONTEXT_TOKENS, + ); + let (input_events, included_events_count) = prompt_for_events(); + let editable_range = input_excerpt.editable_range.to_offset(&snapshot); + + let body = PredictEditsBody { + input_events, + input_excerpt: input_excerpt.prompt, + can_collect_data: false, + diagnostic_groups: None, + git_info: None, + outline: None, + speculated_output: None, + }; + + Ok(GatherContextOutput { + body, + context_range: input_excerpt.context_range, + editable_range, + included_events_count, + }) + } + }) +} + +fn prompt_for_events_impl(events: &[Arc], mut remaining_tokens: usize) -> (String, usize) { + let mut result = String::new(); + for (ix, event) in events.iter().rev().enumerate() { + let event_string = format_event(event.as_ref()); + let event_tokens = guess_token_count(event_string.len()); + if event_tokens > remaining_tokens { + return (result, ix); + } + + if !result.is_empty() { + result.insert_str(0, "\n\n"); + } + result.insert_str(0, &event_string); + remaining_tokens -= event_tokens; + } + return (result, events.len()); +} + +pub fn format_event(event: &Event) -> String { + match event { + Event::BufferChange { + path, + old_path, + diff, + .. + } => { + let mut prompt = String::new(); + + if old_path != path { + writeln!( + prompt, + "User renamed {} to {}\n", + old_path.display(), + path.display() + ) + .unwrap(); + } + + if !diff.is_empty() { + write!( + prompt, + "User edited {}:\n```diff\n{}\n```", + path.display(), + diff + ) + .unwrap(); + } + + prompt + } + } +} + +/// Typical number of string bytes per token for the purposes of limiting model input. This is +/// intentionally low to err on the side of underestimating limits. +pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3; + +fn guess_token_count(bytes: usize) -> usize { + bytes / BYTES_PER_TOKEN_GUESS +} diff --git a/crates/zeta/src/input_excerpt.rs b/crates/zeta/src/zeta1/input_excerpt.rs similarity index 98% rename from crates/zeta/src/input_excerpt.rs rename to crates/zeta/src/zeta1/input_excerpt.rs index 06bff5b1bea0f099b2ccd98605ac5de5bb5e6360..853d74da463c19de4f1d3915cb703a53b6c43c61 100644 --- a/crates/zeta/src/input_excerpt.rs +++ b/crates/zeta/src/zeta1/input_excerpt.rs @@ -1,4 +1,4 @@ -use crate::{ +use super::{ CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER, guess_token_count, }; @@ -7,6 +7,7 @@ use std::{fmt::Write, ops::Range}; #[derive(Debug)] pub struct InputExcerpt { + pub context_range: Range, pub editable_range: Range, pub prompt: String, } @@ -63,6 +64,7 @@ pub fn excerpt_for_cursor_position( write!(prompt, "\n```").unwrap(); InputExcerpt { + context_range, editable_range, prompt, } @@ -124,7 +126,7 @@ mod tests { use super::*; use gpui::{App, AppContext}; use indoc::indoc; - use language::{Buffer, Language, LanguageConfig, LanguageMatcher}; + use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; use std::sync::Arc; #[gpui::test] diff --git a/crates/zeta/src/zeta_tests.rs b/crates/zeta/src/zeta_tests.rs new file mode 100644 index 0000000000000000000000000000000000000000..eb12f81af25d72b5e7003187ab0a9536622c9a74 --- /dev/null +++ b/crates/zeta/src/zeta_tests.rs @@ -0,0 +1,671 @@ +use client::test::FakeServer; +use clock::{FakeSystemClock, ReplicaId}; +use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; +use cloud_llm_client::{PredictEditsBody, PredictEditsResponse}; +use gpui::TestAppContext; +use http_client::FakeHttpClient; +use indoc::indoc; +use language::Point; +use parking_lot::Mutex; +use serde_json::json; +use settings::SettingsStore; +use util::{path, rel_path::rel_path}; + +use crate::zeta1::MAX_EVENT_TOKENS; + +use super::*; + +const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt"); + +#[gpui::test] +async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { + let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); + let edits: Arc<[(Range, Arc)]> = cx.update(|cx| { + to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into() + }); + + let edit_preview = cx + .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx)) + .await; + + let completion = EditPrediction { + edits, + edit_preview, + buffer: buffer.clone(), + snapshot: cx.read(|cx| buffer.read(cx).snapshot()), + id: EditPredictionId("the-id".into()), + inputs: EditPredictionInputs { + events: Default::default(), + included_files: Default::default(), + cursor_point: cloud_llm_client::predict_edits_v3::Point { + line: Line(0), + column: 0, + }, + cursor_path: Path::new("").into(), + }, + buffer_snapshotted_at: Instant::now(), + response_received_at: Instant::now(), + }; + + cx.update(|cx| { + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(2..5, "REM".into()), (9..11, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(2..2, "REM".into()), (6..8, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.undo(cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(2..5, "REM".into()), (9..11, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(3..3, "EM".into()), (7..9, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(4..4, "M".into()), (8..10, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(9..11, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(4..4, "M".into()), (8..10, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(4..4, "M".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx)); + assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None); + }) +} + +#[gpui::test] +async fn test_clean_up_diff(cx: &mut TestAppContext) { + init_test(cx); + + assert_eq!( + apply_edit_prediction( + indoc! {" + fn main() { + let word_1 = \"lorem\"; + let range = word.len()..word.len(); + } + "}, + indoc! {" + <|editable_region_start|> + fn main() { + let word_1 = \"lorem\"; + let range = word_1.len()..word_1.len(); + } + + <|editable_region_end|> + "}, + cx, + ) + .await, + indoc! {" + fn main() { + let word_1 = \"lorem\"; + let range = word_1.len()..word_1.len(); + } + "}, + ); + + assert_eq!( + apply_edit_prediction( + indoc! {" + fn main() { + let story = \"the quick\" + } + "}, + indoc! {" + <|editable_region_start|> + fn main() { + let story = \"the quick brown fox jumps over the lazy dog\"; + } + + <|editable_region_end|> + "}, + cx, + ) + .await, + indoc! {" + fn main() { + let story = \"the quick brown fox jumps over the lazy dog\"; + } + "}, + ); +} + +#[gpui::test] +async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) { + init_test(cx); + + let buffer_content = "lorem\n"; + let completion_response = indoc! {" + ```animals.js + <|start_of_file|> + <|editable_region_start|> + lorem + ipsum + <|editable_region_end|> + ```"}; + + assert_eq!( + apply_edit_prediction(buffer_content, completion_response, cx).await, + "lorem\nipsum" + ); +} + +#[gpui::test] +async fn test_can_collect_data(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT })) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/project/src/main.rs"), cx) + }) + .await + .unwrap(); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); + + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Disabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [], cx).await; + + let buffer = cx.new(|_cx| { + Buffer::remote( + language::BufferId::new(1).unwrap(), + ReplicaId::new(1), + language::Capability::ReadWrite, + "fn main() {\n println!(\"Hello\");\n}", + ) + }); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/project"), + json!({ + "LICENSE": BSD_0_TXT, + ".env": "SECRET_KEY=secret" + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer("/project/.env", cx) + }) + .await + .unwrap(); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [], cx).await; + let buffer = cx.new(|cx| Buffer::local("", cx)); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" })) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer("/project/main.rs", cx) + }) + .await + .unwrap(); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/open_source_worktree"), + json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }), + ) + .await; + fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" })) + .await; + + let project = Project::test( + fs.clone(), + [ + path!("/open_source_worktree").as_ref(), + path!("/closed_source_worktree").as_ref(), + ], + cx, + ) + .await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx) + }) + .await + .unwrap(); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); + + let closed_source_file = project + .update(cx, |project, cx| { + let worktree2 = project + .worktree_for_root_name("closed_source_worktree", cx) + .unwrap(); + worktree2.update(cx, |worktree2, cx| { + worktree2.load_file(rel_path("main.rs"), cx) + }) + }) + .await + .unwrap() + .file; + + buffer.update(cx, |buffer, cx| { + buffer.file_updated(closed_source_file, cx); + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/worktree1"), + json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }), + ) + .await; + fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" })) + .await; + + let project = Project::test( + fs.clone(), + [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()], + cx, + ) + .await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/worktree1/main.rs"), cx) + }) + .await + .unwrap(); + let private_buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/worktree2/file.rs"), cx) + }) + .await + .unwrap(); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); + + // this has a side effect of registering the buffer to watch for edits + run_edit_prediction(&private_buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + + private_buffer.update(cx, |private_buffer, cx| { + private_buffer.edit([(0..0, "An edit for the history!")], None, cx); + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + + // make an edit that uses too many bytes, causing private_buffer edit to not be able to be + // included + buffer.update(cx, |buffer, cx| { + buffer.edit( + [( + 0..0, + " ".repeat(MAX_EVENT_TOKENS * zeta1::BYTES_PER_TOKEN_GUESS), + )], + None, + cx, + ); + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); +} + +fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + }); +} + +async fn apply_edit_prediction( + buffer_content: &str, + completion_response: &str, + cx: &mut TestAppContext, +) -> String { + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); + let (zeta, _, response) = make_test_zeta(&project, cx).await; + *response.lock() = completion_response.to_string(); + let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await; + buffer.update(cx, |buffer, cx| { + buffer.edit(edit_prediction.edits.iter().cloned(), None, cx) + }); + buffer.read_with(cx, |buffer, _| buffer.text()) +} + +async fn run_edit_prediction( + buffer: &Entity, + project: &Entity, + zeta: &Entity, + cx: &mut TestAppContext, +) -> EditPrediction { + let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); + zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx)); + cx.background_executor.run_until_parked(); + let prediction_task = zeta.update(cx, |zeta, cx| { + zeta.request_prediction(&project, buffer, cursor, cx) + }); + prediction_task.await.unwrap().unwrap() +} + +async fn make_test_zeta( + project: &Entity, + cx: &mut TestAppContext, +) -> ( + Entity, + Arc>>, + Arc>, +) { + let default_response = indoc! {" + ```main.rs + <|start_of_file|> + <|editable_region_start|> + hello world + <|editable_region_end|> + ```" + }; + let captured_request: Arc>> = Arc::new(Mutex::new(None)); + let completion_response: Arc> = + Arc::new(Mutex::new(default_response.to_string())); + let http_client = FakeHttpClient::create({ + let captured_request = captured_request.clone(); + let completion_response = completion_response.clone(); + let mut next_request_id = 0; + move |req| { + let captured_request = captured_request.clone(); + let completion_response = completion_response.clone(); + async move { + match (req.method(), req.uri().path()) { + (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&CreateLlmTokenResponse { + token: LlmToken("the-llm-token".to_string()), + }) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/predict_edits/v2") => { + let mut request_body = String::new(); + req.into_body().read_to_string(&mut request_body).await?; + *captured_request.lock() = + Some(serde_json::from_str(&request_body).unwrap()); + next_request_id += 1; + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: format!("request-{next_request_id}"), + output_excerpt: completion_response.lock().clone(), + }) + .unwrap() + .into(), + ) + .unwrap()) + } + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } + } + } + }); + + let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); + cx.update(|cx| { + RefreshLlmTokenListener::register(client.clone(), cx); + }); + let _server = FakeServer::for_client(42, &client, cx).await; + + let zeta = cx.new(|cx| { + let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx); + zeta.set_edit_prediction_model(ZetaEditPredictionModel::Zeta1); + + let worktrees = project.read(cx).worktrees(cx).collect::>(); + for worktree in worktrees { + let worktree_id = worktree.read(cx).id(); + zeta.get_or_init_zeta_project(project, cx) + .license_detection_watchers + .entry(worktree_id) + .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx))); + } + + zeta + }); + + (zeta, captured_request, completion_response) +} + +fn to_completion_edits( + iterator: impl IntoIterator, Arc)>, + buffer: &Entity, + cx: &App, +) -> Vec<(Range, Arc)> { + let buffer = buffer.read(cx); + iterator + .into_iter() + .map(|(range, text)| { + ( + buffer.anchor_after(range.start)..buffer.anchor_before(range.end), + text, + ) + }) + .collect() +} + +fn from_completion_edits( + editor_edits: &[(Range, Arc)], + buffer: &Entity, + cx: &App, +) -> Vec<(Range, Arc)> { + let buffer = buffer.read(cx); + editor_edits + .iter() + .map(|(range, text)| { + ( + range.start.to_offset(buffer)..range.end.to_offset(buffer), + text.clone(), + ) + }) + .collect() +} + +#[ctor::ctor] +fn init_logger() { + zlog::init_test(); +} diff --git a/crates/zeta2/Cargo.toml b/crates/zeta2/Cargo.toml deleted file mode 100644 index 0b20f980feaa6c2e86b0d3a6b88150d27d06fab2..0000000000000000000000000000000000000000 --- a/crates/zeta2/Cargo.toml +++ /dev/null @@ -1,61 +0,0 @@ -[package] -name = "zeta2" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/zeta2.rs" - -[features] -eval-support = [] - -[dependencies] -anyhow.workspace = true -arrayvec.workspace = true -brotli.workspace = true -chrono.workspace = true -client.workspace = true -cloud_llm_client.workspace = true -cloud_zeta2_prompt.workspace = true -collections.workspace = true -edit_prediction.workspace = true -edit_prediction_context.workspace = true -feature_flags.workspace = true -futures.workspace = true -gpui.workspace = true -indoc.workspace = true -language.workspace = true -language_model.workspace = true -log.workspace = true -lsp.workspace = true -open_ai.workspace = true -pretty_assertions.workspace = true -project.workspace = true -release_channel.workspace = true -semver.workspace = true -serde.workspace = true -serde_json.workspace = true -smol.workspace = true -strsim.workspace = true -thiserror.workspace = true -util.workspace = true -uuid.workspace = true -workspace.workspace = true -worktree.workspace = true - -[dev-dependencies] -clock = { workspace = true, features = ["test-support"] } -cloud_llm_client = { workspace = true, features = ["test-support"] } -gpui = { workspace = true, features = ["test-support"] } -lsp.workspace = true -indoc.workspace = true -language = { workspace = true, features = ["test-support"] } -language_model = { workspace = true, features = ["test-support"] } -project = { workspace = true, features = ["test-support"] } -settings = { workspace = true, features = ["test-support"] } -zlog.workspace = true diff --git a/crates/zeta2/LICENSE-GPL b/crates/zeta2/LICENSE-GPL deleted file mode 120000 index 89e542f750cd3860a0598eff0dc34b56d7336dc4..0000000000000000000000000000000000000000 --- a/crates/zeta2/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs deleted file mode 100644 index 255b294d7cc25fade197c3a50d39130bc6bb99c5..0000000000000000000000000000000000000000 --- a/crates/zeta2/src/zeta2.rs +++ /dev/null @@ -1,2968 +0,0 @@ -use anyhow::{Context as _, Result, anyhow, bail}; -use arrayvec::ArrayVec; -use chrono::TimeDelta; -use client::{Client, EditPredictionUsage, UserStore}; -use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature}; -use cloud_llm_client::{ - AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, - ZED_VERSION_HEADER_NAME, -}; -use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery}; -use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES}; -use collections::HashMap; -use edit_prediction_context::{ - DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions, - EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line, - SyntaxIndex, SyntaxIndexState, -}; -use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; -use futures::AsyncReadExt as _; -use futures::channel::{mpsc, oneshot}; -use gpui::http_client::{AsyncBody, Method}; -use gpui::{ - App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, - http_client, prelude::*, -}; -use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint}; -use language::{BufferSnapshot, OffsetRangeExt}; -use language_model::{LlmApiToken, RefreshLlmTokenListener}; -use lsp::DiagnosticSeverity; -use open_ai::FunctionDefinition; -use project::{Project, ProjectPath}; -use release_channel::AppVersion; -use semver::Version; -use serde::de::DeserializeOwned; -use std::collections::{VecDeque, hash_map}; - -use std::fmt::Write; -use std::ops::Range; -use std::path::Path; -use std::str::FromStr; -use std::sync::{Arc, LazyLock}; -use std::time::{Duration, Instant}; -use std::{env, mem}; -use thiserror::Error; -use util::rel_path::RelPathBuf; -use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt}; -use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; - -pub mod assemble_excerpts; -mod prediction; -mod provider; -pub mod retrieval_search; -mod sweep_ai; -pub mod udiff; -mod xml_edits; - -use crate::assemble_excerpts::assemble_excerpts; -pub use crate::prediction::EditPrediction; -pub use crate::prediction::EditPredictionId; -pub use provider::ZetaEditPredictionProvider; - -/// Maximum number of events to track. -const EVENT_COUNT_MAX_SWEEP: usize = 6; -const EVENT_COUNT_MAX_ZETA: usize = 16; -const CHANGE_GROUPING_LINE_SPAN: u32 = 8; - -pub struct SweepFeatureFlag; - -impl FeatureFlag for SweepFeatureFlag { - const NAME: &str = "sweep-ai"; -} -pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions { - max_bytes: 512, - min_bytes: 128, - target_before_cursor_over_total_bytes: 0.5, -}; - -pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = - ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS); - -pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions { - excerpt: DEFAULT_EXCERPT_OPTIONS, -}; - -pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions = - EditPredictionContextOptions { - use_imports: true, - max_retrieved_declarations: 0, - excerpt: DEFAULT_EXCERPT_OPTIONS, - score: EditPredictionScoreOptions { - omit_excerpt_overlaps: true, - }, - }; - -pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { - context: DEFAULT_CONTEXT_OPTIONS, - max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES, - max_diagnostic_bytes: 2048, - prompt_format: PromptFormat::DEFAULT, - file_indexing_parallelism: 1, - buffer_change_grouping_interval: Duration::from_secs(1), -}; - -static USE_OLLAMA: LazyLock = - LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty())); -static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock = LazyLock::new(|| { - env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA { - "qwen3-coder:30b".to_string() - } else { - "yqvev8r3".to_string() - }) -}); -static EDIT_PREDICTIONS_MODEL_ID: LazyLock = LazyLock::new(|| { - match env::var("ZED_ZETA2_MODEL").as_deref() { - Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten - Ok(model) => model, - Err(_) if *USE_OLLAMA => "qwen3-coder:30b", - Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten - } - .to_string() -}); -static PREDICT_EDITS_URL: LazyLock> = LazyLock::new(|| { - env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| { - if *USE_OLLAMA { - Some("http://localhost:11434/v1/chat/completions".into()) - } else { - None - } - }) -}); - -pub struct Zeta2FeatureFlag; - -impl FeatureFlag for Zeta2FeatureFlag { - const NAME: &'static str = "zeta2"; - - fn enabled_for_staff() -> bool { - false - } -} - -#[derive(Clone)] -struct ZetaGlobal(Entity); - -impl Global for ZetaGlobal {} - -pub struct Zeta { - client: Arc, - user_store: Entity, - llm_token: LlmApiToken, - _llm_token_subscription: Subscription, - projects: HashMap, - options: ZetaOptions, - update_required: bool, - debug_tx: Option>, - #[cfg(feature = "eval-support")] - eval_cache: Option>, - edit_prediction_model: ZetaEditPredictionModel, - sweep_api_token: Option, - sweep_ai_debug_info: Arc, -} - -#[derive(PartialEq, Eq)] -pub enum ZetaEditPredictionModel { - ZedCloud, - Sweep, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct ZetaOptions { - pub context: ContextMode, - pub max_prompt_bytes: usize, - pub max_diagnostic_bytes: usize, - pub prompt_format: predict_edits_v3::PromptFormat, - pub file_indexing_parallelism: usize, - pub buffer_change_grouping_interval: Duration, -} - -#[derive(Debug, Clone, PartialEq)] -pub enum ContextMode { - Agentic(AgenticContextOptions), - Syntax(EditPredictionContextOptions), -} - -#[derive(Debug, Clone, PartialEq)] -pub struct AgenticContextOptions { - pub excerpt: EditPredictionExcerptOptions, -} - -impl ContextMode { - pub fn excerpt(&self) -> &EditPredictionExcerptOptions { - match self { - ContextMode::Agentic(options) => &options.excerpt, - ContextMode::Syntax(options) => &options.excerpt, - } - } -} - -#[derive(Debug)] -pub enum ZetaDebugInfo { - ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo), - SearchQueriesGenerated(ZetaSearchQueryDebugInfo), - SearchQueriesExecuted(ZetaContextRetrievalDebugInfo), - ContextRetrievalFinished(ZetaContextRetrievalDebugInfo), - EditPredictionRequested(ZetaEditPredictionDebugInfo), -} - -#[derive(Debug)] -pub struct ZetaContextRetrievalStartedDebugInfo { - pub project: Entity, - pub timestamp: Instant, - pub search_prompt: String, -} - -#[derive(Debug)] -pub struct ZetaContextRetrievalDebugInfo { - pub project: Entity, - pub timestamp: Instant, -} - -#[derive(Debug)] -pub struct ZetaEditPredictionDebugInfo { - pub request: predict_edits_v3::PredictEditsRequest, - pub retrieval_time: TimeDelta, - pub buffer: WeakEntity, - pub position: language::Anchor, - pub local_prompt: Result, - pub response_rx: oneshot::Receiver<(Result, TimeDelta)>, -} - -#[derive(Debug)] -pub struct ZetaSearchQueryDebugInfo { - pub project: Entity, - pub timestamp: Instant, - pub search_queries: Vec, -} - -pub type RequestDebugInfo = predict_edits_v3::DebugInfo; - -struct ZetaProject { - syntax_index: Option>, - events: VecDeque, - recent_paths: VecDeque, - registered_buffers: HashMap, - current_prediction: Option, - next_pending_prediction_id: usize, - pending_predictions: ArrayVec, - last_prediction_refresh: Option<(EntityId, Instant)>, - context: Option, Vec>>>, - refresh_context_task: Option>>>, - refresh_context_debounce_task: Option>>, - refresh_context_timestamp: Option, - _subscription: gpui::Subscription, -} - -#[derive(Debug, Clone)] -struct CurrentEditPrediction { - pub requested_by: PredictionRequestedBy, - pub prediction: EditPrediction, -} - -impl CurrentEditPrediction { - fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool { - let Some(new_edits) = self - .prediction - .interpolate(&self.prediction.buffer.read(cx)) - else { - return false; - }; - - if self.prediction.buffer != old_prediction.prediction.buffer { - return true; - } - - let Some(old_edits) = old_prediction - .prediction - .interpolate(&old_prediction.prediction.buffer.read(cx)) - else { - return true; - }; - - let requested_by_buffer_id = self.requested_by.buffer_id(); - - // This reduces the occurrence of UI thrash from replacing edits - // - // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits. - if requested_by_buffer_id == Some(self.prediction.buffer.entity_id()) - && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id()) - && old_edits.len() == 1 - && new_edits.len() == 1 - { - let (old_range, old_text) = &old_edits[0]; - let (new_range, new_text) = &new_edits[0]; - new_range == old_range && new_text.starts_with(old_text.as_ref()) - } else { - true - } - } -} - -#[derive(Debug, Clone)] -enum PredictionRequestedBy { - DiagnosticsUpdate, - Buffer(EntityId), -} - -impl PredictionRequestedBy { - pub fn buffer_id(&self) -> Option { - match self { - PredictionRequestedBy::DiagnosticsUpdate => None, - PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id), - } - } -} - -struct PendingPrediction { - id: usize, - _task: Task<()>, -} - -/// A prediction from the perspective of a buffer. -#[derive(Debug)] -enum BufferEditPrediction<'a> { - Local { prediction: &'a EditPrediction }, - Jump { prediction: &'a EditPrediction }, -} - -struct RegisteredBuffer { - snapshot: BufferSnapshot, - _subscriptions: [gpui::Subscription; 2], -} - -#[derive(Clone)] -pub enum Event { - BufferChange { - old_snapshot: BufferSnapshot, - new_snapshot: BufferSnapshot, - end_edit_anchor: Option, - timestamp: Instant, - }, -} - -impl Event { - pub fn to_request_event(&self, cx: &App) -> Option { - match self { - Event::BufferChange { - old_snapshot, - new_snapshot, - .. - } => { - let path = new_snapshot.file().map(|f| f.full_path(cx)); - - let old_path = old_snapshot.file().and_then(|f| { - let old_path = f.full_path(cx); - if Some(&old_path) != path.as_ref() { - Some(old_path) - } else { - None - } - }); - - // TODO [zeta2] move to bg? - let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text()); - - if path == old_path && diff.is_empty() { - None - } else { - Some(predict_edits_v3::Event::BufferChange { - old_path, - path, - diff, - //todo: Actually detect if this edit was predicted or not - predicted: false, - }) - } - } - } - } - - pub fn project_path(&self, cx: &App) -> Option { - match self { - Event::BufferChange { new_snapshot, .. } => new_snapshot - .file() - .map(|f| project::ProjectPath::from_file(f.as_ref(), cx)), - } - } -} - -impl Zeta { - pub fn try_global(cx: &App) -> Option> { - cx.try_global::().map(|global| global.0.clone()) - } - - pub fn global( - client: &Arc, - user_store: &Entity, - cx: &mut App, - ) -> Entity { - cx.try_global::() - .map(|global| global.0.clone()) - .unwrap_or_else(|| { - let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx)); - cx.set_global(ZetaGlobal(zeta.clone())); - zeta - }) - } - - pub fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { - let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); - - Self { - projects: HashMap::default(), - client, - user_store, - options: DEFAULT_OPTIONS, - llm_token: LlmApiToken::default(), - _llm_token_subscription: cx.subscribe( - &refresh_llm_token_listener, - |this, _listener, _event, cx| { - let client = this.client.clone(); - let llm_token = this.llm_token.clone(); - cx.spawn(async move |_this, _cx| { - llm_token.refresh(&client).await?; - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - }, - ), - update_required: false, - debug_tx: None, - #[cfg(feature = "eval-support")] - eval_cache: None, - edit_prediction_model: ZetaEditPredictionModel::ZedCloud, - sweep_api_token: std::env::var("SWEEP_AI_TOKEN") - .context("No SWEEP_AI_TOKEN environment variable set") - .log_err(), - sweep_ai_debug_info: sweep_ai::debug_info(cx), - } - } - - pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) { - self.edit_prediction_model = model; - } - - pub fn has_sweep_api_token(&self) -> bool { - self.sweep_api_token.is_some() - } - - #[cfg(feature = "eval-support")] - pub fn with_eval_cache(&mut self, cache: Arc) { - self.eval_cache = Some(cache); - } - - pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver { - let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded(); - self.debug_tx = Some(debug_watch_tx); - debug_watch_rx - } - - pub fn options(&self) -> &ZetaOptions { - &self.options - } - - pub fn set_options(&mut self, options: ZetaOptions) { - self.options = options; - } - - pub fn clear_history(&mut self) { - for zeta_project in self.projects.values_mut() { - zeta_project.events.clear(); - } - } - - pub fn history_for_project( - &self, - project: &Entity, - ) -> impl DoubleEndedIterator { - self.projects - .get(&project.entity_id()) - .map(|project| project.events.iter()) - .into_iter() - .flatten() - } - - pub fn context_for_project( - &self, - project: &Entity, - ) -> impl Iterator, &[Range])> { - self.projects - .get(&project.entity_id()) - .and_then(|project| { - Some( - project - .context - .as_ref()? - .iter() - .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())), - ) - }) - .into_iter() - .flatten() - } - - pub fn usage(&self, cx: &App) -> Option { - if self.edit_prediction_model == ZetaEditPredictionModel::ZedCloud { - self.user_store.read(cx).edit_prediction_usage() - } else { - None - } - } - - pub fn register_project(&mut self, project: &Entity, cx: &mut Context) { - self.get_or_init_zeta_project(project, cx); - } - - pub fn register_buffer( - &mut self, - buffer: &Entity, - project: &Entity, - cx: &mut Context, - ) { - let zeta_project = self.get_or_init_zeta_project(project, cx); - Self::register_buffer_impl(zeta_project, buffer, project, cx); - } - - fn get_or_init_zeta_project( - &mut self, - project: &Entity, - cx: &mut Context, - ) -> &mut ZetaProject { - self.projects - .entry(project.entity_id()) - .or_insert_with(|| ZetaProject { - syntax_index: if let ContextMode::Syntax(_) = &self.options.context { - Some(cx.new(|cx| { - SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx) - })) - } else { - None - }, - events: VecDeque::new(), - recent_paths: VecDeque::new(), - registered_buffers: HashMap::default(), - current_prediction: None, - pending_predictions: ArrayVec::new(), - next_pending_prediction_id: 0, - last_prediction_refresh: None, - context: None, - refresh_context_task: None, - refresh_context_debounce_task: None, - refresh_context_timestamp: None, - _subscription: cx.subscribe(&project, Self::handle_project_event), - }) - } - - fn handle_project_event( - &mut self, - project: Entity, - event: &project::Event, - cx: &mut Context, - ) { - // TODO [zeta2] init with recent paths - match event { - project::Event::ActiveEntryChanged(Some(active_entry_id)) => { - let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { - return; - }; - let path = project.read(cx).path_for_entry(*active_entry_id, cx); - if let Some(path) = path { - if let Some(ix) = zeta_project - .recent_paths - .iter() - .position(|probe| probe == &path) - { - zeta_project.recent_paths.remove(ix); - } - zeta_project.recent_paths.push_front(path); - } - } - project::Event::DiagnosticsUpdated { .. } => { - self.refresh_prediction_from_diagnostics(project, cx); - } - _ => (), - } - } - - fn register_buffer_impl<'a>( - zeta_project: &'a mut ZetaProject, - buffer: &Entity, - project: &Entity, - cx: &mut Context, - ) -> &'a mut RegisteredBuffer { - let buffer_id = buffer.entity_id(); - match zeta_project.registered_buffers.entry(buffer_id) { - hash_map::Entry::Occupied(entry) => entry.into_mut(), - hash_map::Entry::Vacant(entry) => { - let snapshot = buffer.read(cx).snapshot(); - let project_entity_id = project.entity_id(); - entry.insert(RegisteredBuffer { - snapshot, - _subscriptions: [ - cx.subscribe(buffer, { - let project = project.downgrade(); - move |this, buffer, event, cx| { - if let language::BufferEvent::Edited = event - && let Some(project) = project.upgrade() - { - this.report_changes_for_buffer(&buffer, &project, cx); - } - } - }), - cx.observe_release(buffer, move |this, _buffer, _cx| { - let Some(zeta_project) = this.projects.get_mut(&project_entity_id) - else { - return; - }; - zeta_project.registered_buffers.remove(&buffer_id); - }), - ], - }) - } - } - } - - fn report_changes_for_buffer( - &mut self, - buffer: &Entity, - project: &Entity, - cx: &mut Context, - ) { - let event_count_max = match self.edit_prediction_model { - ZetaEditPredictionModel::ZedCloud => EVENT_COUNT_MAX_ZETA, - ZetaEditPredictionModel::Sweep => EVENT_COUNT_MAX_SWEEP, - }; - - let sweep_ai_project = self.get_or_init_zeta_project(project, cx); - let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx); - - let new_snapshot = buffer.read(cx).snapshot(); - if new_snapshot.version == registered_buffer.snapshot.version { - return; - } - - let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone()); - let end_edit_anchor = new_snapshot - .anchored_edits_since::(&old_snapshot.version) - .last() - .map(|(_, range)| range.end); - let events = &mut sweep_ai_project.events; - - if let Some(Event::BufferChange { - new_snapshot: last_new_snapshot, - end_edit_anchor: last_end_edit_anchor, - .. - }) = events.back_mut() - { - let is_next_snapshot_of_same_buffer = old_snapshot.remote_id() - == last_new_snapshot.remote_id() - && old_snapshot.version == last_new_snapshot.version; - - let should_coalesce = is_next_snapshot_of_same_buffer - && end_edit_anchor - .as_ref() - .zip(last_end_edit_anchor.as_ref()) - .is_some_and(|(a, b)| { - let a = a.to_point(&new_snapshot); - let b = b.to_point(&new_snapshot); - a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN - }); - - if should_coalesce { - *last_end_edit_anchor = end_edit_anchor; - *last_new_snapshot = new_snapshot; - return; - } - } - - if events.len() >= event_count_max { - events.pop_front(); - } - - events.push_back(Event::BufferChange { - old_snapshot, - new_snapshot, - end_edit_anchor, - timestamp: Instant::now(), - }); - } - - fn current_prediction_for_buffer( - &self, - buffer: &Entity, - project: &Entity, - cx: &App, - ) -> Option> { - let project_state = self.projects.get(&project.entity_id())?; - - let CurrentEditPrediction { - requested_by, - prediction, - } = project_state.current_prediction.as_ref()?; - - if prediction.targets_buffer(buffer.read(cx)) { - Some(BufferEditPrediction::Local { prediction }) - } else { - let show_jump = match requested_by { - PredictionRequestedBy::Buffer(requested_by_buffer_id) => { - requested_by_buffer_id == &buffer.entity_id() - } - PredictionRequestedBy::DiagnosticsUpdate => true, - }; - - if show_jump { - Some(BufferEditPrediction::Jump { prediction }) - } else { - None - } - } - } - - fn accept_current_prediction(&mut self, project: &Entity, cx: &mut Context) { - if self.edit_prediction_model != ZetaEditPredictionModel::ZedCloud { - return; - } - - let Some(project_state) = self.projects.get_mut(&project.entity_id()) else { - return; - }; - - let Some(prediction) = project_state.current_prediction.take() else { - return; - }; - let request_id = prediction.prediction.id.to_string(); - project_state.pending_predictions.clear(); - - let client = self.client.clone(); - let llm_token = self.llm_token.clone(); - let app_version = AppVersion::global(cx); - cx.spawn(async move |this, cx| { - let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") { - http_client::Url::parse(&predict_edits_url)? - } else { - client - .http_client() - .build_zed_llm_url("/predict_edits/accept", &[])? - }; - - let response = cx - .background_spawn(Self::send_api_request::<()>( - move |builder| { - let req = builder.uri(url.as_ref()).body( - serde_json::to_string(&AcceptEditPredictionBody { - request_id: request_id.clone(), - })? - .into(), - ); - Ok(req?) - }, - client, - llm_token, - app_version, - )) - .await; - - Self::handle_api_response(&this, response, cx)?; - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - } - - fn discard_current_prediction(&mut self, project: &Entity) { - if let Some(project_state) = self.projects.get_mut(&project.entity_id()) { - project_state.current_prediction.take(); - project_state.pending_predictions.clear(); - }; - } - - fn is_refreshing(&self, project: &Entity) -> bool { - self.projects - .get(&project.entity_id()) - .is_some_and(|project_state| !project_state.pending_predictions.is_empty()) - } - - pub fn refresh_prediction_from_buffer( - &mut self, - project: Entity, - buffer: Entity, - position: language::Anchor, - cx: &mut Context, - ) { - self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| { - let Some(request_task) = this - .update(cx, |this, cx| { - this.request_prediction(&project, &buffer, position, cx) - }) - .log_err() - else { - return Task::ready(anyhow::Ok(())); - }; - - let project = project.clone(); - cx.spawn(async move |cx| { - if let Some(prediction) = request_task.await? { - this.update(cx, |this, cx| { - let project_state = this - .projects - .get_mut(&project.entity_id()) - .context("Project not found")?; - - let new_prediction = CurrentEditPrediction { - requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()), - prediction: prediction, - }; - - if project_state - .current_prediction - .as_ref() - .is_none_or(|old_prediction| { - new_prediction.should_replace_prediction(&old_prediction, cx) - }) - { - project_state.current_prediction = Some(new_prediction); - cx.notify(); - } - anyhow::Ok(()) - })??; - } - Ok(()) - }) - }) - } - - pub fn refresh_prediction_from_diagnostics( - &mut self, - project: Entity, - cx: &mut Context, - ) { - let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { - return; - }; - - // Prefer predictions from buffer - if zeta_project.current_prediction.is_some() { - return; - }; - - self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| { - let Some(open_buffer_task) = project - .update(cx, |project, cx| { - project - .active_entry() - .and_then(|entry| project.path_for_entry(entry, cx)) - .map(|path| project.open_buffer(path, cx)) - }) - .log_err() - .flatten() - else { - return Task::ready(anyhow::Ok(())); - }; - - cx.spawn(async move |cx| { - let active_buffer = open_buffer_task.await?; - let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - - let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( - active_buffer, - &snapshot, - Default::default(), - Default::default(), - &project, - cx, - ) - .await? - else { - return anyhow::Ok(()); - }; - - let Some(prediction) = this - .update(cx, |this, cx| { - this.request_prediction(&project, &jump_buffer, jump_position, cx) - })? - .await? - else { - return anyhow::Ok(()); - }; - - this.update(cx, |this, cx| { - if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) { - zeta_project.current_prediction.get_or_insert_with(|| { - cx.notify(); - CurrentEditPrediction { - requested_by: PredictionRequestedBy::DiagnosticsUpdate, - prediction, - } - }); - } - })?; - - anyhow::Ok(()) - }) - }); - } - - #[cfg(not(test))] - pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); - #[cfg(test)] - pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO; - - fn queue_prediction_refresh( - &mut self, - project: Entity, - throttle_entity: EntityId, - cx: &mut Context, - do_refresh: impl FnOnce(WeakEntity, &mut AsyncApp) -> Task> + 'static, - ) { - let zeta_project = self.get_or_init_zeta_project(&project, cx); - let pending_prediction_id = zeta_project.next_pending_prediction_id; - zeta_project.next_pending_prediction_id += 1; - let last_request = zeta_project.last_prediction_refresh; - - // TODO report cancelled requests like in zeta1 - let task = cx.spawn(async move |this, cx| { - if let Some((last_entity, last_timestamp)) = last_request - && throttle_entity == last_entity - && let Some(timeout) = - (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now()) - { - cx.background_executor().timer(timeout).await; - } - - do_refresh(this.clone(), cx).await.log_err(); - - this.update(cx, |this, cx| { - let zeta_project = this.get_or_init_zeta_project(&project, cx); - - if zeta_project.pending_predictions[0].id == pending_prediction_id { - zeta_project.pending_predictions.remove(0); - } else { - zeta_project.pending_predictions.clear(); - } - - cx.notify(); - }) - .ok(); - }); - - if zeta_project.pending_predictions.len() <= 1 { - zeta_project.pending_predictions.push(PendingPrediction { - id: pending_prediction_id, - _task: task, - }); - } else if zeta_project.pending_predictions.len() == 2 { - zeta_project.pending_predictions.pop(); - zeta_project.pending_predictions.push(PendingPrediction { - id: pending_prediction_id, - _task: task, - }); - } - } - - pub fn request_prediction( - &mut self, - project: &Entity, - active_buffer: &Entity, - position: language::Anchor, - cx: &mut Context, - ) -> Task>> { - match self.edit_prediction_model { - ZetaEditPredictionModel::ZedCloud => { - self.request_prediction_with_zed_cloud(project, active_buffer, position, cx) - } - ZetaEditPredictionModel::Sweep => { - self.request_prediction_with_sweep(project, active_buffer, position, true, cx) - } - } - } - - fn request_prediction_with_sweep( - &mut self, - project: &Entity, - active_buffer: &Entity, - position: language::Anchor, - allow_jump: bool, - cx: &mut Context, - ) -> Task>> { - let snapshot = active_buffer.read(cx).snapshot(); - let debug_info = self.sweep_ai_debug_info.clone(); - let Some(api_token) = self.sweep_api_token.clone() else { - return Task::ready(Ok(None)); - }; - let full_path: Arc = snapshot - .file() - .map(|file| file.full_path(cx)) - .unwrap_or_else(|| "untitled".into()) - .into(); - - let project_file = project::File::from_dyn(snapshot.file()); - let repo_name = project_file - .map(|file| file.worktree.read(cx).root_name_str()) - .unwrap_or("untitled") - .into(); - let offset = position.to_offset(&snapshot); - - let project_state = self.get_or_init_zeta_project(project, cx); - let events = project_state.events.clone(); - let has_events = !events.is_empty(); - let recent_buffers = project_state.recent_paths.iter().cloned(); - let http_client = cx.http_client(); - - let recent_buffer_snapshots = recent_buffers - .filter_map(|project_path| { - let buffer = project.read(cx).get_open_buffer(&project_path, cx)?; - if active_buffer == &buffer { - None - } else { - Some(buffer.read(cx).snapshot()) - } - }) - .take(3) - .collect::>(); - - const DIAGNOSTIC_LINES_RANGE: u32 = 20; - - let cursor_point = position.to_point(&snapshot); - let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE); - let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE; - let diagnostic_search_range = - Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0); - - let result = cx.background_spawn({ - let snapshot = snapshot.clone(); - let diagnostic_search_range = diagnostic_search_range.clone(); - async move { - let text = snapshot.text(); - - let mut recent_changes = String::new(); - for event in events { - sweep_ai::write_event(event, &mut recent_changes).unwrap(); - } - - let mut file_chunks = recent_buffer_snapshots - .into_iter() - .map(|snapshot| { - let end_point = Point::new(30, 0).min(snapshot.max_point()); - sweep_ai::FileChunk { - content: snapshot.text_for_range(Point::zero()..end_point).collect(), - file_path: snapshot - .file() - .map(|f| f.path().as_unix_str()) - .unwrap_or("untitled") - .to_string(), - start_line: 0, - end_line: end_point.row as usize, - timestamp: snapshot.file().and_then(|file| { - Some( - file.disk_state() - .mtime()? - .to_seconds_and_nanos_for_persistence()? - .0, - ) - }), - } - }) - .collect::>(); - - let diagnostic_entries = - snapshot.diagnostics_in_range(diagnostic_search_range, false); - let mut diagnostic_content = String::new(); - let mut diagnostic_count = 0; - - for entry in diagnostic_entries { - let start_point: Point = entry.range.start; - - let severity = match entry.diagnostic.severity { - DiagnosticSeverity::ERROR => "error", - DiagnosticSeverity::WARNING => "warning", - DiagnosticSeverity::INFORMATION => "info", - DiagnosticSeverity::HINT => "hint", - _ => continue, - }; - - diagnostic_count += 1; - - writeln!( - &mut diagnostic_content, - "{} at line {}: {}", - severity, - start_point.row + 1, - entry.diagnostic.message - )?; - } - - if !diagnostic_content.is_empty() { - file_chunks.push(sweep_ai::FileChunk { - file_path: format!("Diagnostics for {}", full_path.display()), - start_line: 0, - end_line: diagnostic_count, - content: diagnostic_content, - timestamp: None, - }); - } - - let request_body = sweep_ai::AutocompleteRequest { - debug_info, - repo_name, - file_path: full_path.clone(), - file_contents: text.clone(), - original_file_contents: text, - cursor_position: offset, - recent_changes: recent_changes.clone(), - changes_above_cursor: true, - multiple_suggestions: false, - branch: None, - file_chunks, - retrieval_chunks: vec![], - recent_user_actions: vec![], - // TODO - privacy_mode_enabled: false, - }; - - let mut buf: Vec = Vec::new(); - let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22); - serde_json::to_writer(writer, &request_body)?; - let body: AsyncBody = buf.into(); - - const SWEEP_API_URL: &str = - "https://autocomplete.sweep.dev/backend/next_edit_autocomplete"; - - let request = http_client::Request::builder() - .uri(SWEEP_API_URL) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_token)) - .header("Connection", "keep-alive") - .header("Content-Encoding", "br") - .method(Method::POST) - .body(body)?; - - let mut response = http_client.send(request).await?; - - let mut body: Vec = Vec::new(); - response.body_mut().read_to_end(&mut body).await?; - - if !response.status().is_success() { - anyhow::bail!( - "Request failed with status: {:?}\nBody: {}", - response.status(), - String::from_utf8_lossy(&body), - ); - }; - - let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?; - - let old_text = snapshot - .text_for_range(response.start_index..response.end_index) - .collect::(); - let edits = language::text_diff(&old_text, &response.completion) - .into_iter() - .map(|(range, text)| { - ( - snapshot.anchor_after(response.start_index + range.start) - ..snapshot.anchor_before(response.start_index + range.end), - text, - ) - }) - .collect::>(); - - anyhow::Ok((response.autocomplete_id, edits, snapshot)) - } - }); - - let buffer = active_buffer.clone(); - let project = project.clone(); - let active_buffer = active_buffer.clone(); - - cx.spawn(async move |this, cx| { - let (id, edits, old_snapshot) = result.await?; - - if edits.is_empty() { - if has_events - && allow_jump - && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( - active_buffer, - &snapshot, - diagnostic_search_range, - cursor_point, - &project, - cx, - ) - .await? - { - return this - .update(cx, |this, cx| { - this.request_prediction_with_sweep( - &project, - &jump_buffer, - jump_position, - false, - cx, - ) - })? - .await; - } - - return anyhow::Ok(None); - } - - let Some((edits, new_snapshot, preview_task)) = - buffer.read_with(cx, |buffer, cx| { - let new_snapshot = buffer.snapshot(); - - let edits: Arc<[(Range, Arc)]> = - edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)? - .into(); - let preview_task = buffer.preview_edits(edits.clone(), cx); - - Some((edits, new_snapshot, preview_task)) - })? - else { - return anyhow::Ok(None); - }; - - let prediction = EditPrediction { - id: EditPredictionId(id.into()), - edits, - snapshot: new_snapshot, - edit_preview: preview_task.await, - buffer, - }; - - anyhow::Ok(Some(prediction)) - }) - } - - async fn next_diagnostic_location( - active_buffer: Entity, - active_buffer_snapshot: &BufferSnapshot, - active_buffer_diagnostic_search_range: Range, - active_buffer_cursor_point: Point, - project: &Entity, - cx: &mut AsyncApp, - ) -> Result, language::Anchor)>> { - // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request - let mut jump_location = active_buffer_snapshot - .diagnostic_groups(None) - .into_iter() - .filter_map(|(_, group)| { - let range = &group.entries[group.primary_ix] - .range - .to_point(&active_buffer_snapshot); - if range.overlaps(&active_buffer_diagnostic_search_range) { - None - } else { - Some(range.start) - } - }) - .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row)) - .map(|position| { - ( - active_buffer.clone(), - active_buffer_snapshot.anchor_before(position), - ) - }); - - if jump_location.is_none() { - let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| { - let file = buffer.file()?; - - Some(ProjectPath { - worktree_id: file.worktree_id(cx), - path: file.path().clone(), - }) - })?; - - let buffer_task = project.update(cx, |project, cx| { - let (path, _, _) = project - .diagnostic_summaries(false, cx) - .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref()) - .max_by_key(|(path, _, _)| { - // find the buffer with errors that shares most parent directories - path.path - .components() - .zip( - active_buffer_path - .as_ref() - .map(|p| p.path.components()) - .unwrap_or_default(), - ) - .take_while(|(a, b)| a == b) - .count() - })?; - - Some(project.open_buffer(path, cx)) - })?; - - if let Some(buffer_task) = buffer_task { - let closest_buffer = buffer_task.await?; - - jump_location = closest_buffer - .read_with(cx, |buffer, _cx| { - buffer - .buffer_diagnostics(None) - .into_iter() - .min_by_key(|entry| entry.diagnostic.severity) - .map(|entry| entry.range.start) - })? - .map(|position| (closest_buffer, position)); - } - } - - anyhow::Ok(jump_location) - } - - fn request_prediction_with_zed_cloud( - &mut self, - project: &Entity, - active_buffer: &Entity, - position: language::Anchor, - cx: &mut Context, - ) -> Task>> { - let project_state = self.projects.get(&project.entity_id()); - - let index_state = project_state.and_then(|state| { - state - .syntax_index - .as_ref() - .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone())) - }); - let options = self.options.clone(); - let active_snapshot = active_buffer.read(cx).snapshot(); - let Some(excerpt_path) = active_snapshot - .file() - .map(|path| -> Arc { path.full_path(cx).into() }) - else { - return Task::ready(Err(anyhow!("No file path for excerpt"))); - }; - let client = self.client.clone(); - let llm_token = self.llm_token.clone(); - let app_version = AppVersion::global(cx); - let worktree_snapshots = project - .read(cx) - .worktrees(cx) - .map(|worktree| worktree.read(cx).snapshot()) - .collect::>(); - let debug_tx = self.debug_tx.clone(); - - let events = project_state - .map(|state| { - state - .events - .iter() - .filter_map(|event| event.to_request_event(cx)) - .collect::>() - }) - .unwrap_or_default(); - - let diagnostics = active_snapshot.diagnostic_sets().clone(); - - let parent_abs_path = - project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| { - let mut path = f.worktree.read(cx).absolutize(&f.path); - if path.pop() { Some(path) } else { None } - }); - - // TODO data collection - let can_collect_data = cx.is_staff(); - - let empty_context_files = HashMap::default(); - let context_files = project_state - .and_then(|project_state| project_state.context.as_ref()) - .unwrap_or(&empty_context_files); - - #[cfg(feature = "eval-support")] - let parsed_fut = futures::future::join_all( - context_files - .keys() - .map(|buffer| buffer.read(cx).parsing_idle()), - ); - - let mut included_files = context_files - .iter() - .filter_map(|(buffer_entity, ranges)| { - let buffer = buffer_entity.read(cx); - Some(( - buffer_entity.clone(), - buffer.snapshot(), - buffer.file()?.full_path(cx).into(), - ranges.clone(), - )) - }) - .collect::>(); - - included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| { - (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len())) - }); - - #[cfg(feature = "eval-support")] - let eval_cache = self.eval_cache.clone(); - - let request_task = cx.background_spawn({ - let active_buffer = active_buffer.clone(); - async move { - #[cfg(feature = "eval-support")] - parsed_fut.await; - - let index_state = if let Some(index_state) = index_state { - Some(index_state.lock_owned().await) - } else { - None - }; - - let cursor_offset = position.to_offset(&active_snapshot); - let cursor_point = cursor_offset.to_point(&active_snapshot); - - let before_retrieval = chrono::Utc::now(); - - let (diagnostic_groups, diagnostic_groups_truncated) = - Self::gather_nearby_diagnostics( - cursor_offset, - &diagnostics, - &active_snapshot, - options.max_diagnostic_bytes, - ); - - let cloud_request = match options.context { - ContextMode::Agentic(context_options) => { - let Some(excerpt) = EditPredictionExcerpt::select_from_buffer( - cursor_point, - &active_snapshot, - &context_options.excerpt, - index_state.as_deref(), - ) else { - return Ok((None, None)); - }; - - let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start) - ..active_snapshot.anchor_before(excerpt.range.end); - - if let Some(buffer_ix) = - included_files.iter().position(|(_, snapshot, _, _)| { - snapshot.remote_id() == active_snapshot.remote_id() - }) - { - let (_, buffer, _, ranges) = &mut included_files[buffer_ix]; - ranges.push(excerpt_anchor_range); - retrieval_search::merge_anchor_ranges(ranges, buffer); - let last_ix = included_files.len() - 1; - included_files.swap(buffer_ix, last_ix); - } else { - included_files.push(( - active_buffer.clone(), - active_snapshot.clone(), - excerpt_path.clone(), - vec![excerpt_anchor_range], - )); - } - - let included_files = included_files - .iter() - .map(|(_, snapshot, path, ranges)| { - let ranges = ranges - .iter() - .map(|range| { - let point_range = range.to_point(&snapshot); - Line(point_range.start.row)..Line(point_range.end.row) - }) - .collect::>(); - let excerpts = assemble_excerpts(&snapshot, ranges); - predict_edits_v3::IncludedFile { - path: path.clone(), - max_row: Line(snapshot.max_point().row), - excerpts, - } - }) - .collect::>(); - - predict_edits_v3::PredictEditsRequest { - excerpt_path, - excerpt: String::new(), - excerpt_line_range: Line(0)..Line(0), - excerpt_range: 0..0, - cursor_point: predict_edits_v3::Point { - line: predict_edits_v3::Line(cursor_point.row), - column: cursor_point.column, - }, - included_files, - referenced_declarations: vec![], - events, - can_collect_data, - diagnostic_groups, - diagnostic_groups_truncated, - debug_info: debug_tx.is_some(), - prompt_max_bytes: Some(options.max_prompt_bytes), - prompt_format: options.prompt_format, - // TODO [zeta2] - signatures: vec![], - excerpt_parent: None, - git_info: None, - } - } - ContextMode::Syntax(context_options) => { - let Some(context) = EditPredictionContext::gather_context( - cursor_point, - &active_snapshot, - parent_abs_path.as_deref(), - &context_options, - index_state.as_deref(), - ) else { - return Ok((None, None)); - }; - - make_syntax_context_cloud_request( - excerpt_path, - context, - events, - can_collect_data, - diagnostic_groups, - diagnostic_groups_truncated, - None, - debug_tx.is_some(), - &worktree_snapshots, - index_state.as_deref(), - Some(options.max_prompt_bytes), - options.prompt_format, - ) - } - }; - - let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request); - - let retrieval_time = chrono::Utc::now() - before_retrieval; - - let debug_response_tx = if let Some(debug_tx) = &debug_tx { - let (response_tx, response_rx) = oneshot::channel(); - - debug_tx - .unbounded_send(ZetaDebugInfo::EditPredictionRequested( - ZetaEditPredictionDebugInfo { - request: cloud_request.clone(), - retrieval_time, - buffer: active_buffer.downgrade(), - local_prompt: match prompt_result.as_ref() { - Ok((prompt, _)) => Ok(prompt.clone()), - Err(err) => Err(err.to_string()), - }, - position, - response_rx, - }, - )) - .ok(); - Some(response_tx) - } else { - None - }; - - if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() { - if let Some(debug_response_tx) = debug_response_tx { - debug_response_tx - .send((Err("Request skipped".to_string()), TimeDelta::zero())) - .ok(); - } - anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set") - } - - let (prompt, _) = prompt_result?; - let generation_params = - cloud_zeta2_prompt::generation_params(cloud_request.prompt_format); - let request = open_ai::Request { - model: EDIT_PREDICTIONS_MODEL_ID.clone(), - messages: vec![open_ai::RequestMessage::User { - content: open_ai::MessageContent::Plain(prompt), - }], - stream: false, - max_completion_tokens: None, - stop: generation_params.stop.unwrap_or_default(), - temperature: generation_params.temperature.unwrap_or(0.7), - tool_choice: None, - parallel_tool_calls: None, - tools: vec![], - prompt_cache_key: None, - reasoning_effort: None, - }; - - log::trace!("Sending edit prediction request"); - - let before_request = chrono::Utc::now(); - let response = Self::send_raw_llm_request( - request, - client, - llm_token, - app_version, - #[cfg(feature = "eval-support")] - eval_cache, - #[cfg(feature = "eval-support")] - EvalCacheEntryKind::Prediction, - ) - .await; - let request_time = chrono::Utc::now() - before_request; - - log::trace!("Got edit prediction response"); - - if let Some(debug_response_tx) = debug_response_tx { - debug_response_tx - .send(( - response - .as_ref() - .map_err(|err| err.to_string()) - .map(|response| response.0.clone()), - request_time, - )) - .ok(); - } - - let (res, usage) = response?; - let request_id = EditPredictionId(res.id.clone().into()); - let Some(mut output_text) = text_from_response(res) else { - return Ok((None, usage)); - }; - - if output_text.contains(CURSOR_MARKER) { - log::trace!("Stripping out {CURSOR_MARKER} from response"); - output_text = output_text.replace(CURSOR_MARKER, ""); - } - - let get_buffer_from_context = |path: &Path| { - included_files - .iter() - .find_map(|(_, buffer, probe_path, ranges)| { - if probe_path.as_ref() == path { - Some((buffer, ranges.as_slice())) - } else { - None - } - }) - }; - - let (edited_buffer_snapshot, edits) = match options.prompt_format { - PromptFormat::NumLinesUniDiff => { - // TODO: Implement parsing of multi-file diffs - crate::udiff::parse_diff(&output_text, get_buffer_from_context).await? - } - PromptFormat::Minimal - | PromptFormat::MinimalQwen - | PromptFormat::SeedCoder1120 => { - if output_text.contains("--- a/\n+++ b/\nNo edits") { - let edits = vec![]; - (&active_snapshot, edits) - } else { - crate::udiff::parse_diff(&output_text, get_buffer_from_context).await? - } - } - PromptFormat::OldTextNewText => { - crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context) - .await? - } - _ => { - bail!("unsupported prompt format {}", options.prompt_format) - } - }; - - let edited_buffer = included_files - .iter() - .find_map(|(buffer, snapshot, _, _)| { - if snapshot.remote_id() == edited_buffer_snapshot.remote_id() { - Some(buffer.clone()) - } else { - None - } - }) - .context("Failed to find buffer in included_buffers")?; - - anyhow::Ok(( - Some(( - request_id, - edited_buffer, - edited_buffer_snapshot.clone(), - edits, - )), - usage, - )) - } - }); - - cx.spawn({ - async move |this, cx| { - let Some((id, edited_buffer, edited_buffer_snapshot, edits)) = - Self::handle_api_response(&this, request_task.await, cx)? - else { - return Ok(None); - }; - - // TODO telemetry: duration, etc - Ok( - EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx) - .await, - ) - } - }) - } - - async fn send_raw_llm_request( - request: open_ai::Request, - client: Arc, - llm_token: LlmApiToken, - app_version: Version, - #[cfg(feature = "eval-support")] eval_cache: Option>, - #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind, - ) -> Result<(open_ai::Response, Option)> { - let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() { - http_client::Url::parse(&predict_edits_url)? - } else { - client - .http_client() - .build_zed_llm_url("/predict_edits/raw", &[])? - }; - - #[cfg(feature = "eval-support")] - let cache_key = if let Some(cache) = eval_cache { - use collections::FxHasher; - use std::hash::{Hash, Hasher}; - - let mut hasher = FxHasher::default(); - url.hash(&mut hasher); - let request_str = serde_json::to_string_pretty(&request)?; - request_str.hash(&mut hasher); - let hash = hasher.finish(); - - let key = (eval_cache_kind, hash); - if let Some(response_str) = cache.read(key) { - return Ok((serde_json::from_str(&response_str)?, None)); - } - - Some((cache, request_str, key)) - } else { - None - }; - - let (response, usage) = Self::send_api_request( - |builder| { - let req = builder - .uri(url.as_ref()) - .body(serde_json::to_string(&request)?.into()); - Ok(req?) - }, - client, - llm_token, - app_version, - ) - .await?; - - #[cfg(feature = "eval-support")] - if let Some((cache, request, key)) = cache_key { - cache.write(key, &request, &serde_json::to_string_pretty(&response)?); - } - - Ok((response, usage)) - } - - fn handle_api_response( - this: &WeakEntity, - response: Result<(T, Option)>, - cx: &mut gpui::AsyncApp, - ) -> Result { - match response { - Ok((data, usage)) => { - if let Some(usage) = usage { - this.update(cx, |this, cx| { - this.user_store.update(cx, |user_store, cx| { - user_store.update_edit_prediction_usage(usage, cx); - }); - }) - .ok(); - } - Ok(data) - } - Err(err) => { - if err.is::() { - cx.update(|cx| { - this.update(cx, |this, _cx| { - this.update_required = true; - }) - .ok(); - - let error_message: SharedString = err.to_string().into(); - show_app_notification( - NotificationId::unique::(), - cx, - move |cx| { - cx.new(|cx| { - ErrorMessagePrompt::new(error_message.clone(), cx) - .with_link_button("Update Zed", "https://zed.dev/releases") - }) - }, - ); - }) - .ok(); - } - Err(err) - } - } - } - - async fn send_api_request( - build: impl Fn(http_client::http::request::Builder) -> Result>, - client: Arc, - llm_token: LlmApiToken, - app_version: Version, - ) -> Result<(Res, Option)> - where - Res: DeserializeOwned, - { - let http_client = client.http_client(); - let mut token = llm_token.acquire(&client).await?; - let mut did_retry = false; - - loop { - let request_builder = http_client::Request::builder().method(Method::POST); - - let request = build( - request_builder - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", token)) - .header(ZED_VERSION_HEADER_NAME, app_version.to_string()), - )?; - - let mut response = http_client.send(request).await?; - - if let Some(minimum_required_version) = response - .headers() - .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME) - .and_then(|version| Version::from_str(version.to_str().ok()?).ok()) - { - anyhow::ensure!( - app_version >= minimum_required_version, - ZedUpdateRequiredError { - minimum_version: minimum_required_version - } - ); - } - - if response.status().is_success() { - let usage = EditPredictionUsage::from_headers(response.headers()).ok(); - - let mut body = Vec::new(); - response.body_mut().read_to_end(&mut body).await?; - return Ok((serde_json::from_slice(&body)?, usage)); - } else if !did_retry - && response - .headers() - .get(EXPIRED_LLM_TOKEN_HEADER_NAME) - .is_some() - { - did_retry = true; - token = llm_token.refresh(&client).await?; - } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - anyhow::bail!( - "Request failed with status: {:?}\nBody: {}", - response.status(), - body - ); - } - } - } - - pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10); - pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3); - - // Refresh the related excerpts when the user just beguns editing after - // an idle period, and after they pause editing. - fn refresh_context_if_needed( - &mut self, - project: &Entity, - buffer: &Entity, - cursor_position: language::Anchor, - cx: &mut Context, - ) { - if !matches!(&self.options().context, ContextMode::Agentic { .. }) { - return; - } - - let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { - return; - }; - - let now = Instant::now(); - let was_idle = zeta_project - .refresh_context_timestamp - .map_or(true, |timestamp| { - now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION - }); - zeta_project.refresh_context_timestamp = Some(now); - zeta_project.refresh_context_debounce_task = Some(cx.spawn({ - let buffer = buffer.clone(); - let project = project.clone(); - async move |this, cx| { - if was_idle { - log::debug!("refetching edit prediction context after idle"); - } else { - cx.background_executor() - .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION) - .await; - log::debug!("refetching edit prediction context after pause"); - } - this.update(cx, |this, cx| { - let task = this.refresh_context(project.clone(), buffer, cursor_position, cx); - - if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) { - zeta_project.refresh_context_task = Some(task.log_err()); - }; - }) - .ok() - } - })); - } - - // Refresh the related excerpts asynchronously. Ensure the task runs to completion, - // and avoid spawning more than one concurrent task. - pub fn refresh_context( - &mut self, - project: Entity, - buffer: Entity, - cursor_position: language::Anchor, - cx: &mut Context, - ) -> Task> { - let Some(zeta_project) = self.projects.get(&project.entity_id()) else { - return Task::ready(anyhow::Ok(())); - }; - - let ContextMode::Agentic(options) = &self.options().context else { - return Task::ready(anyhow::Ok(())); - }; - - let snapshot = buffer.read(cx).snapshot(); - let cursor_point = cursor_position.to_point(&snapshot); - let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer( - cursor_point, - &snapshot, - &options.excerpt, - None, - ) else { - return Task::ready(Ok(())); - }; - - let app_version = AppVersion::global(cx); - let client = self.client.clone(); - let llm_token = self.llm_token.clone(); - let debug_tx = self.debug_tx.clone(); - let current_file_path: Arc = snapshot - .file() - .map(|f| f.full_path(cx).into()) - .unwrap_or_else(|| Path::new("untitled").into()); - - let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt( - predict_edits_v3::PlanContextRetrievalRequest { - excerpt: cursor_excerpt.text(&snapshot).body, - excerpt_path: current_file_path, - excerpt_line_range: cursor_excerpt.line_range, - cursor_file_max_row: Line(snapshot.max_point().row), - events: zeta_project - .events - .iter() - .filter_map(|ev| ev.to_request_event(cx)) - .collect(), - }, - ) { - Ok(prompt) => prompt, - Err(err) => { - return Task::ready(Err(err)); - } - }; - - if let Some(debug_tx) = &debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted( - ZetaContextRetrievalStartedDebugInfo { - project: project.clone(), - timestamp: Instant::now(), - search_prompt: prompt.clone(), - }, - )) - .ok(); - } - - pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| { - let schema = language_model::tool_schema::root_schema_for::( - language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset, - ); - - let description = schema - .get("description") - .and_then(|description| description.as_str()) - .unwrap() - .to_string(); - - (schema.into(), description) - }); - - let (tool_schema, tool_description) = TOOL_SCHEMA.clone(); - - let request = open_ai::Request { - model: CONTEXT_RETRIEVAL_MODEL_ID.clone(), - messages: vec![open_ai::RequestMessage::User { - content: open_ai::MessageContent::Plain(prompt), - }], - stream: false, - max_completion_tokens: None, - stop: Default::default(), - temperature: 0.7, - tool_choice: None, - parallel_tool_calls: None, - tools: vec![open_ai::ToolDefinition::Function { - function: FunctionDefinition { - name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(), - description: Some(tool_description), - parameters: Some(tool_schema), - }, - }], - prompt_cache_key: None, - reasoning_effort: None, - }; - - #[cfg(feature = "eval-support")] - let eval_cache = self.eval_cache.clone(); - - cx.spawn(async move |this, cx| { - log::trace!("Sending search planning request"); - let response = Self::send_raw_llm_request( - request, - client, - llm_token, - app_version, - #[cfg(feature = "eval-support")] - eval_cache.clone(), - #[cfg(feature = "eval-support")] - EvalCacheEntryKind::Context, - ) - .await; - let mut response = Self::handle_api_response(&this, response, cx)?; - log::trace!("Got search planning response"); - - let choice = response - .choices - .pop() - .context("No choices in retrieval response")?; - let open_ai::RequestMessage::Assistant { - content: _, - tool_calls, - } = choice.message - else { - anyhow::bail!("Retrieval response didn't include an assistant message"); - }; - - let mut queries: Vec = Vec::new(); - for tool_call in tool_calls { - let open_ai::ToolCallContent::Function { function } = tool_call.content; - if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME { - log::warn!( - "Context retrieval response tried to call an unknown tool: {}", - function.name - ); - - continue; - } - - let input: SearchToolInput = serde_json::from_str(&function.arguments) - .with_context(|| format!("invalid search json {}", &function.arguments))?; - queries.extend(input.queries); - } - - if let Some(debug_tx) = &debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated( - ZetaSearchQueryDebugInfo { - project: project.clone(), - timestamp: Instant::now(), - search_queries: queries.clone(), - }, - )) - .ok(); - } - - log::trace!("Running retrieval search: {queries:#?}"); - - let related_excerpts_result = retrieval_search::run_retrieval_searches( - queries, - project.clone(), - #[cfg(feature = "eval-support")] - eval_cache, - cx, - ) - .await; - - log::trace!("Search queries executed"); - - if let Some(debug_tx) = &debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted( - ZetaContextRetrievalDebugInfo { - project: project.clone(), - timestamp: Instant::now(), - }, - )) - .ok(); - } - - this.update(cx, |this, _cx| { - let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else { - return Ok(()); - }; - zeta_project.refresh_context_task.take(); - if let Some(debug_tx) = &this.debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished( - ZetaContextRetrievalDebugInfo { - project, - timestamp: Instant::now(), - }, - )) - .ok(); - } - match related_excerpts_result { - Ok(excerpts) => { - zeta_project.context = Some(excerpts); - Ok(()) - } - Err(error) => Err(error), - } - })? - }) - } - - pub fn set_context( - &mut self, - project: Entity, - context: HashMap, Vec>>, - ) { - if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) { - zeta_project.context = Some(context); - } - } - - fn gather_nearby_diagnostics( - cursor_offset: usize, - diagnostic_sets: &[(LanguageServerId, DiagnosticSet)], - snapshot: &BufferSnapshot, - max_diagnostics_bytes: usize, - ) -> (Vec, bool) { - // TODO: Could make this more efficient - let mut diagnostic_groups = Vec::new(); - for (language_server_id, diagnostics) in diagnostic_sets { - let mut groups = Vec::new(); - diagnostics.groups(*language_server_id, &mut groups, &snapshot); - diagnostic_groups.extend( - groups - .into_iter() - .map(|(_, group)| group.resolve::(&snapshot)), - ); - } - - // sort by proximity to cursor - diagnostic_groups.sort_by_key(|group| { - let range = &group.entries[group.primary_ix].range; - if range.start >= cursor_offset { - range.start - cursor_offset - } else if cursor_offset >= range.end { - cursor_offset - range.end - } else { - (cursor_offset - range.start).min(range.end - cursor_offset) - } - }); - - let mut results = Vec::new(); - let mut diagnostic_groups_truncated = false; - let mut diagnostics_byte_count = 0; - for group in diagnostic_groups { - let raw_value = serde_json::value::to_raw_value(&group).unwrap(); - diagnostics_byte_count += raw_value.get().len(); - if diagnostics_byte_count > max_diagnostics_bytes { - diagnostic_groups_truncated = true; - break; - } - results.push(predict_edits_v3::DiagnosticGroup(raw_value)); - } - - (results, diagnostic_groups_truncated) - } - - // TODO: Dedupe with similar code in request_prediction? - pub fn cloud_request_for_zeta_cli( - &mut self, - project: &Entity, - buffer: &Entity, - position: language::Anchor, - cx: &mut Context, - ) -> Task> { - let project_state = self.projects.get(&project.entity_id()); - - let index_state = project_state.and_then(|state| { - state - .syntax_index - .as_ref() - .map(|index| index.read_with(cx, |index, _cx| index.state().clone())) - }); - let options = self.options.clone(); - let snapshot = buffer.read(cx).snapshot(); - let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else { - return Task::ready(Err(anyhow!("No file path for excerpt"))); - }; - let worktree_snapshots = project - .read(cx) - .worktrees(cx) - .map(|worktree| worktree.read(cx).snapshot()) - .collect::>(); - - let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| { - let mut path = f.worktree.read(cx).absolutize(&f.path); - if path.pop() { Some(path) } else { None } - }); - - cx.background_spawn(async move { - let index_state = if let Some(index_state) = index_state { - Some(index_state.lock_owned().await) - } else { - None - }; - - let cursor_point = position.to_point(&snapshot); - - let debug_info = true; - EditPredictionContext::gather_context( - cursor_point, - &snapshot, - parent_abs_path.as_deref(), - match &options.context { - ContextMode::Agentic(_) => { - // TODO - panic!("Llm mode not supported in zeta cli yet"); - } - ContextMode::Syntax(edit_prediction_context_options) => { - edit_prediction_context_options - } - }, - index_state.as_deref(), - ) - .context("Failed to select excerpt") - .map(|context| { - make_syntax_context_cloud_request( - excerpt_path.into(), - context, - // TODO pass everything - Vec::new(), - false, - Vec::new(), - false, - None, - debug_info, - &worktree_snapshots, - index_state.as_deref(), - Some(options.max_prompt_bytes), - options.prompt_format, - ) - }) - }) - } - - pub fn wait_for_initial_indexing( - &mut self, - project: &Entity, - cx: &mut Context, - ) -> Task> { - let zeta_project = self.get_or_init_zeta_project(project, cx); - if let Some(syntax_index) = &zeta_project.syntax_index { - syntax_index.read(cx).wait_for_initial_file_indexing(cx) - } else { - Task::ready(Ok(())) - } - } -} - -pub fn text_from_response(mut res: open_ai::Response) -> Option { - let choice = res.choices.pop()?; - let output_text = match choice.message { - open_ai::RequestMessage::Assistant { - content: Some(open_ai::MessageContent::Plain(content)), - .. - } => content, - open_ai::RequestMessage::Assistant { - content: Some(open_ai::MessageContent::Multipart(mut content)), - .. - } => { - if content.is_empty() { - log::error!("No output from Baseten completion response"); - return None; - } - - match content.remove(0) { - open_ai::MessagePart::Text { text } => text, - open_ai::MessagePart::Image { .. } => { - log::error!("Expected text, got an image"); - return None; - } - } - } - _ => { - log::error!("Invalid response message: {:?}", choice.message); - return None; - } - }; - Some(output_text) -} - -#[derive(Error, Debug)] -#[error( - "You must update to Zed version {minimum_version} or higher to continue using edit predictions." -)] -pub struct ZedUpdateRequiredError { - minimum_version: Version, -} - -fn make_syntax_context_cloud_request( - excerpt_path: Arc, - context: EditPredictionContext, - events: Vec, - can_collect_data: bool, - diagnostic_groups: Vec, - diagnostic_groups_truncated: bool, - git_info: Option, - debug_info: bool, - worktrees: &Vec, - index_state: Option<&SyntaxIndexState>, - prompt_max_bytes: Option, - prompt_format: PromptFormat, -) -> predict_edits_v3::PredictEditsRequest { - let mut signatures = Vec::new(); - let mut declaration_to_signature_index = HashMap::default(); - let mut referenced_declarations = Vec::new(); - - for snippet in context.declarations { - let project_entry_id = snippet.declaration.project_entry_id(); - let Some(path) = worktrees.iter().find_map(|worktree| { - worktree.entry_for_id(project_entry_id).map(|entry| { - let mut full_path = RelPathBuf::new(); - full_path.push(worktree.root_name()); - full_path.push(&entry.path); - full_path - }) - }) else { - continue; - }; - - let parent_index = index_state.and_then(|index_state| { - snippet.declaration.parent().and_then(|parent| { - add_signature( - parent, - &mut declaration_to_signature_index, - &mut signatures, - index_state, - ) - }) - }); - - let (text, text_is_truncated) = snippet.declaration.item_text(); - referenced_declarations.push(predict_edits_v3::ReferencedDeclaration { - path: path.as_std_path().into(), - text: text.into(), - range: snippet.declaration.item_line_range(), - text_is_truncated, - signature_range: snippet.declaration.signature_range_in_item_text(), - parent_index, - signature_score: snippet.score(DeclarationStyle::Signature), - declaration_score: snippet.score(DeclarationStyle::Declaration), - score_components: snippet.components, - }); - } - - let excerpt_parent = index_state.and_then(|index_state| { - context - .excerpt - .parent_declarations - .last() - .and_then(|(parent, _)| { - add_signature( - *parent, - &mut declaration_to_signature_index, - &mut signatures, - index_state, - ) - }) - }); - - predict_edits_v3::PredictEditsRequest { - excerpt_path, - excerpt: context.excerpt_text.body, - excerpt_line_range: context.excerpt.line_range, - excerpt_range: context.excerpt.range, - cursor_point: predict_edits_v3::Point { - line: predict_edits_v3::Line(context.cursor_point.row), - column: context.cursor_point.column, - }, - referenced_declarations, - included_files: vec![], - signatures, - excerpt_parent, - events, - can_collect_data, - diagnostic_groups, - diagnostic_groups_truncated, - git_info, - debug_info, - prompt_max_bytes, - prompt_format, - } -} - -fn add_signature( - declaration_id: DeclarationId, - declaration_to_signature_index: &mut HashMap, - signatures: &mut Vec, - index: &SyntaxIndexState, -) -> Option { - if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) { - return Some(*signature_index); - } - let Some(parent_declaration) = index.declaration(declaration_id) else { - log::error!("bug: missing parent declaration"); - return None; - }; - let parent_index = parent_declaration.parent().and_then(|parent| { - add_signature(parent, declaration_to_signature_index, signatures, index) - }); - let (text, text_is_truncated) = parent_declaration.signature_text(); - let signature_index = signatures.len(); - signatures.push(Signature { - text: text.into(), - text_is_truncated, - parent_index, - range: parent_declaration.signature_line_range(), - }); - declaration_to_signature_index.insert(declaration_id, signature_index); - Some(signature_index) -} - -#[cfg(feature = "eval-support")] -pub type EvalCacheKey = (EvalCacheEntryKind, u64); - -#[cfg(feature = "eval-support")] -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum EvalCacheEntryKind { - Context, - Search, - Prediction, -} - -#[cfg(feature = "eval-support")] -impl std::fmt::Display for EvalCacheEntryKind { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - EvalCacheEntryKind::Search => write!(f, "search"), - EvalCacheEntryKind::Context => write!(f, "context"), - EvalCacheEntryKind::Prediction => write!(f, "prediction"), - } - } -} - -#[cfg(feature = "eval-support")] -pub trait EvalCache: Send + Sync { - fn read(&self, key: EvalCacheKey) -> Option; - fn write(&self, key: EvalCacheKey, input: &str, value: &str); -} - -#[cfg(test)] -mod tests { - use std::{path::Path, sync::Arc}; - - use client::UserStore; - use clock::FakeSystemClock; - use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery}; - use futures::{ - AsyncReadExt, StreamExt, - channel::{mpsc, oneshot}, - }; - use gpui::{ - Entity, TestAppContext, - http_client::{FakeHttpClient, Response}, - prelude::*, - }; - use indoc::indoc; - use language::OffsetRangeExt as _; - use open_ai::Usage; - use pretty_assertions::{assert_eq, assert_matches}; - use project::{FakeFs, Project}; - use serde_json::json; - use settings::SettingsStore; - use util::path; - use uuid::Uuid; - - use crate::{BufferEditPrediction, Zeta}; - - #[gpui::test] - async fn test_current_state(cx: &mut TestAppContext) { - let (zeta, mut req_rx) = init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "1.txt": "Hello!\nHow\nBye\n", - "2.txt": "Hola!\nComo\nAdios\n" - }), - ) - .await; - let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - - zeta.update(cx, |zeta, cx| { - zeta.register_project(&project, cx); - }); - - let buffer1 = project - .update(cx, |project, cx| { - let path = project.find_project_path(path!("root/1.txt"), cx).unwrap(); - project.open_buffer(path, cx) - }) - .await - .unwrap(); - let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot()); - let position = snapshot1.anchor_before(language::Point::new(1, 3)); - - // Prediction for current file - - zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx) - }); - let (_request, respond_tx) = req_rx.next().await.unwrap(); - - respond_tx - .send(model_response(indoc! {r" - --- a/root/1.txt - +++ b/root/1.txt - @@ ... @@ - Hello! - -How - +How are you? - Bye - "})) - .unwrap(); - - cx.run_until_parked(); - - zeta.read_with(cx, |zeta, cx| { - let prediction = zeta - .current_prediction_for_buffer(&buffer1, &project, cx) - .unwrap(); - assert_matches!(prediction, BufferEditPrediction::Local { .. }); - }); - - // Context refresh - let refresh_task = zeta.update(cx, |zeta, cx| { - zeta.refresh_context(project.clone(), buffer1.clone(), position, cx) - }); - let (_request, respond_tx) = req_rx.next().await.unwrap(); - respond_tx - .send(open_ai::Response { - id: Uuid::new_v4().to_string(), - object: "response".into(), - created: 0, - model: "model".into(), - choices: vec![open_ai::Choice { - index: 0, - message: open_ai::RequestMessage::Assistant { - content: None, - tool_calls: vec![open_ai::ToolCall { - id: "search".into(), - content: open_ai::ToolCallContent::Function { - function: open_ai::FunctionContent { - name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME - .to_string(), - arguments: serde_json::to_string(&SearchToolInput { - queries: Box::new([SearchToolQuery { - glob: "root/2.txt".to_string(), - syntax_node: vec![], - content: Some(".".into()), - }]), - }) - .unwrap(), - }, - }, - }], - }, - finish_reason: None, - }], - usage: Usage { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0, - }, - }) - .unwrap(); - refresh_task.await.unwrap(); - - zeta.update(cx, |zeta, _cx| { - zeta.discard_current_prediction(&project); - }); - - // Prediction for another file - zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx) - }); - let (_request, respond_tx) = req_rx.next().await.unwrap(); - respond_tx - .send(model_response(indoc! {r#" - --- a/root/2.txt - +++ b/root/2.txt - Hola! - -Como - +Como estas? - Adios - "#})) - .unwrap(); - cx.run_until_parked(); - - zeta.read_with(cx, |zeta, cx| { - let prediction = zeta - .current_prediction_for_buffer(&buffer1, &project, cx) - .unwrap(); - assert_matches!( - prediction, - BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt")) - ); - }); - - let buffer2 = project - .update(cx, |project, cx| { - let path = project.find_project_path(path!("root/2.txt"), cx).unwrap(); - project.open_buffer(path, cx) - }) - .await - .unwrap(); - - zeta.read_with(cx, |zeta, cx| { - let prediction = zeta - .current_prediction_for_buffer(&buffer2, &project, cx) - .unwrap(); - assert_matches!(prediction, BufferEditPrediction::Local { .. }); - }); - } - - #[gpui::test] - async fn test_simple_request(cx: &mut TestAppContext) { - let (zeta, mut req_rx) = init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "foo.md": "Hello!\nHow\nBye\n" - }), - ) - .await; - let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - - let buffer = project - .update(cx, |project, cx| { - let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); - project.open_buffer(path, cx) - }) - .await - .unwrap(); - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let position = snapshot.anchor_before(language::Point::new(1, 3)); - - let prediction_task = zeta.update(cx, |zeta, cx| { - zeta.request_prediction(&project, &buffer, position, cx) - }); - - let (_, respond_tx) = req_rx.next().await.unwrap(); - - // TODO Put back when we have a structured request again - // assert_eq!( - // request.excerpt_path.as_ref(), - // Path::new(path!("root/foo.md")) - // ); - // assert_eq!( - // request.cursor_point, - // Point { - // line: Line(1), - // column: 3 - // } - // ); - - respond_tx - .send(model_response(indoc! { r" - --- a/root/foo.md - +++ b/root/foo.md - @@ ... @@ - Hello! - -How - +How are you? - Bye - "})) - .unwrap(); - - let prediction = prediction_task.await.unwrap().unwrap(); - - assert_eq!(prediction.edits.len(), 1); - assert_eq!( - prediction.edits[0].0.to_point(&snapshot).start, - language::Point::new(1, 3) - ); - assert_eq!(prediction.edits[0].1.as_ref(), " are you?"); - } - - #[gpui::test] - async fn test_request_events(cx: &mut TestAppContext) { - let (zeta, mut req_rx) = init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "foo.md": "Hello!\n\nBye\n" - }), - ) - .await; - let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - - let buffer = project - .update(cx, |project, cx| { - let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); - project.open_buffer(path, cx) - }) - .await - .unwrap(); - - zeta.update(cx, |zeta, cx| { - zeta.register_buffer(&buffer, &project, cx); - }); - - buffer.update(cx, |buffer, cx| { - buffer.edit(vec![(7..7, "How")], None, cx); - }); - - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let position = snapshot.anchor_before(language::Point::new(1, 3)); - - let prediction_task = zeta.update(cx, |zeta, cx| { - zeta.request_prediction(&project, &buffer, position, cx) - }); - - let (request, respond_tx) = req_rx.next().await.unwrap(); - - let prompt = prompt_from_request(&request); - assert!( - prompt.contains(indoc! {" - --- a/root/foo.md - +++ b/root/foo.md - @@ -1,3 +1,3 @@ - Hello! - - - +How - Bye - "}), - "{prompt}" - ); - - respond_tx - .send(model_response(indoc! {r#" - --- a/root/foo.md - +++ b/root/foo.md - @@ ... @@ - Hello! - -How - +How are you? - Bye - "#})) - .unwrap(); - - let prediction = prediction_task.await.unwrap().unwrap(); - - assert_eq!(prediction.edits.len(), 1); - assert_eq!( - prediction.edits[0].0.to_point(&snapshot).start, - language::Point::new(1, 3) - ); - assert_eq!(prediction.edits[0].1.as_ref(), " are you?"); - } - - // Skipped until we start including diagnostics in prompt - // #[gpui::test] - // async fn test_request_diagnostics(cx: &mut TestAppContext) { - // let (zeta, mut req_rx) = init_test(cx); - // let fs = FakeFs::new(cx.executor()); - // fs.insert_tree( - // "/root", - // json!({ - // "foo.md": "Hello!\nBye" - // }), - // ) - // .await; - // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; - - // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap(); - // let diagnostic = lsp::Diagnostic { - // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)), - // severity: Some(lsp::DiagnosticSeverity::ERROR), - // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(), - // ..Default::default() - // }; - - // project.update(cx, |project, cx| { - // project.lsp_store().update(cx, |lsp_store, cx| { - // // Create some diagnostics - // lsp_store - // .update_diagnostics( - // LanguageServerId(0), - // lsp::PublishDiagnosticsParams { - // uri: path_to_buffer_uri.clone(), - // diagnostics: vec![diagnostic], - // version: None, - // }, - // None, - // language::DiagnosticSourceKind::Pushed, - // &[], - // cx, - // ) - // .unwrap(); - // }); - // }); - - // let buffer = project - // .update(cx, |project, cx| { - // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); - // project.open_buffer(path, cx) - // }) - // .await - // .unwrap(); - - // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - // let position = snapshot.anchor_before(language::Point::new(0, 0)); - - // let _prediction_task = zeta.update(cx, |zeta, cx| { - // zeta.request_prediction(&project, &buffer, position, cx) - // }); - - // let (request, _respond_tx) = req_rx.next().await.unwrap(); - - // assert_eq!(request.diagnostic_groups.len(), 1); - // let value = serde_json::from_str::(request.diagnostic_groups[0].0.get()) - // .unwrap(); - // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3 - // assert_eq!( - // value, - // json!({ - // "entries": [{ - // "range": { - // "start": 8, - // "end": 10 - // }, - // "diagnostic": { - // "source": null, - // "code": null, - // "code_description": null, - // "severity": 1, - // "message": "\"Hello\" deprecated. Use \"Hi\" instead", - // "markdown": null, - // "group_id": 0, - // "is_primary": true, - // "is_disk_based": false, - // "is_unnecessary": false, - // "source_kind": "Pushed", - // "data": null, - // "underline": true - // } - // }], - // "primary_ix": 0 - // }) - // ); - // } - - fn model_response(text: &str) -> open_ai::Response { - open_ai::Response { - id: Uuid::new_v4().to_string(), - object: "response".into(), - created: 0, - model: "model".into(), - choices: vec![open_ai::Choice { - index: 0, - message: open_ai::RequestMessage::Assistant { - content: Some(open_ai::MessageContent::Plain(text.to_string())), - tool_calls: vec![], - }, - finish_reason: None, - }], - usage: Usage { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0, - }, - } - } - - fn prompt_from_request(request: &open_ai::Request) -> &str { - assert_eq!(request.messages.len(), 1); - let open_ai::RequestMessage::User { - content: open_ai::MessageContent::Plain(content), - .. - } = &request.messages[0] - else { - panic!( - "Request does not have single user message of type Plain. {:#?}", - request - ); - }; - content - } - - fn init_test( - cx: &mut TestAppContext, - ) -> ( - Entity, - mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender)>, - ) { - cx.update(move |cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - zlog::init_test(); - - let (req_tx, req_rx) = mpsc::unbounded(); - - let http_client = FakeHttpClient::create({ - move |req| { - let uri = req.uri().path().to_string(); - let mut body = req.into_body(); - let req_tx = req_tx.clone(); - async move { - let resp = match uri.as_str() { - "/client/llm_tokens" => serde_json::to_string(&json!({ - "token": "test" - })) - .unwrap(), - "/predict_edits/raw" => { - let mut buf = Vec::new(); - body.read_to_end(&mut buf).await.ok(); - let req = serde_json::from_slice(&buf).unwrap(); - - let (res_tx, res_rx) = oneshot::channel(); - req_tx.unbounded_send((req, res_tx)).unwrap(); - serde_json::to_string(&res_rx.await?).unwrap() - } - _ => { - panic!("Unexpected path: {}", uri) - } - }; - - Ok(Response::builder().body(resp.into()).unwrap()) - } - } - }); - - let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx); - client.cloud_client().set_credentials(1, "test".into()); - - language_model::init(client.clone(), cx); - - let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = Zeta::global(&client, &user_store, cx); - - (zeta, req_rx) - }) - } -} diff --git a/crates/zeta2_tools/Cargo.toml b/crates/zeta2_tools/Cargo.toml index 3a9b1ccbf9340dfdaa06030e59c2112b9cda6307..607e24c895d96de1464ff1bfa2a4dfa01c5d9669 100644 --- a/crates/zeta2_tools/Cargo.toml +++ b/crates/zeta2_tools/Cargo.toml @@ -13,7 +13,6 @@ path = "src/zeta2_tools.rs" [dependencies] anyhow.workspace = true -chrono.workspace = true client.workspace = true cloud_llm_client.workspace = true cloud_zeta2_prompt.workspace = true @@ -24,9 +23,7 @@ feature_flags.workspace = true futures.workspace = true gpui.workspace = true language.workspace = true -log.workspace = true multi_buffer.workspace = true -ordered-float.workspace = true project.workspace = true serde.workspace = true serde_json.workspace = true @@ -36,7 +33,7 @@ ui.workspace = true ui_input.workspace = true util.workspace = true workspace.workspace = true -zeta2.workspace = true +zeta.workspace = true [dev-dependencies] clap.workspace = true diff --git a/crates/zeta2_tools/src/zeta2_context_view.rs b/crates/zeta2_tools/src/zeta2_context_view.rs index 759d0d0a3da1adbd9e61fa05b5d305ca9de1f823..54f1ea2d813f7c00d30b12e341fb3e5ac3f155dc 100644 --- a/crates/zeta2_tools/src/zeta2_context_view.rs +++ b/crates/zeta2_tools/src/zeta2_context_view.rs @@ -25,7 +25,7 @@ use ui::{ v_flex, }; use workspace::Item; -use zeta2::{ +use zeta::{ Zeta, ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo, ZetaSearchQueryDebugInfo, }; diff --git a/crates/zeta2_tools/src/zeta2_tools.rs b/crates/zeta2_tools/src/zeta2_tools.rs index 8758857e7cf50d6a5f2e5a4ea509293b18a8cb2c..6a6268f68ad0fa10e2379ac21e07d4fa530dddc1 100644 --- a/crates/zeta2_tools/src/zeta2_tools.rs +++ b/crates/zeta2_tools/src/zeta2_tools.rs @@ -1,30 +1,26 @@ mod zeta2_context_view; -use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc}; +use std::{str::FromStr, sync::Arc, time::Duration}; -use chrono::TimeDelta; use client::{Client, UserStore}; -use cloud_llm_client::predict_edits_v3::{ - DeclarationScoreComponents, PredictEditsRequest, PromptFormat, -}; +use cloud_llm_client::predict_edits_v3::PromptFormat; use collections::HashMap; -use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer}; +use editor::{Editor, EditorEvent, EditorMode, MultiBuffer}; use feature_flags::FeatureFlagAppExt as _; use futures::{FutureExt, StreamExt as _, channel::oneshot, future::Shared}; use gpui::{ - CursorStyle, Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, - WeakEntity, actions, prelude::*, + Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, actions, + prelude::*, }; -use language::{Buffer, DiskState}; -use ordered_float::OrderedFloat; -use project::{Project, WorktreeId, telemetry_snapshot::TelemetrySnapshot}; +use language::Buffer; +use project::{Project, telemetry_snapshot::TelemetrySnapshot}; use ui::{ButtonLike, ContextMenu, ContextMenuEntry, DropdownMenu, KeyBinding, prelude::*}; use ui_input::InputField; -use util::{ResultExt, paths::PathStyle, rel_path::RelPath}; +use util::ResultExt; use workspace::{Item, SplitDirection, Workspace}; -use zeta2::{ - AgenticContextOptions, ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, Zeta, Zeta2FeatureFlag, - ZetaDebugInfo, ZetaEditPredictionDebugInfo, ZetaOptions, +use zeta::{ + AgenticContextOptions, ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, EditPredictionInputs, Zeta, + Zeta2FeatureFlag, ZetaDebugInfo, ZetaEditPredictionDebugInfo, ZetaOptions, }; use edit_prediction_context::{EditPredictionContextOptions, EditPredictionExcerptOptions}; @@ -99,7 +95,6 @@ pub struct Zeta2Inspector { cursor_context_ratio_input: Entity, max_prompt_bytes_input: Entity, context_mode: ContextModeState, - active_view: ActiveView, zeta: Entity, _active_editor_subscription: Option, _update_state_task: Task<()>, @@ -113,21 +108,14 @@ pub enum ContextModeState { }, } -#[derive(PartialEq)] -enum ActiveView { - Context, - Inference, -} - struct LastPrediction { - context_editor: Entity, prompt_editor: Entity, - retrieval_time: TimeDelta, - request_time: Option, + retrieval_time: Duration, + request_time: Option, buffer: WeakEntity, position: language::Anchor, state: LastPredictionState, - request: PredictEditsRequest, + inputs: EditPredictionInputs, project_snapshot: Shared>>, _task: Option>, } @@ -175,7 +163,6 @@ impl Zeta2Inspector { focus_handle: cx.focus_handle(), project: project.clone(), last_prediction: None, - active_view: ActiveView::Inference, max_excerpt_bytes_input: Self::number_input("Max Excerpt Bytes", window, cx), min_excerpt_bytes_input: Self::number_input("Min Excerpt Bytes", window, cx), cursor_context_ratio_input: Self::number_input("Cursor Context Ratio", window, cx), @@ -305,7 +292,7 @@ impl Zeta2Inspector { ContextMode::Syntax(context_options) => { let max_retrieved_declarations = match &this.context_mode { ContextModeState::Llm => { - zeta2::DEFAULT_SYNTAX_CONTEXT_OPTIONS.max_retrieved_declarations + zeta::DEFAULT_SYNTAX_CONTEXT_OPTIONS.max_retrieved_declarations } ContextModeState::Syntax { max_retrieved_declarations, @@ -340,22 +327,10 @@ impl Zeta2Inspector { fn update_last_prediction( &mut self, - prediction: zeta2::ZetaDebugInfo, + prediction: zeta::ZetaDebugInfo, window: &mut Window, cx: &mut Context, ) { - let project = self.project.read(cx); - let path_style = project.path_style(cx); - let Some(worktree_id) = project - .worktrees(cx) - .next() - .map(|worktree| worktree.read(cx).id()) - else { - log::error!("Open a worktree to use edit prediction debug view"); - self.last_prediction.take(); - return; - }; - self._update_state_task = cx.spawn_in(window, { let language_registry = self.project.read(cx).languages().clone(); async move |this, cx| { @@ -364,11 +339,10 @@ impl Zeta2Inspector { return; }; for ext in prediction - .request - .referenced_declarations + .inputs + .included_files .iter() - .filter_map(|snippet| snippet.path.extension()) - .chain(prediction.request.excerpt_path.extension()) + .filter_map(|file| file.path.extension()) { if !languages.contains_key(ext) { // Most snippets are gonna be the same language, @@ -391,90 +365,6 @@ impl Zeta2Inspector { let json_language = language_registry.language_for_name("Json").await.log_err(); this.update_in(cx, |this, window, cx| { - let context_editor = cx.new(|cx| { - let mut excerpt_score_components = HashMap::default(); - - let multibuffer = cx.new(|cx| { - let mut multibuffer = MultiBuffer::new(language::Capability::ReadOnly); - let excerpt_file = Arc::new(ExcerptMetadataFile { - title: RelPath::unix("Cursor Excerpt").unwrap().into(), - path_style, - worktree_id, - }); - - let excerpt_buffer = cx.new(|cx| { - let mut buffer = - Buffer::local(prediction.request.excerpt.clone(), cx); - if let Some(language) = prediction - .request - .excerpt_path - .extension() - .and_then(|ext| languages.get(ext)) - { - buffer.set_language(language.clone(), cx); - } - buffer.file_updated(excerpt_file, cx); - buffer - }); - - multibuffer.push_excerpts( - excerpt_buffer, - [ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)], - cx, - ); - - let mut declarations = - prediction.request.referenced_declarations.clone(); - declarations.sort_unstable_by_key(|declaration| { - Reverse(OrderedFloat(declaration.declaration_score)) - }); - - for snippet in &declarations { - let snippet_file = Arc::new(ExcerptMetadataFile { - title: RelPath::unix(&format!( - "{} (Score: {})", - snippet.path.display(), - snippet.declaration_score - )) - .unwrap() - .into(), - path_style, - worktree_id, - }); - - let excerpt_buffer = cx.new(|cx| { - let mut buffer = Buffer::local(snippet.text.clone(), cx); - buffer.file_updated(snippet_file, cx); - if let Some(ext) = snippet.path.extension() - && let Some(language) = languages.get(ext) - { - buffer.set_language(language.clone(), cx); - } - buffer - }); - - let excerpt_ids = multibuffer.push_excerpts( - excerpt_buffer, - [ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)], - cx, - ); - let excerpt_id = excerpt_ids.first().unwrap(); - - excerpt_score_components - .insert(*excerpt_id, snippet.score_components.clone()); - } - - multibuffer - }); - - let mut editor = - Editor::new(EditorMode::full(), multibuffer, None, window, cx); - editor.register_addon(ZetaContextAddon { - excerpt_score_components, - }); - editor - }); - let ZetaEditPredictionDebugInfo { response_rx, position, @@ -606,7 +496,6 @@ impl Zeta2Inspector { let project_snapshot_task = TelemetrySnapshot::new(&this.project, cx); this.last_prediction = Some(LastPrediction { - context_editor, prompt_editor: cx.new(|cx| { let buffer = cx.new(|cx| { let mut buffer = @@ -632,7 +521,7 @@ impl Zeta2Inspector { .foreground_executor() .spawn(async move { Arc::new(project_snapshot_task.await) }) .shared(), - request: prediction.request, + inputs: prediction.inputs, _task: Some(task), }); cx.notify(); @@ -664,9 +553,6 @@ impl Zeta2Inspector { let Some(last_prediction) = self.last_prediction.as_mut() else { return; }; - if !last_prediction.request.can_collect_data { - return; - } let project_snapshot_task = last_prediction.project_snapshot.clone(); @@ -718,7 +604,7 @@ impl Zeta2Inspector { id = request_id, kind = kind, text = text, - request = last_prediction.request, + request = last_prediction.inputs, project_snapshot = project_snapshot, ); }) @@ -727,17 +613,6 @@ impl Zeta2Inspector { .detach(); } - fn focus_feedback(&mut self, window: &mut Window, cx: &mut Context) { - if let Some(last_prediction) = self.last_prediction.as_mut() { - if let LastPredictionState::Success { - feedback_editor, .. - } = &mut last_prediction.state - { - feedback_editor.focus_handle(cx).focus(window); - } - }; - } - fn render_options(&self, window: &mut Window, cx: &mut Context) -> Div { v_flex() .gap_2() @@ -747,11 +622,11 @@ impl Zeta2Inspector { .justify_between() .child( ui::Button::new("reset-options", "Reset") - .disabled(self.zeta.read(cx).options() == &zeta2::DEFAULT_OPTIONS) + .disabled(self.zeta.read(cx).options() == &zeta::DEFAULT_OPTIONS) .style(ButtonStyle::Outlined) .size(ButtonSize::Large) .on_click(cx.listener(|this, _, window, cx| { - this.set_options_state(&zeta2::DEFAULT_OPTIONS, window, cx); + this.set_options_state(&zeta::DEFAULT_OPTIONS, window, cx); })), ), ) @@ -915,42 +790,6 @@ impl Zeta2Inspector { ) } - fn render_tabs(&self, cx: &mut Context) -> Option { - if self.last_prediction.is_none() { - return None; - }; - - Some( - ui::ToggleButtonGroup::single_row( - "prediction", - [ - ui::ToggleButtonSimple::new( - "Context", - cx.listener(|this, _, _, cx| { - this.active_view = ActiveView::Context; - cx.notify(); - }), - ), - ui::ToggleButtonSimple::new( - "Inference", - cx.listener(|this, _, window, cx| { - this.active_view = ActiveView::Inference; - this.focus_feedback(window, cx); - cx.notify(); - }), - ), - ], - ) - .style(ui::ToggleButtonGroupStyle::Outlined) - .selected_index(if self.active_view == ActiveView::Context { - 0 - } else { - 1 - }) - .into_any_element(), - ) - } - fn render_stats(&self) -> Option
{ let Some(prediction) = self.last_prediction.as_ref() else { return None; @@ -970,15 +809,15 @@ impl Zeta2Inspector { ) } - fn render_duration(name: &'static str, time: Option) -> Div { + fn render_duration(name: &'static str, time: Option) -> Div { h_flex() .gap_1() .child(Label::new(name).color(Color::Muted).size(LabelSize::Small)) .child(match time { - Some(time) => Label::new(if time.num_microseconds().unwrap_or(0) >= 1000 { - format!("{} ms", time.num_milliseconds()) + Some(time) => Label::new(if time.as_micros() >= 1000 { + format!("{} ms", time.as_millis()) } else { - format!("{} µs", time.num_microseconds().unwrap_or(0)) + format!("{} µs", time.as_micros()) }) .size(LabelSize::Small), None => Label::new("...").size(LabelSize::Small), @@ -1006,144 +845,135 @@ impl Zeta2Inspector { } fn render_last_prediction(&self, prediction: &LastPrediction, cx: &mut Context) -> Div { - match &self.active_view { - ActiveView::Context => div().size_full().child(prediction.context_editor.clone()), - ActiveView::Inference => h_flex() - .items_start() - .w_full() - .flex_1() - .border_t_1() - .border_color(cx.theme().colors().border) - .bg(cx.theme().colors().editor_background) - .child( - v_flex() - .flex_1() - .gap_2() - .p_4() - .h_full() - .child( - h_flex() - .justify_between() - .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall)) - .child(match prediction.state { - LastPredictionState::Requested - | LastPredictionState::Failed { .. } => ui::Chip::new("Local") - .bg_color(cx.theme().status().warning_background) - .label_color(Color::Success), - LastPredictionState::Success { .. } => ui::Chip::new("Cloud") - .bg_color(cx.theme().status().success_background) - .label_color(Color::Success), - }), - ) - .child(prediction.prompt_editor.clone()), - ) - .child(ui::vertical_divider()) - .child( - v_flex() - .flex_1() - .gap_2() - .h_full() - .child( + h_flex() + .items_start() + .w_full() + .flex_1() + .border_t_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().editor_background) + .child( + v_flex() + .flex_1() + .gap_2() + .p_4() + .h_full() + .child( + h_flex() + .justify_between() + .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall)) + .child(match prediction.state { + LastPredictionState::Requested + | LastPredictionState::Failed { .. } => ui::Chip::new("Local") + .bg_color(cx.theme().status().warning_background) + .label_color(Color::Success), + LastPredictionState::Success { .. } => ui::Chip::new("Cloud") + .bg_color(cx.theme().status().success_background) + .label_color(Color::Success), + }), + ) + .child(prediction.prompt_editor.clone()), + ) + .child(ui::vertical_divider()) + .child( + v_flex() + .flex_1() + .gap_2() + .h_full() + .child( + v_flex() + .flex_1() + .gap_2() + .p_4() + .child( + ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall), + ) + .child(match &prediction.state { + LastPredictionState::Success { + model_response_editor, + .. + } => model_response_editor.clone().into_any_element(), + LastPredictionState::Requested => v_flex() + .gap_2() + .child(Label::new("Loading...").buffer_font(cx)) + .into_any_element(), + LastPredictionState::Failed { message } => v_flex() + .gap_2() + .max_w_96() + .child(Label::new(message.clone()).buffer_font(cx)) + .into_any_element(), + }), + ) + .child(ui::divider()) + .child( + if let LastPredictionState::Success { + feedback_editor, + feedback: feedback_state, + .. + } = &prediction.state + { v_flex() - .flex_1() + .key_context("Zeta2Feedback") + .on_action(cx.listener(Self::handle_rate_positive)) + .on_action(cx.listener(Self::handle_rate_negative)) .gap_2() - .p_4() + .p_2() + .child(feedback_editor.clone()) .child( - ui::Headline::new("Model Response") - .size(ui::HeadlineSize::XSmall), - ) - .child(match &prediction.state { - LastPredictionState::Success { - model_response_editor, - .. - } => model_response_editor.clone().into_any_element(), - LastPredictionState::Requested => v_flex() - .gap_2() - .child(Label::new("Loading...").buffer_font(cx)) - .into_any_element(), - LastPredictionState::Failed { message } => v_flex() - .gap_2() - .max_w_96() - .child(Label::new(message.clone()).buffer_font(cx)) - .into_any_element(), - }), - ) - .child(ui::divider()) - .child( - if prediction.request.can_collect_data - && let LastPredictionState::Success { - feedback_editor, - feedback: feedback_state, - .. - } = &prediction.state - { - v_flex() - .key_context("Zeta2Feedback") - .on_action(cx.listener(Self::handle_rate_positive)) - .on_action(cx.listener(Self::handle_rate_negative)) - .gap_2() - .p_2() - .child(feedback_editor.clone()) - .child( - h_flex() - .justify_end() - .w_full() - .child( - ButtonLike::new("rate-positive") - .when( - *feedback_state == Some(Feedback::Positive), - |this| this.style(ButtonStyle::Filled), - ) - .child( - KeyBinding::for_action( - &Zeta2RatePredictionPositive, - cx, - ) - .size(TextSize::Small.rems(cx)), - ) - .child(ui::Icon::new(ui::IconName::ThumbsUp)) - .on_click(cx.listener( - |this, _, window, cx| { - this.handle_rate_positive( - &Zeta2RatePredictionPositive, - window, - cx, - ); - }, - )), - ) - .child( - ButtonLike::new("rate-negative") - .when( - *feedback_state == Some(Feedback::Negative), - |this| this.style(ButtonStyle::Filled), + h_flex() + .justify_end() + .w_full() + .child( + ButtonLike::new("rate-positive") + .when( + *feedback_state == Some(Feedback::Positive), + |this| this.style(ButtonStyle::Filled), + ) + .child( + KeyBinding::for_action( + &Zeta2RatePredictionPositive, + cx, ) - .child( - KeyBinding::for_action( - &Zeta2RatePredictionNegative, - cx, - ) - .size(TextSize::Small.rems(cx)), + .size(TextSize::Small.rems(cx)), + ) + .child(ui::Icon::new(ui::IconName::ThumbsUp)) + .on_click(cx.listener(|this, _, window, cx| { + this.handle_rate_positive( + &Zeta2RatePredictionPositive, + window, + cx, + ); + })), + ) + .child( + ButtonLike::new("rate-negative") + .when( + *feedback_state == Some(Feedback::Negative), + |this| this.style(ButtonStyle::Filled), + ) + .child( + KeyBinding::for_action( + &Zeta2RatePredictionNegative, + cx, ) - .child(ui::Icon::new(ui::IconName::ThumbsDown)) - .on_click(cx.listener( - |this, _, window, cx| { - this.handle_rate_negative( - &Zeta2RatePredictionNegative, - window, - cx, - ); - }, - )), - ), - ) - .into_any() - } else { - Empty.into_any_element() - }, - ), - ), - } + .size(TextSize::Small.rems(cx)), + ) + .child(ui::Icon::new(ui::IconName::ThumbsDown)) + .on_click(cx.listener(|this, _, window, cx| { + this.handle_rate_negative( + &Zeta2RatePredictionNegative, + window, + cx, + ); + })), + ), + ) + .into_any() + } else { + Empty.into_any_element() + }, + ), + ) } } @@ -1178,8 +1008,7 @@ impl Render for Zeta2Inspector { .h_full() .justify_between() .child(self.render_options(window, cx)) - .gap_4() - .children(self.render_tabs(cx)), + .gap_4(), ) .child(ui::vertical_divider()) .children(self.render_stats()), @@ -1187,104 +1016,3 @@ impl Render for Zeta2Inspector { .child(self.render_content(window, cx)) } } - -// Using same approach as commit view - -struct ExcerptMetadataFile { - title: Arc, - worktree_id: WorktreeId, - path_style: PathStyle, -} - -impl language::File for ExcerptMetadataFile { - fn as_local(&self) -> Option<&dyn language::LocalFile> { - None - } - - fn disk_state(&self) -> DiskState { - DiskState::New - } - - fn path(&self) -> &Arc { - &self.title - } - - fn full_path(&self, _: &App) -> PathBuf { - self.title.as_std_path().to_path_buf() - } - - fn file_name<'a>(&'a self, _: &'a App) -> &'a str { - self.title.file_name().unwrap() - } - - fn path_style(&self, _: &App) -> PathStyle { - self.path_style - } - - fn worktree_id(&self, _: &App) -> WorktreeId { - self.worktree_id - } - - fn to_proto(&self, _: &App) -> language::proto::File { - unimplemented!() - } - - fn is_private(&self) -> bool { - false - } -} - -struct ZetaContextAddon { - excerpt_score_components: HashMap, -} - -impl editor::Addon for ZetaContextAddon { - fn to_any(&self) -> &dyn std::any::Any { - self - } - - fn render_buffer_header_controls( - &self, - excerpt_info: &multi_buffer::ExcerptInfo, - _window: &Window, - _cx: &App, - ) -> Option { - let score_components = self.excerpt_score_components.get(&excerpt_info.id)?.clone(); - - Some( - div() - .id(excerpt_info.id.to_proto() as usize) - .child(ui::Icon::new(IconName::Info)) - .cursor(CursorStyle::PointingHand) - .tooltip(move |_, cx| { - cx.new(|_| ScoreComponentsTooltip::new(&score_components)) - .into() - }) - .into_any(), - ) - } -} - -struct ScoreComponentsTooltip { - text: SharedString, -} - -impl ScoreComponentsTooltip { - fn new(components: &DeclarationScoreComponents) -> Self { - Self { - text: format!("{:#?}", components).into(), - } - } -} - -impl Render for ScoreComponentsTooltip { - fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - div().pl_2().pt_2p5().child( - div() - .elevation_2(cx) - .py_1() - .px_2() - .child(ui::Label::new(self.text.clone()).buffer_font(cx)), - ) - } -} diff --git a/crates/zeta_cli/Cargo.toml b/crates/zeta_cli/Cargo.toml index e18cf54787ca98e2be60db4977dd2de18e9c09e2..2dbca537f55377e84f306e13649dfb71ccf2f181 100644 --- a/crates/zeta_cli/Cargo.toml +++ b/crates/zeta_cli/Cargo.toml @@ -53,8 +53,7 @@ terminal_view.workspace = true toml.workspace = true util.workspace = true watch.workspace = true -zeta.workspace = true -zeta2 = { workspace = true, features = ["eval-support"] } +zeta = { workspace = true, features = ["eval-support"] } zlog.workspace = true [dev-dependencies] diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/zeta_cli/src/evaluate.rs index a9d7acaee2287450eac828bd2d770b88a8150940..a0ebdf998595ccacec2dafecf51b6094e5e401b5 100644 --- a/crates/zeta_cli/src/evaluate.rs +++ b/crates/zeta_cli/src/evaluate.rs @@ -9,7 +9,7 @@ use collections::HashSet; use gpui::{AsyncApp, Entity}; use project::Project; use util::ResultExt as _; -use zeta2::{Zeta, udiff::DiffLine}; +use zeta::{Zeta, udiff::DiffLine}; use crate::{ EvaluateArguments, PredictionOptions, diff --git a/crates/zeta_cli/src/example.rs b/crates/zeta_cli/src/example.rs index 67eed23f90dc1a5b48a53a2a7de07f500396ba9f..7dbe304a88b9ea024adab793fa782fd2f4bdf1c0 100644 --- a/crates/zeta_cli/src/example.rs +++ b/crates/zeta_cli/src/example.rs @@ -26,7 +26,7 @@ use project::{Project, ProjectPath}; use pulldown_cmark::CowStr; use serde::{Deserialize, Serialize}; use util::{paths::PathStyle, rel_path::RelPath}; -use zeta2::udiff::OpenedBuffers; +use zeta::udiff::OpenedBuffers; use crate::paths::{REPOS_DIR, WORKTREES_DIR}; @@ -557,7 +557,7 @@ impl NamedExample { project: &Entity, cx: &mut AsyncApp, ) -> Result> { - zeta2::udiff::apply_diff(&self.example.edit_history, project, cx).await + zeta::udiff::apply_diff(&self.example.edit_history, project, cx).await } } diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 914b141915cd3a89cd35a02bc6c9463094f0de96..f87563cc34ca7631baf8195e42e4e3473f522659 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -31,7 +31,7 @@ use serde_json::json; use std::io::{self}; use std::time::Duration; use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc}; -use zeta2::ContextMode; +use zeta::ContextMode; #[derive(Parser, Debug)] #[command(name = "zeta")] @@ -193,13 +193,14 @@ pub struct EvaluateArguments { #[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)] enum PredictionProvider { + Zeta1, #[default] Zeta2, Sweep, } -fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions { - zeta2::ZetaOptions { +fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta::ZetaOptions { + zeta::ZetaOptions { context: ContextMode::Syntax(EditPredictionContextOptions { max_retrieved_declarations: args.max_retrieved_definitions, use_imports: !args.disable_imports_gathering, @@ -397,7 +398,7 @@ async fn zeta2_syntax_context( let output = cx .update(|cx| { let zeta = cx.new(|cx| { - zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx) + zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx) }); let indexing_done_task = zeta.update(cx, |zeta, cx| { zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true)); @@ -435,7 +436,7 @@ async fn zeta1_context( args: ContextArgs, app_state: &Arc, cx: &mut AsyncApp, -) -> Result { +) -> Result { let LoadedContext { full_path_str, snapshot, @@ -450,7 +451,7 @@ async fn zeta1_context( let prompt_for_events = move || (events, 0); cx.update(|cx| { - zeta::gather_context( + zeta::zeta1::gather_context( full_path_str, &snapshot, clipped_cursor, diff --git a/crates/zeta_cli/src/predict.rs b/crates/zeta_cli/src/predict.rs index c792b318cec6de42e518793ed5400df0010ae5ea..a757a5faa0dbae95c4dcab58c76d50450b1d2e9f 100644 --- a/crates/zeta_cli/src/predict.rs +++ b/crates/zeta_cli/src/predict.rs @@ -21,7 +21,7 @@ use std::path::PathBuf; use std::sync::Arc; use std::sync::Mutex; use std::time::{Duration, Instant}; -use zeta2::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta}; +use zeta::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta}; pub async fn run_predict( args: PredictArguments, @@ -47,12 +47,13 @@ pub fn setup_zeta( cx: &mut AsyncApp, ) -> Result> { let zeta = - cx.new(|cx| zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?; + cx.new(|cx| zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?; zeta.update(cx, |zeta, _cx| { let model = match provider { - PredictionProvider::Zeta2 => zeta2::ZetaEditPredictionModel::ZedCloud, - PredictionProvider::Sweep => zeta2::ZetaEditPredictionModel::Sweep, + PredictionProvider::Zeta1 => zeta::ZetaEditPredictionModel::Zeta1, + PredictionProvider::Zeta2 => zeta::ZetaEditPredictionModel::Zeta2, + PredictionProvider::Sweep => zeta::ZetaEditPredictionModel::Sweep, }; zeta.set_edit_prediction_model(model); })?; @@ -142,25 +143,25 @@ pub async fn perform_predict( let mut search_queries_executed_at = None; while let Some(event) = debug_rx.next().await { match event { - zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => { + zeta::ZetaDebugInfo::ContextRetrievalStarted(info) => { start_time = Some(info.timestamp); fs::write( example_run_dir.join("search_prompt.md"), &info.search_prompt, )?; } - zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => { + zeta::ZetaDebugInfo::SearchQueriesGenerated(info) => { search_queries_generated_at = Some(info.timestamp); fs::write( example_run_dir.join("search_queries.json"), serde_json::to_string_pretty(&info.search_queries).unwrap(), )?; } - zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => { + zeta::ZetaDebugInfo::SearchQueriesExecuted(info) => { search_queries_executed_at = Some(info.timestamp); } - zeta2::ZetaDebugInfo::ContextRetrievalFinished(_info) => {} - zeta2::ZetaDebugInfo::EditPredictionRequested(request) => { + zeta::ZetaDebugInfo::ContextRetrievalFinished(_info) => {} + zeta::ZetaDebugInfo::EditPredictionRequested(request) => { let prediction_started_at = Instant::now(); start_time.get_or_insert(prediction_started_at); let prompt = request.local_prompt.unwrap_or_default(); @@ -170,9 +171,9 @@ pub async fn perform_predict( let mut result = result.lock().unwrap(); result.prompt_len = prompt.chars().count(); - for included_file in request.request.included_files { + for included_file in request.inputs.included_files { let insertions = - vec![(request.request.cursor_point, CURSOR_MARKER)]; + vec![(request.inputs.cursor_point, CURSOR_MARKER)]; result.excerpts.extend(included_file.excerpts.iter().map( |excerpt| ActualExcerpt { path: included_file.path.components().skip(1).collect(), @@ -182,7 +183,7 @@ pub async fn perform_predict( write_codeblock( &included_file.path, included_file.excerpts.iter(), - if included_file.path == request.request.excerpt_path { + if included_file.path == request.inputs.cursor_path { &insertions } else { &[] @@ -196,7 +197,7 @@ pub async fn perform_predict( let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?; - let response = zeta2::text_from_response(response).unwrap_or_default(); + let response = zeta::text_from_response(response).unwrap_or_default(); let prediction_finished_at = Instant::now(); fs::write(example_run_dir.join("prediction_response.md"), &response)?; @@ -267,20 +268,7 @@ pub async fn perform_predict( let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap(); result.diff = prediction - .map(|prediction| { - let old_text = prediction.snapshot.text(); - let new_text = prediction - .buffer - .update(cx, |buffer, cx| { - let branch = buffer.branch(cx); - branch.update(cx, |branch, cx| { - branch.edit(prediction.edits.iter().cloned(), None, cx); - branch.text() - }) - }) - .unwrap(); - language::unified_diff(&old_text, &new_text) - }) + .and_then(|prediction| prediction.edit_preview.as_unified_diff(&prediction.edits)) .unwrap_or_default(); anyhow::Ok(result) diff --git a/crates/zeta_cli/src/syntax_retrieval_stats.rs b/crates/zeta_cli/src/syntax_retrieval_stats.rs index f2634b1323d92b7136c591627226161b2905a955..4c7506ff78952da79acfeae751959bfe8182b9d4 100644 --- a/crates/zeta_cli/src/syntax_retrieval_stats.rs +++ b/crates/zeta_cli/src/syntax_retrieval_stats.rs @@ -32,7 +32,7 @@ use std::{ time::Duration, }; use util::paths::PathStyle; -use zeta2::ContextMode; +use zeta::ContextMode; use crate::headless::ZetaCliAppState; use crate::source_location::SourceLocation; @@ -44,7 +44,7 @@ pub async fn retrieval_stats( only_extension: Option, file_limit: Option, skip_files: Option, - options: zeta2::ZetaOptions, + options: zeta::ZetaOptions, cx: &mut AsyncApp, ) -> Result { let ContextMode::Syntax(context_options) = options.context.clone() else {