zeta2: Allow provider to suggest edits in different files (#39110)

Bennet Bo Fenner and Agus Zubiaga created

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <agus@zed.dev>

Change summary

Cargo.lock                                              |   1 
crates/cloud_llm_client/src/predict_edits_v3.rs         |  13 
crates/copilot/src/copilot_completion_provider.rs       |   4 
crates/edit_prediction/Cargo.toml                       |   1 
crates/edit_prediction/src/edit_prediction.rs           |  24 
crates/editor/src/edit_prediction_tests.rs              |  11 
crates/editor/src/editor.rs                             | 191 +++++-
crates/editor/src/editor_tests.rs                       |   2 
crates/supermaven/Cargo.toml                            |   1 
crates/supermaven/src/supermaven_completion_provider.rs |   4 
crates/zed/src/zed/edit_prediction_registry.rs          |  70 +-
crates/zeta/src/zeta.rs                                 |  15 
crates/zeta2/src/prediction.rs                          | 143 +++
crates/zeta2/src/provider.rs                            | 179 +---
crates/zeta2/src/zeta2.rs                               | 332 ++++++++--
crates/zeta2_tools/src/zeta2_tools.rs                   |   2 
16 files changed, 682 insertions(+), 311 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5126,7 +5126,6 @@ dependencies = [
  "client",
  "gpui",
  "language",
- "project",
  "workspace-hack",
 ]
 

crates/cloud_llm_client/src/predict_edits_v3.rs 🔗

@@ -43,15 +43,24 @@ pub struct PredictEditsRequest {
     pub prompt_format: PromptFormat,
 }
 
-#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
 pub enum PromptFormat {
-    #[default]
     MarkedExcerpt,
     LabeledSections,
     /// Prompt format intended for use via zeta_cli
     OnlySnippets,
 }
 
+impl PromptFormat {
+    pub const DEFAULT: PromptFormat = PromptFormat::LabeledSections;
+}
+
+impl Default for PromptFormat {
+    fn default() -> Self {
+        Self::DEFAULT
+    }
+}
+
 impl PromptFormat {
     pub fn iter() -> impl Iterator<Item = Self> {
         <Self as strum::IntoEnumIterator>::iter()

crates/copilot/src/copilot_completion_provider.rs 🔗

@@ -3,7 +3,6 @@ use anyhow::Result;
 use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
 use gpui::{App, Context, Entity, EntityId, Task};
 use language::{Buffer, OffsetRangeExt, ToOffset, language_settings::AllLanguageSettings};
-use project::Project;
 use settings::Settings;
 use std::{path::Path, time::Duration};
 
@@ -84,7 +83,6 @@ impl EditPredictionProvider for CopilotCompletionProvider {
 
     fn refresh(
         &mut self,
-        _project: Option<Entity<Project>>,
         buffer: Entity<Buffer>,
         cursor_position: language::Anchor,
         debounce: bool,
@@ -249,7 +247,7 @@ impl EditPredictionProvider for CopilotCompletionProvider {
                 None
             } else {
                 let position = cursor_position.bias_right(buffer);
-                Some(EditPrediction {
+                Some(EditPrediction::Local {
                     id: None,
                     edits: vec![(position..position, completion_text.into())],
                     edit_preview: None,

crates/edit_prediction/Cargo.toml 🔗

@@ -15,5 +15,4 @@ path = "src/edit_prediction.rs"
 client.workspace = true
 gpui.workspace = true
 language.workspace = true
-project.workspace = true
 workspace-hack.workspace = true

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -3,7 +3,6 @@ use std::ops::Range;
 use client::EditPredictionUsage;
 use gpui::{App, Context, Entity, SharedString};
 use language::Buffer;
-use project::Project;
 
 // TODO: Find a better home for `Direction`.
 //
@@ -16,11 +15,19 @@ pub enum Direction {
 }
 
 #[derive(Clone)]
-pub struct EditPrediction {
-    /// The ID of the completion, if it has one.
-    pub id: Option<SharedString>,
-    pub edits: Vec<(Range<language::Anchor>, String)>,
-    pub edit_preview: Option<language::EditPreview>,
+pub enum EditPrediction {
+    /// Edits within the buffer that requested the prediction
+    Local {
+        id: Option<SharedString>,
+        edits: Vec<(Range<language::Anchor>, String)>,
+        edit_preview: Option<language::EditPreview>,
+    },
+    /// Jump to a different file from the one that requested the prediction
+    Jump {
+        id: Option<SharedString>,
+        snapshot: language::BufferSnapshot,
+        target: language::Anchor,
+    },
 }
 
 pub enum DataCollectionState {
@@ -83,7 +90,6 @@ pub trait EditPredictionProvider: 'static + Sized {
     fn is_refreshing(&self) -> bool;
     fn refresh(
         &mut self,
-        project: Option<Entity<Project>>,
         buffer: Entity<Buffer>,
         cursor_position: language::Anchor,
         debounce: bool,
@@ -124,7 +130,6 @@ pub trait EditPredictionProviderHandle {
     fn is_refreshing(&self, cx: &App) -> bool;
     fn refresh(
         &self,
-        project: Option<Entity<Project>>,
         buffer: Entity<Buffer>,
         cursor_position: language::Anchor,
         debounce: bool,
@@ -198,14 +203,13 @@ where
 
     fn refresh(
         &self,
-        project: Option<Entity<Project>>,
         buffer: Entity<Buffer>,
         cursor_position: language::Anchor,
         debounce: bool,
         cx: &mut App,
     ) {
         self.update(cx, |this, cx| {
-            this.refresh(project, buffer, cursor_position, debounce, cx)
+            this.refresh(buffer, cursor_position, debounce, cx)
         })
     }
 

crates/editor/src/edit_prediction_tests.rs 🔗

@@ -2,7 +2,6 @@ use edit_prediction::EditPredictionProvider;
 use gpui::{Entity, prelude::*};
 use indoc::indoc;
 use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint};
-use project::Project;
 use std::ops::Range;
 use text::{Point, ToOffset};
 
@@ -261,7 +260,7 @@ async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui:
                 EditPrediction::Edit { .. } => {
                     // This is expected for non-Zed providers
                 }
-                EditPrediction::Move { .. } => {
+                EditPrediction::MoveWithin { .. } | EditPrediction::MoveOutside { .. } => {
                     panic!(
                         "Non-Zed providers should not show Move predictions (jump functionality)"
                     );
@@ -299,7 +298,7 @@ fn assert_editor_active_move_completion(
             .as_ref()
             .expect("editor has no active completion");
 
-        if let EditPrediction::Move { target, .. } = &completion_state.completion {
+        if let EditPrediction::MoveWithin { target, .. } = &completion_state.completion {
             assert(editor.buffer().read(cx).snapshot(cx), *target);
         } else {
             panic!("expected move completion");
@@ -326,7 +325,7 @@ fn propose_edits<T: ToOffset>(
 
     cx.update(|_, cx| {
         provider.update(cx, |provider, _| {
-            provider.set_edit_prediction(Some(edit_prediction::EditPrediction {
+            provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
                 id: None,
                 edits: edits.collect(),
                 edit_preview: None,
@@ -357,7 +356,7 @@ fn propose_edits_non_zed<T: ToOffset>(
 
     cx.update(|_, cx| {
         provider.update(cx, |provider, _| {
-            provider.set_edit_prediction(Some(edit_prediction::EditPrediction {
+            provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
                 id: None,
                 edits: edits.collect(),
                 edit_preview: None,
@@ -418,7 +417,6 @@ impl EditPredictionProvider for FakeEditPredictionProvider {
 
     fn refresh(
         &mut self,
-        _project: Option<Entity<Project>>,
         _buffer: gpui::Entity<language::Buffer>,
         _cursor_position: language::Anchor,
         _debounce: bool,
@@ -492,7 +490,6 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
 
     fn refresh(
         &mut self,
-        _project: Option<Entity<Project>>,
         _buffer: gpui::Entity<language::Buffer>,
         _cursor_position: language::Anchor,
         _debounce: bool,

crates/editor/src/editor.rs 🔗

@@ -638,17 +638,23 @@ enum EditPrediction {
         display_mode: EditDisplayMode,
         snapshot: BufferSnapshot,
     },
-    Move {
+    /// Move to a specific location in the active editor
+    MoveWithin {
         target: Anchor,
         snapshot: BufferSnapshot,
     },
+    /// Move to a specific location in a different editor (not the active one)
+    MoveOutside {
+        target: language::Anchor,
+        snapshot: BufferSnapshot,
+    },
 }
 
 struct EditPredictionState {
     inlay_ids: Vec<InlayId>,
     completion: EditPrediction,
     completion_id: Option<SharedString>,
-    invalidation_range: Range<Anchor>,
+    invalidation_range: Option<Range<Anchor>>,
 }
 
 enum EditPredictionSettings {
@@ -7175,13 +7181,7 @@ impl Editor {
             return None;
         }
 
-        provider.refresh(
-            self.project.clone(),
-            buffer,
-            cursor_buffer_position,
-            debounce,
-            cx,
-        );
+        provider.refresh(buffer, cursor_buffer_position, debounce, cx);
         Some(())
     }
 
@@ -7424,10 +7424,8 @@ impl Editor {
             return;
         };
 
-        self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx);
-
         match &active_edit_prediction.completion {
-            EditPrediction::Move { target, .. } => {
+            EditPrediction::MoveWithin { target, .. } => {
                 let target = *target;
 
                 if let Some(position_map) = &self.last_position_map {
@@ -7469,7 +7467,19 @@ impl Editor {
                     }
                 }
             }
+            EditPrediction::MoveOutside { snapshot, target } => {
+                if let Some(workspace) = self.workspace() {
+                    Self::open_editor_at_anchor(snapshot, *target, &workspace, window, cx)
+                        .detach_and_log_err(cx);
+                }
+            }
             EditPrediction::Edit { edits, .. } => {
+                self.report_edit_prediction_event(
+                    active_edit_prediction.completion_id.clone(),
+                    true,
+                    cx,
+                );
+
                 if let Some(provider) = self.edit_prediction_provider() {
                     provider.accept(cx);
                 }
@@ -7522,10 +7532,8 @@ impl Editor {
             return;
         }
 
-        self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx);
-
         match &active_edit_prediction.completion {
-            EditPrediction::Move { target, .. } => {
+            EditPrediction::MoveWithin { target, .. } => {
                 let target = *target;
                 self.change_selections(
                     SelectionEffects::scroll(Autoscroll::newest()),
@@ -7536,7 +7544,19 @@ impl Editor {
                     },
                 );
             }
+            EditPrediction::MoveOutside { snapshot, target } => {
+                if let Some(workspace) = self.workspace() {
+                    Self::open_editor_at_anchor(snapshot, *target, &workspace, window, cx)
+                        .detach_and_log_err(cx);
+                }
+            }
             EditPrediction::Edit { edits, .. } => {
+                self.report_edit_prediction_event(
+                    active_edit_prediction.completion_id.clone(),
+                    true,
+                    cx,
+                );
+
                 // Find an insertion that starts at the cursor position.
                 let snapshot = self.buffer.read(cx).snapshot(cx);
                 let cursor_offset = self.selections.newest::<usize>(cx).head();
@@ -7631,6 +7651,36 @@ impl Editor {
         );
     }
 
+    fn open_editor_at_anchor(
+        snapshot: &language::BufferSnapshot,
+        target: language::Anchor,
+        workspace: &Entity<Workspace>,
+        window: &mut Window,
+        cx: &mut App,
+    ) -> Task<Result<()>> {
+        workspace.update(cx, |workspace, cx| {
+            let path = snapshot.file().map(|file| file.full_path(cx));
+            let Some(path) =
+                path.and_then(|path| workspace.project().read(cx).find_project_path(path, cx))
+            else {
+                return Task::ready(Err(anyhow::anyhow!("Project path not found")));
+            };
+            let target = text::ToPoint::to_point(&target, snapshot);
+            let item = workspace.open_path(path, None, true, window, cx);
+            window.spawn(cx, async move |cx| {
+                let Some(editor) = item.await?.downcast::<Editor>() else {
+                    return Ok(());
+                };
+                editor
+                    .update_in(cx, |editor, window, cx| {
+                        editor.go_to_singleton_buffer_point(target, window, cx);
+                    })
+                    .ok();
+                anyhow::Ok(())
+            })
+        })
+    }
+
     pub fn has_active_edit_prediction(&self) -> bool {
         self.active_edit_prediction.is_some()
     }
@@ -7846,7 +7896,10 @@ impl Editor {
                 .active_edit_prediction
                 .as_ref()
                 .is_some_and(|completion| {
-                    let invalidation_range = completion.invalidation_range.to_offset(&multibuffer);
+                    let Some(invalidation_range) = completion.invalidation_range.as_ref() else {
+                        return false;
+                    };
+                    let invalidation_range = invalidation_range.to_offset(&multibuffer);
                     let invalidation_range = invalidation_range.start..=invalidation_range.end;
                     !invalidation_range.contains(&offset_selection.head())
                 })
@@ -7882,8 +7935,31 @@ impl Editor {
         }
 
         let edit_prediction = provider.suggest(&buffer, cursor_buffer_position, cx)?;
-        let edits = edit_prediction
-            .edits
+
+        let (completion_id, edits, edit_preview) = match edit_prediction {
+            edit_prediction::EditPrediction::Local {
+                id,
+                edits,
+                edit_preview,
+            } => (id, edits, edit_preview),
+            edit_prediction::EditPrediction::Jump {
+                id,
+                snapshot,
+                target,
+            } => {
+                self.stale_edit_prediction_in_menu = None;
+                self.active_edit_prediction = Some(EditPredictionState {
+                    inlay_ids: vec![],
+                    completion: EditPrediction::MoveOutside { snapshot, target },
+                    completion_id: id,
+                    invalidation_range: None,
+                });
+                cx.notify();
+                return Some(());
+            }
+        };
+
+        let edits = edits
             .into_iter()
             .flat_map(|(range, new_text)| {
                 let start = multibuffer.anchor_in_excerpt(excerpt_id, range.start)?;
@@ -7928,7 +8004,7 @@ impl Editor {
             invalidation_row_range =
                 move_invalidation_row_range.unwrap_or(edit_start_row..edit_end_row);
             let target = first_edit_start;
-            EditPrediction::Move { target, snapshot }
+            EditPrediction::MoveWithin { target, snapshot }
         } else {
             let show_completions_in_buffer = !self.edit_prediction_visible_in_cursor_popover(true)
                 && !self.edit_predictions_hidden_for_vim_mode;
@@ -7977,7 +8053,7 @@ impl Editor {
 
             EditPrediction::Edit {
                 edits,
-                edit_preview: edit_prediction.edit_preview,
+                edit_preview,
                 display_mode,
                 snapshot,
             }
@@ -7994,8 +8070,8 @@ impl Editor {
         self.active_edit_prediction = Some(EditPredictionState {
             inlay_ids,
             completion,
-            completion_id: edit_prediction.id,
-            invalidation_range,
+            completion_id,
+            invalidation_range: Some(invalidation_range),
         });
 
         cx.notify();
@@ -8581,7 +8657,7 @@ impl Editor {
         }
 
         match &active_edit_prediction.completion {
-            EditPrediction::Move { target, .. } => {
+            EditPrediction::MoveWithin { target, .. } => {
                 let target_display_point = target.to_display_point(editor_snapshot);
 
                 if self.edit_prediction_requires_modifier() {
@@ -8666,6 +8742,28 @@ impl Editor {
                 window,
                 cx,
             ),
+            EditPrediction::MoveOutside { snapshot, .. } => {
+                let file_name = snapshot
+                    .file()
+                    .map(|file| file.file_name(cx))
+                    .unwrap_or("untitled");
+                let mut element = self
+                    .render_edit_prediction_line_popover(
+                        format!("Jump to {file_name}"),
+                        Some(IconName::ZedPredict),
+                        window,
+                        cx,
+                    )
+                    .into_any();
+
+                let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
+                let origin_x = text_bounds.size.width / 2. - size.width / 2.;
+                let origin_y = text_bounds.size.height - size.height - px(30.);
+                let origin = text_bounds.origin + gpui::Point::new(origin_x, origin_y);
+                element.prepaint_at(origin, window, cx);
+
+                Some((element, origin))
+            }
         }
     }
 
@@ -8730,13 +8828,13 @@ impl Editor {
             .items_end()
             .when(flag_on_right, |el| el.items_start())
             .child(if flag_on_right {
-                self.render_edit_prediction_line_popover("Jump", None, window, cx)?
+                self.render_edit_prediction_line_popover("Jump", None, window, cx)
                     .rounded_bl(px(0.))
                     .rounded_tl(px(0.))
                     .border_l_2()
                     .border_color(border_color)
             } else {
-                self.render_edit_prediction_line_popover("Jump", None, window, cx)?
+                self.render_edit_prediction_line_popover("Jump", None, window, cx)
                     .rounded_br(px(0.))
                     .rounded_tr(px(0.))
                     .border_r_2()
@@ -8776,7 +8874,7 @@ impl Editor {
         cx: &mut App,
     ) -> Option<(AnyElement, gpui::Point<Pixels>)> {
         let mut element = self
-            .render_edit_prediction_line_popover("Scroll", Some(scroll_icon), window, cx)?
+            .render_edit_prediction_line_popover("Scroll", Some(scroll_icon), window, cx)
             .into_any();
 
         let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
@@ -8816,7 +8914,7 @@ impl Editor {
                     Some(IconName::ArrowUp),
                     window,
                     cx,
-                )?
+                )
                 .into_any();
 
             let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
@@ -8835,7 +8933,7 @@ impl Editor {
                     Some(IconName::ArrowDown),
                     window,
                     cx,
-                )?
+                )
                 .into_any();
 
             let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
@@ -8882,7 +8980,7 @@ impl Editor {
         );
 
         let mut element = self
-            .render_edit_prediction_line_popover(label, None, window, cx)?
+            .render_edit_prediction_line_popover(label, None, window, cx)
             .into_any();
 
         let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
@@ -8909,7 +9007,7 @@ impl Editor {
             };
 
             element = self
-                .render_edit_prediction_line_popover(label, Some(icon), window, cx)?
+                .render_edit_prediction_line_popover(label, Some(icon), window, cx)
                 .into_any();
 
             let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
@@ -9163,13 +9261,13 @@ impl Editor {
         icon: Option<IconName>,
         window: &mut Window,
         cx: &App,
-    ) -> Option<Stateful<Div>> {
+    ) -> Stateful<Div> {
         let padding_right = if icon.is_some() { px(4.) } else { px(8.) };
 
         let keybind = self.render_edit_prediction_accept_keybind(window, cx);
         let has_keybind = keybind.is_some();
 
-        let result = h_flex()
+        h_flex()
             .id("ep-line-popover")
             .py_0p5()
             .pl_1()
@@ -9215,9 +9313,7 @@ impl Editor {
                         .mt(px(1.5))
                         .child(Icon::new(icon).size(IconSize::Small)),
                 )
-            });
-
-        Some(result)
+            })
     }
 
     fn edit_prediction_line_popover_bg_color(cx: &App) -> Hsla {
@@ -9281,7 +9377,7 @@ impl Editor {
                             .rounded_tl(px(0.))
                             .overflow_hidden()
                             .child(div().px_1p5().child(match &prediction.completion {
-                                EditPrediction::Move { target, snapshot } => {
+                                EditPrediction::MoveWithin { target, snapshot } => {
                                     use text::ToPoint as _;
                                     if target.text_anchor.to_point(snapshot).row > cursor_point.row
                                     {
@@ -9290,6 +9386,10 @@ impl Editor {
                                         Icon::new(IconName::ZedPredictUp)
                                     }
                                 }
+                                EditPrediction::MoveOutside { .. } => {
+                                    // TODO [zeta2] custom icon for external jump?
+                                    Icon::new(provider_icon)
+                                }
                                 EditPrediction::Edit { .. } => Icon::new(provider_icon),
                             }))
                             .child(
@@ -9472,7 +9572,7 @@ impl Editor {
             .unwrap_or(true);
 
         match &completion.completion {
-            EditPrediction::Move {
+            EditPrediction::MoveWithin {
                 target, snapshot, ..
             } => {
                 if !supports_jump {
@@ -9494,7 +9594,20 @@ impl Editor {
                         .child(Label::new("Jump to Edit")),
                 )
             }
-
+            EditPrediction::MoveOutside { snapshot, .. } => {
+                let file_name = snapshot
+                    .file()
+                    .map(|file| file.file_name(cx))
+                    .unwrap_or("untitled");
+                Some(
+                    h_flex()
+                        .px_2()
+                        .gap_2()
+                        .flex_1()
+                        .child(Icon::new(IconName::ZedPredict))
+                        .child(Label::new(format!("Jump to {file_name}"))),
+                )
+            }
             EditPrediction::Edit {
                 edits,
                 edit_preview,
@@ -21418,7 +21531,7 @@ impl Editor {
         {
             self.hide_context_menu(window, cx);
         }
-        self.discard_edit_prediction(false, cx);
+        self.take_active_edit_prediction(cx);
         cx.emit(EditorEvent::Blurred);
         cx.notify();
     }

crates/editor/src/editor_tests.rs 🔗

@@ -8272,7 +8272,7 @@ async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext)
 
     cx.update(|_, cx| {
         provider.update(cx, |provider, _| {
-            provider.set_edit_prediction(Some(edit_prediction::EditPrediction {
+            provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
                 id: None,
                 edits: vec![(edit_position..edit_position, "X".into())],
                 edit_preview: None,

crates/supermaven/Cargo.toml 🔗

@@ -22,7 +22,6 @@ gpui.workspace = true
 language.workspace = true
 log.workspace = true
 postage.workspace = true
-project.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true

crates/supermaven/src/supermaven_completion_provider.rs 🔗

@@ -4,7 +4,6 @@ use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
 use futures::StreamExt as _;
 use gpui::{App, Context, Entity, EntityId, Task};
 use language::{Anchor, Buffer, BufferSnapshot};
-use project::Project;
 use std::{
     ops::{AddAssign, Range},
     path::Path,
@@ -94,7 +93,7 @@ fn completion_from_diff(
         edits.push((edit_range, edit_text));
     }
 
-    EditPrediction {
+    EditPrediction::Local {
         id: None,
         edits,
         edit_preview: None,
@@ -132,7 +131,6 @@ impl EditPredictionProvider for SupermavenCompletionProvider {
 
     fn refresh(
         &mut self,
-        _project: Option<Entity<Project>>,
         buffer_handle: Entity<Buffer>,
         cursor_position: Anchor,
         debounce: bool,

crates/zed/src/zed/edit_prediction_registry.rs 🔗

@@ -205,42 +205,48 @@ fn assign_edit_prediction_provider(
                     }
                 }
 
-                if std::env::var("ZED_ZETA2").is_ok() {
-                    let zeta = zeta2::Zeta::global(client, &user_store, cx);
-                    let provider = cx.new(|cx| {
-                        zeta2::ZetaEditPredictionProvider::new(
-                            editor.project(),
-                            &client,
-                            &user_store,
-                            cx,
-                        )
-                    });
-
-                    if let Some(buffer) = &singleton_buffer
-                        && buffer.read(cx).file().is_some()
-                        && let Some(project) = editor.project()
-                    {
-                        zeta.update(cx, |zeta, cx| {
-                            zeta.register_buffer(buffer, project, cx);
+                if let Some(project) = editor.project() {
+                    if std::env::var("ZED_ZETA2").is_ok() {
+                        let zeta = zeta2::Zeta::global(client, &user_store, cx);
+                        let provider = cx.new(|cx| {
+                            zeta2::ZetaEditPredictionProvider::new(
+                                project.clone(),
+                                &client,
+                                &user_store,
+                                cx,
+                            )
                         });
-                    }
 
-                    editor.set_edit_prediction_provider(Some(provider), window, cx);
-                } else {
-                    let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
-
-                    if let Some(buffer) = &singleton_buffer
-                        && buffer.read(cx).file().is_some()
-                        && let Some(project) = editor.project()
-                    {
-                        zeta.update(cx, |zeta, cx| {
-                            zeta.register_buffer(buffer, project, cx);
+                        // TODO [zeta2] handle multibuffers
+                        if let Some(buffer) = &singleton_buffer
+                            && buffer.read(cx).file().is_some()
+                        {
+                            zeta.update(cx, |zeta, cx| {
+                                zeta.register_buffer(buffer, project, cx);
+                            });
+                        }
+
+                        editor.set_edit_prediction_provider(Some(provider), window, cx);
+                    } else {
+                        let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
+
+                        if let Some(buffer) = &singleton_buffer
+                            && buffer.read(cx).file().is_some()
+                        {
+                            zeta.update(cx, |zeta, cx| {
+                                zeta.register_buffer(buffer, project, cx);
+                            });
+                        }
+
+                        let provider = cx.new(|_| {
+                            zeta::ZetaEditPredictionProvider::new(
+                                zeta,
+                                project.clone(),
+                                singleton_buffer,
+                            )
                         });
+                        editor.set_edit_prediction_provider(Some(provider), window, cx);
                     }
-
-                    let provider =
-                        cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, singleton_buffer));
-                    editor.set_edit_prediction_provider(Some(provider), window, cx);
                 }
             }
         }

crates/zeta/src/zeta.rs 🔗

@@ -1316,12 +1316,17 @@ pub struct ZetaEditPredictionProvider {
     next_pending_completion_id: usize,
     current_completion: Option<CurrentEditPrediction>,
     last_request_timestamp: Instant,
+    project: Entity<Project>,
 }
 
 impl ZetaEditPredictionProvider {
     pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
 
-    pub fn new(zeta: Entity<Zeta>, singleton_buffer: Option<Entity<Buffer>>) -> Self {
+    pub fn new(
+        zeta: Entity<Zeta>,
+        project: Entity<Project>,
+        singleton_buffer: Option<Entity<Buffer>>,
+    ) -> Self {
         Self {
             zeta,
             singleton_buffer,
@@ -1329,6 +1334,7 @@ impl ZetaEditPredictionProvider {
             next_pending_completion_id: 0,
             current_completion: None,
             last_request_timestamp: Instant::now(),
+            project,
         }
     }
 }
@@ -1394,7 +1400,6 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
 
     fn refresh(
         &mut self,
-        project: Option<Entity<Project>>,
         buffer: Entity<Buffer>,
         position: language::Anchor,
         _debounce: bool,
@@ -1403,9 +1408,6 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
         if self.zeta.read(cx).update_required {
             return;
         }
-        let Some(project) = project else {
-            return;
-        };
 
         if self
             .zeta
@@ -1433,6 +1435,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
         self.next_pending_completion_id += 1;
         let last_request_timestamp = self.last_request_timestamp;
 
+        let project = self.project.clone();
         let task = cx.spawn(async move |this, cx| {
             if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                 .checked_duration_since(Instant::now())
@@ -1604,7 +1607,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
             }
         }
 
-        Some(edit_prediction::EditPrediction {
+        Some(edit_prediction::EditPrediction::Local {
             id: Some(completion.id.to_string().into()),
             edits: edits[edit_start_ix..edit_end_ix].to_vec(),
             edit_preview: Some(completion.edit_preview.clone()),

crates/zeta2/src/prediction.rs 🔗

@@ -1,50 +1,146 @@
-use std::{borrow::Cow, ops::Range, sync::Arc};
+use std::{borrow::Cow, ops::Range, path::Path, sync::Arc};
 
+use anyhow::Context as _;
 use cloud_llm_client::predict_edits_v3;
-use language::{Anchor, BufferSnapshot, EditPreview, OffsetRangeExt, text_diff};
+use gpui::{App, AsyncApp, Entity};
+use language::{
+    Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot, text_diff,
+};
+use project::Project;
+use util::ResultExt;
 use uuid::Uuid;
 
+#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
+pub struct EditPredictionId(Uuid);
+
+impl From<EditPredictionId> for gpui::ElementId {
+    fn from(value: EditPredictionId) -> Self {
+        gpui::ElementId::Uuid(value.0)
+    }
+}
+
+impl std::fmt::Display for EditPredictionId {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
 #[derive(Clone)]
 pub struct EditPrediction {
     pub id: EditPredictionId,
+    pub path: Arc<Path>,
     pub edits: Arc<[(Range<Anchor>, String)]>,
     pub snapshot: BufferSnapshot,
     pub edit_preview: EditPreview,
+    // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
+    _buffer: Entity<Buffer>,
 }
 
 impl EditPrediction {
+    pub async fn from_response(
+        response: predict_edits_v3::PredictEditsResponse,
+        active_buffer_old_snapshot: &TextBufferSnapshot,
+        active_buffer: &Entity<Buffer>,
+        project: &Entity<Project>,
+        cx: &mut AsyncApp,
+    ) -> Option<Self> {
+        // TODO only allow cloud to return one path
+        let Some(path) = response.edits.first().map(|e| e.path.clone()) else {
+            return None;
+        };
+
+        let is_same_path = active_buffer
+            .read_with(cx, |buffer, cx| buffer_path_eq(buffer, &path, cx))
+            .ok()?;
+
+        let (buffer, edits, snapshot, edit_preview_task) = if is_same_path {
+            active_buffer
+                .read_with(cx, |buffer, cx| {
+                    let new_snapshot = buffer.snapshot();
+                    let edits = edits_from_response(&response.edits, &active_buffer_old_snapshot);
+                    let edits: Arc<[_]> =
+                        interpolate_edits(active_buffer_old_snapshot, &new_snapshot, edits)?.into();
+
+                    Some((
+                        active_buffer.clone(),
+                        edits.clone(),
+                        new_snapshot,
+                        buffer.preview_edits(edits, cx),
+                    ))
+                })
+                .ok()??
+        } else {
+            let buffer_handle = project
+                .update(cx, |project, cx| {
+                    let project_path = project
+                        .find_project_path(&path, cx)
+                        .context("Failed to find project path for zeta edit")?;
+                    anyhow::Ok(project.open_buffer(project_path, cx))
+                })
+                .ok()?
+                .log_err()?
+                .await
+                .context("Failed to open buffer for zeta edit")
+                .log_err()?;
+
+            buffer_handle
+                .read_with(cx, |buffer, cx| {
+                    let snapshot = buffer.snapshot();
+                    let edits = edits_from_response(&response.edits, &snapshot);
+                    if edits.is_empty() {
+                        return None;
+                    }
+                    Some((
+                        buffer_handle.clone(),
+                        edits.clone(),
+                        snapshot,
+                        buffer.preview_edits(edits, cx),
+                    ))
+                })
+                .ok()??
+        };
+
+        let edit_preview = edit_preview_task.await;
+
+        Some(EditPrediction {
+            id: EditPredictionId(response.request_id),
+            path,
+            edits,
+            snapshot,
+            edit_preview,
+            _buffer: buffer,
+        })
+    }
+
     pub fn interpolate(
         &self,
-        new_snapshot: &BufferSnapshot,
+        new_snapshot: &TextBufferSnapshot,
     ) -> Option<Vec<(Range<Anchor>, String)>> {
         interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
     }
-}
 
-#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
-pub struct EditPredictionId(Uuid);
-
-impl From<Uuid> for EditPredictionId {
-    fn from(value: Uuid) -> Self {
-        EditPredictionId(value)
+    pub fn targets_buffer(&self, buffer: &Buffer, cx: &App) -> bool {
+        buffer_path_eq(buffer, &self.path, cx)
     }
 }
 
-impl From<EditPredictionId> for gpui::ElementId {
-    fn from(value: EditPredictionId) -> Self {
-        gpui::ElementId::Uuid(value.0)
+impl std::fmt::Debug for EditPrediction {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("EditPrediction")
+            .field("id", &self.id)
+            .field("path", &self.path)
+            .field("edits", &self.edits)
+            .finish()
     }
 }
 
-impl std::fmt::Display for EditPredictionId {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{}", self.0)
-    }
+pub fn buffer_path_eq(buffer: &Buffer, path: &Path, cx: &App) -> bool {
+    buffer.file().map(|p| p.full_path(cx)).as_deref() == Some(path)
 }
 
 pub fn interpolate_edits(
-    old_snapshot: &BufferSnapshot,
-    new_snapshot: &BufferSnapshot,
+    old_snapshot: &TextBufferSnapshot,
+    new_snapshot: &TextBufferSnapshot,
     current_edits: Arc<[(Range<Anchor>, String)]>,
 ) -> Option<Vec<(Range<Anchor>, String)>> {
     let mut edits = Vec::new();
@@ -88,14 +184,13 @@ pub fn interpolate_edits(
     if edits.is_empty() { None } else { Some(edits) }
 }
 
-pub fn edits_from_response(
+fn edits_from_response(
     edits: &[predict_edits_v3::Edit],
-    snapshot: &BufferSnapshot,
+    snapshot: &TextBufferSnapshot,
 ) -> Arc<[(Range<Anchor>, String)]> {
     edits
         .iter()
         .flat_map(|edit| {
-            // TODO multi-file edits
             let old_text = snapshot.text_for_range(edit.range.clone());
 
             excerpt_edits_from_response(
@@ -113,7 +208,7 @@ fn excerpt_edits_from_response(
     old_text: Cow<str>,
     new_text: &str,
     offset: usize,
-    snapshot: &BufferSnapshot,
+    snapshot: &TextBufferSnapshot,
 ) -> impl Iterator<Item = (Range<Anchor>, String)> {
     text_diff(&old_text, new_text)
         .into_iter()
@@ -221,6 +316,8 @@ mod tests {
             id: EditPredictionId(Uuid::new_v4()),
             edits,
             snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
+            path: Path::new("test.txt").into(),
+            _buffer: buffer.clone(),
             edit_preview,
         };
 

crates/zeta2/src/provider.rs 🔗

@@ -4,76 +4,44 @@ use std::{
     time::{Duration, Instant},
 };
 
-use anyhow::Context as _;
 use arrayvec::ArrayVec;
 use client::{Client, UserStore};
 use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
-use gpui::{App, Entity, EntityId, Task, prelude::*};
-use language::{BufferSnapshot, ToPoint as _};
+use gpui::{App, Entity, Task, prelude::*};
+use language::ToPoint as _;
 use project::Project;
 use util::ResultExt as _;
 
-use crate::{Zeta, prediction::EditPrediction};
+use crate::{BufferEditPrediction, Zeta};
 
 pub struct ZetaEditPredictionProvider {
     zeta: Entity<Zeta>,
-    current_prediction: Option<CurrentEditPrediction>,
     next_pending_prediction_id: usize,
     pending_predictions: ArrayVec<PendingPrediction, 2>,
     last_request_timestamp: Instant,
+    project: Entity<Project>,
 }
 
 impl ZetaEditPredictionProvider {
     pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
 
     pub fn new(
-        project: Option<&Entity<Project>>,
+        project: Entity<Project>,
         client: &Arc<Client>,
         user_store: &Entity<UserStore>,
         cx: &mut App,
     ) -> Self {
         let zeta = Zeta::global(client, user_store, cx);
-        if let Some(project) = project {
-            zeta.update(cx, |zeta, cx| {
-                zeta.register_project(project, cx);
-            });
-        }
+        zeta.update(cx, |zeta, cx| {
+            zeta.register_project(&project, cx);
+        });
 
         Self {
             zeta,
-            current_prediction: None,
             next_pending_prediction_id: 0,
             pending_predictions: ArrayVec::new(),
             last_request_timestamp: Instant::now(),
-        }
-    }
-}
-
-#[derive(Clone)]
-struct CurrentEditPrediction {
-    buffer_id: EntityId,
-    prediction: EditPrediction,
-}
-
-impl CurrentEditPrediction {
-    fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
-        if self.buffer_id != old_prediction.buffer_id {
-            return true;
-        }
-
-        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
-            return true;
-        };
-        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
-            return false;
-        };
-
-        if old_edits.len() == 1 && new_edits.len() == 1 {
-            let (old_range, old_text) = &old_edits[0];
-            let (new_range, new_text) = &new_edits[0];
-            new_range == old_range && new_text.starts_with(old_text)
-        } else {
-            true
+            project: project,
         }
     }
 }
@@ -128,42 +96,31 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
 
     fn refresh(
         &mut self,
-        project: Option<Entity<project::Project>>,
         buffer: Entity<language::Buffer>,
         cursor_position: language::Anchor,
         _debounce: bool,
         cx: &mut Context<Self>,
     ) {
-        let Some(project) = project else {
-            return;
-        };
+        let zeta = self.zeta.read(cx);
 
-        if self
-            .zeta
-            .read(cx)
-            .user_store
-            .read_with(cx, |user_store, _cx| {
-                user_store.account_too_young() || user_store.has_overdue_invoices()
-            })
-        {
+        if zeta.user_store.read_with(cx, |user_store, _cx| {
+            user_store.account_too_young() || user_store.has_overdue_invoices()
+        }) {
             return;
         }
 
-        if let Some(current_prediction) = self.current_prediction.as_ref() {
-            let snapshot = buffer.read(cx).snapshot();
-            if current_prediction
-                .prediction
-                .interpolate(&snapshot)
-                .is_some()
-            {
-                return;
-            }
+        if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
+            && let BufferEditPrediction::Local { prediction } = current
+            && prediction.interpolate(buffer.read(cx)).is_some()
+        {
+            return;
         }
 
         let pending_prediction_id = self.next_pending_prediction_id;
         self.next_pending_prediction_id += 1;
         let last_request_timestamp = self.last_request_timestamp;
 
+        let project = self.project.clone();
         let task = cx.spawn(async move |this, cx| {
             if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                 .checked_duration_since(Instant::now())
@@ -171,25 +128,16 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
                 cx.background_executor().timer(timeout).await;
             }
 
-            let prediction_request = this.update(cx, |this, cx| {
+            let refresh_task = this.update(cx, |this, cx| {
                 this.last_request_timestamp = Instant::now();
                 this.zeta.update(cx, |zeta, cx| {
-                    zeta.request_prediction(&project, &buffer, cursor_position, cx)
+                    zeta.refresh_prediction(&project, &buffer, cursor_position, cx)
                 })
             });
 
-            let prediction = match prediction_request {
-                Ok(prediction_request) => {
-                    let prediction_request = prediction_request.await;
-                    prediction_request.map(|c| {
-                        c.map(|prediction| CurrentEditPrediction {
-                            buffer_id: buffer.entity_id(),
-                            prediction,
-                        })
-                    })
-                }
-                Err(error) => Err(error),
-            };
+            if let Some(refresh_task) = refresh_task.ok() {
+                refresh_task.await.log_err();
+            }
 
             this.update(cx, |this, cx| {
                 if this.pending_predictions[0].id == pending_prediction_id {
@@ -198,24 +146,6 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
                     this.pending_predictions.clear();
                 }
 
-                let Some(new_prediction) = prediction
-                    .context("edit prediction failed")
-                    .log_err()
-                    .flatten()
-                else {
-                    cx.notify();
-                    return;
-                };
-
-                if let Some(old_prediction) = this.current_prediction.as_ref() {
-                    let snapshot = buffer.read(cx).snapshot();
-                    if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
-                        this.current_prediction = Some(new_prediction);
-                    }
-                } else {
-                    this.current_prediction = Some(new_prediction);
-                }
-
                 cx.notify();
             })
             .ok();
@@ -248,15 +178,18 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
     ) {
     }
 
-    fn accept(&mut self, _cx: &mut Context<Self>) {
-        // TODO [zeta2] report accept
-        self.current_prediction.take();
+    fn accept(&mut self, cx: &mut Context<Self>) {
+        self.zeta.update(cx, |zeta, _cx| {
+            zeta.accept_current_prediction(&self.project);
+        });
         self.pending_predictions.clear();
     }
 
-    fn discard(&mut self, _cx: &mut Context<Self>) {
+    fn discard(&mut self, cx: &mut Context<Self>) {
+        self.zeta.update(cx, |zeta, _cx| {
+            zeta.discard_current_prediction(&self.project);
+        });
         self.pending_predictions.clear();
-        self.current_prediction.take();
     }
 
     fn suggest(
@@ -265,36 +198,44 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
         cursor_position: language::Anchor,
         cx: &mut Context<Self>,
     ) -> Option<edit_prediction::EditPrediction> {
-        let CurrentEditPrediction {
-            buffer_id,
-            prediction,
-            ..
-        } = self.current_prediction.as_mut()?;
-
-        // Invalidate previous prediction if it was generated for a different buffer.
-        if *buffer_id != buffer.entity_id() {
-            self.current_prediction.take();
-            return None;
-        }
+        let prediction =
+            self.zeta
+                .read(cx)
+                .current_prediction_for_buffer(buffer, &self.project, cx)?;
+
+        let prediction = match prediction {
+            BufferEditPrediction::Local { prediction } => prediction,
+            BufferEditPrediction::Jump { prediction } => {
+                return Some(edit_prediction::EditPrediction::Jump {
+                    id: Some(prediction.id.to_string().into()),
+                    snapshot: prediction.snapshot.clone(),
+                    target: prediction.edits.first().unwrap().0.start,
+                });
+            }
+        };
 
         let buffer = buffer.read(cx);
-        let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
-            self.current_prediction.take();
+        let snapshot = buffer.snapshot();
+
+        let Some(edits) = prediction.interpolate(&snapshot) else {
+            self.zeta.update(cx, |zeta, _cx| {
+                zeta.discard_current_prediction(&self.project);
+            });
             return None;
         };
 
-        let cursor_row = cursor_position.to_point(buffer).row;
+        let cursor_row = cursor_position.to_point(&snapshot).row;
         let (closest_edit_ix, (closest_edit_range, _)) =
             edits.iter().enumerate().min_by_key(|(_, (range, _))| {
-                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
-                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
+                let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
+                let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
                 cmp::min(distance_from_start, distance_from_end)
             })?;
 
         let mut edit_start_ix = closest_edit_ix;
         for (range, _) in edits[..edit_start_ix].iter().rev() {
-            let distance_from_closest_edit =
-                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
+            let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
+                - range.end.to_point(&snapshot).row;
             if distance_from_closest_edit <= 1 {
                 edit_start_ix -= 1;
             } else {
@@ -305,7 +246,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
         let mut edit_end_ix = closest_edit_ix + 1;
         for (range, _) in &edits[edit_end_ix..] {
             let distance_from_closest_edit =
-                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
+                range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
             if distance_from_closest_edit <= 1 {
                 edit_end_ix += 1;
             } else {
@@ -313,7 +254,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
             }
         }
 
-        Some(edit_prediction::EditPrediction {
+        Some(edit_prediction::EditPrediction::Local {
             id: Some(prediction.id.to_string().into()),
             edits: edits[edit_start_ix..edit_end_ix].to_vec(),
             edit_preview: Some(prediction.edit_preview.clone()),

crates/zeta2/src/zeta2.rs 🔗

@@ -17,8 +17,8 @@ use gpui::{
     App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
     http_client, prelude::*,
 };
-use language::BufferSnapshot;
 use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
+use language::{BufferSnapshot, TextBufferSnapshot};
 use language_model::{LlmApiToken, RefreshLlmTokenListener};
 use project::Project;
 use release_channel::AppVersion;
@@ -35,7 +35,7 @@ use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_noti
 mod prediction;
 mod provider;
 
-use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
+use crate::prediction::EditPrediction;
 pub use provider::ZetaEditPredictionProvider;
 
 const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
@@ -53,7 +53,7 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
     excerpt: DEFAULT_EXCERPT_OPTIONS,
     max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
     max_diagnostic_bytes: 2048,
-    prompt_format: PromptFormat::MarkedExcerpt,
+    prompt_format: PromptFormat::DEFAULT,
 };
 
 #[derive(Clone)]
@@ -94,6 +94,47 @@ struct ZetaProject {
     syntax_index: Entity<SyntaxIndex>,
     events: VecDeque<Event>,
     registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
+    current_prediction: Option<CurrentEditPrediction>,
+}
+
+#[derive(Clone)]
+struct CurrentEditPrediction {
+    pub requested_by_buffer_id: EntityId,
+    pub prediction: EditPrediction,
+}
+
+impl CurrentEditPrediction {
+    fn should_replace_prediction(
+        &self,
+        old_prediction: &Self,
+        snapshot: &TextBufferSnapshot,
+    ) -> bool {
+        if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id {
+            return true;
+        }
+
+        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
+            return true;
+        };
+
+        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
+            return false;
+        };
+        if old_edits.len() == 1 && new_edits.len() == 1 {
+            let (old_range, old_text) = &old_edits[0];
+            let (new_range, new_text) = &new_edits[0];
+            new_range == old_range && new_text.starts_with(old_text)
+        } else {
+            true
+        }
+    }
+}
+
+/// A prediction from the perspective of a buffer.
+#[derive(Debug)]
+enum BufferEditPrediction<'a> {
+    Local { prediction: &'a EditPrediction },
+    Jump { prediction: &'a EditPrediction },
 }
 
 struct RegisteredBuffer {
@@ -204,6 +245,7 @@ impl Zeta {
                 syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
                 events: VecDeque::new(),
                 registered_buffers: HashMap::new(),
+                current_prediction: None,
             })
     }
 
@@ -305,7 +347,83 @@ impl Zeta {
         events.push_back(event);
     }
 
-    pub fn request_prediction(
+    fn current_prediction_for_buffer(
+        &self,
+        buffer: &Entity<Buffer>,
+        project: &Entity<Project>,
+        cx: &App,
+    ) -> Option<BufferEditPrediction<'_>> {
+        let project_state = self.projects.get(&project.entity_id())?;
+
+        let CurrentEditPrediction {
+            requested_by_buffer_id,
+            prediction,
+        } = project_state.current_prediction.as_ref()?;
+
+        if prediction.targets_buffer(buffer.read(cx), cx) {
+            Some(BufferEditPrediction::Local { prediction })
+        } else if *requested_by_buffer_id == buffer.entity_id() {
+            Some(BufferEditPrediction::Jump { prediction })
+        } else {
+            None
+        }
+    }
+
+    fn accept_current_prediction(&mut self, project: &Entity<Project>) {
+        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
+            project_state.current_prediction.take();
+        };
+        // TODO report accepted
+    }
+
+    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
+        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
+            project_state.current_prediction.take();
+        };
+    }
+
+    pub fn refresh_prediction(
+        &mut self,
+        project: &Entity<Project>,
+        buffer: &Entity<Buffer>,
+        position: language::Anchor,
+        cx: &mut Context<Self>,
+    ) -> Task<Result<()>> {
+        let request_task = self.request_prediction(project, buffer, position, cx);
+        let buffer = buffer.clone();
+        let project = project.clone();
+
+        cx.spawn(async move |this, cx| {
+            if let Some(prediction) = request_task.await? {
+                this.update(cx, |this, cx| {
+                    let project_state = this
+                        .projects
+                        .get_mut(&project.entity_id())
+                        .context("Project not found")?;
+
+                    let new_prediction = CurrentEditPrediction {
+                        requested_by_buffer_id: buffer.entity_id(),
+                        prediction: prediction,
+                    };
+
+                    if project_state
+                        .current_prediction
+                        .as_ref()
+                        .is_none_or(|old_prediction| {
+                            new_prediction
+                                .should_replace_prediction(&old_prediction, buffer.read(cx))
+                        })
+                    {
+                        project_state.current_prediction = Some(new_prediction);
+                    }
+                    anyhow::Ok(())
+                })??;
+            }
+            Ok(())
+        })
+    }
+
+    fn request_prediction(
         &mut self,
         project: &Entity<Project>,
         buffer: &Entity<Buffer>,
@@ -457,74 +575,63 @@ impl Zeta {
                         .ok();
                 }
 
-                let (response, usage) = response?;
-                let edits = edits_from_response(&response.edits, &snapshot);
-
-                anyhow::Ok(Some((response.request_id, edits, usage)))
+                anyhow::Ok(Some(response?))
             }
         });
 
         let buffer = buffer.clone();
 
-        cx.spawn(async move |this, cx| {
-            match request_task.await {
-                Ok(Some((id, edits, usage))) => {
-                    if let Some(usage) = usage {
-                        this.update(cx, |this, cx| {
-                            this.user_store.update(cx, |user_store, cx| {
-                                user_store.update_edit_prediction_usage(usage, cx);
-                            });
-                        })
-                        .ok();
-                    }
+        cx.spawn({
+            let project = project.clone();
+            async move |this, cx| {
+                match request_task.await {
+                    Ok(Some((response, usage))) => {
+                        if let Some(usage) = usage {
+                            this.update(cx, |this, cx| {
+                                this.user_store.update(cx, |user_store, cx| {
+                                    user_store.update_edit_prediction_usage(usage, cx);
+                                });
+                            })
+                            .ok();
+                        }
 
-                    // TODO telemetry: duration, etc
-                    let Some((edits, snapshot, edit_preview_task)) =
-                        buffer.read_with(cx, |buffer, cx| {
-                            let new_snapshot = buffer.snapshot();
-                            let edits: Arc<[_]> =
-                                interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
-                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
-                        })?
-                    else {
-                        return Ok(None);
-                    };
+                        let prediction = EditPrediction::from_response(
+                            response, &snapshot, &buffer, &project, cx,
+                        )
+                        .await;
 
-                    Ok(Some(EditPrediction {
-                        id: id.into(),
-                        edits,
-                        snapshot,
-                        edit_preview: edit_preview_task.await,
-                    }))
-                }
-                Ok(None) => Ok(None),
-                Err(err) => {
-                    if err.is::<ZedUpdateRequiredError>() {
-                        cx.update(|cx| {
-                            this.update(cx, |this, _cx| {
-                                this.update_required = true;
+                        // TODO telemetry: duration, etc
+                        Ok(prediction)
+                    }
+                    Ok(None) => Ok(None),
+                    Err(err) => {
+                        if err.is::<ZedUpdateRequiredError>() {
+                            cx.update(|cx| {
+                                this.update(cx, |this, _cx| {
+                                    this.update_required = true;
+                                })
+                                .ok();
+
+                                let error_message: SharedString = err.to_string().into();
+                                show_app_notification(
+                                    NotificationId::unique::<ZedUpdateRequiredError>(),
+                                    cx,
+                                    move |cx| {
+                                        cx.new(|cx| {
+                                            ErrorMessagePrompt::new(error_message.clone(), cx)
+                                                .with_link_button(
+                                                    "Update Zed",
+                                                    "https://zed.dev/releases",
+                                                )
+                                        })
+                                    },
+                                );
                             })
                             .ok();
+                        }
 
-                            let error_message: SharedString = err.to_string().into();
-                            show_app_notification(
-                                NotificationId::unique::<ZedUpdateRequiredError>(),
-                                cx,
-                                move |cx| {
-                                    cx.new(|cx| {
-                                        ErrorMessagePrompt::new(error_message.clone(), cx)
-                                            .with_link_button(
-                                                "Update Zed",
-                                                "https://zed.dev/releases",
-                                            )
-                                    })
-                                },
-                            );
-                        })
-                        .ok();
+                        Err(err)
                     }
-
-                    Err(err)
                 }
             }
         })
@@ -859,13 +966,113 @@ mod tests {
     };
     use indoc::indoc;
     use language::{LanguageServerId, OffsetRangeExt as _};
+    use pretty_assertions::{assert_eq, assert_matches};
     use project::{FakeFs, Project};
     use serde_json::json;
     use settings::SettingsStore;
     use util::path;
     use uuid::Uuid;
 
-    use crate::Zeta;
+    use crate::{BufferEditPrediction, Zeta};
+
+    #[gpui::test]
+    async fn test_current_state(cx: &mut TestAppContext) {
+        let (zeta, mut req_rx) = init_test(cx);
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/root",
+            json!({
+                "1.txt": "Hello!\nHow\nBye",
+                "2.txt": "Hola!\nComo\nAdios"
+            }),
+        )
+        .await;
+        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+        zeta.update(cx, |zeta, cx| {
+            zeta.register_project(&project, cx);
+        });
+
+        let buffer1 = project
+            .update(cx, |project, cx| {
+                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
+                project.open_buffer(path, cx)
+            })
+            .await
+            .unwrap();
+        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
+        let position = snapshot1.anchor_before(language::Point::new(1, 3));
+
+        // Prediction for current file
+
+        let prediction_task = zeta.update(cx, |zeta, cx| {
+            zeta.refresh_prediction(&project, &buffer1, position, cx)
+        });
+        let (_request, respond_tx) = req_rx.next().await.unwrap();
+        respond_tx
+            .send(predict_edits_v3::PredictEditsResponse {
+                request_id: Uuid::new_v4(),
+                edits: vec![predict_edits_v3::Edit {
+                    path: Path::new(path!("root/1.txt")).into(),
+                    range: 0..snapshot1.len(),
+                    content: "Hello!\nHow are you?\nBye".into(),
+                }],
+                debug_info: None,
+            })
+            .unwrap();
+        prediction_task.await.unwrap();
+
+        zeta.read_with(cx, |zeta, cx| {
+            let prediction = zeta
+                .current_prediction_for_buffer(&buffer1, &project, cx)
+                .unwrap();
+            assert_matches!(prediction, BufferEditPrediction::Local { .. });
+        });
+
+        // Prediction for another file
+
+        let prediction_task = zeta.update(cx, |zeta, cx| {
+            zeta.refresh_prediction(&project, &buffer1, position, cx)
+        });
+        let (_request, respond_tx) = req_rx.next().await.unwrap();
+        respond_tx
+            .send(predict_edits_v3::PredictEditsResponse {
+                request_id: Uuid::new_v4(),
+                edits: vec![predict_edits_v3::Edit {
+                    path: Path::new(path!("root/2.txt")).into(),
+                    range: 0..snapshot1.len(),
+                    content: "Hola!\nComo estas?\nAdios".into(),
+                }],
+                debug_info: None,
+            })
+            .unwrap();
+        prediction_task.await.unwrap();
+
+        zeta.read_with(cx, |zeta, cx| {
+            let prediction = zeta
+                .current_prediction_for_buffer(&buffer1, &project, cx)
+                .unwrap();
+            assert_matches!(
+                prediction,
+                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
+            );
+        });
+
+        let buffer2 = project
+            .update(cx, |project, cx| {
+                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
+                project.open_buffer(path, cx)
+            })
+            .await
+            .unwrap();
+
+        zeta.read_with(cx, |zeta, cx| {
+            let prediction = zeta
+                .current_prediction_for_buffer(&buffer2, &project, cx)
+                .unwrap();
+            assert_matches!(prediction, BufferEditPrediction::Local { .. });
+        });
+    }
 
     #[gpui::test]
     async fn test_simple_request(cx: &mut TestAppContext) {
@@ -1146,6 +1353,7 @@ mod tests {
 
             let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
             let zeta = Zeta::global(&client, &user_store, cx);
+
             (zeta, req_rx)
         })
     }

crates/zeta2_tools/src/zeta2_tools.rs 🔗

@@ -185,7 +185,7 @@ impl Zeta2Inspector {
                     cx.background_executor().timer(THROTTLE_TIME).await;
                     if let Some(task) = zeta
                         .update(cx, |zeta, cx| {
-                            zeta.request_prediction(&project, &buffer, position, cx)
+                            zeta.refresh_prediction(&project, &buffer, position, cx)
                         })
                         .ok()
                     {