From 4fb671f4eb95b5190c4e9540322a61f137c556f5 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Fri, 21 Nov 2025 13:39:08 -0300 Subject: [PATCH] zeta2: Predict at next diagnostic location (#43257) When no predictions are available for the current buffer, we will now attempt to predict at the closest diagnostic from the cursor location that wasn't included in the last prediction request. This enables a commonly desired kind of far-away jump without requiring explicit model support. Release Notes: - N/A --- crates/codestral/src/codestral.rs | 2 +- .../src/copilot_completion_provider.rs | 2 +- crates/edit_prediction/src/edit_prediction.rs | 4 +- crates/editor/src/edit_prediction_tests.rs | 4 +- .../src/supermaven_completion_provider.rs | 2 +- crates/util/src/rel_path.rs | 1 + crates/zeta/src/zeta.rs | 2 +- crates/zeta2/Cargo.toml | 3 +- crates/zeta2/src/provider.rs | 89 +-- crates/zeta2/src/zeta2.rs | 668 ++++++++++++++---- crates/zeta2_tools/src/zeta2_tools.rs | 19 +- 11 files changed, 539 insertions(+), 257 deletions(-) diff --git a/crates/codestral/src/codestral.rs b/crates/codestral/src/codestral.rs index 9fbd207a809fb2cb3ac685ea6629a36c8631d1fe..6a500acbf6ec5eea63c35a8deb83a8545cee497e 100644 --- a/crates/codestral/src/codestral.rs +++ b/crates/codestral/src/codestral.rs @@ -182,7 +182,7 @@ impl EditPredictionProvider for CodestralCompletionProvider { Self::api_key(cx).is_some() } - fn is_refreshing(&self) -> bool { + fn is_refreshing(&self, _cx: &App) -> bool { self.pending_request.is_some() } diff --git a/crates/copilot/src/copilot_completion_provider.rs b/crates/copilot/src/copilot_completion_provider.rs index 30ef6de07ec92f66cc888b52a540cf9c7e673bb4..e92f0c7d7dd7e51c4a8fdc19f34bd6eb4189c097 100644 --- a/crates/copilot/src/copilot_completion_provider.rs +++ b/crates/copilot/src/copilot_completion_provider.rs @@ -68,7 +68,7 @@ impl EditPredictionProvider for CopilotCompletionProvider { false } - fn is_refreshing(&self) -> bool { + fn is_refreshing(&self, _cx: &App) -> bool { self.pending_refresh.is_some() && self.completions.is_empty() } diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index aebfa5e5229ef1fec50f2d9cf74e354878ddc1c5..1984383a9691ae9373973a3eb9f00db4e7e795f2 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -87,7 +87,7 @@ pub trait EditPredictionProvider: 'static + Sized { cursor_position: language::Anchor, cx: &App, ) -> bool; - fn is_refreshing(&self) -> bool; + fn is_refreshing(&self, cx: &App) -> bool; fn refresh( &mut self, buffer: Entity, @@ -200,7 +200,7 @@ where } fn is_refreshing(&self, cx: &App) -> bool { - self.read(cx).is_refreshing() + self.read(cx).is_refreshing(cx) } fn refresh( diff --git a/crates/editor/src/edit_prediction_tests.rs b/crates/editor/src/edit_prediction_tests.rs index 74f13a404c6a52db448d68eba9e5c255e7276923..a1839144a47a81f668ba2743cd5e362f6711d0e9 100644 --- a/crates/editor/src/edit_prediction_tests.rs +++ b/crates/editor/src/edit_prediction_tests.rs @@ -469,7 +469,7 @@ impl EditPredictionProvider for FakeEditPredictionProvider { true } - fn is_refreshing(&self) -> bool { + fn is_refreshing(&self, _cx: &gpui::App) -> bool { false } @@ -542,7 +542,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider { true } - fn is_refreshing(&self) -> bool { + fn is_refreshing(&self, _cx: &gpui::App) -> bool { false } diff --git a/crates/supermaven/src/supermaven_completion_provider.rs b/crates/supermaven/src/supermaven_completion_provider.rs index 0c9fe85da6130f5ea2040434a0dcd3727754d3c0..9d5e256aca1b66644145cb688851d0ec5c1b81b9 100644 --- a/crates/supermaven/src/supermaven_completion_provider.rs +++ b/crates/supermaven/src/supermaven_completion_provider.rs @@ -129,7 +129,7 @@ impl EditPredictionProvider for SupermavenCompletionProvider { self.supermaven.read(cx).is_enabled() } - fn is_refreshing(&self) -> bool { + fn is_refreshing(&self, _cx: &App) -> bool { self.pending_refresh.is_some() && self.completion_id.is_none() } diff --git a/crates/util/src/rel_path.rs b/crates/util/src/rel_path.rs index b360297f209c54c6a33b174a738ed1876fbc16a0..60a0e2ef9ef51ee579a30fac57486b0040c42227 100644 --- a/crates/util/src/rel_path.rs +++ b/crates/util/src/rel_path.rs @@ -374,6 +374,7 @@ impl PartialEq for RelPath { } } +#[derive(Default)] pub struct RelPathComponents<'a>(&'a str); pub struct RelPathAncestors<'a>(Option<&'a str>); diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index c2ef5cb826db0947c18e1e91a6163cccc12deb11..cb31488d17668531ee11a67d1e4be19a1674d3d2 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -1486,7 +1486,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { ) -> bool { true } - fn is_refreshing(&self) -> bool { + fn is_refreshing(&self, _cx: &App) -> bool { !self.pending_completions.is_empty() } diff --git a/crates/zeta2/Cargo.toml b/crates/zeta2/Cargo.toml index 0f156f68fac881d65d76f178315f40df1dba9d7f..834762447707b88d6b009f0d6700c639306c9bbd 100644 --- a/crates/zeta2/Cargo.toml +++ b/crates/zeta2/Cargo.toml @@ -32,7 +32,9 @@ indoc.workspace = true language.workspace = true language_model.workspace = true log.workspace = true +lsp.workspace = true open_ai.workspace = true +pretty_assertions.workspace = true project.workspace = true release_channel.workspace = true serde.workspace = true @@ -44,7 +46,6 @@ util.workspace = true uuid.workspace = true workspace.workspace = true worktree.workspace = true -pretty_assertions.workspace = true [dev-dependencies] clock = { workspace = true, features = ["test-support"] } diff --git a/crates/zeta2/src/provider.rs b/crates/zeta2/src/provider.rs index 1b82826f663b092b5763935d9a7a2d4bb9607ebf..768af6253fe1a2aa60ef9cb0a10fcee0035dc3e2 100644 --- a/crates/zeta2/src/provider.rs +++ b/crates/zeta2/src/provider.rs @@ -1,24 +1,15 @@ -use std::{ - cmp, - sync::Arc, - time::{Duration, Instant}, -}; +use std::{cmp, sync::Arc, time::Duration}; -use arrayvec::ArrayVec; use client::{Client, UserStore}; use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider}; -use gpui::{App, Entity, Task, prelude::*}; +use gpui::{App, Entity, prelude::*}; use language::ToPoint as _; use project::Project; -use util::ResultExt as _; use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel}; pub struct ZetaEditPredictionProvider { zeta: Entity, - next_pending_prediction_id: usize, - pending_predictions: ArrayVec, - last_request_timestamp: Instant, project: Entity, } @@ -29,28 +20,25 @@ impl ZetaEditPredictionProvider { project: Entity, client: &Arc, user_store: &Entity, - cx: &mut App, + cx: &mut Context, ) -> Self { let zeta = Zeta::global(client, user_store, cx); zeta.update(cx, |zeta, cx| { zeta.register_project(&project, cx); }); + cx.observe(&zeta, |_this, _zeta, cx| { + cx.notify(); + }) + .detach(); + Self { - zeta, - next_pending_prediction_id: 0, - pending_predictions: ArrayVec::new(), - last_request_timestamp: Instant::now(), project: project, + zeta, } } } -struct PendingPrediction { - id: usize, - _task: Task<()>, -} - impl EditPredictionProvider for ZetaEditPredictionProvider { fn name() -> &'static str { "zed-predict2" @@ -95,8 +83,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { } } - fn is_refreshing(&self) -> bool { - !self.pending_predictions.is_empty() + fn is_refreshing(&self, cx: &App) -> bool { + self.zeta.read(cx).is_refreshing(&self.project) } fn refresh( @@ -123,59 +111,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { self.zeta.update(cx, |zeta, cx| { zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx); + zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx) }); - - let pending_prediction_id = self.next_pending_prediction_id; - self.next_pending_prediction_id += 1; - let last_request_timestamp = self.last_request_timestamp; - - let project = self.project.clone(); - let task = cx.spawn(async move |this, cx| { - if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT) - .checked_duration_since(Instant::now()) - { - cx.background_executor().timer(timeout).await; - } - - let refresh_task = this.update(cx, |this, cx| { - this.last_request_timestamp = Instant::now(); - this.zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction(&project, &buffer, cursor_position, cx) - }) - }); - - if let Some(refresh_task) = refresh_task.ok() { - refresh_task.await.log_err(); - } - - this.update(cx, |this, cx| { - if this.pending_predictions[0].id == pending_prediction_id { - this.pending_predictions.remove(0); - } else { - this.pending_predictions.clear(); - } - - cx.notify(); - }) - .ok(); - }); - - // We always maintain at most two pending predictions. When we already - // have two, we replace the newest one. - if self.pending_predictions.len() <= 1 { - self.pending_predictions.push(PendingPrediction { - id: pending_prediction_id, - _task: task, - }); - } else if self.pending_predictions.len() == 2 { - self.pending_predictions.pop(); - self.pending_predictions.push(PendingPrediction { - id: pending_prediction_id, - _task: task, - }); - } - - cx.notify(); } fn cycle( @@ -191,14 +128,12 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { self.zeta.update(cx, |zeta, cx| { zeta.accept_current_prediction(&self.project, cx); }); - self.pending_predictions.clear(); } fn discard(&mut self, cx: &mut Context) { self.zeta.update(cx, |zeta, _cx| { zeta.discard_current_prediction(&self.project); }); - self.pending_predictions.clear(); } fn suggest( diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 0d0f4f3d39e9c997282695828ba16e7eccd7d8e2..a06d7043cf565dccf0d8a4e8830cbb41c2e9981b 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -1,4 +1,5 @@ use anyhow::{Context as _, Result, anyhow, bail}; +use arrayvec::ArrayVec; use chrono::TimeDelta; use client::{Client, EditPredictionUsage, UserStore}; use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature}; @@ -19,18 +20,20 @@ use futures::AsyncReadExt as _; use futures::channel::{mpsc, oneshot}; use gpui::http_client::{AsyncBody, Method}; use gpui::{ - App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity, - http_client, prelude::*, + App, AsyncApp, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, + WeakEntity, http_client, prelude::*, }; use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint}; use language::{BufferSnapshot, OffsetRangeExt}; use language_model::{LlmApiToken, RefreshLlmTokenListener}; +use lsp::DiagnosticSeverity; use open_ai::FunctionDefinition; use project::{Project, ProjectPath}; use release_channel::AppVersion; use serde::de::DeserializeOwned; use std::collections::{VecDeque, hash_map}; +use std::fmt::Write; use std::ops::Range; use std::path::Path; use std::str::FromStr as _; @@ -39,7 +42,7 @@ use std::time::{Duration, Instant}; use std::{env, mem}; use thiserror::Error; use util::rel_path::RelPathBuf; -use util::{LogErrorFuture, ResultExt as _, TryFutureExt}; +use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt}; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; pub mod assemble_excerpts; @@ -239,6 +242,9 @@ struct ZetaProject { recent_paths: VecDeque, registered_buffers: HashMap, current_prediction: Option, + next_pending_prediction_id: usize, + pending_predictions: ArrayVec, + last_prediction_refresh: Option<(EntityId, Instant)>, context: Option, Vec>>>, refresh_context_task: Option>>>, refresh_context_debounce_task: Option>>, @@ -248,7 +254,7 @@ struct ZetaProject { #[derive(Debug, Clone)] struct CurrentEditPrediction { - pub requested_by_buffer_id: EntityId, + pub requested_by: PredictionRequestedBy, pub prediction: EditPrediction, } @@ -272,11 +278,13 @@ impl CurrentEditPrediction { return true; }; + let requested_by_buffer_id = self.requested_by.buffer_id(); + // This reduces the occurrence of UI thrash from replacing edits // // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits. - if self.requested_by_buffer_id == self.prediction.buffer.entity_id() - && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id() + if requested_by_buffer_id == Some(self.prediction.buffer.entity_id()) + && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id()) && old_edits.len() == 1 && new_edits.len() == 1 { @@ -289,6 +297,26 @@ impl CurrentEditPrediction { } } +#[derive(Debug, Clone)] +enum PredictionRequestedBy { + DiagnosticsUpdate, + Buffer(EntityId), +} + +impl PredictionRequestedBy { + pub fn buffer_id(&self) -> Option { + match self { + PredictionRequestedBy::DiagnosticsUpdate => None, + PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id), + } + } +} + +struct PendingPrediction { + id: usize, + _task: Task<()>, +} + /// A prediction from the perspective of a buffer. #[derive(Debug)] enum BufferEditPrediction<'a> { @@ -513,31 +541,48 @@ impl Zeta { recent_paths: VecDeque::new(), registered_buffers: HashMap::default(), current_prediction: None, + pending_predictions: ArrayVec::new(), + next_pending_prediction_id: 0, + last_prediction_refresh: None, context: None, refresh_context_task: None, refresh_context_debounce_task: None, refresh_context_timestamp: None, - _subscription: cx.subscribe(&project, |this, project, event, cx| { - // TODO [zeta2] init with recent paths - if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) { - if let project::Event::ActiveEntryChanged(Some(active_entry_id)) = event { - let path = project.read(cx).path_for_entry(*active_entry_id, cx); - if let Some(path) = path { - if let Some(ix) = zeta_project - .recent_paths - .iter() - .position(|probe| probe == &path) - { - zeta_project.recent_paths.remove(ix); - } - zeta_project.recent_paths.push_front(path); - } - } - } - }), + _subscription: cx.subscribe(&project, Self::handle_project_event), }) } + fn handle_project_event( + &mut self, + project: Entity, + event: &project::Event, + cx: &mut Context, + ) { + // TODO [zeta2] init with recent paths + match event { + project::Event::ActiveEntryChanged(Some(active_entry_id)) => { + let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { + return; + }; + let path = project.read(cx).path_for_entry(*active_entry_id, cx); + if let Some(path) = path { + if let Some(ix) = zeta_project + .recent_paths + .iter() + .position(|probe| probe == &path) + { + zeta_project.recent_paths.remove(ix); + } + zeta_project.recent_paths.push_front(path); + } + } + project::Event::DiagnosticsUpdated { .. } => { + self.refresh_prediction_from_diagnostics(project, cx); + } + _ => (), + } + } + fn register_buffer_impl<'a>( zeta_project: &'a mut ZetaProject, buffer: &Entity, @@ -650,16 +695,25 @@ impl Zeta { let project_state = self.projects.get(&project.entity_id())?; let CurrentEditPrediction { - requested_by_buffer_id, + requested_by, prediction, } = project_state.current_prediction.as_ref()?; if prediction.targets_buffer(buffer.read(cx)) { Some(BufferEditPrediction::Local { prediction }) - } else if *requested_by_buffer_id == buffer.entity_id() { - Some(BufferEditPrediction::Jump { prediction }) } else { - None + let show_jump = match requested_by { + PredictionRequestedBy::Buffer(requested_by_buffer_id) => { + requested_by_buffer_id == &buffer.entity_id() + } + PredictionRequestedBy::DiagnosticsUpdate => true, + }; + + if show_jump { + Some(BufferEditPrediction::Jump { prediction }) + } else { + None + } } } @@ -676,6 +730,7 @@ impl Zeta { return; }; let request_id = prediction.prediction.id.to_string(); + project_state.pending_predictions.clear(); let client = self.client.clone(); let llm_token = self.llm_token.clone(); @@ -715,47 +770,191 @@ impl Zeta { fn discard_current_prediction(&mut self, project: &Entity) { if let Some(project_state) = self.projects.get_mut(&project.entity_id()) { project_state.current_prediction.take(); + project_state.pending_predictions.clear(); }; } - pub fn refresh_prediction( + fn is_refreshing(&self, project: &Entity) -> bool { + self.projects + .get(&project.entity_id()) + .is_some_and(|project_state| !project_state.pending_predictions.is_empty()) + } + + pub fn refresh_prediction_from_buffer( &mut self, - project: &Entity, - buffer: &Entity, + project: Entity, + buffer: Entity, position: language::Anchor, cx: &mut Context, - ) -> Task> { - let request_task = self.request_prediction(project, buffer, position, cx); - let buffer = buffer.clone(); - let project = project.clone(); + ) { + self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| { + let Some(request_task) = this + .update(cx, |this, cx| { + this.request_prediction(&project, &buffer, position, cx) + }) + .log_err() + else { + return Task::ready(anyhow::Ok(())); + }; - cx.spawn(async move |this, cx| { - if let Some(prediction) = request_task.await? { - this.update(cx, |this, cx| { - let project_state = this - .projects - .get_mut(&project.entity_id()) - .context("Project not found")?; - - let new_prediction = CurrentEditPrediction { - requested_by_buffer_id: buffer.entity_id(), - prediction: prediction, - }; + let project = project.clone(); + cx.spawn(async move |cx| { + if let Some(prediction) = request_task.await? { + this.update(cx, |this, cx| { + let project_state = this + .projects + .get_mut(&project.entity_id()) + .context("Project not found")?; + + let new_prediction = CurrentEditPrediction { + requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()), + prediction: prediction, + }; - if project_state - .current_prediction - .as_ref() - .is_none_or(|old_prediction| { - new_prediction.should_replace_prediction(&old_prediction, cx) - }) - { - project_state.current_prediction = Some(new_prediction); + if project_state + .current_prediction + .as_ref() + .is_none_or(|old_prediction| { + new_prediction.should_replace_prediction(&old_prediction, cx) + }) + { + project_state.current_prediction = Some(new_prediction); + cx.notify(); + } + anyhow::Ok(()) + })??; + } + Ok(()) + }) + }) + } + + pub fn refresh_prediction_from_diagnostics( + &mut self, + project: Entity, + cx: &mut Context, + ) { + let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { + return; + }; + + // Prefer predictions from buffer + if zeta_project.current_prediction.is_some() { + return; + }; + + self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| { + let Some(open_buffer_task) = project + .update(cx, |project, cx| { + project + .active_entry() + .and_then(|entry| project.path_for_entry(entry, cx)) + .map(|path| project.open_buffer(path, cx)) + }) + .log_err() + .flatten() + else { + return Task::ready(anyhow::Ok(())); + }; + + cx.spawn(async move |cx| { + let active_buffer = open_buffer_task.await?; + let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( + active_buffer, + &snapshot, + Default::default(), + Default::default(), + &project, + cx, + ) + .await? + else { + return anyhow::Ok(()); + }; + + let Some(prediction) = this + .update(cx, |this, cx| { + this.request_prediction(&project, &jump_buffer, jump_position, cx) + })? + .await? + else { + return anyhow::Ok(()); + }; + + this.update(cx, |this, cx| { + if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) { + zeta_project.current_prediction.get_or_insert_with(|| { + cx.notify(); + CurrentEditPrediction { + requested_by: PredictionRequestedBy::DiagnosticsUpdate, + prediction, + } + }); } - anyhow::Ok(()) - })??; + })?; + + anyhow::Ok(()) + }) + }); + } + + #[cfg(not(test))] + pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); + #[cfg(test)] + pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO; + + fn queue_prediction_refresh( + &mut self, + project: Entity, + throttle_entity: EntityId, + cx: &mut Context, + do_refresh: impl FnOnce(WeakEntity, &mut AsyncApp) -> Task> + 'static, + ) { + let zeta_project = self.get_or_init_zeta_project(&project, cx); + let pending_prediction_id = zeta_project.next_pending_prediction_id; + zeta_project.next_pending_prediction_id += 1; + let last_request = zeta_project.last_prediction_refresh; + + // TODO report cancelled requests like in zeta1 + let task = cx.spawn(async move |this, cx| { + if let Some((last_entity, last_timestamp)) = last_request + && throttle_entity == last_entity + && let Some(timeout) = + (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now()) + { + cx.background_executor().timer(timeout).await; } - Ok(()) - }) + + do_refresh(this.clone(), cx).await.log_err(); + + this.update(cx, |this, cx| { + let zeta_project = this.get_or_init_zeta_project(&project, cx); + + if zeta_project.pending_predictions[0].id == pending_prediction_id { + zeta_project.pending_predictions.remove(0); + } else { + zeta_project.pending_predictions.clear(); + } + + cx.notify(); + }) + .ok(); + }); + + if zeta_project.pending_predictions.len() <= 1 { + zeta_project.pending_predictions.push(PendingPrediction { + id: pending_prediction_id, + _task: task, + }); + } else if zeta_project.pending_predictions.len() == 2 { + zeta_project.pending_predictions.pop(); + zeta_project.pending_predictions.push(PendingPrediction { + id: pending_prediction_id, + _task: task, + }); + } } pub fn request_prediction( @@ -770,7 +969,7 @@ impl Zeta { self.request_prediction_with_zed_cloud(project, active_buffer, position, cx) } ZetaEditPredictionModel::Sweep => { - self.request_prediction_with_sweep(project, active_buffer, position, cx) + self.request_prediction_with_sweep(project, active_buffer, position, true, cx) } } } @@ -780,6 +979,7 @@ impl Zeta { project: &Entity, active_buffer: &Entity, position: language::Anchor, + allow_jump: bool, cx: &mut Context, ) -> Task>> { let snapshot = active_buffer.read(cx).snapshot(); @@ -802,6 +1002,7 @@ impl Zeta { let project_state = self.get_or_init_zeta_project(project, cx); let events = project_state.events.clone(); + let has_events = !events.is_empty(); let recent_buffers = project_state.recent_paths.iter().cloned(); let http_client = cx.http_client(); @@ -817,114 +1018,188 @@ impl Zeta { .take(3) .collect::>(); - let result = cx.background_spawn(async move { - let text = snapshot.text(); + const DIAGNOSTIC_LINES_RANGE: u32 = 20; - let mut recent_changes = String::new(); - for event in events { - sweep_ai::write_event(event, &mut recent_changes).unwrap(); - } + let cursor_point = position.to_point(&snapshot); + let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE); + let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE; + let diagnostic_search_range = + Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0); + + let result = cx.background_spawn({ + let snapshot = snapshot.clone(); + let diagnostic_search_range = diagnostic_search_range.clone(); + async move { + let text = snapshot.text(); + + let mut recent_changes = String::new(); + for event in events { + sweep_ai::write_event(event, &mut recent_changes).unwrap(); + } + + let mut file_chunks = recent_buffer_snapshots + .into_iter() + .map(|snapshot| { + let end_point = Point::new(30, 0).min(snapshot.max_point()); + sweep_ai::FileChunk { + content: snapshot.text_for_range(Point::zero()..end_point).collect(), + file_path: snapshot + .file() + .map(|f| f.path().as_unix_str()) + .unwrap_or("untitled") + .to_string(), + start_line: 0, + end_line: end_point.row as usize, + timestamp: snapshot.file().and_then(|file| { + Some( + file.disk_state() + .mtime()? + .to_seconds_and_nanos_for_persistence()? + .0, + ) + }), + } + }) + .collect::>(); + + let diagnostic_entries = + snapshot.diagnostics_in_range(diagnostic_search_range, false); + let mut diagnostic_content = String::new(); + let mut diagnostic_count = 0; + + for entry in diagnostic_entries { + let start_point: Point = entry.range.start; + + let severity = match entry.diagnostic.severity { + DiagnosticSeverity::ERROR => "error", + DiagnosticSeverity::WARNING => "warning", + DiagnosticSeverity::INFORMATION => "info", + DiagnosticSeverity::HINT => "hint", + _ => continue, + }; + + diagnostic_count += 1; - let file_chunks = recent_buffer_snapshots - .into_iter() - .map(|snapshot| { - let end_point = language::Point::new(30, 0).min(snapshot.max_point()); - sweep_ai::FileChunk { - content: snapshot - .text_for_range(language::Point::zero()..end_point) - .collect(), - file_path: snapshot - .file() - .map(|f| f.path().as_unix_str()) - .unwrap_or("untitled") - .to_string(), + writeln!( + &mut diagnostic_content, + "{} at line {}: {}", + severity, + start_point.row + 1, + entry.diagnostic.message + )?; + } + + if !diagnostic_content.is_empty() { + file_chunks.push(sweep_ai::FileChunk { + file_path: format!("Diagnostics for {}", full_path.display()), start_line: 0, - end_line: end_point.row as usize, - timestamp: snapshot.file().and_then(|file| { - Some( - file.disk_state() - .mtime()? - .to_seconds_and_nanos_for_persistence()? - .0, - ) - }), - } - }) - .collect(); - - let request_body = sweep_ai::AutocompleteRequest { - debug_info, - repo_name, - file_path: full_path.clone(), - file_contents: text.clone(), - original_file_contents: text, - cursor_position: offset, - recent_changes: recent_changes.clone(), - changes_above_cursor: true, - multiple_suggestions: false, - branch: None, - file_chunks, - retrieval_chunks: vec![], - recent_user_actions: vec![], - // TODO - privacy_mode_enabled: false, - }; + end_line: diagnostic_count, + content: diagnostic_content, + timestamp: None, + }); + } - let mut buf: Vec = Vec::new(); - let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22); - serde_json::to_writer(writer, &request_body)?; - let body: AsyncBody = buf.into(); + let request_body = sweep_ai::AutocompleteRequest { + debug_info, + repo_name, + file_path: full_path.clone(), + file_contents: text.clone(), + original_file_contents: text, + cursor_position: offset, + recent_changes: recent_changes.clone(), + changes_above_cursor: true, + multiple_suggestions: false, + branch: None, + file_chunks, + retrieval_chunks: vec![], + recent_user_actions: vec![], + // TODO + privacy_mode_enabled: false, + }; - const SWEEP_API_URL: &str = - "https://autocomplete.sweep.dev/backend/next_edit_autocomplete"; + let mut buf: Vec = Vec::new(); + let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22); + serde_json::to_writer(writer, &request_body)?; + let body: AsyncBody = buf.into(); - let request = http_client::Request::builder() - .uri(SWEEP_API_URL) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_token)) - .header("Connection", "keep-alive") - .header("Content-Encoding", "br") - .method(Method::POST) - .body(body)?; + const SWEEP_API_URL: &str = + "https://autocomplete.sweep.dev/backend/next_edit_autocomplete"; - let mut response = http_client.send(request).await?; + let request = http_client::Request::builder() + .uri(SWEEP_API_URL) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_token)) + .header("Connection", "keep-alive") + .header("Content-Encoding", "br") + .method(Method::POST) + .body(body)?; - let mut body: Vec = Vec::new(); - response.body_mut().read_to_end(&mut body).await?; + let mut response = http_client.send(request).await?; - if !response.status().is_success() { - anyhow::bail!( - "Request failed with status: {:?}\nBody: {}", - response.status(), - String::from_utf8_lossy(&body), - ); - }; + let mut body: Vec = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; - let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?; - - let old_text = snapshot - .text_for_range(response.start_index..response.end_index) - .collect::(); - let edits = language::text_diff(&old_text, &response.completion) - .into_iter() - .map(|(range, text)| { - ( - snapshot.anchor_after(response.start_index + range.start) - ..snapshot.anchor_before(response.start_index + range.end), - text, - ) - }) - .collect::>(); + if !response.status().is_success() { + anyhow::bail!( + "Request failed with status: {:?}\nBody: {}", + response.status(), + String::from_utf8_lossy(&body), + ); + }; + + let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?; - anyhow::Ok((response.autocomplete_id, edits, snapshot)) + let old_text = snapshot + .text_for_range(response.start_index..response.end_index) + .collect::(); + let edits = language::text_diff(&old_text, &response.completion) + .into_iter() + .map(|(range, text)| { + ( + snapshot.anchor_after(response.start_index + range.start) + ..snapshot.anchor_before(response.start_index + range.end), + text, + ) + }) + .collect::>(); + + anyhow::Ok((response.autocomplete_id, edits, snapshot)) + } }); let buffer = active_buffer.clone(); + let project = project.clone(); + let active_buffer = active_buffer.clone(); - cx.spawn(async move |_, cx| { + cx.spawn(async move |this, cx| { let (id, edits, old_snapshot) = result.await?; if edits.is_empty() { + if has_events + && allow_jump + && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( + active_buffer, + &snapshot, + diagnostic_search_range, + cursor_point, + &project, + cx, + ) + .await? + { + return this + .update(cx, |this, cx| { + this.request_prediction_with_sweep( + &project, + &jump_buffer, + jump_position, + false, + cx, + ) + })? + .await; + } + return anyhow::Ok(None); } @@ -955,6 +1230,85 @@ impl Zeta { }) } + async fn next_diagnostic_location( + active_buffer: Entity, + active_buffer_snapshot: &BufferSnapshot, + active_buffer_diagnostic_search_range: Range, + active_buffer_cursor_point: Point, + project: &Entity, + cx: &mut AsyncApp, + ) -> Result, language::Anchor)>> { + // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request + let mut jump_location = active_buffer_snapshot + .diagnostic_groups(None) + .into_iter() + .filter_map(|(_, group)| { + let range = &group.entries[group.primary_ix] + .range + .to_point(&active_buffer_snapshot); + if range.overlaps(&active_buffer_diagnostic_search_range) { + None + } else { + Some(range.start) + } + }) + .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row)) + .map(|position| { + ( + active_buffer.clone(), + active_buffer_snapshot.anchor_before(position), + ) + }); + + if jump_location.is_none() { + let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| { + let file = buffer.file()?; + + Some(ProjectPath { + worktree_id: file.worktree_id(cx), + path: file.path().clone(), + }) + })?; + + let buffer_task = project.update(cx, |project, cx| { + let (path, _, _) = project + .diagnostic_summaries(false, cx) + .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref()) + .max_by_key(|(path, _, _)| { + // find the buffer with errors that shares most parent directories + path.path + .components() + .zip( + active_buffer_path + .as_ref() + .map(|p| p.path.components()) + .unwrap_or_default(), + ) + .take_while(|(a, b)| a == b) + .count() + })?; + + Some(project.open_buffer(path, cx)) + })?; + + if let Some(buffer_task) = buffer_task { + let closest_buffer = buffer_task.await?; + + jump_location = closest_buffer + .read_with(cx, |buffer, _cx| { + buffer + .buffer_diagnostics(None) + .into_iter() + .min_by_key(|entry| entry.diagnostic.severity) + .map(|entry| entry.range.start) + })? + .map(|position| (closest_buffer, position)); + } + } + + anyhow::Ok(jump_location) + } + fn request_prediction_with_zed_cloud( &mut self, project: &Entity, @@ -2168,8 +2522,8 @@ mod tests { // Prediction for current file - let prediction_task = zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction(&project, &buffer1, position, cx) + zeta.update(cx, |zeta, cx| { + zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx) }); let (_request, respond_tx) = req_rx.next().await.unwrap(); @@ -2184,7 +2538,8 @@ mod tests { Bye "})) .unwrap(); - prediction_task.await.unwrap(); + + cx.run_until_parked(); zeta.read_with(cx, |zeta, cx| { let prediction = zeta @@ -2242,8 +2597,8 @@ mod tests { }); // Prediction for another file - let prediction_task = zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction(&project, &buffer1, position, cx) + zeta.update(cx, |zeta, cx| { + zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx) }); let (_request, respond_tx) = req_rx.next().await.unwrap(); respond_tx @@ -2256,7 +2611,8 @@ mod tests { Adios "#})) .unwrap(); - prediction_task.await.unwrap(); + cx.run_until_parked(); + zeta.read_with(cx, |zeta, cx| { let prediction = zeta .current_prediction_for_buffer(&buffer1, &project, cx) diff --git a/crates/zeta2_tools/src/zeta2_tools.rs b/crates/zeta2_tools/src/zeta2_tools.rs index 756fff5d621a85f7936a980d71f68c87098c4539..8758857e7cf50d6a5f2e5a4ea509293b18a8cb2c 100644 --- a/crates/zeta2_tools/src/zeta2_tools.rs +++ b/crates/zeta2_tools/src/zeta2_tools.rs @@ -1,6 +1,6 @@ mod zeta2_context_view; -use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc, time::Duration}; +use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc}; use chrono::TimeDelta; use client::{Client, UserStore}; @@ -237,24 +237,13 @@ impl Zeta2Inspector { fn set_zeta_options(&mut self, options: ZetaOptions, cx: &mut Context) { self.zeta.update(cx, |this, _cx| this.set_options(options)); - const DEBOUNCE_TIME: Duration = Duration::from_millis(100); - if let Some(prediction) = self.last_prediction.as_mut() { if let Some(buffer) = prediction.buffer.upgrade() { let position = prediction.position; - let zeta = self.zeta.clone(); let project = self.project.clone(); - prediction._task = Some(cx.spawn(async move |_this, cx| { - cx.background_executor().timer(DEBOUNCE_TIME).await; - if let Some(task) = zeta - .update(cx, |zeta, cx| { - zeta.refresh_prediction(&project, &buffer, position, cx) - }) - .ok() - { - task.await.log_err(); - } - })); + self.zeta.update(cx, |zeta, cx| { + zeta.refresh_prediction_from_buffer(project, buffer, position, cx) + }); prediction.state = LastPredictionState::Requested; } else { self.last_prediction.take();