zeta2: Predict at next diagnostic location (#43257)

Agus Zubiaga created

When no predictions are available for the current buffer, we will now
attempt to predict at the closest diagnostic from the cursor location
that wasn't included in the last prediction request. This enables a
commonly desired kind of far-away jump without requiring explicit model
support.

Release Notes:

- N/A

Change summary

crates/codestral/src/codestral.rs                       |   2 
crates/copilot/src/copilot_completion_provider.rs       |   2 
crates/edit_prediction/src/edit_prediction.rs           |   4 
crates/editor/src/edit_prediction_tests.rs              |   4 
crates/supermaven/src/supermaven_completion_provider.rs |   2 
crates/util/src/rel_path.rs                             |   1 
crates/zeta/src/zeta.rs                                 |   2 
crates/zeta2/Cargo.toml                                 |   3 
crates/zeta2/src/provider.rs                            |  89 -
crates/zeta2/src/zeta2.rs                               | 668 ++++++++--
crates/zeta2_tools/src/zeta2_tools.rs                   |  19 
11 files changed, 539 insertions(+), 257 deletions(-)

Detailed changes

crates/codestral/src/codestral.rs 🔗

@@ -182,7 +182,7 @@ impl EditPredictionProvider for CodestralCompletionProvider {
         Self::api_key(cx).is_some()
     }
 
-    fn is_refreshing(&self) -> bool {
+    fn is_refreshing(&self, _cx: &App) -> bool {
         self.pending_request.is_some()
     }
 

crates/copilot/src/copilot_completion_provider.rs 🔗

@@ -68,7 +68,7 @@ impl EditPredictionProvider for CopilotCompletionProvider {
         false
     }
 
-    fn is_refreshing(&self) -> bool {
+    fn is_refreshing(&self, _cx: &App) -> bool {
         self.pending_refresh.is_some() && self.completions.is_empty()
     }
 

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -87,7 +87,7 @@ pub trait EditPredictionProvider: 'static + Sized {
         cursor_position: language::Anchor,
         cx: &App,
     ) -> bool;
-    fn is_refreshing(&self) -> bool;
+    fn is_refreshing(&self, cx: &App) -> bool;
     fn refresh(
         &mut self,
         buffer: Entity<Buffer>,
@@ -200,7 +200,7 @@ where
     }
 
     fn is_refreshing(&self, cx: &App) -> bool {
-        self.read(cx).is_refreshing()
+        self.read(cx).is_refreshing(cx)
     }
 
     fn refresh(

crates/editor/src/edit_prediction_tests.rs 🔗

@@ -469,7 +469,7 @@ impl EditPredictionProvider for FakeEditPredictionProvider {
         true
     }
 
-    fn is_refreshing(&self) -> bool {
+    fn is_refreshing(&self, _cx: &gpui::App) -> bool {
         false
     }
 
@@ -542,7 +542,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
         true
     }
 
-    fn is_refreshing(&self) -> bool {
+    fn is_refreshing(&self, _cx: &gpui::App) -> bool {
         false
     }
 

crates/supermaven/src/supermaven_completion_provider.rs 🔗

@@ -129,7 +129,7 @@ impl EditPredictionProvider for SupermavenCompletionProvider {
         self.supermaven.read(cx).is_enabled()
     }
 
-    fn is_refreshing(&self) -> bool {
+    fn is_refreshing(&self, _cx: &App) -> bool {
         self.pending_refresh.is_some() && self.completion_id.is_none()
     }
 

crates/util/src/rel_path.rs 🔗

@@ -374,6 +374,7 @@ impl PartialEq<str> for RelPath {
     }
 }
 
+#[derive(Default)]
 pub struct RelPathComponents<'a>(&'a str);
 
 pub struct RelPathAncestors<'a>(Option<&'a str>);

crates/zeta/src/zeta.rs 🔗

@@ -1486,7 +1486,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
     ) -> bool {
         true
     }
-    fn is_refreshing(&self) -> bool {
+    fn is_refreshing(&self, _cx: &App) -> bool {
         !self.pending_completions.is_empty()
     }
 

crates/zeta2/Cargo.toml 🔗

@@ -32,7 +32,9 @@ indoc.workspace = true
 language.workspace = true
 language_model.workspace = true
 log.workspace = true
+lsp.workspace = true
 open_ai.workspace = true
+pretty_assertions.workspace = true
 project.workspace = true
 release_channel.workspace = true
 serde.workspace = true
@@ -44,7 +46,6 @@ util.workspace = true
 uuid.workspace = true
 workspace.workspace = true
 worktree.workspace = true
-pretty_assertions.workspace = true
 
 [dev-dependencies]
 clock = { workspace = true, features = ["test-support"] }

crates/zeta2/src/provider.rs 🔗

@@ -1,24 +1,15 @@
-use std::{
-    cmp,
-    sync::Arc,
-    time::{Duration, Instant},
-};
+use std::{cmp, sync::Arc, time::Duration};
 
-use arrayvec::ArrayVec;
 use client::{Client, UserStore};
 use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
-use gpui::{App, Entity, Task, prelude::*};
+use gpui::{App, Entity, prelude::*};
 use language::ToPoint as _;
 use project::Project;
-use util::ResultExt as _;
 
 use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
 
 pub struct ZetaEditPredictionProvider {
     zeta: Entity<Zeta>,
-    next_pending_prediction_id: usize,
-    pending_predictions: ArrayVec<PendingPrediction, 2>,
-    last_request_timestamp: Instant,
     project: Entity<Project>,
 }
 
@@ -29,28 +20,25 @@ impl ZetaEditPredictionProvider {
         project: Entity<Project>,
         client: &Arc<Client>,
         user_store: &Entity<UserStore>,
-        cx: &mut App,
+        cx: &mut Context<Self>,
     ) -> Self {
         let zeta = Zeta::global(client, user_store, cx);
         zeta.update(cx, |zeta, cx| {
             zeta.register_project(&project, cx);
         });
 
+        cx.observe(&zeta, |_this, _zeta, cx| {
+            cx.notify();
+        })
+        .detach();
+
         Self {
-            zeta,
-            next_pending_prediction_id: 0,
-            pending_predictions: ArrayVec::new(),
-            last_request_timestamp: Instant::now(),
             project: project,
+            zeta,
         }
     }
 }
 
-struct PendingPrediction {
-    id: usize,
-    _task: Task<()>,
-}
-
 impl EditPredictionProvider for ZetaEditPredictionProvider {
     fn name() -> &'static str {
         "zed-predict2"
@@ -95,8 +83,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
         }
     }
 
