diff --git a/crates/zeta2/src/prediction.rs b/crates/zeta2/src/prediction.rs new file mode 100644 index 0000000000000000000000000000000000000000..d7b3c584a0324869921bff2868838a7ba09585ac --- /dev/null +++ b/crates/zeta2/src/prediction.rs @@ -0,0 +1,345 @@ +use std::{borrow::Cow, ops::Range, sync::Arc}; + +use cloud_llm_client::predict_edits_v3; +use language::{Anchor, BufferSnapshot, EditPreview, OffsetRangeExt, text_diff}; +use uuid::Uuid; + +#[derive(Clone)] +pub struct EditPrediction { + pub id: EditPredictionId, + pub edits: Arc<[(Range, String)]>, + pub snapshot: BufferSnapshot, + pub edit_preview: EditPreview, +} + +impl EditPrediction { + pub fn interpolate( + &self, + new_snapshot: &BufferSnapshot, + ) -> Option, String)>> { + interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone()) + } +} + +#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] +pub struct EditPredictionId(Uuid); + +impl From for EditPredictionId { + fn from(value: Uuid) -> Self { + EditPredictionId(value) + } +} + +impl From for gpui::ElementId { + fn from(value: EditPredictionId) -> Self { + gpui::ElementId::Uuid(value.0) + } +} + +impl std::fmt::Display for EditPredictionId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +pub fn interpolate_edits( + old_snapshot: &BufferSnapshot, + new_snapshot: &BufferSnapshot, + current_edits: Arc<[(Range, String)]>, +) -> Option, String)>> { + let mut edits = Vec::new(); + + let mut model_edits = current_edits.iter().peekable(); + for user_edit in new_snapshot.edits_since::(&old_snapshot.version) { + while let Some((model_old_range, _)) = model_edits.peek() { + let model_old_range = model_old_range.to_offset(old_snapshot); + if model_old_range.end < user_edit.old.start { + let (model_old_range, model_new_text) = model_edits.next().unwrap(); + edits.push((model_old_range.clone(), model_new_text.clone())); + } else { + break; + } + } + + if let Some((model_old_range, model_new_text)) = model_edits.peek() { + let model_old_offset_range = model_old_range.to_offset(old_snapshot); + if user_edit.old == model_old_offset_range { + let user_new_text = new_snapshot + .text_for_range(user_edit.new.clone()) + .collect::(); + + if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) { + if !model_suffix.is_empty() { + let anchor = old_snapshot.anchor_after(user_edit.old.end); + edits.push((anchor..anchor, model_suffix.to_string())); + } + + model_edits.next(); + continue; + } + } + } + + return None; + } + + edits.extend(model_edits.cloned()); + + if edits.is_empty() { None } else { Some(edits) } +} + +pub fn edits_from_response( + edits: &[predict_edits_v3::Edit], + snapshot: &BufferSnapshot, +) -> Arc<[(Range, String)]> { + edits + .iter() + .flat_map(|edit| { + // TODO multi-file edits + let old_text = snapshot.text_for_range(edit.range.clone()); + + excerpt_edits_from_response( + old_text.collect::>(), + &edit.content, + edit.range.start, + &snapshot, + ) + }) + .collect::>() + .into() +} + +fn excerpt_edits_from_response( + old_text: Cow, + new_text: &str, + offset: usize, + snapshot: &BufferSnapshot, +) -> impl Iterator, String)> { + text_diff(&old_text, new_text) + .into_iter() + .map(move |(mut old_range, new_text)| { + old_range.start += offset; + old_range.end += offset; + + let prefix_len = common_prefix( + snapshot.chars_for_range(old_range.clone()), + new_text.chars(), + ); + old_range.start += prefix_len; + + let suffix_len = common_prefix( + snapshot.reversed_chars_for_range(old_range.clone()), + new_text[prefix_len..].chars().rev(), + ); + old_range.end = old_range.end.saturating_sub(suffix_len); + + let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string(); + let range = if old_range.is_empty() { + let anchor = snapshot.anchor_after(old_range.start); + anchor..anchor + } else { + snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end) + }; + (range, new_text) + }) +} + +fn common_prefix, T2: Iterator>(a: T1, b: T2) -> usize { + a.zip(b) + .take_while(|(a, b)| a == b) + .map(|(a, _)| a.len_utf8()) + .sum() +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use super::*; + use cloud_llm_client::predict_edits_v3; + use gpui::{App, Entity, TestAppContext, prelude::*}; + use indoc::indoc; + use language::{Buffer, ToOffset as _}; + + #[gpui::test] + async fn test_compute_edits(cx: &mut TestAppContext) { + let old = indoc! {r#" + fn main() { + let args = + println!("{}", args[1]) + } + "#}; + + let new = indoc! {r#" + fn main() { + let args = std::env::args(); + println!("{}", args[1]); + } + "#}; + + let buffer = cx.new(|cx| Buffer::local(old, cx)); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + + // TODO cover more cases when multi-file is supported + let big_edits = vec![predict_edits_v3::Edit { + path: PathBuf::from("test.txt"), + range: 0..old.len(), + content: new.into(), + }]; + + let edits = edits_from_response(&big_edits, &snapshot); + assert_eq!(edits.len(), 2); + assert_eq!( + edits[0].0.to_point(&snapshot).start, + language::Point::new(1, 14) + ); + assert_eq!(edits[0].1, " std::env::args();"); + assert_eq!( + edits[1].0.to_point(&snapshot).start, + language::Point::new(2, 27) + ); + assert_eq!(edits[1].1, ";"); + } + + #[gpui::test] + async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { + let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); + let edits: Arc<[(Range, String)]> = cx.update(|cx| { + to_prediction_edits( + [(2..5, "REM".to_string()), (9..11, "".to_string())], + &buffer, + cx, + ) + .into() + }); + + let edit_preview = cx + .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx)) + .await; + + let prediction = EditPrediction { + id: EditPredictionId(Uuid::new_v4()), + edits, + snapshot: cx.read(|cx| buffer.read(cx).snapshot()), + edit_preview, + }; + + cx.update(|cx| { + assert_eq!( + from_prediction_edits( + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(2..5, "REM".to_string()), (9..11, "".to_string())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx)); + assert_eq!( + from_prediction_edits( + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(2..2, "REM".to_string()), (6..8, "".to_string())] + ); + + buffer.update(cx, |buffer, cx| buffer.undo(cx)); + assert_eq!( + from_prediction_edits( + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(2..5, "REM".to_string()), (9..11, "".to_string())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx)); + assert_eq!( + from_prediction_edits( + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(3..3, "EM".to_string()), (7..9, "".to_string())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx)); + assert_eq!( + from_prediction_edits( + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(4..4, "M".to_string()), (8..10, "".to_string())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx)); + assert_eq!( + from_prediction_edits( + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(9..11, "".to_string())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx)); + assert_eq!( + from_prediction_edits( + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(4..4, "M".to_string()), (8..10, "".to_string())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx)); + assert_eq!( + from_prediction_edits( + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(4..4, "M".to_string())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx)); + assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None); + }) + } + + fn to_prediction_edits( + iterator: impl IntoIterator, String)>, + buffer: &Entity, + cx: &App, + ) -> Vec<(Range, String)> { + let buffer = buffer.read(cx); + iterator + .into_iter() + .map(|(range, text)| { + ( + buffer.anchor_after(range.start)..buffer.anchor_before(range.end), + text, + ) + }) + .collect() + } + + fn from_prediction_edits( + editor_edits: &[(Range, String)], + buffer: &Entity, + cx: &App, + ) -> Vec<(Range, String)> { + let buffer = buffer.read(cx); + editor_edits + .iter() + .map(|(range, text)| { + ( + range.start.to_offset(buffer)..range.end.to_offset(buffer), + text.clone(), + ) + }) + .collect() + } +} diff --git a/crates/zeta2/src/provider.rs b/crates/zeta2/src/provider.rs new file mode 100644 index 0000000000000000000000000000000000000000..ae30c0bee0da47d8f6174e76918e6dd751d348d2 --- /dev/null +++ b/crates/zeta2/src/provider.rs @@ -0,0 +1,322 @@ +use std::{ + cmp, + sync::Arc, + time::{Duration, Instant}, +}; + +use anyhow::Context as _; +use arrayvec::ArrayVec; +use client::{Client, UserStore}; +use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider}; +use gpui::{App, Entity, EntityId, Task, prelude::*}; +use language::{BufferSnapshot, ToPoint as _}; +use project::Project; +use util::ResultExt as _; + +use crate::{Zeta, prediction::EditPrediction}; + +pub struct ZetaEditPredictionProvider { + zeta: Entity, + current_prediction: Option, + next_pending_prediction_id: usize, + pending_predictions: ArrayVec, + last_request_timestamp: Instant, +} + +impl ZetaEditPredictionProvider { + pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); + + pub fn new( + project: Option<&Entity>, + client: &Arc, + user_store: &Entity, + cx: &mut App, + ) -> Self { + let zeta = Zeta::global(client, user_store, cx); + if let Some(project) = project { + zeta.update(cx, |zeta, cx| { + zeta.register_project(project, cx); + }); + } + + Self { + zeta, + current_prediction: None, + next_pending_prediction_id: 0, + pending_predictions: ArrayVec::new(), + last_request_timestamp: Instant::now(), + } + } +} + +#[derive(Clone)] +struct CurrentEditPrediction { + buffer_id: EntityId, + prediction: EditPrediction, +} + +impl CurrentEditPrediction { + fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool { + if self.buffer_id != old_prediction.buffer_id { + return true; + } + + let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else { + return true; + }; + let Some(new_edits) = self.prediction.interpolate(snapshot) else { + return false; + }; + + if old_edits.len() == 1 && new_edits.len() == 1 { + let (old_range, old_text) = &old_edits[0]; + let (new_range, new_text) = &new_edits[0]; + new_range == old_range && new_text.starts_with(old_text) + } else { + true + } + } +} + +struct PendingPrediction { + id: usize, + _task: Task<()>, +} + +impl EditPredictionProvider for ZetaEditPredictionProvider { + fn name() -> &'static str { + "zed-predict2" + } + + fn display_name() -> &'static str { + "Zed's Edit Predictions 2" + } + + fn show_completions_in_menu() -> bool { + true + } + + fn show_tab_accept_marker() -> bool { + true + } + + fn data_collection_state(&self, _cx: &App) -> DataCollectionState { + // TODO [zeta2] + DataCollectionState::Unsupported + } + + fn toggle_data_collection(&mut self, _cx: &mut App) { + // TODO [zeta2] + } + + fn usage(&self, cx: &App) -> Option { + self.zeta.read(cx).usage(cx) + } + + fn is_enabled( + &self, + _buffer: &Entity, + _cursor_position: language::Anchor, + _cx: &App, + ) -> bool { + true + } + + fn is_refreshing(&self) -> bool { + !self.pending_predictions.is_empty() + } + + fn refresh( + &mut self, + project: Option>, + buffer: Entity, + cursor_position: language::Anchor, + _debounce: bool, + cx: &mut Context, + ) { + let Some(project) = project else { + return; + }; + + if self + .zeta + .read(cx) + .user_store + .read_with(cx, |user_store, _cx| { + user_store.account_too_young() || user_store.has_overdue_invoices() + }) + { + return; + } + + if let Some(current_prediction) = self.current_prediction.as_ref() { + let snapshot = buffer.read(cx).snapshot(); + if current_prediction + .prediction + .interpolate(&snapshot) + .is_some() + { + return; + } + } + + let pending_prediction_id = self.next_pending_prediction_id; + self.next_pending_prediction_id += 1; + let last_request_timestamp = self.last_request_timestamp; + + let task = cx.spawn(async move |this, cx| { + if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT) + .checked_duration_since(Instant::now()) + { + cx.background_executor().timer(timeout).await; + } + + let prediction_request = this.update(cx, |this, cx| { + this.last_request_timestamp = Instant::now(); + this.zeta.update(cx, |zeta, cx| { + zeta.request_prediction(&project, &buffer, cursor_position, cx) + }) + }); + + let prediction = match prediction_request { + Ok(prediction_request) => { + let prediction_request = prediction_request.await; + prediction_request.map(|c| { + c.map(|prediction| CurrentEditPrediction { + buffer_id: buffer.entity_id(), + prediction, + }) + }) + } + Err(error) => Err(error), + }; + + this.update(cx, |this, cx| { + if this.pending_predictions[0].id == pending_prediction_id { + this.pending_predictions.remove(0); + } else { + this.pending_predictions.clear(); + } + + let Some(new_prediction) = prediction + .context("edit prediction failed") + .log_err() + .flatten() + else { + cx.notify(); + return; + }; + + if let Some(old_prediction) = this.current_prediction.as_ref() { + let snapshot = buffer.read(cx).snapshot(); + if new_prediction.should_replace_prediction(old_prediction, &snapshot) { + this.current_prediction = Some(new_prediction); + } + } else { + this.current_prediction = Some(new_prediction); + } + + cx.notify(); + }) + .ok(); + }); + + // We always maintain at most two pending predictions. When we already + // have two, we replace the newest one. + if self.pending_predictions.len() <= 1 { + self.pending_predictions.push(PendingPrediction { + id: pending_prediction_id, + _task: task, + }); + } else if self.pending_predictions.len() == 2 { + self.pending_predictions.pop(); + self.pending_predictions.push(PendingPrediction { + id: pending_prediction_id, + _task: task, + }); + } + + cx.notify(); + } + + fn cycle( + &mut self, + _buffer: Entity, + _cursor_position: language::Anchor, + _direction: Direction, + _cx: &mut Context, + ) { + } + + fn accept(&mut self, _cx: &mut Context) { + // TODO [zeta2] report accept + self.current_prediction.take(); + self.pending_predictions.clear(); + } + + fn discard(&mut self, _cx: &mut Context) { + self.pending_predictions.clear(); + self.current_prediction.take(); + } + + fn suggest( + &mut self, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &mut Context, + ) -> Option { + let CurrentEditPrediction { + buffer_id, + prediction, + .. + } = self.current_prediction.as_mut()?; + + // Invalidate previous prediction if it was generated for a different buffer. + if *buffer_id != buffer.entity_id() { + self.current_prediction.take(); + return None; + } + + let buffer = buffer.read(cx); + let Some(edits) = prediction.interpolate(&buffer.snapshot()) else { + self.current_prediction.take(); + return None; + }; + + let cursor_row = cursor_position.to_point(buffer).row; + let (closest_edit_ix, (closest_edit_range, _)) = + edits.iter().enumerate().min_by_key(|(_, (range, _))| { + let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row); + let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row); + cmp::min(distance_from_start, distance_from_end) + })?; + + let mut edit_start_ix = closest_edit_ix; + for (range, _) in edits[..edit_start_ix].iter().rev() { + let distance_from_closest_edit = + closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row; + if distance_from_closest_edit <= 1 { + edit_start_ix -= 1; + } else { + break; + } + } + + let mut edit_end_ix = closest_edit_ix + 1; + for (range, _) in &edits[edit_end_ix..] { + let distance_from_closest_edit = + range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row; + if distance_from_closest_edit <= 1 { + edit_end_ix += 1; + } else { + break; + } + } + + Some(edit_prediction::EditPrediction { + id: Some(prediction.id.to_string().into()), + edits: edits[edit_start_ix..edit_end_ix].to_vec(), + edit_preview: Some(prediction.edit_preview.clone()), + }) + } +} diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 8986fb20740327f5598611d981b53569edcb559e..8af20a6236e06c78d50b9f3c22609adfbe5571e2 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -1,5 +1,4 @@ use anyhow::{Context as _, Result, anyhow}; -use arrayvec::ArrayVec; use chrono::TimeDelta; use client::{Client, EditPredictionUsage, UserStore}; use cloud_llm_client::predict_edits_v3::{self, Signature}; @@ -7,7 +6,6 @@ use cloud_llm_client::{ EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME, }; use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES; -use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider}; use edit_prediction_context::{ DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex, SyntaxIndexState, @@ -19,26 +17,26 @@ use gpui::{ App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity, http_client, prelude::*, }; -use language::{ - Anchor, Buffer, DiagnosticSet, LanguageServerId, OffsetRangeExt as _, ToOffset as _, ToPoint, - text_diff, -}; -use language::{BufferSnapshot, EditPreview}; +use language::BufferSnapshot; +use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint}; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use project::Project; use release_channel::AppVersion; -use std::borrow::Cow; -use std::cmp; use std::collections::{HashMap, VecDeque, hash_map}; use std::path::PathBuf; use std::str::FromStr as _; +use std::sync::Arc; use std::time::{Duration, Instant}; -use std::{ops::Range, sync::Arc}; use thiserror::Error; -use util::{ResultExt as _, some_or_debug_panic}; -use uuid::Uuid; +use util::some_or_debug_panic; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; +mod prediction; +mod provider; + +use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits}; +pub use provider::ZetaEditPredictionProvider; + const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1); /// Maximum number of events to track. @@ -441,7 +439,7 @@ impl Zeta { } let (response, usage) = response?; - let edits = Self::compute_edits(&response.edits, &snapshot); + let edits = edits_from_response(&response.edits, &snapshot); anyhow::Ok(Some((response.request_id, edits, usage))) } @@ -466,7 +464,7 @@ impl Zeta { buffer.read_with(cx, |buffer, cx| { let new_snapshot = buffer.snapshot(); let edits: Arc<[_]> = - interpolate(&snapshot, &new_snapshot, edits)?.into(); + interpolate_edits(&snapshot, &new_snapshot, edits)?.into(); Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx))) })? else { @@ -474,7 +472,7 @@ impl Zeta { }; Ok(Some(EditPrediction { - id: EditPredictionId(id), + id: id.into(), edits, snapshot, edit_preview: edit_preview_task.await, @@ -585,62 +583,6 @@ impl Zeta { } } - fn compute_edits( - edits: &[predict_edits_v3::Edit], - snapshot: &BufferSnapshot, - ) -> Arc<[(Range, String)]> { - edits - .iter() - .flat_map(|edit| { - // TODO multi-file edits - let old_text = snapshot.text_for_range(edit.range.clone()); - - Self::compute_excerpt_edits( - old_text.collect::>(), - &edit.content, - edit.range.start, - &snapshot, - ) - }) - .collect::>() - .into() - } - - fn compute_excerpt_edits( - old_text: Cow, - new_text: &str, - offset: usize, - snapshot: &BufferSnapshot, - ) -> impl Iterator, String)> { - text_diff(&old_text, new_text) - .into_iter() - .map(move |(mut old_range, new_text)| { - old_range.start += offset; - old_range.end += offset; - - let prefix_len = common_prefix( - snapshot.chars_for_range(old_range.clone()), - new_text.chars(), - ); - old_range.start += prefix_len; - - let suffix_len = common_prefix( - snapshot.reversed_chars_for_range(old_range.clone()), - new_text[prefix_len..].chars().rev(), - ); - old_range.end = old_range.end.saturating_sub(suffix_len); - - let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string(); - let range = if old_range.is_empty() { - let anchor = snapshot.anchor_after(old_range.start); - anchor..anchor - } else { - snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end) - }; - (range, new_text) - }) - } - fn gather_nearby_diagnostics( cursor_offset: usize, diagnostic_sets: &[(LanguageServerId, DiagnosticSet)], @@ -750,13 +692,6 @@ impl Zeta { } } -fn common_prefix, T2: Iterator>(a: T1, b: T2) -> usize { - a.zip(b) - .take_while(|(a, b)| a == b) - .map(|(a, _)| a.len_utf8()) - .sum() -} - #[derive(Error, Debug)] #[error( "You must update to Zed version {minimum_version} or higher to continue using edit predictions." @@ -765,341 +700,6 @@ pub struct ZedUpdateRequiredError { minimum_version: SemanticVersion, } -pub struct ZetaEditPredictionProvider { - zeta: Entity, - current_prediction: Option, - next_pending_prediction_id: usize, - pending_predictions: ArrayVec, - last_request_timestamp: Instant, -} - -impl ZetaEditPredictionProvider { - pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); - - pub fn new( - project: Option<&Entity>, - client: &Arc, - user_store: &Entity, - cx: &mut App, - ) -> Self { - let zeta = Zeta::global(client, user_store, cx); - if let Some(project) = project { - zeta.update(cx, |zeta, cx| { - zeta.register_project(project, cx); - }); - } - - Self { - zeta, - current_prediction: None, - next_pending_prediction_id: 0, - pending_predictions: ArrayVec::new(), - last_request_timestamp: Instant::now(), - } - } -} - -#[derive(Clone)] -struct CurrentEditPrediction { - buffer_id: EntityId, - prediction: EditPrediction, -} - -impl CurrentEditPrediction { - fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool { - if self.buffer_id != old_prediction.buffer_id { - return true; - } - - let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else { - return true; - }; - let Some(new_edits) = self.prediction.interpolate(snapshot) else { - return false; - }; - - if old_edits.len() == 1 && new_edits.len() == 1 { - let (old_range, old_text) = &old_edits[0]; - let (new_range, new_text) = &new_edits[0]; - new_range == old_range && new_text.starts_with(old_text) - } else { - true - } - } -} - -#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] -pub struct EditPredictionId(Uuid); - -impl From for gpui::ElementId { - fn from(value: EditPredictionId) -> Self { - gpui::ElementId::Uuid(value.0) - } -} - -impl std::fmt::Display for EditPredictionId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -#[derive(Clone)] -pub struct EditPrediction { - id: EditPredictionId, - edits: Arc<[(Range, String)]>, - snapshot: BufferSnapshot, - edit_preview: EditPreview, -} - -impl EditPrediction { - fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option, String)>> { - interpolate(&self.snapshot, new_snapshot, self.edits.clone()) - } -} - -struct PendingPrediction { - id: usize, - _task: Task<()>, -} - -impl EditPredictionProvider for ZetaEditPredictionProvider { - fn name() -> &'static str { - "zed-predict2" - } - - fn display_name() -> &'static str { - "Zed's Edit Predictions 2" - } - - fn show_completions_in_menu() -> bool { - true - } - - fn show_tab_accept_marker() -> bool { - true - } - - fn data_collection_state(&self, _cx: &App) -> DataCollectionState { - // TODO [zeta2] - DataCollectionState::Unsupported - } - - fn toggle_data_collection(&mut self, _cx: &mut App) { - // TODO [zeta2] - } - - fn usage(&self, cx: &App) -> Option { - self.zeta.read(cx).usage(cx) - } - - fn is_enabled( - &self, - _buffer: &Entity, - _cursor_position: language::Anchor, - _cx: &App, - ) -> bool { - true - } - - fn is_refreshing(&self) -> bool { - !self.pending_predictions.is_empty() - } - - fn refresh( - &mut self, - project: Option>, - buffer: Entity, - cursor_position: language::Anchor, - _debounce: bool, - cx: &mut Context, - ) { - let Some(project) = project else { - return; - }; - - if self - .zeta - .read(cx) - .user_store - .read_with(cx, |user_store, _cx| { - user_store.account_too_young() || user_store.has_overdue_invoices() - }) - { - return; - } - - if let Some(current_prediction) = self.current_prediction.as_ref() { - let snapshot = buffer.read(cx).snapshot(); - if current_prediction - .prediction - .interpolate(&snapshot) - .is_some() - { - return; - } - } - - let pending_prediction_id = self.next_pending_prediction_id; - self.next_pending_prediction_id += 1; - let last_request_timestamp = self.last_request_timestamp; - - let task = cx.spawn(async move |this, cx| { - if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT) - .checked_duration_since(Instant::now()) - { - cx.background_executor().timer(timeout).await; - } - - let prediction_request = this.update(cx, |this, cx| { - this.last_request_timestamp = Instant::now(); - this.zeta.update(cx, |zeta, cx| { - zeta.request_prediction(&project, &buffer, cursor_position, cx) - }) - }); - - let prediction = match prediction_request { - Ok(prediction_request) => { - let prediction_request = prediction_request.await; - prediction_request.map(|c| { - c.map(|prediction| CurrentEditPrediction { - buffer_id: buffer.entity_id(), - prediction, - }) - }) - } - Err(error) => Err(error), - }; - - this.update(cx, |this, cx| { - if this.pending_predictions[0].id == pending_prediction_id { - this.pending_predictions.remove(0); - } else { - this.pending_predictions.clear(); - } - - let Some(new_prediction) = prediction - .context("edit prediction failed") - .log_err() - .flatten() - else { - cx.notify(); - return; - }; - - if let Some(old_prediction) = this.current_prediction.as_ref() { - let snapshot = buffer.read(cx).snapshot(); - if new_prediction.should_replace_prediction(old_prediction, &snapshot) { - this.current_prediction = Some(new_prediction); - } - } else { - this.current_prediction = Some(new_prediction); - } - - cx.notify(); - }) - .ok(); - }); - - // We always maintain at most two pending predictions. When we already - // have two, we replace the newest one. - if self.pending_predictions.len() <= 1 { - self.pending_predictions.push(PendingPrediction { - id: pending_prediction_id, - _task: task, - }); - } else if self.pending_predictions.len() == 2 { - self.pending_predictions.pop(); - self.pending_predictions.push(PendingPrediction { - id: pending_prediction_id, - _task: task, - }); - } - - cx.notify(); - } - - fn cycle( - &mut self, - _buffer: Entity, - _cursor_position: language::Anchor, - _direction: Direction, - _cx: &mut Context, - ) { - } - - fn accept(&mut self, _cx: &mut Context) { - // TODO [zeta2] report accept - self.current_prediction.take(); - self.pending_predictions.clear(); - } - - fn discard(&mut self, _cx: &mut Context) { - self.pending_predictions.clear(); - self.current_prediction.take(); - } - - fn suggest( - &mut self, - buffer: &Entity, - cursor_position: language::Anchor, - cx: &mut Context, - ) -> Option { - let CurrentEditPrediction { - buffer_id, - prediction, - .. - } = self.current_prediction.as_mut()?; - - // Invalidate previous prediction if it was generated for a different buffer. - if *buffer_id != buffer.entity_id() { - self.current_prediction.take(); - return None; - } - - let buffer = buffer.read(cx); - let Some(edits) = prediction.interpolate(&buffer.snapshot()) else { - self.current_prediction.take(); - return None; - }; - - let cursor_row = cursor_position.to_point(buffer).row; - let (closest_edit_ix, (closest_edit_range, _)) = - edits.iter().enumerate().min_by_key(|(_, (range, _))| { - let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row); - let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row); - cmp::min(distance_from_start, distance_from_end) - })?; - - let mut edit_start_ix = closest_edit_ix; - for (range, _) in edits[..edit_start_ix].iter().rev() { - let distance_from_closest_edit = - closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row; - if distance_from_closest_edit <= 1 { - edit_start_ix -= 1; - } else { - break; - } - } - - let mut edit_end_ix = closest_edit_ix + 1; - for (range, _) in &edits[edit_end_ix..] { - let distance_from_closest_edit = - range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row; - if distance_from_closest_edit <= 1 { - edit_end_ix += 1; - } else { - break; - } - } - - Some(edit_prediction::EditPrediction { - id: Some(prediction.id.to_string().into()), - edits: edits[edit_start_ix..edit_end_ix].to_vec(), - edit_preview: Some(prediction.edit_preview.clone()), - }) - } -} - fn make_cloud_request( excerpt_path: PathBuf, context: EditPredictionContext, @@ -1215,238 +815,3 @@ fn add_signature( declaration_to_signature_index.insert(declaration_id, signature_index); Some(signature_index) } - -fn interpolate( - old_snapshot: &BufferSnapshot, - new_snapshot: &BufferSnapshot, - current_edits: Arc<[(Range, String)]>, -) -> Option, String)>> { - let mut edits = Vec::new(); - - let mut model_edits = current_edits.iter().peekable(); - for user_edit in new_snapshot.edits_since::(&old_snapshot.version) { - while let Some((model_old_range, _)) = model_edits.peek() { - let model_old_range = model_old_range.to_offset(old_snapshot); - if model_old_range.end < user_edit.old.start { - let (model_old_range, model_new_text) = model_edits.next().unwrap(); - edits.push((model_old_range.clone(), model_new_text.clone())); - } else { - break; - } - } - - if let Some((model_old_range, model_new_text)) = model_edits.peek() { - let model_old_offset_range = model_old_range.to_offset(old_snapshot); - if user_edit.old == model_old_offset_range { - let user_new_text = new_snapshot - .text_for_range(user_edit.new.clone()) - .collect::(); - - if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) { - if !model_suffix.is_empty() { - let anchor = old_snapshot.anchor_after(user_edit.old.end); - edits.push((anchor..anchor, model_suffix.to_string())); - } - - model_edits.next(); - continue; - } - } - } - - return None; - } - - edits.extend(model_edits.cloned()); - - if edits.is_empty() { None } else { Some(edits) } -} - -#[cfg(test)] -mod tests { - use super::*; - use gpui::TestAppContext; - use indoc::indoc; - - #[gpui::test] - async fn test_compute_edits(cx: &mut TestAppContext) { - let old = indoc! {r#" - fn main() { - let args = - println!("{}", args[1]) - } - "#}; - - let new = indoc! {r#" - fn main() { - let args = std::env::args(); - println!("{}", args[1]); - } - "#}; - - let buffer = cx.new(|cx| Buffer::local(old, cx)); - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - - // TODO cover more cases when multi-file is supported - let big_edits = vec![predict_edits_v3::Edit { - path: PathBuf::from("test.txt"), - range: 0..old.len(), - content: new.into(), - }]; - - let edits = Zeta::compute_edits(&big_edits, &snapshot); - assert_eq!(edits.len(), 2); - assert_eq!( - edits[0].0.to_point(&snapshot).start, - language::Point::new(1, 14) - ); - assert_eq!(edits[0].1, " std::env::args();"); - assert_eq!( - edits[1].0.to_point(&snapshot).start, - language::Point::new(2, 27) - ); - assert_eq!(edits[1].1, ";"); - } - - #[gpui::test] - async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { - let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); - let edits: Arc<[(Range, String)]> = cx.update(|cx| { - to_prediction_edits( - [(2..5, "REM".to_string()), (9..11, "".to_string())], - &buffer, - cx, - ) - .into() - }); - - let edit_preview = cx - .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx)) - .await; - - let prediction = EditPrediction { - id: EditPredictionId(Uuid::new_v4()), - edits, - snapshot: cx.read(|cx| buffer.read(cx).snapshot()), - edit_preview, - }; - - cx.update(|cx| { - assert_eq!( - from_prediction_edits( - &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(2..5, "REM".to_string()), (9..11, "".to_string())] - ); - - buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx)); - assert_eq!( - from_prediction_edits( - &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(2..2, "REM".to_string()), (6..8, "".to_string())] - ); - - buffer.update(cx, |buffer, cx| buffer.undo(cx)); - assert_eq!( - from_prediction_edits( - &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(2..5, "REM".to_string()), (9..11, "".to_string())] - ); - - buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx)); - assert_eq!( - from_prediction_edits( - &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(3..3, "EM".to_string()), (7..9, "".to_string())] - ); - - buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx)); - assert_eq!( - from_prediction_edits( - &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(4..4, "M".to_string()), (8..10, "".to_string())] - ); - - buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx)); - assert_eq!( - from_prediction_edits( - &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(9..11, "".to_string())] - ); - - buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx)); - assert_eq!( - from_prediction_edits( - &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(4..4, "M".to_string()), (8..10, "".to_string())] - ); - - buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx)); - assert_eq!( - from_prediction_edits( - &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), - &buffer, - cx - ), - vec![(4..4, "M".to_string())] - ); - - buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx)); - assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None); - }) - } - - fn to_prediction_edits( - iterator: impl IntoIterator, String)>, - buffer: &Entity, - cx: &App, - ) -> Vec<(Range, String)> { - let buffer = buffer.read(cx); - iterator - .into_iter() - .map(|(range, text)| { - ( - buffer.anchor_after(range.start)..buffer.anchor_before(range.end), - text, - ) - }) - .collect() - } - - fn from_prediction_edits( - editor_edits: &[(Range, String)], - buffer: &Entity, - cx: &App, - ) -> Vec<(Range, String)> { - let buffer = buffer.read(cx); - editor_edits - .iter() - .map(|(range, text)| { - ( - range.start.to_offset(buffer)..range.end.to_offset(buffer), - text.clone(), - ) - }) - .collect() - } -}