diff --git a/crates/codestral/src/codestral.rs b/crates/codestral/src/codestral.rs index 9fbd207a809fb2cb3ac685ea6629a36c8631d1fe..6a500acbf6ec5eea63c35a8deb83a8545cee497e 100644 --- a/crates/codestral/src/codestral.rs +++ b/crates/codestral/src/codestral.rs @@ -182,7 +182,7 @@ impl EditPredictionProvider for CodestralCompletionProvider { Self::api_key(cx).is_some() } - fn is_refreshing(&self) -> bool { + fn is_refreshing(&self, _cx: &App) -> bool { self.pending_request.is_some() } diff --git a/crates/copilot/src/copilot_completion_provider.rs b/crates/copilot/src/copilot_completion_provider.rs index 30ef6de07ec92f66cc888b52a540cf9c7e673bb4..e92f0c7d7dd7e51c4a8fdc19f34bd6eb4189c097 100644 --- a/crates/copilot/src/copilot_completion_provider.rs +++ b/crates/copilot/src/copilot_completion_provider.rs @@ -68,7 +68,7 @@ impl EditPredictionProvider for CopilotCompletionProvider { false } - fn is_refreshing(&self) -> bool { + fn is_refreshing(&self, _cx: &App) -> bool { self.pending_refresh.is_some() && self.completions.is_empty() } diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index aebfa5e5229ef1fec50f2d9cf74e354878ddc1c5..1984383a9691ae9373973a3eb9f00db4e7e795f2 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -87,7 +87,7 @@ pub trait EditPredictionProvider: 'static + Sized { cursor_position: language::Anchor, cx: &App, ) -> bool; - fn is_refreshing(&self) -> bool; + fn is_refreshing(&self, cx: &App) -> bool; fn refresh( &mut self, buffer: Entity, @@ -200,7 +200,7 @@ where } fn is_refreshing(&self, cx: &App) -> bool { - self.read(cx).is_refreshing() + self.read(cx).is_refreshing(cx) } fn refresh( diff --git a/crates/editor/src/edit_prediction_tests.rs b/crates/editor/src/edit_prediction_tests.rs index 74f13a404c6a52db448d68eba9e5c255e7276923..a1839144a47a81f668ba2743cd5e362f6711d0e9 100644 --- a/crates/editor/src/edit_prediction_tests.rs +++ b/crates/editor/src/edit_prediction_tests.rs @@ -469,7 +469,7 @@ impl EditPredictionProvider for FakeEditPredictionProvider { true } - fn is_refreshing(&self) -> bool { + fn is_refreshing(&self, _cx: &gpui::App) -> bool { false } @@ -542,7 +542,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider { true } - fn is_refreshing(&self) -> bool { + fn is_refreshing(&self, _cx: &gpui::App) -> bool { false } diff --git a/crates/supermaven/src/supermaven_completion_provider.rs b/crates/supermaven/src/supermaven_completion_provider.rs index 0c9fe85da6130f5ea2040434a0dcd3727754d3c0..9d5e256aca1b66644145cb688851d0ec5c1b81b9 100644 --- a/crates/supermaven/src/supermaven_completion_provider.rs +++ b/crates/supermaven/src/supermaven_completion_provider.rs @@ -129,7 +129,7 @@ impl EditPredictionProvider for SupermavenCompletionProvider { self.supermaven.read(cx).is_enabled() } - fn is_refreshing(&self) -> bool { + fn is_refreshing(&self, _cx: &App) -> bool { self.pending_refresh.is_some() && self.completion_id.is_none() } diff --git a/crates/util/src/rel_path.rs b/crates/util/src/rel_path.rs index b360297f209c54c6a33b174a738ed1876fbc16a0..60a0e2ef9ef51ee579a30fac57486b0040c42227 100644 --- a/crates/util/src/rel_path.rs +++ b/crates/util/src/rel_path.rs @@ -374,6 +374,7 @@ impl PartialEq for RelPath { } } +#[derive(Default)] pub struct RelPathComponents<'a>(&'a str); pub struct RelPathAncestors<'a>(Option<&'a str>); diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index c2ef5cb826db0947c18e1e91a6163cccc12deb11..cb31488d17668531ee11a67d1e4be19a1674d3d2 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -1486,7 +1486,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { ) -> bool { true } - fn is_refreshing(&self) -> bool { + fn is_refreshing(&self, _cx: &App) -> bool { !self.pending_completions.is_empty() } diff --git a/crates/zeta2/Cargo.toml b/crates/zeta2/Cargo.toml index 0f156f68fac881d65d76f178315f40df1dba9d7f..834762447707b88d6b009f0d6700c639306c9bbd 100644 --- a/crates/zeta2/Cargo.toml +++ b/crates/zeta2/Cargo.toml @@ -32,7 +32,9 @@ indoc.workspace = true language.workspace = true language_model.workspace = true log.workspace = true +lsp.workspace = true open_ai.workspace = true +pretty_assertions.workspace = true project.workspace = true release_channel.workspace = true serde.workspace = true @@ -44,7 +46,6 @@ util.workspace = true uuid.workspace = true workspace.workspace = true worktree.workspace = true -pretty_assertions.workspace = true [dev-dependencies] clock = { workspace = true, features = ["test-support"] } diff --git a/crates/zeta2/src/provider.rs b/crates/zeta2/src/provider.rs index 1b82826f663b092b5763935d9a7a2d4bb9607ebf..768af6253fe1a2aa60ef9cb0a10fcee0035dc3e2 100644 --- a/crates/zeta2/src/provider.rs +++ b/crates/zeta2/src/provider.rs @@ -1,24 +1,15 @@ -use std::{ - cmp, - sync::Arc, - time::{Duration, Instant}, -}; +use std::{cmp, sync::Arc, time::Duration}; -use arrayvec::ArrayVec; use client::{Client, UserStore}; use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider}; -use gpui::{App, Entity, Task, prelude::*}; +use gpui::{App, Entity, prelude::*}; use language::ToPoint as _; use project::Project; -use util::ResultExt as _; use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel}; pub struct ZetaEditPredictionProvider { zeta: Entity, - next_pending_prediction_id: usize, - pending_predictions: ArrayVec, - last_request_timestamp: Instant, project: Entity, } @@ -29,28 +20,25 @@ impl ZetaEditPredictionProvider { project: Entity, client: &Arc, user_store: &Entity, - cx: &mut App, + cx: &mut Context, ) -> Self { let zeta = Zeta::global(client, user_store, cx); zeta.update(cx, |zeta, cx| { zeta.register_project(&project, cx); }); + cx.observe(&zeta, |_this, _zeta, cx| { + cx.notify(); + }) + .detach(); + Self { - zeta, - next_pending_prediction_id: 0, - pending_predictions: ArrayVec::new(), - last_request_timestamp: Instant::now(), project: project, + zeta, } } } -struct PendingPrediction { - id: usize, - _task: Task<()>, -} - impl EditPredictionProvider for ZetaEditPredictionProvider { fn name() -> &'static str { "zed-predict2" @@ -95,8 +83,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { } } - fn is_refreshing(&self) -> bool { - !self.pending_predictions.is_empty() + fn is_refreshing(&self, cx: &App) -> bool { + self.zeta.read(cx).is_refreshing(&self.project) } fn refresh( @@ -123,59 +111,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { self.zeta.update(cx, |zeta, cx| { zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx); + zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx) }); - - let pending_prediction_id = self.next_pending_prediction_id; - self.next_pending_prediction_id += 1; - let last_request_timestamp = self.last_request_timestamp; - - let project = self.project.clone(); - let task = cx.spawn(async move |this, cx| { - if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT) - .checked_duration_since(Instant::now()) - { - cx.background_executor().timer(timeout).await; - } - - let refresh_task = this.update(cx, |this, cx| { - this.last_request_timestamp = Instant::now(); - this.zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction(&project, &buffer, cursor_position, cx) - }) - }); - - if let Some(refresh_task) = refresh_task.ok() { - refresh_task.await.log_err(); - } - - this.update(cx, |this, cx| { - if this.pending_predictions[0].id == pending_prediction_id { - this.pending_predictions.remove(0); - } else { - this.pending_predictions.clear(); - } - - cx.notify(); - }) - .ok(); - }); - - // We always maintain at most two pending predictions. When we already - // have two, we replace the newest one. - if self.pending_predictions.len() <= 1 { - self.pending_predictions.push(PendingPrediction { - id: pending_prediction_id, - _task: task, - }); - } else if self.pending_predictions.len() == 2 { - self.pending_predictions.pop(); - self.pending_predictions.push(PendingPrediction { - id: pending_prediction_id, - _task: task, - }); - } - - cx.notify(); } fn cycle( @@ -191,14 +128,12 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { self.zeta.update(cx, |zeta, cx| { zeta.accept_current_prediction(&self.project, cx); }); - self.pending_predictions.clear(); } fn discard(&mut self, cx: &mut Context) { self.zeta.update(cx, |zeta, _cx| { zeta.discard_current_prediction(&self.project); }); - self.pending_predictions.clear(); } fn suggest( diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 0d0f4f3d39e9c997282695828ba16e7eccd7d8e2..a06d7043cf565dccf0d8a4e8830cbb41c2e9981b 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -1,4 +1,5 @@ use anyhow::{Context as _, Result, anyhow, bail}; +use arrayvec::ArrayVec; use chrono::TimeDelta; use client::{Client, EditPredictionUsage, UserStore}; use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature}; @@ -19,18 +20,20 @@ use futures::AsyncReadExt as _; use futures::channel::{mpsc, oneshot}; use gpui::http_client::{AsyncBody, Method}; use gpui::{ - App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity, - http_client, prelude::*, + App, AsyncApp, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, + WeakEntity, http_client, prelude::*, }; use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint}; use language::{BufferSnapshot, OffsetRangeExt}; use language_model::{LlmApiToken, RefreshLlmTokenListener}; +use lsp::DiagnosticSeverity; use open_ai::FunctionDefinition; use project::{Project, ProjectPath}; use release_channel::AppVersion; use serde::de::DeserializeOwned; use std::collections::{VecDeque, hash_map}; +use std::fmt::Write; use std::ops::Range; use std::path::Path; use std::str::FromStr as _; @@ -39,7 +42,7 @@ use std::time::{Duration, Instant}; use std::{env, mem}; use thiserror::Error; use util::rel_path::RelPathBuf; -use util::{LogErrorFuture, ResultExt as _, TryFutureExt}; +use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt}; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; pub mod assemble_excerpts; @@ -239,6 +242,9 @@ struct ZetaProject { recent_paths: VecDeque, registered_buffers: HashMap, current_prediction: Option, + next_pending_prediction_id: usize, + pending_predictions: ArrayVec, + last_prediction_refresh: Option<(EntityId, Instant)>, context: Option, Vec>>>, refresh_context_task: Option>>>, refresh_context_debounce_task: Option>>, @@ -248,7 +254,7 @@ struct ZetaProject { #[derive(Debug, Clone)] struct CurrentEditPrediction { - pub requested_by_buffer_id: EntityId, + pub requested_by: PredictionRequestedBy, pub prediction: EditPrediction, } @@ -272,11 +278,13 @@ impl CurrentEditPrediction { return true; }; + let requested_by_buffer_id = self.requested_by.buffer_id(); + // This reduces the occurrence of UI thrash from replacing edits // // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits. - if self.requested_by_buffer_id == self.prediction.buffer.entity_id() - && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id() + if requested_by_buffer_id == Some(self.prediction.buffer.entity_id()) + && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id()) && old_edits.len() == 1 && new_edits.len() == 1 { @@ -289,6 +297,26 @@ impl CurrentEditPrediction { } } +#[derive(Debug, Clone)] +enum PredictionRequestedBy { + DiagnosticsUpdate, + Buffer(EntityId), +} + +impl PredictionRequestedBy { + pub fn buffer_id(&self) -> Option { + match self { + PredictionRequestedBy::DiagnosticsUpdate => None, + PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id), + } + } +} + +struct PendingPrediction { + id: usize, + _task: Task<()>, +} + /// A prediction from the perspective of a buffer. #[derive(Debug)] enum BufferEditPrediction<'a> { @@ -513,31 +541,48 @@ impl Zeta { recent_paths: VecDeque::new(), registered_buffers: HashMap::default(), current_prediction: None, + pending_predictions: ArrayVec::new(), + next_pending_prediction_id: 0, + last_prediction_refresh: None, context: None, refresh_context_task: None, refresh_context_debounce_task: None, refresh_context_timestamp: None, - _subscription: cx.subscribe(&project, |this, project, event, cx| { - // TODO [zeta2] init with recent paths - if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) { - if let project::Event::ActiveEntryChanged(Some(active_entry_id)) = event { - let path = project.read(cx).path_for_entry(*active_entry_id, cx); - if let Some(path) = path { - if let Some(ix) = zeta_project - .recent_paths - .iter() - .position(|probe| probe == &path) - { - zeta_project.recent_paths.remove(ix); - } - zeta_project.recent_paths.push_front(path); - } - } - } - }), + _subscription: cx.subscribe(&project, Self::handle_project_event), }) } + fn handle_project_event( + &mut self, + project: Entity, + event: &project::Event, + cx: &mut Context, + ) { + // TODO [zeta2] init with recent paths + match event { + project::Event::ActiveEntryChanged(Some(active_entry_id)) => { + let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { + return; + }; + let path = project.read(cx).path_for_entry(*active_entry_id, cx); + if let Some(path) = path { + if let Some(ix) = zeta_project + .recent_paths + .iter() + .position(|probe| probe == &path) + { + zeta_project.recent_paths.remove(ix); + } + zeta_project.recent_paths.push_front(path); + } + } + project::Event::DiagnosticsUpdated { .. } => { + self.refresh_prediction_from_diagnostics(project, cx); + } + _ => (), + } + } + fn register_buffer_impl<'a>( zeta_project: &'a mut ZetaProject, buffer: &Entity, @@ -650,16 +695,25 @@ impl Zeta { let project_state = self.projects.get(&project.entity_id())?; let CurrentEditPrediction { - requested_by_buffer_id, + requested_by, prediction, } = project_state.current_prediction.as_ref()?; if prediction.targets_buffer(buffer.read(cx)) { Some(BufferEditPrediction::Local { prediction }) - } else if *requested_by_buffer_id == buffer.entity_id() { - Some(BufferEditPrediction::Jump { prediction }) } else { - None + let show_jump = match requested_by { + PredictionRequestedBy::Buffer(requested_by_buffer_id) => { + requested_by_buffer_id == &buffer.entity_id() + } + PredictionRequestedBy::DiagnosticsUpdate => true, + }; + + if show_jump { + Some(BufferEditPrediction::Jump { prediction }) + } else { + None + } } } @@ -676,6 +730,7 @@ impl Zeta { return; }; let request_id = prediction.prediction.id.to_string(); + project_state.pending_predictions.clear(); let client = self.client.clone(); let llm_token = self.llm_token.clone(); @@ -715,47 +770,191 @@ impl Zeta { fn discard_current_prediction(&mut self, project: &Entity) { if let Some(project_state) = self.projects.get_mut(&project.entity_id()) { project_state.current_prediction.take(); + project_state.pending_predictions.clear(); }; } - pub fn refresh_prediction( + fn is_refreshing(&self, project: &Entity) -> bool { + self.projects + .get(&project.entity_id()) + .is_some_and(|project_state| !project_state.pending_predictions.is_empty()) + } + + pub fn refresh_prediction_from_buffer( &mut self, - project: &Entity, - buffer: &Entity, + project: Entity, + buffer: Entity, position: language::Anchor, cx: &mut Context, - ) -> Task> { - let request_task = self.request_prediction(project, buffer, position, cx); - let buffer = buffer.clone(); - let project = project.clone(); + ) { + self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| { + let Some(request_task) = this + .update(cx, |this, cx| { + this.request_prediction(&project, &buffer, position, cx) + }) + .log_err() + else { + return Task::ready(anyhow::Ok(())); + }; - cx.spawn(async move |this, cx| { - if let Some(prediction) = request_task.await? { - this.update(cx, |this, cx| { - let project_state = this - .projects - .get_mut(&project.entity_id()) - .context("Project not found")?; - - let new_prediction = CurrentEditPrediction { - requested_by_buffer_id: buffer.entity_id(), - prediction: prediction, - }; + let project = project.clone(); + cx.spawn(async move |cx| { + if let Some(prediction) = request_task.await? { + this.update(cx, |this, cx| { + let project_state = this + .projects + .get_mut(&project.entity_id()) + .context("Project not found")?; + + let new_prediction = CurrentEditPrediction { + requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()), + prediction: prediction, + }; - if project_state - .current_prediction - .as_ref() - .is_none_or(|old_prediction| { - new_prediction.should_replace_prediction(&old_prediction, cx) - }) - { - project_state.current_prediction = Some(new_prediction); + if project_state + .current_prediction + .as_ref() + .is_none_or(|old_prediction| { + new_prediction.should_replace_prediction(&old_prediction, cx) + }) + { + project_state.current_prediction = Some(new_prediction); + cx.notify(); + } + anyhow::Ok(()) + })??; + } + Ok(()) + }) + }) + } + + pub fn refresh_prediction_from_diagnostics( + &mut self, + project: Entity, + cx: &mut Context, + ) { + let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { + return; + }; + + // Prefer predictions from buffer + if zeta_project.current_prediction.is_some() { + return; + }; + + self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| { + let Some(open_buffer_task) = project + .update(cx, |project, cx| { + project + .active_entry() + .and_then(|entry| project.path_for_entry(entry, cx)) + .map(|path| project.open_buffer(path, cx)) + }) + .log_err() + .flatten() + else { + return Task::ready(anyhow::Ok(())); + }; + + cx.spawn(async move |cx| { + let active_buffer = open_buffer_task.await?; + let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( + active_buffer, + &snapshot, + Default::default(), + Default::default(), + &project, + cx, + ) + .await? + else { + return anyhow::Ok(()); + }; + + let Some(prediction) = this + .update(cx, |this, cx| { + this.request_prediction(&project, &jump_buffer, jump_position, cx) + })? + .await? + else { + return anyhow::Ok(()); + }; + + this.update(cx, |this, cx| { + if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) { + zeta_project.current_prediction.get_or_insert_with(|| { + cx.notify(); + CurrentEditPrediction { + requested_by: PredictionRequestedBy::DiagnosticsUpdate, + prediction, + } + }); } - anyhow::Ok(()) - })??; + })?; + + anyhow::Ok(()) + }) + }); + } + + #[cfg(not(test))] + pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); + #[cfg(test)] + pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO; + + fn queue_prediction_refresh( + &mut self, + project: Entity, + throttle_entity: EntityId, + cx: &mut Context, + do_refresh: impl FnOnce(WeakEntity, &mut AsyncApp) -> Task> + 'static, + ) { + let zeta_project = self.get_or_init_zeta_project(&project, cx); + let pending_prediction_id = zeta_project.next_pending_prediction_id; + zeta_project.next_pending_prediction_id += 1; + let last_request = zeta_project.last_prediction_refresh; + + // TODO report cancelled requests like in zeta1 + let task = cx.spawn(async move |this, cx| { + if let Some((last_entity, last_timestamp)) = last_request + && throttle_entity == last_entity + && let Some(timeout) = + (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now()) + { + cx.background_executor().timer(timeout).await; } - Ok(()) - }) + + do_refresh(this.clone(), cx).await.log_err(); + + this.update(cx, |this, cx| { + let zeta_project = this.get_or_init_zeta_project(&project, cx); + + if zeta_project.pending_predictions[0].id == pending_prediction_id { + zeta_project.pending_predictions.remove(0); + } else { + zeta_project.pending_predictions.clear(); + } + + cx.notify(); + }) + .ok(); + }); + + if zeta_project.pending_predictions.len() <= 1 { + zeta_project.pending_predictions.push(PendingPrediction { + id: pending_prediction_id, + _task: task, + }); + } else if zeta_project.pending_predictions.len() == 2 { + zeta_project.pending_predictions.pop(); + zeta_project.pending_predictions.push(PendingPrediction { + id: pending_prediction_id, + _task: task, + }); + } } pub fn request_prediction( @@ -770,7 +969,7 @@ impl Zeta { self.request_prediction_with_zed_cloud(project, active_buffer, position, cx) } ZetaEditPredictionModel::Sweep => { - self.request_prediction_with_sweep(project, active_buffer, position, cx) + self.request_prediction_with_sweep(project, active_buffer, position, true, cx) } } } @@ -780,6 +979,7 @@ impl Zeta { project: &Entity, active_buffer: &Entity, position: language::Anchor, + allow_jump: bool, cx: &mut Context, ) -> Task>> { let snapshot = active_buffer.read(cx).snapshot(); @@ -802,6 +1002,7 @@ impl Zeta { let project_state = self.get_or_init_zeta_project(project, cx); let events = project_state.events.clone(); + let has_events = !events.is_empty(); let recent_buffers = project_state.recent_paths.iter().cloned(); let http_client = cx.http_client(); @@ -817,114 +1018,188 @@ impl Zeta { .take(3) .collect::>(); - let result = cx.background_spawn(async move { - let text = snapshot.text(); + const DIAGNOSTIC_LINES_RANGE: u32 = 20; - let mut recent_changes = String::new(); - for event in events { - sweep_ai::write_event(event, &mut recent_changes).unwrap(); - } + let cursor_point = position.to_point(&snapshot); + let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE); + let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE; + let diagnostic_search_range = + Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0); + + let result = cx.background_spawn({ + let snapshot = snapshot.clone(); + let diagnostic_search_range = diagnostic_search_range.clone(); + async move { + let text = snapshot.text(); + + let mut recent_changes = String::new(); + for event in events { + sweep_ai::write_event(event, &mut recent_changes).unwrap(); + } + + let mut file_chunks = recent_buffer_snapshots + .into_iter() + .map(|snapshot| { + let end_point = Point::new(30, 0).min(snapshot.max_point()); + sweep_ai::FileChunk { + content: snapshot.text_for_range(Point::zero()..end_point).collect(), + file_path: snapshot + .file() + .map(|f| f.path().as_unix_str()) + .unwrap_or("untitled") + .to_string(), + start_line: 0, + end_line: end_point.row as usize, + timestamp: snapshot.file().and_then(|file| { + Some( + file.disk_state() + .mtime()? + .to_seconds_and_nanos_for_persistence()? + .0, + ) + }), + } + }) + .collect::>(); + + let diagnostic_entries = + snapshot.diagnostics_in_range(diagnostic_search_range, false); + let mut diagnostic_content = String::new(); + let mut diagnostic_count = 0; + + for entry in diagnostic_entries { + let start_point: Point = entry.range.start; + + let severity = match entry.diagnostic.severity { + DiagnosticSeverity::ERROR => "error", + DiagnosticSeverity::WARNING => "warning", + DiagnosticSeverity::INFORMATION => "info", + DiagnosticSeverity::HINT => "hint", + _ => continue, + }; + + diagnostic_count += 1; - let file_chunks = recent_buffer_snapshots - .into_iter() - .map(|snapshot| { - let end_point = language::Point::new(30, 0).min(snapshot.max_point()); - sweep_ai::FileChunk { - content: snapshot - .text_for_range(language::Point::zero()..end_point) - .collect(), - file_path: snapshot - .file() - .map(|f| f.path().as_unix_str()) - .unwrap_or("untitled") - .to_string(), + writeln!( + &mut diagnostic_content, + "{} at line {}: {}", + severity, + start_point.row + 1, + entry.diagnostic.message + )?; + } + + if !diagnostic_content.is_empty() { + file_chunks.push(sweep_ai::FileChunk { + file_path: format!("Diagnostics for {}", full_path.display()), start_line: 0, - end_line: end_point.row as usize, - timestamp: snapshot.file().and_then(|file| { - Some( - file.disk_state() - .mtime()? - .to_seconds_and_nanos_for_persistence()? - .0, - ) - }), - } - }) - .collect(); - - let request_body = sweep_ai::AutocompleteRequest { - debug_info, - repo_name, - file_path: full_path.clone(), - file_contents: text.clone(), - original_file_contents: text, - cursor_position: offset, - recent_changes: recent_changes.clone(), - changes_above_cursor: true, - multiple_suggestions: false, - branch: None, - file_chunks, - retrieval_chunks: vec![], - recent_user_actions: vec![], - // TODO - privacy_mode_enabled: false, - }; + end_line: diagnostic_count, + content: diagnostic_content, + timestamp: None, + }); + } - let mut buf: Vec = Vec::new(); - let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22); - serde_json::to_writer(writer, &request_body)?; - let body: AsyncBody = buf.into(); + let request_body = sweep_ai::AutocompleteRequest { + debug_info, + repo_name, + file_path: full_path.clone(), + file_contents: text.clone(), + original_file_contents: text, + cursor_position: offset, + recent_changes: recent_changes.clone(), + changes_above_cursor: true, + multiple_suggestions: false, + branch: None, + file_chunks, + retrieval_chunks: vec![], + recent_user_actions: vec![], + // TODO + privacy_mode_enabled: false, + }; - const SWEEP_API_URL: &str = - "https://autocomplete.sweep.dev/backend/next_edit_autocomplete"; + let mut buf: Vec = Vec::new(); + let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22); + serde_json::to_writer(writer, &request_body)?; + let body: AsyncBody = buf.into(); - let request = http_client::Request::builder() - .uri(SWEEP_API_URL) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_token)) - .header("Connection", "keep-alive") - .header("Content-Encoding", "br") - .method(Method::POST) - .body(body)?; + const SWEEP_API_URL: &str = + "https://autocomplete.sweep.dev/backend/next_edit_autocomplete"; - let mut response = http_client.send(request).await?; + let request = http_client::Request::builder() + .uri(SWEEP_API_URL) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_token)) + .header("Connection", "keep-alive") + .header("Content-Encoding", "br") + .method(Method::POST) + .body(body)?; - let mut body: Vec = Vec::new(); - response.body_mut().read_to_end(&mut body).await?; + let mut response = http_client.send(request).await?; - if !response.status().is_success() { - anyhow::bail!( - "Request failed with status: {:?}\nBody: {}", - response.status(), - String::from_utf8_lossy(&body), - ); - }; + let mut body: Vec = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; - let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?; - - let old_text = snapshot - .text_for_range(response.start_index..response.end_index) - .collect::(); - let edits = language::text_diff(&old_text, &response.completion) - .into_iter() - .map(|(range, text)| { - ( - snapshot.anchor_after(response.start_index + range.start) - ..snapshot.anchor_before(response.start_index + range.end), - text, - ) - }) - .collect::>(); + if !response.status().is_success() { + anyhow::bail!( + "Request failed with status: {:?}\nBody: {}", + response.status(), + String::from_utf8_lossy(&body), + ); + }; + + let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?; - anyhow::Ok((response.autocomplete_id, edits, snapshot)) + let old_text = snapshot + .text_for_range(response.start_index..response.end_index) + .collect::(); + let edits = language::text_diff(&old_text, &response.completion) + .into_iter() + .map(|(range, text)| { + ( + snapshot.anchor_after(response.start_index + range.start) + ..snapshot.anchor_before(response.start_index + range.end), + text, + ) + }) + .collect::>(); + + anyhow::Ok((response.autocomplete_id, edits, snapshot)) + } }); let buffer = active_buffer.clone(); + let project = project.clone(); + let active_buffer = active_buffer.clone(); - cx.spawn(async move |_, cx| { + cx.spawn(async move |this, cx| { let (id, edits, old_snapshot) = result.await?; if edits.is_empty() { + if has_events + && allow_jump + && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( + active_buffer, + &snapshot, + diagnostic_search_range, + cursor_point, + &project, + cx, + ) + .await? + { + return this + .update(cx, |this, cx| { + this.request_prediction_with_sweep( + &project, + &jump_buffer, + jump_position, + false, + cx, + ) + })? + .await; + } + return anyhow::Ok(None); } @@ -955,6 +1230,85 @@ impl Zeta { }) } + async fn next_diagnostic_location( + active_buffer: Entity, + active_buffer_snapshot: &BufferSnapshot, + active_buffer_diagnostic_search_range: Range, + active_buffer_cursor_point: Point, + project: &Entity, + cx: &mut AsyncApp, + ) -> Result, language::Anchor)>> { + // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request + let mut jump_location = active_buffer_snapshot + .diagnostic_groups(None) + .into_iter() + .filter_map(|(_, group)| { + let range = &group.entries[group.primary_ix] + .range + .to_point(&active_buffer_snapshot); + if range.overlaps(&active_buffer_diagnostic_search_range) { + None + } else { + Some(range.start) + } + }) + .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row)) + .map(|position| { + ( + active_buffer.clone(), + active_buffer_snapshot.anchor_before(position), + ) + }); + + if jump_location.is_none() { + let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| { + let file = buffer.file()?; + + Some(ProjectPath { + worktree_id: file.worktree_id(cx), + path: file.path().clone(), + }) + })?; + + let buffer_task = project.update(cx, |project, cx| { + let (path, _, _) = project + .diagnostic_summaries(false, cx) + .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref()) + .max_by_key(|(path, _, _)| { + // find the buffer with errors that shares most parent directories + path.path + .components() + .zip( + active_buffer_path + .as_ref() + .map(|p| p.path.components()) + .unwrap_or_default(), + ) + .take_while(|(a, b)| a == b) + .count() + })?; + + Some(project.open_buffer(path, cx)) + })?; + + if let Some(buffer_task) = buffer_task { + let closest_buffer = buffer_task.await?; + + jump_location = closest_buffer + .read_with(cx, |buffer, _cx| { + buffer + .buffer_diagnostics(None) + .into_iter() + .min_by_key(|entry| entry.diagnostic.severity) + .map(|entry| entry.range.start) + })? + .map(|position| (closest_buffer, position)); + } + } + + anyhow::Ok(jump_location) + } + fn request_prediction_with_zed_cloud( &mut self, project: &Entity, @@ -2168,8 +2522,8 @@ mod tests { // Prediction for current file - let prediction_task = zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction(&project, &buffer1, position, cx) + zeta.update(cx, |zeta, cx| { + zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx) }); let (_request, respond_tx) = req_rx.next().await.unwrap(); @@ -2184,7 +2538,8 @@ mod tests { Bye "})) .unwrap(); - prediction_task.await.unwrap(); + + cx.run_until_parked(); zeta.read_with(cx, |zeta, cx| { let prediction = zeta @@ -2242,8 +2597,8 @@ mod tests { }); // Prediction for another file - let prediction_task = zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction(&project, &buffer1, position, cx) + zeta.update(cx, |zeta, cx| { + zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx) }); let (_request, respond_tx) = req_rx.next().await.unwrap(); respond_tx @@ -2256,7 +2611,8 @@ mod tests { Adios "#})) .unwrap(); - prediction_task.await.unwrap(); + cx.run_until_parked(); + zeta.read_with(cx, |zeta, cx| { let prediction = zeta .current_prediction_for_buffer(&buffer1, &project, cx) diff --git a/crates/zeta2_tools/src/zeta2_tools.rs b/crates/zeta2_tools/src/zeta2_tools.rs index 756fff5d621a85f7936a980d71f68c87098c4539..8758857e7cf50d6a5f2e5a4ea509293b18a8cb2c 100644 --- a/crates/zeta2_tools/src/zeta2_tools.rs +++ b/crates/zeta2_tools/src/zeta2_tools.rs @@ -1,6 +1,6 @@ mod zeta2_context_view; -use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc, time::Duration}; +use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc}; use chrono::TimeDelta; use client::{Client, UserStore}; @@ -237,24 +237,13 @@ impl Zeta2Inspector { fn set_zeta_options(&mut self, options: ZetaOptions, cx: &mut Context) { self.zeta.update(cx, |this, _cx| this.set_options(options)); - const DEBOUNCE_TIME: Duration = Duration::from_millis(100); - if let Some(prediction) = self.last_prediction.as_mut() { if let Some(buffer) = prediction.buffer.upgrade() { let position = prediction.position; - let zeta = self.zeta.clone(); let project = self.project.clone(); - prediction._task = Some(cx.spawn(async move |_this, cx| { - cx.background_executor().timer(DEBOUNCE_TIME).await; - if let Some(task) = zeta - .update(cx, |zeta, cx| { - zeta.refresh_prediction(&project, &buffer, position, cx) - }) - .ok() - { - task.await.log_err(); - } - })); + self.zeta.update(cx, |zeta, cx| { + zeta.refresh_prediction_from_buffer(project, buffer, position, cx) + }); prediction.state = LastPredictionState::Requested; } else { self.last_prediction.take();