-    fn is_refreshing(&self) -> bool {
-        !self.pending_predictions.is_empty()
+    fn is_refreshing(&self, cx: &App) -> bool {
+        self.zeta.read(cx).is_refreshing(&self.project)
     }
 
     fn refresh(
@@ -123,59 +111,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
 
         self.zeta.update(cx, |zeta, cx| {
             zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
+            zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
         });
-
-        let pending_prediction_id = self.next_pending_prediction_id;
-        self.next_pending_prediction_id += 1;
-        let last_request_timestamp = self.last_request_timestamp;
-
-        let project = self.project.clone();
-        let task = cx.spawn(async move |this, cx| {
-            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
-                .checked_duration_since(Instant::now())
-            {
-                cx.background_executor().timer(timeout).await;
-            }
-
-            let refresh_task = this.update(cx, |this, cx| {
-                this.last_request_timestamp = Instant::now();
-                this.zeta.update(cx, |zeta, cx| {
-                    zeta.refresh_prediction(&project, &buffer, cursor_position, cx)
-                })
-            });
-
-            if let Some(refresh_task) = refresh_task.ok() {
-                refresh_task.await.log_err();
-            }
-
-            this.update(cx, |this, cx| {
-                if this.pending_predictions[0].id == pending_prediction_id {
-                    this.pending_predictions.remove(0);
-                } else {
-                    this.pending_predictions.clear();
-                }
-
-                cx.notify();
-            })
-            .ok();
-        });
-
-        // We always maintain at most two pending predictions. When we already
-        // have two, we replace the newest one.
-        if self.pending_predictions.len() <= 1 {
-            self.pending_predictions.push(PendingPrediction {
-                id: pending_prediction_id,
-                _task: task,
-            });
-        } else if self.pending_predictions.len() == 2 {
-            self.pending_predictions.pop();
-            self.pending_predictions.push(PendingPrediction {
-                id: pending_prediction_id,
-                _task: task,
-            });
-        }
-
-        cx.notify();
     }
 
     fn cycle(
@@ -191,14 +128,12 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
         self.zeta.update(cx, |zeta, cx| {
             zeta.accept_current_prediction(&self.project, cx);
         });
-        self.pending_predictions.clear();
     }
 
     fn discard(&mut self, cx: &mut Context<Self>) {
         self.zeta.update(cx, |zeta, _cx| {
             zeta.discard_current_prediction(&self.project);
         });
-        self.pending_predictions.clear();
     }
 
     fn suggest(

crates/zeta2/src/zeta2.rs 🔗

@@ -1,4 +1,5 @@
 use anyhow::{Context as _, Result, anyhow, bail};
+use arrayvec::ArrayVec;
 use chrono::TimeDelta;
 use client::{Client, EditPredictionUsage, UserStore};
 use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
@@ -19,18 +20,20 @@ use futures::AsyncReadExt as _;
 use futures::channel::{mpsc, oneshot};
 use gpui::http_client::{AsyncBody, Method};
 use gpui::{
-    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
-    http_client, prelude::*,
+    App, AsyncApp, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task,
+    WeakEntity, http_client, prelude::*,
 };
 use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
 use language::{BufferSnapshot, OffsetRangeExt};
 use language_model::{LlmApiToken, RefreshLlmTokenListener};
+use lsp::DiagnosticSeverity;
 use open_ai::FunctionDefinition;
 use project::{Project, ProjectPath};
 use release_channel::AppVersion;
 use serde::de::DeserializeOwned;
 use std::collections::{VecDeque, hash_map};
 
+use std::fmt::Write;
 use std::ops::Range;
 use std::path::Path;
 use std::str::FromStr as _;
@@ -39,7 +42,7 @@ use std::time::{Duration, Instant};
 use std::{env, mem};
 use thiserror::Error;
 use util::rel_path::RelPathBuf;
-use util::{LogErrorFuture, ResultExt as _, TryFutureExt};
+use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
 use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 
 pub mod assemble_excerpts;
@@ -239,6 +242,9 @@ struct ZetaProject {
     recent_paths: VecDeque<ProjectPath>,
     registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
     current_prediction: Option<CurrentEditPrediction>,
+    next_pending_prediction_id: usize,
+    pending_predictions: ArrayVec<PendingPrediction, 2>,
+    last_prediction_refresh: Option<(EntityId, Instant)>,
     context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
     refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
     refresh_context_debounce_task: Option<Task<Option<()>>>,
@@ -248,7 +254,7 @@ struct ZetaProject {
 
 #[derive(Debug, Clone)]
 struct CurrentEditPrediction {
-    pub requested_by_buffer_id: EntityId,
+    pub requested_by: PredictionRequestedBy,
     pub prediction: EditPrediction,
 }
 
@@ -272,11 +278,13 @@ impl CurrentEditPrediction {
             return true;
         };
 
+        let requested_by_buffer_id = self.requested_by.buffer_id();
+
         // This reduces the occurrence of UI thrash from replacing edits
         //
         // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
-        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
-            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
+        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
+            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
             && old_edits.len() == 1
             && new_edits.len() == 1
         {
@@ -289,6 +297,26 @@ impl CurrentEditPrediction {
     }
 }
 
+#[derive(Debug, Clone)]
+enum PredictionRequestedBy {
+    DiagnosticsUpdate,
+    Buffer(EntityId),
+}
+
+impl PredictionRequestedBy {
+    pub fn buffer_id(&self) -> Option<EntityId> {
+        match self {
+            PredictionRequestedBy::DiagnosticsUpdate => None,
+            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
+        }
+    }
+}
+
+struct PendingPrediction {
+    id: usize,
+    _task: Task<()>,
+}
+
 /// A prediction from the perspective of a buffer.
 #[derive(Debug)]
 enum BufferEditPrediction<'a> {
@@ -513,31 +541,48 @@ impl Zeta {
                 recent_paths: VecDeque::new(),
                 registered_buffers: HashMap::default(),
                 current_prediction: None,
+                pending_predictions: ArrayVec::new(),
+                next_pending_prediction_id: 0,
+                last_prediction_refresh: None,
                 context: None,
                 refresh_context_task: None,
                 refresh_context_debounce_task: None,
                 refresh_context_timestamp: None,
-                _subscription: cx.subscribe(&project, |this, project, event, cx| {
-                    // TODO [zeta2] init with recent paths
-                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
-                        if let project::Event::ActiveEntryChanged(Some(active_entry_id)) = event {
-                            let path = project.read(cx).path_for_entry(*active_entry_id, cx);
-                            if let Some(path) = path {
-                                if let Some(ix) = zeta_project
-                                    .recent_paths
-                                    .iter()
-                                    .position(|probe| probe == &path)
-                                {
-                                    zeta_project.recent_paths.remove(ix);
-                                }
-                                zeta_project.recent_paths.push_front(path);
-                            }
-                        }
-                    }
-                }),
+                _subscription: cx.subscribe(&project, Self::handle_project_event),
             })
     }
 
+    fn handle_project_event(
+        &mut self,
+        project: Entity<Project>,
+        event: &project::Event,
+        cx: &mut Context<Self>,
+    ) {
+        // TODO [zeta2] init with recent paths
+        match event {
+            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
+                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
+                    return;
+                };
+                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
+                if let Some(path) = path {
+                    if let Some(ix) = zeta_project
+                        .recent_paths
+                        .iter()
+                        .position(|probe| probe == &path)
+                    {
+                        zeta_project.recent_paths.remove(ix);
+                    }
+                    zeta_project.recent_paths.push_front(path);
+                }
+            }
+            project::Event::DiagnosticsUpdated { .. } => {
+                self.refresh_prediction_from_diagnostics(project, cx);
+            }
+            _ => (),
+        }
+    }
+
     fn register_buffer_impl<'a>(
         zeta_project: &'a mut ZetaProject,
         buffer: &Entity<Buffer>,
