@@ -1,24 +1,15 @@
-use std::{
- cmp,
- sync::Arc,
- time::{Duration, Instant},
-};
+use std::{cmp, sync::Arc, time::Duration};
-use arrayvec::ArrayVec;
use client::{Client, UserStore};
use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
-use gpui::{App, Entity, Task, prelude::*};
+use gpui::{App, Entity, prelude::*};
use language::ToPoint as _;
use project::Project;
-use util::ResultExt as _;
use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
pub struct ZetaEditPredictionProvider {
zeta: Entity<Zeta>,
- next_pending_prediction_id: usize,
- pending_predictions: ArrayVec<PendingPrediction, 2>,
- last_request_timestamp: Instant,
project: Entity<Project>,
}
@@ -29,28 +20,25 @@ impl ZetaEditPredictionProvider {
project: Entity<Project>,
client: &Arc<Client>,
user_store: &Entity<UserStore>,
- cx: &mut App,
+ cx: &mut Context<Self>,
) -> Self {
let zeta = Zeta::global(client, user_store, cx);
zeta.update(cx, |zeta, cx| {
zeta.register_project(&project, cx);
});
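+        // Keep the provider's UI in sync: re-notify whenever the shared Zeta entity changes.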
+ cx.observe(&zeta, |_this, _zeta, cx| {
+ cx.notify();
+ })
+ .detach();
+
Self {
- zeta,
- next_pending_prediction_id: 0,
- pending_predictions: ArrayVec::new(),
- last_request_timestamp: Instant::now(),
-            project: project,
+            project,
+            zeta,
}
}
}
-struct PendingPrediction {
- id: usize,
- _task: Task<()>,
-}
-
impl EditPredictionProvider for ZetaEditPredictionProvider {
fn name() -> &'static str {
"zed-predict2"
@@ -95,8 +83,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
}
}
- fn is_refreshing(&self) -> bool {
- !self.pending_predictions.is_empty()
+ fn is_refreshing(&self, cx: &App) -> bool {
+ self.zeta.read(cx).is_refreshing(&self.project)
}
fn refresh(
@@ -123,59 +111,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
self.zeta.update(cx, |zeta, cx| {
zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
+ zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
});
-
- let pending_prediction_id = self.next_pending_prediction_id;
- self.next_pending_prediction_id += 1;
- let last_request_timestamp = self.last_request_timestamp;
-
- let project = self.project.clone();
- let task = cx.spawn(async move |this, cx| {
- if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
- .checked_duration_since(Instant::now())
- {
- cx.background_executor().timer(timeout).await;
- }
-
- let refresh_task = this.update(cx, |this, cx| {
- this.last_request_timestamp = Instant::now();
- this.zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction(&project, &buffer, cursor_position, cx)
- })
- });
-
- if let Some(refresh_task) = refresh_task.ok() {
- refresh_task.await.log_err();
- }
-
- this.update(cx, |this, cx| {
- if this.pending_predictions[0].id == pending_prediction_id {
- this.pending_predictions.remove(0);
- } else {
- this.pending_predictions.clear();
- }
-
- cx.notify();
- })
- .ok();
- });
-
- // We always maintain at most two pending predictions. When we already
- // have two, we replace the newest one.
- if self.pending_predictions.len() <= 1 {
- self.pending_predictions.push(PendingPrediction {
- id: pending_prediction_id,
- _task: task,
- });
- } else if self.pending_predictions.len() == 2 {
- self.pending_predictions.pop();
- self.pending_predictions.push(PendingPrediction {
- id: pending_prediction_id,
- _task: task,
- });
- }
-
- cx.notify();
}
fn cycle(
@@ -191,14 +128,12 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
self.zeta.update(cx, |zeta, cx| {
zeta.accept_current_prediction(&self.project, cx);
});
- self.pending_predictions.clear();
}
fn discard(&mut self, cx: &mut Context<Self>) {
self.zeta.update(cx, |zeta, _cx| {
zeta.discard_current_prediction(&self.project);
});
- self.pending_predictions.clear();
}
fn suggest(
@@ -1,4 +1,5 @@
use anyhow::{Context as _, Result, anyhow, bail};
+use arrayvec::ArrayVec;
use chrono::TimeDelta;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
@@ -19,18 +20,20 @@ use futures::AsyncReadExt as _;
use futures::channel::{mpsc, oneshot};
use gpui::http_client::{AsyncBody, Method};
use gpui::{
- App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
- http_client, prelude::*,
+ App, AsyncApp, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task,
+ WeakEntity, http_client, prelude::*,
};
use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
+use lsp::DiagnosticSeverity;
use open_ai::FunctionDefinition;
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use serde::de::DeserializeOwned;
use std::collections::{VecDeque, hash_map};
+use std::fmt::Write;
use std::ops::Range;
use std::path::Path;
use std::str::FromStr as _;
@@ -39,7 +42,7 @@ use std::time::{Duration, Instant};
use std::{env, mem};
use thiserror::Error;
use util::rel_path::RelPathBuf;
-use util::{LogErrorFuture, ResultExt as _, TryFutureExt};
+use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
pub mod assemble_excerpts;
@@ -239,6 +242,9 @@ struct ZetaProject {
recent_paths: VecDeque<ProjectPath>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
current_prediction: Option<CurrentEditPrediction>,
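+    // Refresh bookkeeping: id allocation, the two-deep pending queue, and throttle state.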
+ next_pending_prediction_id: usize,
+ pending_predictions: ArrayVec<PendingPrediction, 2>,
+ last_prediction_refresh: Option<(EntityId, Instant)>,
context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
refresh_context_debounce_task: Option<Task<Option<()>>>,
@@ -248,7 +254,7 @@ struct ZetaProject {
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
- pub requested_by_buffer_id: EntityId,
+ pub requested_by: PredictionRequestedBy,
pub prediction: EditPrediction,
}
@@ -272,11 +278,13 @@ impl CurrentEditPrediction {
return true;
};
+        let requested_by_buffer_id = self.requested_by.buffer_id();
+
// This reduces the occurrence of UI thrash from replacing edits
//
-        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
-        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
-            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
+ if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
+ && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
&& old_edits.len() == 1
&& new_edits.len() == 1
{
@@ -289,6 +297,26 @@ impl CurrentEditPrediction {
}
}
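+/// What triggered the current prediction request. Diagnostics-triggered
+/// predictions can offer a jump from any buffer.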
+#[derive(Debug, Clone)]
+enum PredictionRequestedBy {
+ DiagnosticsUpdate,
+ Buffer(EntityId),
+}
+
+impl PredictionRequestedBy {
+ pub fn buffer_id(&self) -> Option<EntityId> {
+ match self {
+ PredictionRequestedBy::DiagnosticsUpdate => None,
+ PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
+ }
+ }
+}
+
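+/// An in-flight prediction refresh. Dropping `_task` cancels it.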
+struct PendingPrediction {
+ id: usize,
+ _task: Task<()>,
+}
+
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
@@ -513,31 +541,48 @@ impl Zeta {
recent_paths: VecDeque::new(),
registered_buffers: HashMap::default(),
current_prediction: None,
+ pending_predictions: ArrayVec::new(),
+ next_pending_prediction_id: 0,
+ last_prediction_refresh: None,
context: None,
refresh_context_task: None,
refresh_context_debounce_task: None,
refresh_context_timestamp: None,
- _subscription: cx.subscribe(&project, |this, project, event, cx| {
- // TODO [zeta2] init with recent paths
- if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
- if let project::Event::ActiveEntryChanged(Some(active_entry_id)) = event {
- let path = project.read(cx).path_for_entry(*active_entry_id, cx);
- if let Some(path) = path {
- if let Some(ix) = zeta_project
- .recent_paths
- .iter()
- .position(|probe| probe == &path)
- {
- zeta_project.recent_paths.remove(ix);
- }
- zeta_project.recent_paths.push_front(path);
- }
- }
- }
- }),
+ _subscription: cx.subscribe(&project, Self::handle_project_event),
})
}
+ fn handle_project_event(
+ &mut self,
+ project: Entity<Project>,
+ event: &project::Event,
+ cx: &mut Context<Self>,
+ ) {
+ // TODO [zeta2] init with recent paths
+ match event {
+ project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
+ let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
+ return;
+ };
+ let path = project.read(cx).path_for_entry(*active_entry_id, cx);
+ if let Some(path) = path {
+ if let Some(ix) = zeta_project
+ .recent_paths
+ .iter()
+ .position(|probe| probe == &path)
+ {
+ zeta_project.recent_paths.remove(ix);
+ }
+ zeta_project.recent_paths.push_front(path);
+ }
+ }
+ project::Event::DiagnosticsUpdated { .. } => {
+ self.refresh_prediction_from_diagnostics(project, cx);
+ }
+ _ => (),
+ }
+ }
+
fn register_buffer_impl<'a>(
zeta_project: &'a mut ZetaProject,
buffer: &Entity<Buffer>,
@@ -650,16 +695,25 @@ impl Zeta {
let project_state = self.projects.get(&project.entity_id())?;
let CurrentEditPrediction {
- requested_by_buffer_id,
+ requested_by,
prediction,
} = project_state.current_prediction.as_ref()?;
if prediction.targets_buffer(buffer.read(cx)) {
Some(BufferEditPrediction::Local { prediction })
- } else if *requested_by_buffer_id == buffer.entity_id() {
- Some(BufferEditPrediction::Jump { prediction })
} else {
- None
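+            // A prediction for another buffer is shown as a jump from the requesting
+            // buffer, or from anywhere when diagnostics triggered it.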
+ let show_jump = match requested_by {
+ PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
+ requested_by_buffer_id == &buffer.entity_id()
+ }
+ PredictionRequestedBy::DiagnosticsUpdate => true,
+ };
+
+ if show_jump {
+ Some(BufferEditPrediction::Jump { prediction })
+ } else {
+ None
+ }
}
}
@@ -676,6 +730,7 @@ impl Zeta {
return;
};
let request_id = prediction.prediction.id.to_string();
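+        // Drop any in-flight refreshes; the accepted prediction supersedes them.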
+ project_state.pending_predictions.clear();
let client = self.client.clone();
let llm_token = self.llm_token.clone();
@@ -715,47 +770,191 @@ impl Zeta {
fn discard_current_prediction(&mut self, project: &Entity<Project>) {
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
project_state.current_prediction.take();
+ project_state.pending_predictions.clear();
};
}
- pub fn refresh_prediction(
+ fn is_refreshing(&self, project: &Entity<Project>) -> bool {
+ self.projects
+ .get(&project.entity_id())
+ .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
+ }
+
+ pub fn refresh_prediction_from_buffer(
&mut self,
- project: &Entity<Project>,
- buffer: &Entity<Buffer>,
+ project: Entity<Project>,
+ buffer: Entity<Buffer>,
position: language::Anchor,
cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
- let request_task = self.request_prediction(project, buffer, position, cx);
- let buffer = buffer.clone();
- let project = project.clone();
+ ) {
+ self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
+ let Some(request_task) = this
+ .update(cx, |this, cx| {
+ this.request_prediction(&project, &buffer, position, cx)
+ })
+ .log_err()
+ else {
+ return Task::ready(anyhow::Ok(()));
+ };
- cx.spawn(async move |this, cx| {
- if let Some(prediction) = request_task.await? {
- this.update(cx, |this, cx| {
- let project_state = this
- .projects
- .get_mut(&project.entity_id())
- .context("Project not found")?;
-
- let new_prediction = CurrentEditPrediction {
- requested_by_buffer_id: buffer.entity_id(),
- prediction: prediction,
- };
+ let project = project.clone();
+ cx.spawn(async move |cx| {
+ if let Some(prediction) = request_task.await? {
+ this.update(cx, |this, cx| {
+ let project_state = this
+ .projects
+ .get_mut(&project.entity_id())
+ .context("Project not found")?;
+
+ let new_prediction = CurrentEditPrediction {
+ requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
+                        prediction,
+ };
- if project_state
- .current_prediction
- .as_ref()
- .is_none_or(|old_prediction| {
- new_prediction.should_replace_prediction(&old_prediction, cx)
- })
- {
- project_state.current_prediction = Some(new_prediction);
+ if project_state
+ .current_prediction
+ .as_ref()
+ .is_none_or(|old_prediction| {
+ new_prediction.should_replace_prediction(&old_prediction, cx)
+ })
+ {
+ project_state.current_prediction = Some(new_prediction);
+ cx.notify();
+ }
+ anyhow::Ok(())
+ })??;
+ }
+ Ok(())
+ })
+ })
+ }
+
+ pub fn refresh_prediction_from_diagnostics(
+ &mut self,
+ project: Entity<Project>,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
+ return;
+ };
+
+        // Prefer buffer-requested predictions: if one is already current, keep it
+        // instead of refreshing from diagnostics.
+ if zeta_project.current_prediction.is_some() {
+ return;
+ };
+
+ self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
+ let Some(open_buffer_task) = project
+ .update(cx, |project, cx| {
+ project
+ .active_entry()
+ .and_then(|entry| project.path_for_entry(entry, cx))
+ .map(|path| project.open_buffer(path, cx))
+ })
+ .log_err()
+ .flatten()
+ else {
+ return Task::ready(anyhow::Ok(()));
+ };
+
+ cx.spawn(async move |cx| {
+ let active_buffer = open_buffer_task.await?;
+ let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+
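+            // No previous request window or cursor position here, so the defaults
+            // make the search start from the top of the buffer.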
+ let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
+ active_buffer,
+ &snapshot,
+ Default::default(),
+ Default::default(),
+ &project,
+ cx,
+ )
+ .await?
+ else {
+ return anyhow::Ok(());
+ };
+
+ let Some(prediction) = this
+ .update(cx, |this, cx| {
+ this.request_prediction(&project, &jump_buffer, jump_position, cx)
+ })?
+ .await?
+ else {
+ return anyhow::Ok(());
+ };
+
+ this.update(cx, |this, cx| {
+ if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
+ zeta_project.current_prediction.get_or_insert_with(|| {
+ cx.notify();
+ CurrentEditPrediction {
+ requested_by: PredictionRequestedBy::DiagnosticsUpdate,
+ prediction,
+ }
+ });
}
- anyhow::Ok(())
- })??;
+ })?;
+
+ anyhow::Ok(())
+ })
+ });
+ }
+
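+    // Minimum delay between consecutive refreshes for the same entity. Zero in
+    // tests so predictions resolve without advancing the fake clock.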
+ #[cfg(not(test))]
+ pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
+ #[cfg(test)]
+ pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
+
+ fn queue_prediction_refresh(
+ &mut self,
+ project: Entity<Project>,
+ throttle_entity: EntityId,
+ cx: &mut Context<Self>,
+ do_refresh: impl FnOnce(WeakEntity<Self>, &mut AsyncApp) -> Task<Result<()>> + 'static,
+ ) {
+ let zeta_project = self.get_or_init_zeta_project(&project, cx);
+ let pending_prediction_id = zeta_project.next_pending_prediction_id;
+ zeta_project.next_pending_prediction_id += 1;
+ let last_request = zeta_project.last_prediction_refresh;
+
+ // TODO report cancelled requests like in zeta1
+ let task = cx.spawn(async move |this, cx| {
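+            // Wait out the throttle only if the previous refresh targeted the same entity.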
+ if let Some((last_entity, last_timestamp)) = last_request
+ && throttle_entity == last_entity
+ && let Some(timeout) =
+ (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
+ {
+ cx.background_executor().timer(timeout).await;
}
- Ok(())
- })
+
+ do_refresh(this.clone(), cx).await.log_err();
+
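+            // Retire this entry if it's still the oldest pending prediction;
+            // otherwise a newer request superseded it, so drop the whole queue.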
+ this.update(cx, |this, cx| {
+ let zeta_project = this.get_or_init_zeta_project(&project, cx);
+
+ if zeta_project.pending_predictions[0].id == pending_prediction_id {
+ zeta_project.pending_predictions.remove(0);
+ } else {
+ zeta_project.pending_predictions.clear();
+ }
+
+ cx.notify();
+ })
+ .ok();
+ });
+
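+        // Maintain at most two pending predictions: when two are already queued,
+        // replace the newest one.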
+ if zeta_project.pending_predictions.len() <= 1 {
+ zeta_project.pending_predictions.push(PendingPrediction {
+ id: pending_prediction_id,
+ _task: task,
+ });
+ } else if zeta_project.pending_predictions.len() == 2 {
+ zeta_project.pending_predictions.pop();
+ zeta_project.pending_predictions.push(PendingPrediction {
+ id: pending_prediction_id,
+ _task: task,
+ });
+ }
}
pub fn request_prediction(
@@ -770,7 +969,7 @@ impl Zeta {
self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
}
ZetaEditPredictionModel::Sweep => {
- self.request_prediction_with_sweep(project, active_buffer, position, cx)
+ self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
}
}
}
@@ -780,6 +979,7 @@ impl Zeta {
project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
position: language::Anchor,
+ allow_jump: bool,
cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
let snapshot = active_buffer.read(cx).snapshot();
@@ -802,6 +1002,7 @@ impl Zeta {
let project_state = self.get_or_init_zeta_project(project, cx);
let events = project_state.events.clone();
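+        // Recorded edit events gate the diagnostic jump below; we only jump if the
+        // user has been editing.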
+ let has_events = !events.is_empty();
let recent_buffers = project_state.recent_paths.iter().cloned();
let http_client = cx.http_client();
@@ -817,114 +1018,188 @@ impl Zeta {
.take(3)
.collect::<Vec<_>>();
- let result = cx.background_spawn(async move {
- let text = snapshot.text();
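+        // Diagnostics within this many lines of the cursor are bundled into the request.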
+ const DIAGNOSTIC_LINES_RANGE: u32 = 20;
- let mut recent_changes = String::new();
- for event in events {
- sweep_ai::write_event(event, &mut recent_changes).unwrap();
- }
+ let cursor_point = position.to_point(&snapshot);
+ let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
+ let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
+ let diagnostic_search_range =
+ Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
+
+ let result = cx.background_spawn({
+ let snapshot = snapshot.clone();
+ let diagnostic_search_range = diagnostic_search_range.clone();
+ async move {
+ let text = snapshot.text();
+
+ let mut recent_changes = String::new();
+ for event in events {
+ sweep_ai::write_event(event, &mut recent_changes).unwrap();
+ }
+
+ let mut file_chunks = recent_buffer_snapshots
+ .into_iter()
+ .map(|snapshot| {
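+                        // Take roughly the first 30 lines of each recently opened buffer as context.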
+ let end_point = Point::new(30, 0).min(snapshot.max_point());
+ sweep_ai::FileChunk {
+ content: snapshot.text_for_range(Point::zero()..end_point).collect(),
+ file_path: snapshot
+ .file()
+ .map(|f| f.path().as_unix_str())
+ .unwrap_or("untitled")
+ .to_string(),
+ start_line: 0,
+ end_line: end_point.row as usize,
+ timestamp: snapshot.file().and_then(|file| {
+ Some(
+ file.disk_state()
+ .mtime()?
+ .to_seconds_and_nanos_for_persistence()?
+ .0,
+ )
+ }),
+ }
+ })
+ .collect::<Vec<_>>();
+
+ let diagnostic_entries =
+ snapshot.diagnostics_in_range(diagnostic_search_range, false);
+ let mut diagnostic_content = String::new();
+ let mut diagnostic_count = 0;
+
+ for entry in diagnostic_entries {
+ let start_point: Point = entry.range.start;
+
+ let severity = match entry.diagnostic.severity {
+ DiagnosticSeverity::ERROR => "error",
+ DiagnosticSeverity::WARNING => "warning",
+ DiagnosticSeverity::INFORMATION => "info",
+ DiagnosticSeverity::HINT => "hint",
+ _ => continue,
+ };
+
+ diagnostic_count += 1;
- let file_chunks = recent_buffer_snapshots
- .into_iter()
- .map(|snapshot| {
- let end_point = language::Point::new(30, 0).min(snapshot.max_point());
- sweep_ai::FileChunk {
- content: snapshot
- .text_for_range(language::Point::zero()..end_point)
- .collect(),
- file_path: snapshot
- .file()
- .map(|f| f.path().as_unix_str())
- .unwrap_or("untitled")
- .to_string(),
+ writeln!(
+ &mut diagnostic_content,
+ "{} at line {}: {}",
+ severity,
+ start_point.row + 1,
+ entry.diagnostic.message
+ )?;
+ }
+
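+                // Surface nearby diagnostics to the model as a synthetic file chunk.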
+ if !diagnostic_content.is_empty() {
+ file_chunks.push(sweep_ai::FileChunk {
+ file_path: format!("Diagnostics for {}", full_path.display()),
start_line: 0,
- end_line: end_point.row as usize,
- timestamp: snapshot.file().and_then(|file| {
- Some(
- file.disk_state()
- .mtime()?
- .to_seconds_and_nanos_for_persistence()?
- .0,
- )
- }),
- }
- })
- .collect();
-
- let request_body = sweep_ai::AutocompleteRequest {
- debug_info,
- repo_name,
- file_path: full_path.clone(),
- file_contents: text.clone(),
- original_file_contents: text,
- cursor_position: offset,
- recent_changes: recent_changes.clone(),
- changes_above_cursor: true,
- multiple_suggestions: false,
- branch: None,
- file_chunks,
- retrieval_chunks: vec![],
- recent_user_actions: vec![],
- // TODO
- privacy_mode_enabled: false,
- };
+ end_line: diagnostic_count,
+ content: diagnostic_content,
+ timestamp: None,
+ });
+ }
- let mut buf: Vec<u8> = Vec::new();
- let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
- serde_json::to_writer(writer, &request_body)?;
- let body: AsyncBody = buf.into();
+ let request_body = sweep_ai::AutocompleteRequest {
+ debug_info,
+ repo_name,
+ file_path: full_path.clone(),
+ file_contents: text.clone(),
+ original_file_contents: text,
+ cursor_position: offset,
+ recent_changes: recent_changes.clone(),
+ changes_above_cursor: true,
+ multiple_suggestions: false,
+ branch: None,
+ file_chunks,
+ retrieval_chunks: vec![],
+ recent_user_actions: vec![],
+ // TODO
+ privacy_mode_enabled: false,
+ };
- const SWEEP_API_URL: &str =
- "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
+ let mut buf: Vec<u8> = Vec::new();
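+                // Brotli-compress the JSON body, matching the `Content-Encoding: br` header below.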
+ let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
+ serde_json::to_writer(writer, &request_body)?;
+ let body: AsyncBody = buf.into();
- let request = http_client::Request::builder()
- .uri(SWEEP_API_URL)
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", api_token))
- .header("Connection", "keep-alive")
- .header("Content-Encoding", "br")
- .method(Method::POST)
- .body(body)?;
+ const SWEEP_API_URL: &str =
+ "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
- let mut response = http_client.send(request).await?;
+ let request = http_client::Request::builder()
+ .uri(SWEEP_API_URL)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_token))
+ .header("Connection", "keep-alive")
+ .header("Content-Encoding", "br")
+ .method(Method::POST)
+ .body(body)?;
- let mut body: Vec<u8> = Vec::new();
- response.body_mut().read_to_end(&mut body).await?;
+ let mut response = http_client.send(request).await?;
- if !response.status().is_success() {
- anyhow::bail!(
- "Request failed with status: {:?}\nBody: {}",
- response.status(),
- String::from_utf8_lossy(&body),
- );
- };
+ let mut body: Vec<u8> = Vec::new();
+ response.body_mut().read_to_end(&mut body).await?;
- let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
-
- let old_text = snapshot
- .text_for_range(response.start_index..response.end_index)
- .collect::<String>();
- let edits = language::text_diff(&old_text, &response.completion)
- .into_iter()
- .map(|(range, text)| {
- (
- snapshot.anchor_after(response.start_index + range.start)
- ..snapshot.anchor_before(response.start_index + range.end),
- text,
- )
- })
- .collect::<Vec<_>>();
+ if !response.status().is_success() {
+ anyhow::bail!(
+ "Request failed with status: {:?}\nBody: {}",
+ response.status(),
+ String::from_utf8_lossy(&body),
+ );
+ };
+
+ let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
- anyhow::Ok((response.autocomplete_id, edits, snapshot))
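+                // Diff the returned completion against the current text to get minimal anchored edits.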
+ let old_text = snapshot
+ .text_for_range(response.start_index..response.end_index)
+ .collect::<String>();
+ let edits = language::text_diff(&old_text, &response.completion)
+ .into_iter()
+ .map(|(range, text)| {
+ (
+ snapshot.anchor_after(response.start_index + range.start)
+ ..snapshot.anchor_before(response.start_index + range.end),
+ text,
+ )
+ })
+ .collect::<Vec<_>>();
+
+ anyhow::Ok((response.autocomplete_id, edits, snapshot))
+ }
});
let buffer = active_buffer.clone();
+ let project = project.clone();
+ let active_buffer = active_buffer.clone();
- cx.spawn(async move |_, cx| {
+ cx.spawn(async move |this, cx| {
let (id, edits, old_snapshot) = result.await?;
if edits.is_empty() {
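+                // Nothing predicted here. If the user has been editing, try the
+                // nearest diagnostic instead; passing `allow_jump: false` on the
+                // retry prevents unbounded recursion.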
+ if has_events
+ && allow_jump
+ && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
+ active_buffer,
+ &snapshot,
+ diagnostic_search_range,
+ cursor_point,
+ &project,
+ cx,
+ )
+ .await?
+ {
+ return this
+ .update(cx, |this, cx| {
+ this.request_prediction_with_sweep(
+ &project,
+ &jump_buffer,
+ jump_position,
+ false,
+ cx,
+ )
+ })?
+ .await;
+ }
+
return anyhow::Ok(None);
}
@@ -955,6 +1230,85 @@ impl Zeta {
})
}
+ async fn next_diagnostic_location(
+ active_buffer: Entity<Buffer>,
+ active_buffer_snapshot: &BufferSnapshot,
+ active_buffer_diagnostic_search_range: Range<Point>,
+ active_buffer_cursor_point: Point,
+ project: &Entity<Project>,
+ cx: &mut AsyncApp,
+ ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
+        // Find the diagnostic closest to the cursor that wasn't close enough to be
+        // included in the previous request.
+ let mut jump_location = active_buffer_snapshot
+ .diagnostic_groups(None)
+ .into_iter()
+ .filter_map(|(_, group)| {
+ let range = &group.entries[group.primary_ix]
+ .range
+ .to_point(&active_buffer_snapshot);
+ if range.overlaps(&active_buffer_diagnostic_search_range) {
+ None
+ } else {
+ Some(range.start)
+ }
+ })
+ .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
+ .map(|position| {
+ (
+ active_buffer.clone(),
+ active_buffer_snapshot.anchor_before(position),
+ )
+ });
+
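+        // No candidate in the active buffer; fall back to diagnostics from other
+        // buffers in the project.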
+ if jump_location.is_none() {
+ let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
+ let file = buffer.file()?;
+
+ Some(ProjectPath {
+ worktree_id: file.worktree_id(cx),
+ path: file.path().clone(),
+ })
+ })?;
+
+ let buffer_task = project.update(cx, |project, cx| {
+ let (path, _, _) = project
+ .diagnostic_summaries(false, cx)
+ .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
+ .max_by_key(|(path, _, _)| {
+                    // Find the buffer with diagnostics that shares the most parent
+                    // directories with the active buffer's path.
+ path.path
+ .components()
+ .zip(
+ active_buffer_path
+ .as_ref()
+ .map(|p| p.path.components())
+ .unwrap_or_default(),
+ )
+ .take_while(|(a, b)| a == b)
+ .count()
+ })?;
+
+ Some(project.open_buffer(path, cx))
+ })?;
+
+ if let Some(buffer_task) = buffer_task {
+ let closest_buffer = buffer_task.await?;
+
+ jump_location = closest_buffer
+ .read_with(cx, |buffer, _cx| {
+ buffer
+ .buffer_diagnostics(None)
+ .into_iter()
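+                        // Lower `DiagnosticSeverity` values are more severe (ERROR = 1),
+                        // so `min` picks the most severe entry.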
+ .min_by_key(|entry| entry.diagnostic.severity)
+ .map(|entry| entry.range.start)
+ })?
+ .map(|position| (closest_buffer, position));
+ }
+ }
+
+ anyhow::Ok(jump_location)
+ }
+
fn request_prediction_with_zed_cloud(
&mut self,
project: &Entity<Project>,
@@ -2168,8 +2522,8 @@ mod tests {
// Prediction for current file
- let prediction_task = zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction(&project, &buffer1, position, cx)
+ zeta.update(cx, |zeta, cx| {
+ zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
});
let (_request, respond_tx) = req_rx.next().await.unwrap();
@@ -2184,7 +2538,8 @@ mod tests {
Bye
"}))
.unwrap();
- prediction_task.await.unwrap();
+
+ cx.run_until_parked();
zeta.read_with(cx, |zeta, cx| {
let prediction = zeta
@@ -2242,8 +2597,8 @@ mod tests {
});
// Prediction for another file
- let prediction_task = zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction(&project, &buffer1, position, cx)
+ zeta.update(cx, |zeta, cx| {
+ zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
});
let (_request, respond_tx) = req_rx.next().await.unwrap();
respond_tx
@@ -2256,7 +2611,8 @@ mod tests {
Adios
"#}))
.unwrap();
- prediction_task.await.unwrap();
+ cx.run_until_parked();
+
zeta.read_with(cx, |zeta, cx| {
let prediction = zeta
.current_prediction_for_buffer(&buffer1, &project, cx)