diff --git a/Cargo.lock b/Cargo.lock index bb3b3d31e5d13e2f59d1adbda927efc6e6884b92..d108636963b0234345d3c3837e203d371ee9c014 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5126,7 +5126,6 @@ dependencies = [ "client", "gpui", "language", - "project", "workspace-hack", ] diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index ec475598245b111e2647c63c3edcddd0d15ee5b8..90df92f54216c9040c3a36b737bcf9415901ee87 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -43,15 +43,24 @@ pub struct PredictEditsRequest { pub prompt_format: PromptFormat, } -#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)] pub enum PromptFormat { - #[default] MarkedExcerpt, LabeledSections, /// Prompt format intended for use via zeta_cli OnlySnippets, } +impl PromptFormat { + pub const DEFAULT: PromptFormat = PromptFormat::LabeledSections; +} + +impl Default for PromptFormat { + fn default() -> Self { + Self::DEFAULT + } +} + impl PromptFormat { pub fn iter() -> impl Iterator { ::iter() diff --git a/crates/copilot/src/copilot_completion_provider.rs b/crates/copilot/src/copilot_completion_provider.rs index c122dccec069ff636d39c64f24ef0aca41145012..6027c081ccef31bfdeb83cb944dcba861bc95da8 100644 --- a/crates/copilot/src/copilot_completion_provider.rs +++ b/crates/copilot/src/copilot_completion_provider.rs @@ -3,7 +3,6 @@ use anyhow::Result; use edit_prediction::{Direction, EditPrediction, EditPredictionProvider}; use gpui::{App, Context, Entity, EntityId, Task}; use language::{Buffer, OffsetRangeExt, ToOffset, language_settings::AllLanguageSettings}; -use project::Project; use settings::Settings; use std::{path::Path, time::Duration}; @@ -84,7 +83,6 @@ impl EditPredictionProvider for CopilotCompletionProvider { fn refresh( &mut self, - _project: Option>, buffer: Entity, cursor_position: language::Anchor, debounce: bool, @@ -249,7 +247,7 @@ impl EditPredictionProvider for CopilotCompletionProvider { None } else { let position = cursor_position.bias_right(buffer); - Some(EditPrediction { + Some(EditPrediction::Local { id: None, edits: vec![(position..position, completion_text.into())], edit_preview: None, diff --git a/crates/edit_prediction/Cargo.toml b/crates/edit_prediction/Cargo.toml index 81c1e5dec20ce9032c4e1422f330b11da56fabe7..0195bdb06d67297569ef14175148fcab71effd6a 100644 --- a/crates/edit_prediction/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -15,5 +15,4 @@ path = "src/edit_prediction.rs" client.workspace = true gpui.workspace = true language.workspace = true -project.workspace = true workspace-hack.workspace = true diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 6b695af1ae0e4807c9aa93af34a5d07de0c15795..90cad9f9227ae8071da6e256c6d9b494e61ac67c 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -3,7 +3,6 @@ use std::ops::Range; use client::EditPredictionUsage; use gpui::{App, Context, Entity, SharedString}; use language::Buffer; -use project::Project; // TODO: Find a better home for `Direction`. // @@ -16,11 +15,19 @@ pub enum Direction { } #[derive(Clone)] -pub struct EditPrediction { - /// The ID of the completion, if it has one. - pub id: Option, - pub edits: Vec<(Range, String)>, - pub edit_preview: Option, +pub enum EditPrediction { + /// Edits within the buffer that requested the prediction + Local { + id: Option, + edits: Vec<(Range, String)>, + edit_preview: Option, + }, + /// Jump to a different file from the one that requested the prediction + Jump { + id: Option, + snapshot: language::BufferSnapshot, + target: language::Anchor, + }, } pub enum DataCollectionState { @@ -83,7 +90,6 @@ pub trait EditPredictionProvider: 'static + Sized { fn is_refreshing(&self) -> bool; fn refresh( &mut self, - project: Option>, buffer: Entity, cursor_position: language::Anchor, debounce: bool, @@ -124,7 +130,6 @@ pub trait EditPredictionProviderHandle { fn is_refreshing(&self, cx: &App) -> bool; fn refresh( &self, - project: Option>, buffer: Entity, cursor_position: language::Anchor, debounce: bool, @@ -198,14 +203,13 @@ where fn refresh( &self, - project: Option>, buffer: Entity, cursor_position: language::Anchor, debounce: bool, cx: &mut App, ) { self.update(cx, |this, cx| { - this.refresh(project, buffer, cursor_position, debounce, cx) + this.refresh(buffer, cursor_position, debounce, cx) }) } diff --git a/crates/editor/src/edit_prediction_tests.rs b/crates/editor/src/edit_prediction_tests.rs index 7bf51e45d72f383b4af34cf6ad493792f8e9d351..7d64dd9749c68cb0e436c1cfcb04e3458d052872 100644 --- a/crates/editor/src/edit_prediction_tests.rs +++ b/crates/editor/src/edit_prediction_tests.rs @@ -2,7 +2,6 @@ use edit_prediction::EditPredictionProvider; use gpui::{Entity, prelude::*}; use indoc::indoc; use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint}; -use project::Project; use std::ops::Range; use text::{Point, ToOffset}; @@ -261,7 +260,7 @@ async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui: EditPrediction::Edit { .. } => { // This is expected for non-Zed providers } - EditPrediction::Move { .. } => { + EditPrediction::MoveWithin { .. } | EditPrediction::MoveOutside { .. } => { panic!( "Non-Zed providers should not show Move predictions (jump functionality)" ); @@ -299,7 +298,7 @@ fn assert_editor_active_move_completion( .as_ref() .expect("editor has no active completion"); - if let EditPrediction::Move { target, .. } = &completion_state.completion { + if let EditPrediction::MoveWithin { target, .. } = &completion_state.completion { assert(editor.buffer().read(cx).snapshot(cx), *target); } else { panic!("expected move completion"); @@ -326,7 +325,7 @@ fn propose_edits( cx.update(|_, cx| { provider.update(cx, |provider, _| { - provider.set_edit_prediction(Some(edit_prediction::EditPrediction { + provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local { id: None, edits: edits.collect(), edit_preview: None, @@ -357,7 +356,7 @@ fn propose_edits_non_zed( cx.update(|_, cx| { provider.update(cx, |provider, _| { - provider.set_edit_prediction(Some(edit_prediction::EditPrediction { + provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local { id: None, edits: edits.collect(), edit_preview: None, @@ -418,7 +417,6 @@ impl EditPredictionProvider for FakeEditPredictionProvider { fn refresh( &mut self, - _project: Option>, _buffer: gpui::Entity, _cursor_position: language::Anchor, _debounce: bool, @@ -492,7 +490,6 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider { fn refresh( &mut self, - _project: Option>, _buffer: gpui::Entity, _cursor_position: language::Anchor, _debounce: bool, diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index e261df4867779a5ecebd8ca4f1dc52394bac62e8..406dd395ba2009a53e5d5c2ec8bd1956571bb11a 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -638,17 +638,23 @@ enum EditPrediction { display_mode: EditDisplayMode, snapshot: BufferSnapshot, }, - Move { + /// Move to a specific location in the active editor + MoveWithin { target: Anchor, snapshot: BufferSnapshot, }, + /// Move to a specific location in a different editor (not the active one) + MoveOutside { + target: language::Anchor, + snapshot: BufferSnapshot, + }, } struct EditPredictionState { inlay_ids: Vec, completion: EditPrediction, completion_id: Option, - invalidation_range: Range, + invalidation_range: Option>, } enum EditPredictionSettings { @@ -7175,13 +7181,7 @@ impl Editor { return None; } - provider.refresh( - self.project.clone(), - buffer, - cursor_buffer_position, - debounce, - cx, - ); + provider.refresh(buffer, cursor_buffer_position, debounce, cx); Some(()) } @@ -7424,10 +7424,8 @@ impl Editor { return; }; - self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx); - match &active_edit_prediction.completion { - EditPrediction::Move { target, .. } => { + EditPrediction::MoveWithin { target, .. } => { let target = *target; if let Some(position_map) = &self.last_position_map { @@ -7469,7 +7467,19 @@ impl Editor { } } } + EditPrediction::MoveOutside { snapshot, target } => { + if let Some(workspace) = self.workspace() { + Self::open_editor_at_anchor(snapshot, *target, &workspace, window, cx) + .detach_and_log_err(cx); + } + } EditPrediction::Edit { edits, .. } => { + self.report_edit_prediction_event( + active_edit_prediction.completion_id.clone(), + true, + cx, + ); + if let Some(provider) = self.edit_prediction_provider() { provider.accept(cx); } @@ -7522,10 +7532,8 @@ impl Editor { return; } - self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx); - match &active_edit_prediction.completion { - EditPrediction::Move { target, .. } => { + EditPrediction::MoveWithin { target, .. } => { let target = *target; self.change_selections( SelectionEffects::scroll(Autoscroll::newest()), @@ -7536,7 +7544,19 @@ impl Editor { }, ); } + EditPrediction::MoveOutside { snapshot, target } => { + if let Some(workspace) = self.workspace() { + Self::open_editor_at_anchor(snapshot, *target, &workspace, window, cx) + .detach_and_log_err(cx); + } + } EditPrediction::Edit { edits, .. } => { + self.report_edit_prediction_event( + active_edit_prediction.completion_id.clone(), + true, + cx, + ); + // Find an insertion that starts at the cursor position. let snapshot = self.buffer.read(cx).snapshot(cx); let cursor_offset = self.selections.newest::(cx).head(); @@ -7631,6 +7651,36 @@ impl Editor { ); } + fn open_editor_at_anchor( + snapshot: &language::BufferSnapshot, + target: language::Anchor, + workspace: &Entity, + window: &mut Window, + cx: &mut App, + ) -> Task> { + workspace.update(cx, |workspace, cx| { + let path = snapshot.file().map(|file| file.full_path(cx)); + let Some(path) = + path.and_then(|path| workspace.project().read(cx).find_project_path(path, cx)) + else { + return Task::ready(Err(anyhow::anyhow!("Project path not found"))); + }; + let target = text::ToPoint::to_point(&target, snapshot); + let item = workspace.open_path(path, None, true, window, cx); + window.spawn(cx, async move |cx| { + let Some(editor) = item.await?.downcast::() else { + return Ok(()); + }; + editor + .update_in(cx, |editor, window, cx| { + editor.go_to_singleton_buffer_point(target, window, cx); + }) + .ok(); + anyhow::Ok(()) + }) + }) + } + pub fn has_active_edit_prediction(&self) -> bool { self.active_edit_prediction.is_some() } @@ -7846,7 +7896,10 @@ impl Editor { .active_edit_prediction .as_ref() .is_some_and(|completion| { - let invalidation_range = completion.invalidation_range.to_offset(&multibuffer); + let Some(invalidation_range) = completion.invalidation_range.as_ref() else { + return false; + }; + let invalidation_range = invalidation_range.to_offset(&multibuffer); let invalidation_range = invalidation_range.start..=invalidation_range.end; !invalidation_range.contains(&offset_selection.head()) }) @@ -7882,8 +7935,31 @@ impl Editor { } let edit_prediction = provider.suggest(&buffer, cursor_buffer_position, cx)?; - let edits = edit_prediction - .edits + + let (completion_id, edits, edit_preview) = match edit_prediction { + edit_prediction::EditPrediction::Local { + id, + edits, + edit_preview, + } => (id, edits, edit_preview), + edit_prediction::EditPrediction::Jump { + id, + snapshot, + target, + } => { + self.stale_edit_prediction_in_menu = None; + self.active_edit_prediction = Some(EditPredictionState { + inlay_ids: vec![], + completion: EditPrediction::MoveOutside { snapshot, target }, + completion_id: id, + invalidation_range: None, + }); + cx.notify(); + return Some(()); + } + }; + + let edits = edits .into_iter() .flat_map(|(range, new_text)| { let start = multibuffer.anchor_in_excerpt(excerpt_id, range.start)?; @@ -7928,7 +8004,7 @@ impl Editor { invalidation_row_range = move_invalidation_row_range.unwrap_or(edit_start_row..edit_end_row); let target = first_edit_start; - EditPrediction::Move { target, snapshot } + EditPrediction::MoveWithin { target, snapshot } } else { let show_completions_in_buffer = !self.edit_prediction_visible_in_cursor_popover(true) && !self.edit_predictions_hidden_for_vim_mode; @@ -7977,7 +8053,7 @@ impl Editor { EditPrediction::Edit { edits, - edit_preview: edit_prediction.edit_preview, + edit_preview, display_mode, snapshot, } @@ -7994,8 +8070,8 @@ impl Editor { self.active_edit_prediction = Some(EditPredictionState { inlay_ids, completion, - completion_id: edit_prediction.id, - invalidation_range, + completion_id, + invalidation_range: Some(invalidation_range), }); cx.notify(); @@ -8581,7 +8657,7 @@ impl Editor { } match &active_edit_prediction.completion { - EditPrediction::Move { target, .. } => { + EditPrediction::MoveWithin { target, .. } => { let target_display_point = target.to_display_point(editor_snapshot); if self.edit_prediction_requires_modifier() { @@ -8666,6 +8742,28 @@ impl Editor { window, cx, ), + EditPrediction::MoveOutside { snapshot, .. } => { + let file_name = snapshot + .file() + .map(|file| file.file_name(cx)) + .unwrap_or("untitled"); + let mut element = self + .render_edit_prediction_line_popover( + format!("Jump to {file_name}"), + Some(IconName::ZedPredict), + window, + cx, + ) + .into_any(); + + let size = element.layout_as_root(AvailableSpace::min_size(), window, cx); + let origin_x = text_bounds.size.width / 2. - size.width / 2.; + let origin_y = text_bounds.size.height - size.height - px(30.); + let origin = text_bounds.origin + gpui::Point::new(origin_x, origin_y); + element.prepaint_at(origin, window, cx); + + Some((element, origin)) + } } } @@ -8730,13 +8828,13 @@ impl Editor { .items_end() .when(flag_on_right, |el| el.items_start()) .child(if flag_on_right { - self.render_edit_prediction_line_popover("Jump", None, window, cx)? + self.render_edit_prediction_line_popover("Jump", None, window, cx) .rounded_bl(px(0.)) .rounded_tl(px(0.)) .border_l_2() .border_color(border_color) } else { - self.render_edit_prediction_line_popover("Jump", None, window, cx)? + self.render_edit_prediction_line_popover("Jump", None, window, cx) .rounded_br(px(0.)) .rounded_tr(px(0.)) .border_r_2() @@ -8776,7 +8874,7 @@ impl Editor { cx: &mut App, ) -> Option<(AnyElement, gpui::Point)> { let mut element = self - .render_edit_prediction_line_popover("Scroll", Some(scroll_icon), window, cx)? + .render_edit_prediction_line_popover("Scroll", Some(scroll_icon), window, cx) .into_any(); let size = element.layout_as_root(AvailableSpace::min_size(), window, cx); @@ -8816,7 +8914,7 @@ impl Editor { Some(IconName::ArrowUp), window, cx, - )? + ) .into_any(); let size = element.layout_as_root(AvailableSpace::min_size(), window, cx); @@ -8835,7 +8933,7 @@ impl Editor { Some(IconName::ArrowDown), window, cx, - )? + ) .into_any(); let size = element.layout_as_root(AvailableSpace::min_size(), window, cx); @@ -8882,7 +8980,7 @@ impl Editor { ); let mut element = self - .render_edit_prediction_line_popover(label, None, window, cx)? + .render_edit_prediction_line_popover(label, None, window, cx) .into_any(); let size = element.layout_as_root(AvailableSpace::min_size(), window, cx); @@ -8909,7 +9007,7 @@ impl Editor { }; element = self - .render_edit_prediction_line_popover(label, Some(icon), window, cx)? + .render_edit_prediction_line_popover(label, Some(icon), window, cx) .into_any(); let size = element.layout_as_root(AvailableSpace::min_size(), window, cx); @@ -9163,13 +9261,13 @@ impl Editor { icon: Option, window: &mut Window, cx: &App, - ) -> Option> { + ) -> Stateful
{ let padding_right = if icon.is_some() { px(4.) } else { px(8.) }; let keybind = self.render_edit_prediction_accept_keybind(window, cx); let has_keybind = keybind.is_some(); - let result = h_flex() + h_flex() .id("ep-line-popover") .py_0p5() .pl_1() @@ -9215,9 +9313,7 @@ impl Editor { .mt(px(1.5)) .child(Icon::new(icon).size(IconSize::Small)), ) - }); - - Some(result) + }) } fn edit_prediction_line_popover_bg_color(cx: &App) -> Hsla { @@ -9281,7 +9377,7 @@ impl Editor { .rounded_tl(px(0.)) .overflow_hidden() .child(div().px_1p5().child(match &prediction.completion { - EditPrediction::Move { target, snapshot } => { + EditPrediction::MoveWithin { target, snapshot } => { use text::ToPoint as _; if target.text_anchor.to_point(snapshot).row > cursor_point.row { @@ -9290,6 +9386,10 @@ impl Editor { Icon::new(IconName::ZedPredictUp) } } + EditPrediction::MoveOutside { .. } => { + // TODO [zeta2] custom icon for external jump? + Icon::new(provider_icon) + } EditPrediction::Edit { .. } => Icon::new(provider_icon), })) .child( @@ -9472,7 +9572,7 @@ impl Editor { .unwrap_or(true); match &completion.completion { - EditPrediction::Move { + EditPrediction::MoveWithin { target, snapshot, .. } => { if !supports_jump { @@ -9494,7 +9594,20 @@ impl Editor { .child(Label::new("Jump to Edit")), ) } - + EditPrediction::MoveOutside { snapshot, .. } => { + let file_name = snapshot + .file() + .map(|file| file.file_name(cx)) + .unwrap_or("untitled"); + Some( + h_flex() + .px_2() + .gap_2() + .flex_1() + .child(Icon::new(IconName::ZedPredict)) + .child(Label::new(format!("Jump to {file_name}"))), + ) + } EditPrediction::Edit { edits, edit_preview, @@ -21418,7 +21531,7 @@ impl Editor { { self.hide_context_menu(window, cx); } - self.discard_edit_prediction(false, cx); + self.take_active_edit_prediction(cx); cx.emit(EditorEvent::Blurred); cx.notify(); } diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index c6692e00fd66b333c367a193ec59398af4391d62..c3a527899e4c509047fbf2d8273f1fb1c5a6c2a7 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -8272,7 +8272,7 @@ async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext) cx.update(|_, cx| { provider.update(cx, |provider, _| { - provider.set_edit_prediction(Some(edit_prediction::EditPrediction { + provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local { id: None, edits: vec![(edit_position..edit_position, "X".into())], edit_preview: None, diff --git a/crates/supermaven/Cargo.toml b/crates/supermaven/Cargo.toml index 4fc6a618ff1b585d9365357dc3a33c1b148feb99..1ee8ca4ffc094210dd1edf231d9160556829745a 100644 --- a/crates/supermaven/Cargo.toml +++ b/crates/supermaven/Cargo.toml @@ -22,7 +22,6 @@ gpui.workspace = true language.workspace = true log.workspace = true postage.workspace = true -project.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true diff --git a/crates/supermaven/src/supermaven_completion_provider.rs b/crates/supermaven/src/supermaven_completion_provider.rs index 89c5129822d94229cd1644587f15f4a4de2bf86a..32177aaa427e8616a2767410a7a6ec84c05abbee 100644 --- a/crates/supermaven/src/supermaven_completion_provider.rs +++ b/crates/supermaven/src/supermaven_completion_provider.rs @@ -4,7 +4,6 @@ use edit_prediction::{Direction, EditPrediction, EditPredictionProvider}; use futures::StreamExt as _; use gpui::{App, Context, Entity, EntityId, Task}; use language::{Anchor, Buffer, BufferSnapshot}; -use project::Project; use std::{ ops::{AddAssign, Range}, path::Path, @@ -94,7 +93,7 @@ fn completion_from_diff( edits.push((edit_range, edit_text)); } - EditPrediction { + EditPrediction::Local { id: None, edits, edit_preview: None, @@ -132,7 +131,6 @@ impl EditPredictionProvider for SupermavenCompletionProvider { fn refresh( &mut self, - _project: Option>, buffer_handle: Entity, cursor_position: Anchor, debounce: bool, diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 5d099834b924859052b2620d7ea33383c3c530c6..a1ae52fc0650b7eb4eacd37b3670a0d93eed532e 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -205,42 +205,48 @@ fn assign_edit_prediction_provider( } } - if std::env::var("ZED_ZETA2").is_ok() { - let zeta = zeta2::Zeta::global(client, &user_store, cx); - let provider = cx.new(|cx| { - zeta2::ZetaEditPredictionProvider::new( - editor.project(), - &client, - &user_store, - cx, - ) - }); - - if let Some(buffer) = &singleton_buffer - && buffer.read(cx).file().is_some() - && let Some(project) = editor.project() - { - zeta.update(cx, |zeta, cx| { - zeta.register_buffer(buffer, project, cx); + if let Some(project) = editor.project() { + if std::env::var("ZED_ZETA2").is_ok() { + let zeta = zeta2::Zeta::global(client, &user_store, cx); + let provider = cx.new(|cx| { + zeta2::ZetaEditPredictionProvider::new( + project.clone(), + &client, + &user_store, + cx, + ) }); - } - editor.set_edit_prediction_provider(Some(provider), window, cx); - } else { - let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx); - - if let Some(buffer) = &singleton_buffer - && buffer.read(cx).file().is_some() - && let Some(project) = editor.project() - { - zeta.update(cx, |zeta, cx| { - zeta.register_buffer(buffer, project, cx); + // TODO [zeta2] handle multibuffers + if let Some(buffer) = &singleton_buffer + && buffer.read(cx).file().is_some() + { + zeta.update(cx, |zeta, cx| { + zeta.register_buffer(buffer, project, cx); + }); + } + + editor.set_edit_prediction_provider(Some(provider), window, cx); + } else { + let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx); + + if let Some(buffer) = &singleton_buffer + && buffer.read(cx).file().is_some() + { + zeta.update(cx, |zeta, cx| { + zeta.register_buffer(buffer, project, cx); + }); + } + + let provider = cx.new(|_| { + zeta::ZetaEditPredictionProvider::new( + zeta, + project.clone(), + singleton_buffer, + ) }); + editor.set_edit_prediction_provider(Some(provider), window, cx); } - - let provider = - cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, singleton_buffer)); - editor.set_edit_prediction_provider(Some(provider), window, cx); } } } diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 6a75b5e2cd369002d928992e61124c18cc47c5e7..3a156f351d8f34e858ce199aa1244729fe07a227 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -1316,12 +1316,17 @@ pub struct ZetaEditPredictionProvider { next_pending_completion_id: usize, current_completion: Option, last_request_timestamp: Instant, + project: Entity, } impl ZetaEditPredictionProvider { pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); - pub fn new(zeta: Entity, singleton_buffer: Option>) -> Self { + pub fn new( + zeta: Entity, + project: Entity, + singleton_buffer: Option>, + ) -> Self { Self { zeta, singleton_buffer, @@ -1329,6 +1334,7 @@ impl ZetaEditPredictionProvider { next_pending_completion_id: 0, current_completion: None, last_request_timestamp: Instant::now(), + project, } } } @@ -1394,7 +1400,6 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { fn refresh( &mut self, - project: Option>, buffer: Entity, position: language::Anchor, _debounce: bool, @@ -1403,9 +1408,6 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { if self.zeta.read(cx).update_required { return; } - let Some(project) = project else { - return; - }; if self .zeta @@ -1433,6 +1435,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { self.next_pending_completion_id += 1; let last_request_timestamp = self.last_request_timestamp; + let project = self.project.clone(); let task = cx.spawn(async move |this, cx| { if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT) .checked_duration_since(Instant::now()) @@ -1604,7 +1607,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { } } - Some(edit_prediction::EditPrediction { + Some(edit_prediction::EditPrediction::Local { id: Some(completion.id.to_string().into()), edits: edits[edit_start_ix..edit_end_ix].to_vec(), edit_preview: Some(completion.edit_preview.clone()), diff --git a/crates/zeta2/src/prediction.rs b/crates/zeta2/src/prediction.rs index cca41efb7c62224e1601001a55d5c6c4c50ff47a..d4832993b9ecd7c40f154f2ab696c66872073d5e 100644 --- a/crates/zeta2/src/prediction.rs +++ b/crates/zeta2/src/prediction.rs @@ -1,50 +1,146 @@ -use std::{borrow::Cow, ops::Range, sync::Arc}; +use std::{borrow::Cow, ops::Range, path::Path, sync::Arc}; +use anyhow::Context as _; use cloud_llm_client::predict_edits_v3; -use language::{Anchor, BufferSnapshot, EditPreview, OffsetRangeExt, text_diff}; +use gpui::{App, AsyncApp, Entity}; +use language::{ + Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot, text_diff, +}; +use project::Project; +use util::ResultExt; use uuid::Uuid; +#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] +pub struct EditPredictionId(Uuid); + +impl From for gpui::ElementId { + fn from(value: EditPredictionId) -> Self { + gpui::ElementId::Uuid(value.0) + } +} + +impl std::fmt::Display for EditPredictionId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + #[derive(Clone)] pub struct EditPrediction { pub id: EditPredictionId, + pub path: Arc, pub edits: Arc<[(Range, String)]>, pub snapshot: BufferSnapshot, pub edit_preview: EditPreview, + // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction. + _buffer: Entity, } impl EditPrediction { + pub async fn from_response( + response: predict_edits_v3::PredictEditsResponse, + active_buffer_old_snapshot: &TextBufferSnapshot, + active_buffer: &Entity, + project: &Entity, + cx: &mut AsyncApp, + ) -> Option { + // TODO only allow cloud to return one path + let Some(path) = response.edits.first().map(|e| e.path.clone()) else { + return None; + }; + + let is_same_path = active_buffer + .read_with(cx, |buffer, cx| buffer_path_eq(buffer, &path, cx)) + .ok()?; + + let (buffer, edits, snapshot, edit_preview_task) = if is_same_path { + active_buffer + .read_with(cx, |buffer, cx| { + let new_snapshot = buffer.snapshot(); + let edits = edits_from_response(&response.edits, &active_buffer_old_snapshot); + let edits: Arc<[_]> = + interpolate_edits(active_buffer_old_snapshot, &new_snapshot, edits)?.into(); + + Some(( + active_buffer.clone(), + edits.clone(), + new_snapshot, + buffer.preview_edits(edits, cx), + )) + }) + .ok()?? + } else { + let buffer_handle = project + .update(cx, |project, cx| { + let project_path = project + .find_project_path(&path, cx) + .context("Failed to find project path for zeta edit")?; + anyhow::Ok(project.open_buffer(project_path, cx)) + }) + .ok()? + .log_err()? + .await + .context("Failed to open buffer for zeta edit") + .log_err()?; + + buffer_handle + .read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(); + let edits = edits_from_response(&response.edits, &snapshot); + if edits.is_empty() { + return None; + } + Some(( + buffer_handle.clone(), + edits.clone(), + snapshot, + buffer.preview_edits(edits, cx), + )) + }) + .ok()?? + }; + + let edit_preview = edit_preview_task.await; + + Some(EditPrediction { + id: EditPredictionId(response.request_id), + path, + edits, + snapshot, + edit_preview, + _buffer: buffer, + }) + } + pub fn interpolate( &self, - new_snapshot: &BufferSnapshot, + new_snapshot: &TextBufferSnapshot, ) -> Option, String)>> { interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone()) } -} -#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] -pub struct EditPredictionId(Uuid); - -impl From for EditPredictionId { - fn from(value: Uuid) -> Self { - EditPredictionId(value) + pub fn targets_buffer(&self, buffer: &Buffer, cx: &App) -> bool { + buffer_path_eq(buffer, &self.path, cx) } } -impl From for gpui::ElementId { - fn from(value: EditPredictionId) -> Self { - gpui::ElementId::Uuid(value.0) +impl std::fmt::Debug for EditPrediction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EditPrediction") + .field("id", &self.id) + .field("path", &self.path) + .field("edits", &self.edits) + .finish() } } -impl std::fmt::Display for EditPredictionId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } +pub fn buffer_path_eq(buffer: &Buffer, path: &Path, cx: &App) -> bool { + buffer.file().map(|p| p.full_path(cx)).as_deref() == Some(path) } pub fn interpolate_edits( - old_snapshot: &BufferSnapshot, - new_snapshot: &BufferSnapshot, + old_snapshot: &TextBufferSnapshot, + new_snapshot: &TextBufferSnapshot, current_edits: Arc<[(Range, String)]>, ) -> Option, String)>> { let mut edits = Vec::new(); @@ -88,14 +184,13 @@ pub fn interpolate_edits( if edits.is_empty() { None } else { Some(edits) } } -pub fn edits_from_response( +fn edits_from_response( edits: &[predict_edits_v3::Edit], - snapshot: &BufferSnapshot, + snapshot: &TextBufferSnapshot, ) -> Arc<[(Range, String)]> { edits .iter() .flat_map(|edit| { - // TODO multi-file edits let old_text = snapshot.text_for_range(edit.range.clone()); excerpt_edits_from_response( @@ -113,7 +208,7 @@ fn excerpt_edits_from_response( old_text: Cow, new_text: &str, offset: usize, - snapshot: &BufferSnapshot, + snapshot: &TextBufferSnapshot, ) -> impl Iterator, String)> { text_diff(&old_text, new_text) .into_iter() @@ -221,6 +316,8 @@ mod tests { id: EditPredictionId(Uuid::new_v4()), edits, snapshot: cx.read(|cx| buffer.read(cx).snapshot()), + path: Path::new("test.txt").into(), + _buffer: buffer.clone(), edit_preview, }; diff --git a/crates/zeta2/src/provider.rs b/crates/zeta2/src/provider.rs index ae30c0bee0da47d8f6174e76918e6dd751d348d2..db637208aa88e8e3ebe4b30dc3d5639497cd0ac0 100644 --- a/crates/zeta2/src/provider.rs +++ b/crates/zeta2/src/provider.rs @@ -4,76 +4,44 @@ use std::{ time::{Duration, Instant}, }; -use anyhow::Context as _; use arrayvec::ArrayVec; use client::{Client, UserStore}; use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider}; -use gpui::{App, Entity, EntityId, Task, prelude::*}; -use language::{BufferSnapshot, ToPoint as _}; +use gpui::{App, Entity, Task, prelude::*}; +use language::ToPoint as _; use project::Project; use util::ResultExt as _; -use crate::{Zeta, prediction::EditPrediction}; +use crate::{BufferEditPrediction, Zeta}; pub struct ZetaEditPredictionProvider { zeta: Entity, - current_prediction: Option, next_pending_prediction_id: usize, pending_predictions: ArrayVec, last_request_timestamp: Instant, + project: Entity, } impl ZetaEditPredictionProvider { pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); pub fn new( - project: Option<&Entity>, + project: Entity, client: &Arc, user_store: &Entity, cx: &mut App, ) -> Self { let zeta = Zeta::global(client, user_store, cx); - if let Some(project) = project { - zeta.update(cx, |zeta, cx| { - zeta.register_project(project, cx); - }); - } + zeta.update(cx, |zeta, cx| { + zeta.register_project(&project, cx); + }); Self { zeta, - current_prediction: None, next_pending_prediction_id: 0, pending_predictions: ArrayVec::new(), last_request_timestamp: Instant::now(), - } - } -} - -#[derive(Clone)] -struct CurrentEditPrediction { - buffer_id: EntityId, - prediction: EditPrediction, -} - -impl CurrentEditPrediction { - fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool { - if self.buffer_id != old_prediction.buffer_id { - return true; - } - - let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else { - return true; - }; - let Some(new_edits) = self.prediction.interpolate(snapshot) else { - return false; - }; - - if old_edits.len() == 1 && new_edits.len() == 1 { - let (old_range, old_text) = &old_edits[0]; - let (new_range, new_text) = &new_edits[0]; - new_range == old_range && new_text.starts_with(old_text) - } else { - true + project: project, } } } @@ -128,42 +96,31 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { fn refresh( &mut self, - project: Option>, buffer: Entity, cursor_position: language::Anchor, _debounce: bool, cx: &mut Context, ) { - let Some(project) = project else { - return; - }; + let zeta = self.zeta.read(cx); - if self - .zeta - .read(cx) - .user_store - .read_with(cx, |user_store, _cx| { - user_store.account_too_young() || user_store.has_overdue_invoices() - }) - { + if zeta.user_store.read_with(cx, |user_store, _cx| { + user_store.account_too_young() || user_store.has_overdue_invoices() + }) { return; } - if let Some(current_prediction) = self.current_prediction.as_ref() { - let snapshot = buffer.read(cx).snapshot(); - if current_prediction - .prediction - .interpolate(&snapshot) - .is_some() - { - return; - } + if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx) + && let BufferEditPrediction::Local { prediction } = current + && prediction.interpolate(buffer.read(cx)).is_some() + { + return; } let pending_prediction_id = self.next_pending_prediction_id; self.next_pending_prediction_id += 1; let last_request_timestamp = self.last_request_timestamp; + let project = self.project.clone(); let task = cx.spawn(async move |this, cx| { if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT) .checked_duration_since(Instant::now()) @@ -171,25 +128,16 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { cx.background_executor().timer(timeout).await; } - let prediction_request = this.update(cx, |this, cx| { + let refresh_task = this.update(cx, |this, cx| { this.last_request_timestamp = Instant::now(); this.zeta.update(cx, |zeta, cx| { - zeta.request_prediction(&project, &buffer, cursor_position, cx) + zeta.refresh_prediction(&project, &buffer, cursor_position, cx) }) }); - let prediction = match prediction_request { - Ok(prediction_request) => { - let prediction_request = prediction_request.await; - prediction_request.map(|c| { - c.map(|prediction| CurrentEditPrediction { - buffer_id: buffer.entity_id(), - prediction, - }) - }) - } - Err(error) => Err(error), - }; + if let Some(refresh_task) = refresh_task.ok() { + refresh_task.await.log_err(); + } this.update(cx, |this, cx| { if this.pending_predictions[0].id == pending_prediction_id { @@ -198,24 +146,6 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { this.pending_predictions.clear(); } - let Some(new_prediction) = prediction - .context("edit prediction failed") - .log_err() - .flatten() - else { - cx.notify(); - return; - }; - - if let Some(old_prediction) = this.current_prediction.as_ref() { - let snapshot = buffer.read(cx).snapshot(); - if new_prediction.should_replace_prediction(old_prediction, &snapshot) { - this.current_prediction = Some(new_prediction); - } - } else { - this.current_prediction = Some(new_prediction); - } - cx.notify(); }) .ok(); @@ -248,15 +178,18 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { ) { } - fn accept(&mut self, _cx: &mut Context) { - // TODO [zeta2] report accept - self.current_prediction.take(); + fn accept(&mut self, cx: &mut Context) { + self.zeta.update(cx, |zeta, _cx| { + zeta.accept_current_prediction(&self.project); + }); self.pending_predictions.clear(); } - fn discard(&mut self, _cx: &mut Context) { + fn discard(&mut self, cx: &mut Context) { + self.zeta.update(cx, |zeta, _cx| { + zeta.discard_current_prediction(&self.project); + }); self.pending_predictions.clear(); - self.current_prediction.take(); } fn suggest( @@ -265,36 +198,44 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { cursor_position: language::Anchor, cx: &mut Context, ) -> Option { - let CurrentEditPrediction { - buffer_id, - prediction, - .. - } = self.current_prediction.as_mut()?; - - // Invalidate previous prediction if it was generated for a different buffer. - if *buffer_id != buffer.entity_id() { - self.current_prediction.take(); - return None; - } + let prediction = + self.zeta + .read(cx) + .current_prediction_for_buffer(buffer, &self.project, cx)?; + + let prediction = match prediction { + BufferEditPrediction::Local { prediction } => prediction, + BufferEditPrediction::Jump { prediction } => { + return Some(edit_prediction::EditPrediction::Jump { + id: Some(prediction.id.to_string().into()), + snapshot: prediction.snapshot.clone(), + target: prediction.edits.first().unwrap().0.start, + }); + } + }; let buffer = buffer.read(cx); - let Some(edits) = prediction.interpolate(&buffer.snapshot()) else { - self.current_prediction.take(); + let snapshot = buffer.snapshot(); + + let Some(edits) = prediction.interpolate(&snapshot) else { + self.zeta.update(cx, |zeta, _cx| { + zeta.discard_current_prediction(&self.project); + }); return None; }; - let cursor_row = cursor_position.to_point(buffer).row; + let cursor_row = cursor_position.to_point(&snapshot).row; let (closest_edit_ix, (closest_edit_range, _)) = edits.iter().enumerate().min_by_key(|(_, (range, _))| { - let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row); - let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row); + let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row); + let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row); cmp::min(distance_from_start, distance_from_end) })?; let mut edit_start_ix = closest_edit_ix; for (range, _) in edits[..edit_start_ix].iter().rev() { - let distance_from_closest_edit = - closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row; + let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row + - range.end.to_point(&snapshot).row; if distance_from_closest_edit <= 1 { edit_start_ix -= 1; } else { @@ -305,7 +246,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { let mut edit_end_ix = closest_edit_ix + 1; for (range, _) in &edits[edit_end_ix..] { let distance_from_closest_edit = - range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row; + range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row; if distance_from_closest_edit <= 1 { edit_end_ix += 1; } else { @@ -313,7 +254,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { } } - Some(edit_prediction::EditPrediction { + Some(edit_prediction::EditPrediction::Local { id: Some(prediction.id.to_string().into()), edits: edits[edit_start_ix..edit_end_ix].to_vec(), edit_preview: Some(prediction.edit_preview.clone()), diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index ba9b50bbf3f0ea51232f39a729b3bf0cdb92aaef..0aaf4c9d35e9e00e066a716bc645b6f4ad56480a 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -17,8 +17,8 @@ use gpui::{ App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity, http_client, prelude::*, }; -use language::BufferSnapshot; use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint}; +use language::{BufferSnapshot, TextBufferSnapshot}; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use project::Project; use release_channel::AppVersion; @@ -35,7 +35,7 @@ use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_noti mod prediction; mod provider; -use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits}; +use crate::prediction::EditPrediction; pub use provider::ZetaEditPredictionProvider; const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1); @@ -53,7 +53,7 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { excerpt: DEFAULT_EXCERPT_OPTIONS, max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES, max_diagnostic_bytes: 2048, - prompt_format: PromptFormat::MarkedExcerpt, + prompt_format: PromptFormat::DEFAULT, }; #[derive(Clone)] @@ -94,6 +94,47 @@ struct ZetaProject { syntax_index: Entity, events: VecDeque, registered_buffers: HashMap, + current_prediction: Option, +} + +#[derive(Clone)] +struct CurrentEditPrediction { + pub requested_by_buffer_id: EntityId, + pub prediction: EditPrediction, +} + +impl CurrentEditPrediction { + fn should_replace_prediction( + &self, + old_prediction: &Self, + snapshot: &TextBufferSnapshot, + ) -> bool { + if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id { + return true; + } + + let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else { + return true; + }; + + let Some(new_edits) = self.prediction.interpolate(snapshot) else { + return false; + }; + if old_edits.len() == 1 && new_edits.len() == 1 { + let (old_range, old_text) = &old_edits[0]; + let (new_range, new_text) = &new_edits[0]; + new_range == old_range && new_text.starts_with(old_text) + } else { + true + } + } +} + +/// A prediction from the perspective of a buffer. +#[derive(Debug)] +enum BufferEditPrediction<'a> { + Local { prediction: &'a EditPrediction }, + Jump { prediction: &'a EditPrediction }, } struct RegisteredBuffer { @@ -204,6 +245,7 @@ impl Zeta { syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)), events: VecDeque::new(), registered_buffers: HashMap::new(), + current_prediction: None, }) } @@ -305,7 +347,83 @@ impl Zeta { events.push_back(event); } - pub fn request_prediction( + fn current_prediction_for_buffer( + &self, + buffer: &Entity, + project: &Entity, + cx: &App, + ) -> Option> { + let project_state = self.projects.get(&project.entity_id())?; + + let CurrentEditPrediction { + requested_by_buffer_id, + prediction, + } = project_state.current_prediction.as_ref()?; + + if prediction.targets_buffer(buffer.read(cx), cx) { + Some(BufferEditPrediction::Local { prediction }) + } else if *requested_by_buffer_id == buffer.entity_id() { + Some(BufferEditPrediction::Jump { prediction }) + } else { + None + } + } + + fn accept_current_prediction(&mut self, project: &Entity) { + if let Some(project_state) = self.projects.get_mut(&project.entity_id()) { + project_state.current_prediction.take(); + }; + // TODO report accepted + } + + fn discard_current_prediction(&mut self, project: &Entity) { + if let Some(project_state) = self.projects.get_mut(&project.entity_id()) { + project_state.current_prediction.take(); + }; + } + + pub fn refresh_prediction( + &mut self, + project: &Entity, + buffer: &Entity, + position: language::Anchor, + cx: &mut Context, + ) -> Task> { + let request_task = self.request_prediction(project, buffer, position, cx); + let buffer = buffer.clone(); + let project = project.clone(); + + cx.spawn(async move |this, cx| { + if let Some(prediction) = request_task.await? { + this.update(cx, |this, cx| { + let project_state = this + .projects + .get_mut(&project.entity_id()) + .context("Project not found")?; + + let new_prediction = CurrentEditPrediction { + requested_by_buffer_id: buffer.entity_id(), + prediction: prediction, + }; + + if project_state + .current_prediction + .as_ref() + .is_none_or(|old_prediction| { + new_prediction + .should_replace_prediction(&old_prediction, buffer.read(cx)) + }) + { + project_state.current_prediction = Some(new_prediction); + } + anyhow::Ok(()) + })??; + } + Ok(()) + }) + } + + fn request_prediction( &mut self, project: &Entity, buffer: &Entity, @@ -457,74 +575,63 @@ impl Zeta { .ok(); } - let (response, usage) = response?; - let edits = edits_from_response(&response.edits, &snapshot); - - anyhow::Ok(Some((response.request_id, edits, usage))) + anyhow::Ok(Some(response?)) } }); let buffer = buffer.clone(); - cx.spawn(async move |this, cx| { - match request_task.await { - Ok(Some((id, edits, usage))) => { - if let Some(usage) = usage { - this.update(cx, |this, cx| { - this.user_store.update(cx, |user_store, cx| { - user_store.update_edit_prediction_usage(usage, cx); - }); - }) - .ok(); - } + cx.spawn({ + let project = project.clone(); + async move |this, cx| { + match request_task.await { + Ok(Some((response, usage))) => { + if let Some(usage) = usage { + this.update(cx, |this, cx| { + this.user_store.update(cx, |user_store, cx| { + user_store.update_edit_prediction_usage(usage, cx); + }); + }) + .ok(); + } - // TODO telemetry: duration, etc - let Some((edits, snapshot, edit_preview_task)) = - buffer.read_with(cx, |buffer, cx| { - let new_snapshot = buffer.snapshot(); - let edits: Arc<[_]> = - interpolate_edits(&snapshot, &new_snapshot, edits)?.into(); - Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx))) - })? - else { - return Ok(None); - }; + let prediction = EditPrediction::from_response( + response, &snapshot, &buffer, &project, cx, + ) + .await; - Ok(Some(EditPrediction { - id: id.into(), - edits, - snapshot, - edit_preview: edit_preview_task.await, - })) - } - Ok(None) => Ok(None), - Err(err) => { - if err.is::() { - cx.update(|cx| { - this.update(cx, |this, _cx| { - this.update_required = true; + // TODO telemetry: duration, etc + Ok(prediction) + } + Ok(None) => Ok(None), + Err(err) => { + if err.is::() { + cx.update(|cx| { + this.update(cx, |this, _cx| { + this.update_required = true; + }) + .ok(); + + let error_message: SharedString = err.to_string().into(); + show_app_notification( + NotificationId::unique::(), + cx, + move |cx| { + cx.new(|cx| { + ErrorMessagePrompt::new(error_message.clone(), cx) + .with_link_button( + "Update Zed", + "https://zed.dev/releases", + ) + }) + }, + ); }) .ok(); + } - let error_message: SharedString = err.to_string().into(); - show_app_notification( - NotificationId::unique::(), - cx, - move |cx| { - cx.new(|cx| { - ErrorMessagePrompt::new(error_message.clone(), cx) - .with_link_button( - "Update Zed", - "https://zed.dev/releases", - ) - }) - }, - ); - }) - .ok(); + Err(err) } - - Err(err) } } }) @@ -859,13 +966,113 @@ mod tests { }; use indoc::indoc; use language::{LanguageServerId, OffsetRangeExt as _}; + use pretty_assertions::{assert_eq, assert_matches}; use project::{FakeFs, Project}; use serde_json::json; use settings::SettingsStore; use util::path; use uuid::Uuid; - use crate::Zeta; + use crate::{BufferEditPrediction, Zeta}; + + #[gpui::test] + async fn test_current_state(cx: &mut TestAppContext) { + let (zeta, mut req_rx) = init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "1.txt": "Hello!\nHow\nBye", + "2.txt": "Hola!\nComo\nAdios" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + zeta.update(cx, |zeta, cx| { + zeta.register_project(&project, cx); + }); + + let buffer1 = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/1.txt"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot1.anchor_before(language::Point::new(1, 3)); + + // Prediction for current file + + let prediction_task = zeta.update(cx, |zeta, cx| { + zeta.refresh_prediction(&project, &buffer1, position, cx) + }); + let (_request, respond_tx) = req_rx.next().await.unwrap(); + respond_tx + .send(predict_edits_v3::PredictEditsResponse { + request_id: Uuid::new_v4(), + edits: vec![predict_edits_v3::Edit { + path: Path::new(path!("root/1.txt")).into(), + range: 0..snapshot1.len(), + content: "Hello!\nHow are you?\nBye".into(), + }], + debug_info: None, + }) + .unwrap(); + prediction_task.await.unwrap(); + + zeta.read_with(cx, |zeta, cx| { + let prediction = zeta + .current_prediction_for_buffer(&buffer1, &project, cx) + .unwrap(); + assert_matches!(prediction, BufferEditPrediction::Local { .. }); + }); + + // Prediction for another file + + let prediction_task = zeta.update(cx, |zeta, cx| { + zeta.refresh_prediction(&project, &buffer1, position, cx) + }); + let (_request, respond_tx) = req_rx.next().await.unwrap(); + respond_tx + .send(predict_edits_v3::PredictEditsResponse { + request_id: Uuid::new_v4(), + edits: vec![predict_edits_v3::Edit { + path: Path::new(path!("root/2.txt")).into(), + range: 0..snapshot1.len(), + content: "Hola!\nComo estas?\nAdios".into(), + }], + debug_info: None, + }) + .unwrap(); + prediction_task.await.unwrap(); + + zeta.read_with(cx, |zeta, cx| { + let prediction = zeta + .current_prediction_for_buffer(&buffer1, &project, cx) + .unwrap(); + assert_matches!( + prediction, + BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt")) + ); + }); + + let buffer2 = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/2.txt"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + + zeta.read_with(cx, |zeta, cx| { + let prediction = zeta + .current_prediction_for_buffer(&buffer2, &project, cx) + .unwrap(); + assert_matches!(prediction, BufferEditPrediction::Local { .. }); + }); + } #[gpui::test] async fn test_simple_request(cx: &mut TestAppContext) { @@ -1146,6 +1353,7 @@ mod tests { let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); let zeta = Zeta::global(&client, &user_store, cx); + (zeta, req_rx) }) } diff --git a/crates/zeta2_tools/src/zeta2_tools.rs b/crates/zeta2_tools/src/zeta2_tools.rs index ac4f27be81243c257efa8a5cc498aa95ce6979d7..e553d941325e1bf4cd4dc0db93175cac51514927 100644 --- a/crates/zeta2_tools/src/zeta2_tools.rs +++ b/crates/zeta2_tools/src/zeta2_tools.rs @@ -185,7 +185,7 @@ impl Zeta2Inspector { cx.background_executor().timer(THROTTLE_TIME).await; if let Some(task) = zeta .update(cx, |zeta, cx| { - zeta.request_prediction(&project, &buffer, position, cx) + zeta.refresh_prediction(&project, &buffer, position, cx) }) .ok() {