@@ -650,16 +695,25 @@ impl Zeta {
         let project_state = self.projects.get(&project.entity_id())?;
 
         let CurrentEditPrediction {
-            requested_by_buffer_id,
+            requested_by,
             prediction,
         } = project_state.current_prediction.as_ref()?;
 
         if prediction.targets_buffer(buffer.read(cx)) {
             Some(BufferEditPrediction::Local { prediction })
-        } else if *requested_by_buffer_id == buffer.entity_id() {
-            Some(BufferEditPrediction::Jump { prediction })
         } else {
-            None
+            let show_jump = match requested_by {
+                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
+                    requested_by_buffer_id == &buffer.entity_id()
+                }
+                PredictionRequestedBy::DiagnosticsUpdate => true,
+            };
+
+            if show_jump {
+                Some(BufferEditPrediction::Jump { prediction })
+            } else {
+                None
+            }
         }
     }
 
@@ -676,6 +730,7 @@ impl Zeta {
             return;
         };
         let request_id = prediction.prediction.id.to_string();
+        project_state.pending_predictions.clear();
 
         let client = self.client.clone();
         let llm_token = self.llm_token.clone();
@@ -715,47 +770,191 @@ impl Zeta {
     fn discard_current_prediction(&mut self, project: &Entity<Project>) {
         if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
             project_state.current_prediction.take();
+            project_state.pending_predictions.clear();
         };
     }
 
-    pub fn refresh_prediction(
+    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
+        self.projects
+            .get(&project.entity_id())
+            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
+    }
+
+    pub fn refresh_prediction_from_buffer(
         &mut self,
-        project: &Entity<Project>,
-        buffer: &Entity<Buffer>,
+        project: Entity<Project>,
+        buffer: Entity<Buffer>,
         position: language::Anchor,
         cx: &mut Context<Self>,
-    ) -> Task<Result<()>> {
-        let request_task = self.request_prediction(project, buffer, position, cx);
-        let buffer = buffer.clone();
-        let project = project.clone();
+    ) {
+        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
+            let Some(request_task) = this
+                .update(cx, |this, cx| {
+                    this.request_prediction(&project, &buffer, position, cx)
+                })
+                .log_err()
+            else {
+                return Task::ready(anyhow::Ok(()));
+            };
 
-        cx.spawn(async move |this, cx| {
-            if let Some(prediction) = request_task.await? {
-                this.update(cx, |this, cx| {
-                    let project_state = this
-                        .projects
-                        .get_mut(&project.entity_id())
-                        .context("Project not found")?;
-
-                    let new_prediction = CurrentEditPrediction {
-                        requested_by_buffer_id: buffer.entity_id(),
-                        prediction: prediction,
-                    };
+            let project = project.clone();
+            cx.spawn(async move |cx| {
+                if let Some(prediction) = request_task.await? {
+                    this.update(cx, |this, cx| {
+                        let project_state = this
+                            .projects
+                            .get_mut(&project.entity_id())
+                            .context("Project not found")?;
+
+                        let new_prediction = CurrentEditPrediction {
+                            requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
+                            prediction: prediction,
+                        };
 
-                    if project_state
-                        .current_prediction
-                        .as_ref()
-                        .is_none_or(|old_prediction| {
-                            new_prediction.should_replace_prediction(&old_prediction, cx)
-                        })
-                    {
-                        project_state.current_prediction = Some(new_prediction);
+                        if project_state
+                            .current_prediction
+                            .as_ref()
+                            .is_none_or(|old_prediction| {
+                                new_prediction.should_replace_prediction(&old_prediction, cx)
+                            })
+                        {
+                            project_state.current_prediction = Some(new_prediction);
+                            cx.notify();
+                        }
+                        anyhow::Ok(())
+                    })??;
+                }
+                Ok(())
+            })
+        })
+    }
+
+    pub fn refresh_prediction_from_diagnostics(
+        &mut self,
+        project: Entity<Project>,
+        cx: &mut Context<Self>,
+    ) {
+        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
+            return;
+        };
+
+        // Prefer predictions from buffer
+        if zeta_project.current_prediction.is_some() {
+            return;
+        };
+
+        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
+            let Some(open_buffer_task) = project
+                .update(cx, |project, cx| {
+                    project
+                        .active_entry()
+                        .and_then(|entry| project.path_for_entry(entry, cx))
+                        .map(|path| project.open_buffer(path, cx))
+                })
+                .log_err()
+                .flatten()
+            else {
+                return Task::ready(anyhow::Ok(()));
+            };
+
+            cx.spawn(async move |cx| {
+                let active_buffer = open_buffer_task.await?;
+                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+
+                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
+                    active_buffer,
+                    &snapshot,
+                    Default::default(),
+                    Default::default(),
+                    &project,
+                    cx,
+                )
+                .await?
+                else {
+                    return anyhow::Ok(());
+                };
+
+                let Some(prediction) = this
+                    .update(cx, |this, cx| {
+                        this.request_prediction(&project, &jump_buffer, jump_position, cx)
+                    })?
+                    .await?
+                else {
+                    return anyhow::Ok(());
+                };
+
+                this.update(cx, |this, cx| {
+                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
+                        zeta_project.current_prediction.get_or_insert_with(|| {
+                            cx.notify();
+                            CurrentEditPrediction {
+                                requested_by: PredictionRequestedBy::DiagnosticsUpdate,
+                                prediction,
+                            }
+                        });
                     }
-                    anyhow::Ok(())
-                })??;
+                })?;
+
+                anyhow::Ok(())
+            })
+        });
+    }
+
+    #[cfg(not(test))]
+    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
+    #[cfg(test)]
+    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
+
+    fn queue_prediction_refresh(
+        &mut self,
+        project: Entity<Project>,
+        throttle_entity: EntityId,
+        cx: &mut Context<Self>,
+        do_refresh: impl FnOnce(WeakEntity<Self>, &mut AsyncApp) -> Task<Result<()>> + 'static,
+    ) {
+        let zeta_project = self.get_or_init_zeta_project(&project, cx);
+        let pending_prediction_id = zeta_project.next_pending_prediction_id;
+        zeta_project.next_pending_prediction_id += 1;
+        let last_request = zeta_project.last_prediction_refresh;
+
+        // TODO report cancelled requests like in zeta1
+        let task = cx.spawn(async move |this, cx| {
+            if let Some((last_entity, last_timestamp)) = last_request
+                && throttle_entity == last_entity
+                && let Some(timeout) =
+                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
+            {
+                cx.background_executor().timer(timeout).await;
             }
-            Ok(())
-        })
+
+            do_refresh(this.clone(), cx).await.log_err();
+
+            this.update(cx, |this, cx| {
+                let zeta_project = this.get_or_init_zeta_project(&project, cx);
+
+                if zeta_project.pending_predictions[0].id == pending_prediction_id {
+                    zeta_project.pending_predictions.remove(0);
+                } else {
+                    zeta_project.pending_predictions.clear();
+                }
+
+                cx.notify();
+            })
+            .ok();
+        });
+
+        if zeta_project.pending_predictions.len() <= 1 {
+            zeta_project.pending_predictions.push(PendingPrediction {
+                id: pending_prediction_id,
+                _task: task,
+            });
+        } else if zeta_project.pending_predictions.len() == 2 {
+            zeta_project.pending_predictions.pop();
+            zeta_project.pending_predictions.push(PendingPrediction {
+                id: pending_prediction_id,
+                _task: task,
+            });
+        }
     }
 
     pub fn request_prediction(
@@ -770,7 +969,7 @@ impl Zeta {
                 self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
             }
             ZetaEditPredictionModel::Sweep => {
-                self.request_prediction_with_sweep(project, active_buffer, position, cx)
+                self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
             }
         }
     }
@@ -780,6 +979,7 @@ impl Zeta {
         project: &Entity<Project>,
         active_buffer: &Entity<Buffer>,
         position: language::Anchor,
+        allow_jump: bool,
         cx: &mut Context<Self>,
     ) -> Task<Result<Option<EditPrediction>>> {
         let snapshot = active_buffer.read(cx).snapshot();
@@ -802,6 +1002,7 @@ impl Zeta {
 
         let project_state = self.get_or_init_zeta_project(project, cx);
         let events = project_state.events.clone();
+        let has_events = !events.is_empty();
         let recent_buffers = project_state.recent_paths.iter().cloned();
         let http_client = cx.http_client();
 
@@ -817,114 +1018,188 @@ impl Zeta {
             .take(3)
             .collect::<Vec<_>>();
 
-        let result = cx.background_spawn(async move {
-            let text = snapshot.text();
+        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
 
-            let mut recent_changes = String::new();
-            for event in events {
-                sweep_ai::write_event(event, &mut recent_changes).unwrap();
-            }
+        let cursor_point = position.to_point(&snapshot);
+        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
+        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
+        let diagnostic_search_range =
+            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
+
+        let result = cx.background_spawn({
+            let snapshot = snapshot.clone();
+            let diagnostic_search_range = diagnostic_search_range.clone();
+            async move {
+                let text = snapshot.text();
+
+                let mut recent_changes = String::new();
+                for event in events {
+                    sweep_ai::write_event(event, &mut recent_changes).unwrap();
+                }
+
+                let mut file_chunks = recent_buffer_snapshots
+                    .into_iter()
+                    .map(|snapshot| {
+                        let end_point = Point::new(30, 0).min(snapshot.max_point());
+                        sweep_ai::FileChunk {
+                            content: snapshot.text_for_range(Point::zero()..end_point).collect(),
+                            file_path: snapshot
+                                .file()
+                                .map(|f| f.path().as_unix_str())
+                                .unwrap_or("untitled")
+                                .to_string(),
+                            start_line: 0,
+                            end_line: end_point.row as usize,
+                            timestamp: snapshot.file().and_then(|file| {
+                                Some(
+                                    file.disk_state()
+                                        .mtime()?
+                                        .to_seconds_and_nanos_for_persistence()?
+                                        .0,
+                                )
+                            }),
+                        }
+                    })
+                    .collect::<Vec<_>>();
+
+                let diagnostic_entries =
+                    snapshot.diagnostics_in_range(diagnostic_search_range, false);
+                let mut diagnostic_content = String::new();
+                let mut diagnostic_count = 0;
+
+                for entry in diagnostic_entries {
+                    let start_point: Point = entry.range.start;
+
+                    let severity = match entry.diagnostic.severity {
+                        DiagnosticSeverity::ERROR => "error",
+                        DiagnosticSeverity::WARNING => "warning",
+                        DiagnosticSeverity::INFORMATION => "info",
+                        DiagnosticSeverity::HINT => "hint",
+                        _ => continue,
+                    };
+
+                    diagnostic_count += 1;
 
-            let file_chunks = recent_buffer_snapshots
-                .into_iter()
-                .map(|snapshot| {
-                    let end_point = language::Point::new(30, 0).min(snapshot.max_point());
-                    sweep_ai::FileChunk {
-                        content: snapshot
-                            .text_for_range(language::Point::zero()..end_point)
-                            .collect(),
-                        file_path: snapshot
-                            .file()
-                            .map(|f| f.path().as_unix_str())
-                            .unwrap_or("untitled")
-                            .to_string(),
+                    writeln!(
+                        &mut diagnostic_content,
+                        "{} at line {}: {}",
+                        severity,
+                        start_point.row + 1,
+                        entry.diagnostic.message
+                    )?;
+                }
+
+                if !diagnostic_content.is_empty() {
+                    file_chunks.push(sweep_ai::FileChunk {
+                        file_path: format!("Diagnostics for {}", full_path.display()),
                         start_line: 0,
-                        end_line: end_point.row as usize,
-                        timestamp: snapshot.file().and_then(|file| {
-                            Some(
-                                file.disk_state()
-                                    .mtime()?
-                                    .to_seconds_and_nanos_for_persistence()?
-                                    .0,
-                            )
-                        }),
-                    }
-                })
-                .collect();
-
-            let request_body = sweep_ai::AutocompleteRequest {
-                debug_info,
-                repo_name,
-                file_path: full_path.clone(),
-                file_contents: text.clone(),
-                original_file_contents: text,
-                cursor_position: offset,
-                recent_changes: recent_changes.clone(),
-                changes_above_cursor: true,
-                multiple_suggestions: false,
-                branch: None,
-                file_chunks,
-                retrieval_chunks: vec![],
-                recent_user_actions: vec![],
-                // TODO
-                privacy_mode_enabled: false,
-            };
+                        end_line: diagnostic_count,
+                        content: diagnostic_content,
+                        timestamp: None,
+                    });
+                }
 
-            let mut buf: Vec<u8> = Vec::new();
-            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
-            serde_json::to_writer(writer, &request_body)?;
-            let body: AsyncBody = buf.into();
+                let request_body = sweep_ai::AutocompleteRequest {
+                    debug_info,
+                    repo_name,
+                    file_path: full_path.clone(),
+                    file_contents: text.clone(),
+                    original_file_contents: text,
+                    cursor_position: offset,
+                    recent_changes: recent_changes.clone(),
+                    changes_above_cursor: true,
+                    multiple_suggestions: false,
+                    branch: None,
+                    file_chunks,
+                    retrieval_chunks: vec![],
+                    recent_user_actions: vec![],
+                    // TODO
+                    privacy_mode_enabled: false,
+                };
 
-            const SWEEP_API_URL: &str =
-                "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
+                let mut buf: Vec<u8> = Vec::new();
+                let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
+                serde_json::to_writer(writer, &request_body)?;
+                let body: AsyncBody = buf.into();
 
-            let request = http_client::Request::builder()
-                .uri(SWEEP_API_URL)
-                .header("Content-Type", "application/json")
-                .header("Authorization", format!("Bearer {}", api_token))
-                .header("Connection", "keep-alive")
-                .header("Content-Encoding", "br")
-                .method(Method::POST)
-                .body(body)?;
+                const SWEEP_API_URL: &str =
+                    "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
 
-            let mut response = http_client.send(request).await?;
+                let request = http_client::Request::builder()
+                    .uri(SWEEP_API_URL)
+                    .header("Content-Type", "application/json")
+                    .header("Authorization", format!("Bearer {}", api_token))
+                    .header("Connection", "keep-alive")
+                    .header("Content-Encoding", "br")
+                    .method(Method::POST)
+                    .body(body)?;
 
-            let mut body: Vec<u8> = Vec::new();
-            response.body_mut().read_to_end(&mut body).await?;
+                let mut response = http_client.send(request).await?;
 
-            if !response.status().is_success() {
-                anyhow::bail!(
-                    "Request failed with status: {:?}\nBody: {}",
-                    response.status(),
-                    String::from_utf8_lossy(&body),
-                );
-            };
+                let mut body: Vec<u8> = Vec::new();
+                response.body_mut().read_to_end(&mut body).await?;
 
-            let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
-
-            let old_text = snapshot
-                .text_for_range(response.start_index..response.end_index)
-                .collect::<String>();
-            let edits = language::text_diff(&old_text, &response.completion)
-                .into_iter()
-                .map(|(range, text)| {
-                    (
-                        snapshot.anchor_after(response.start_index + range.start)
-                            ..snapshot.anchor_before(response.start_index + range.end),
-                        text,
-                    )
-                })
-                .collect::<Vec<_>>();
+                if !response.status().is_success() {
+                    anyhow::bail!(
+                        "Request failed with status: {:?}\nBody: {}",
+                        response.status(),
+                        String::from_utf8_lossy(&body),
+                    );
+                };
+
+                let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
 
-            anyhow::Ok((response.autocomplete_id, edits, snapshot))
+                let old_text = snapshot
+                    .text_for_range(response.start_index..response.end_index)
+                    .collect::<String>();
+                let edits = language::text_diff(&old_text, &response.completion)
+                    .into_iter()
+                    .map(|(range, text)| {
+                        (
+                            snapshot.anchor_after(response.start_index + range.start)
+                                ..snapshot.anchor_before(response.start_index + range.end),
+                            text,
+                        )
+                    })
+                    .collect::<Vec<_>>();
+
+                anyhow::Ok((response.autocomplete_id, edits, snapshot))
+            }
         });
 
         let buffer = active_buffer.clone();
+        let project = project.clone();
+        let active_buffer = active_buffer.clone();
 
-        cx.spawn(async move |_, cx| {
+        cx.spawn(async move |this, cx| {
             let (id, edits, old_snapshot) = result.await?;
 
             if edits.is_empty() {
+                if has_events
+                    && allow_jump
+                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
+                        active_buffer,
+                        &snapshot,
+                        diagnostic_search_range,
+                        cursor_point,
+                        &project,
+                        cx,
+                    )
+                    .await?
+                {
+                    return this
+                        .update(cx, |this, cx| {
+                            this.request_prediction_with_sweep(
+                                &project,
+                                &jump_buffer,
+                                jump_position,
+                                false,
+                                cx,
+                            )
+                        })?
+                        .await;
+                }
+
                 return anyhow::Ok(None);
             }
 
@@ -955,6 +1230,85 @@ impl Zeta {
         })
     }
 
+    async fn next_diagnostic_location(
+        active_buffer: Entity<Buffer>,
+        active_buffer_snapshot: &BufferSnapshot,
+        active_buffer_diagnostic_search_range: Range<Point>,
+        active_buffer_cursor_point: Point,
+        project: &Entity<Project>,
+        cx: &mut AsyncApp,
+    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
+        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
+        let mut jump_location = active_buffer_snapshot
+            .diagnostic_groups(None)
+            .into_iter()
+            .filter_map(|(_, group)| {
+                let range = &group.entries[group.primary_ix]
+                    .range
+                    .to_point(&active_buffer_snapshot);
+                if range.overlaps(&active_buffer_diagnostic_search_range) {
+                    None
+                } else {
+                    Some(range.start)
+                }
+            })
+            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
+            .map(|position| {
+                (
+                    active_buffer.clone(),
+                    active_buffer_snapshot.anchor_before(position),
+                )
+            });
+
+        if jump_location.is_none() {
+            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
+                let file = buffer.file()?;
+
+                Some(ProjectPath {
+                    worktree_id: file.worktree_id(cx),
+                    path: file.path().clone(),
+                })
+            })?;
+
+            let buffer_task = project.update(cx, |project, cx| {
+                let (path, _, _) = project
+                    .diagnostic_summaries(false, cx)
+                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
+                    .max_by_key(|(path, _, _)| {
+                        // find the buffer with errors that shares most parent directories
+                        path.path
+                            .components()
+                            .zip(
+                                active_buffer_path
+                                    .as_ref()
+                                    .map(|p| p.path.components())
+                                    .unwrap_or_default(),
+                            )
+                            .take_while(|(a, b)| a == b)
+                            .count()
+                    })?;
+
+                Some(project.open_buffer(path, cx))
+            })?;
+
+            if let Some(buffer_task) = buffer_task {
+                let closest_buffer = buffer_task.await?;
+
+                jump_location = closest_buffer
+                    .read_with(cx, |buffer, _cx| {
+                        buffer
+                            .buffer_diagnostics(None)
+                            .into_iter()
+                            .min_by_key(|entry| entry.diagnostic.severity)
+                            .map(|entry| entry.range.start)
+                    })?
+                    .map(|position| (closest_buffer, position));
+            }
+        }
+
+        anyhow::Ok(jump_location)
+    }
+
     fn request_prediction_with_zed_cloud(
         &mut self,
         project: &Entity<Project>,
@@ -2168,8 +2522,8 @@ mod tests {
 
         // Prediction for current file
 
-        let prediction_task = zeta.update(cx, |zeta, cx| {
-            zeta.refresh_prediction(&project, &buffer1, position, cx)
+        zeta.update(cx, |zeta, cx| {
+            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
         });
         let (_request, respond_tx) = req_rx.next().await.unwrap();
 
@@ -2184,7 +2538,8 @@ mod tests {
                  Bye
             "}))
             .unwrap();
-        prediction_task.await.unwrap();
+
+        cx.run_until_parked();
 
         zeta.read_with(cx, |zeta, cx| {
             let prediction = zeta
@@ -2242,8 +2597,8 @@ mod tests {
         });
 
         // Prediction for another file
-        let prediction_task = zeta.update(cx, |zeta, cx| {
-            zeta.refresh_prediction(&project, &buffer1, position, cx)
+        zeta.update(cx, |zeta, cx| {
+            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
         });
         let (_request, respond_tx) = req_rx.next().await.unwrap();
         respond_tx
@@ -2256,7 +2611,8 @@ mod tests {
                  Adios
             "#}))
             .unwrap();
-        prediction_task.await.unwrap();
+        cx.run_until_parked();
+
         zeta.read_with(cx, |zeta, cx| {
             let prediction = zeta
                 .current_prediction_for_buffer(&buffer1, &project, cx)

crates/zeta2_tools/src/zeta2_tools.rs 🔗

@@ -1,6 +1,6 @@
 mod zeta2_context_view;
 
-use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc, time::Duration};
+use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc};
 
 use chrono::TimeDelta;
 use client::{Client, UserStore};
@@ -237,24 +237,13 @@ impl Zeta2Inspector {
     fn set_zeta_options(&mut self, options: ZetaOptions, cx: &mut Context<Self>) {
         self.zeta.update(cx, |this, _cx| this.set_options(options));
 
-        const DEBOUNCE_TIME: Duration = Duration::from_millis(100);
-
         if let Some(prediction) = self.last_prediction.as_mut() {
             if let Some(buffer) = prediction.buffer.upgrade() {
                 let position = prediction.position;
-                let zeta = self.zeta.clone();
                 let project = self.project.clone();
-                prediction._task = Some(cx.spawn(async move |_this, cx| {
-                    cx.background_executor().timer(DEBOUNCE_TIME).await;
-                    if let Some(task) = zeta
-                        .update(cx, |zeta, cx| {
-                            zeta.refresh_prediction(&project, &buffer, position, cx)
-                        })
-                        .ok()
-                    {
-                        task.await.log_err();
-                    }
-                }));
+                self.zeta.update(cx, |zeta, cx| {
+                    zeta.refresh_prediction_from_buffer(project, buffer, position, cx)
+                });
                 prediction.state = LastPredictionState::Requested;
             } else {
                 self.last_prediction.take();