@@ -0,0 +1,345 @@
+use std::{borrow::Cow, ops::Range, sync::Arc};
+
+use cloud_llm_client::predict_edits_v3;
+use language::{Anchor, BufferSnapshot, EditPreview, OffsetRangeExt, text_diff};
+use uuid::Uuid;
+
+#[derive(Clone)]
+pub struct EditPrediction {
+ pub id: EditPredictionId,
+ pub edits: Arc<[(Range<Anchor>, String)]>,
+ pub snapshot: BufferSnapshot,
+ pub edit_preview: EditPreview,
+}
+
+impl EditPrediction {
+ pub fn interpolate(
+ &self,
+ new_snapshot: &BufferSnapshot,
+ ) -> Option<Vec<(Range<Anchor>, String)>> {
+ interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
+ }
+}
+
+#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
+pub struct EditPredictionId(Uuid);
+
+impl From<Uuid> for EditPredictionId {
+ fn from(value: Uuid) -> Self {
+ EditPredictionId(value)
+ }
+}
+
+impl From<EditPredictionId> for gpui::ElementId {
+ fn from(value: EditPredictionId) -> Self {
+ gpui::ElementId::Uuid(value.0)
+ }
+}
+
+impl std::fmt::Display for EditPredictionId {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", self.0)
+ }
+}
+
+pub fn interpolate_edits(
+ old_snapshot: &BufferSnapshot,
+ new_snapshot: &BufferSnapshot,
+ current_edits: Arc<[(Range<Anchor>, String)]>,
+) -> Option<Vec<(Range<Anchor>, String)>> {
+ let mut edits = Vec::new();
+
+ let mut model_edits = current_edits.iter().peekable();
+ for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
+ while let Some((model_old_range, _)) = model_edits.peek() {
+ let model_old_range = model_old_range.to_offset(old_snapshot);
+ if model_old_range.end < user_edit.old.start {
+ let (model_old_range, model_new_text) = model_edits.next().unwrap();
+ edits.push((model_old_range.clone(), model_new_text.clone()));
+ } else {
+ break;
+ }
+ }
+
+ if let Some((model_old_range, model_new_text)) = model_edits.peek() {
+ let model_old_offset_range = model_old_range.to_offset(old_snapshot);
+ if user_edit.old == model_old_offset_range {
+ let user_new_text = new_snapshot
+ .text_for_range(user_edit.new.clone())
+ .collect::<String>();
+
+ if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
+ if !model_suffix.is_empty() {
+ let anchor = old_snapshot.anchor_after(user_edit.old.end);
+ edits.push((anchor..anchor, model_suffix.to_string()));
+ }
+
+ model_edits.next();
+ continue;
+ }
+ }
+ }
+
+ return None;
+ }
+
+ edits.extend(model_edits.cloned());
+
+ if edits.is_empty() { None } else { Some(edits) }
+}
+
+pub fn edits_from_response(
+ edits: &[predict_edits_v3::Edit],
+ snapshot: &BufferSnapshot,
+) -> Arc<[(Range<Anchor>, String)]> {
+ edits
+ .iter()
+ .flat_map(|edit| {
+ // TODO multi-file edits
+ let old_text = snapshot.text_for_range(edit.range.clone());
+
+ excerpt_edits_from_response(
+ old_text.collect::<Cow<str>>(),
+ &edit.content,
+ edit.range.start,
+ &snapshot,
+ )
+ })
+ .collect::<Vec<_>>()
+ .into()
+}
+
+fn excerpt_edits_from_response(
+ old_text: Cow<str>,
+ new_text: &str,
+ offset: usize,
+ snapshot: &BufferSnapshot,
+) -> impl Iterator<Item = (Range<Anchor>, String)> {
+ text_diff(&old_text, new_text)
+ .into_iter()
+ .map(move |(mut old_range, new_text)| {
+ old_range.start += offset;
+ old_range.end += offset;
+
+ let prefix_len = common_prefix(
+ snapshot.chars_for_range(old_range.clone()),
+ new_text.chars(),
+ );
+ old_range.start += prefix_len;
+
+ let suffix_len = common_prefix(
+ snapshot.reversed_chars_for_range(old_range.clone()),
+ new_text[prefix_len..].chars().rev(),
+ );
+ old_range.end = old_range.end.saturating_sub(suffix_len);
+
+ let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
+ let range = if old_range.is_empty() {
+ let anchor = snapshot.anchor_after(old_range.start);
+ anchor..anchor
+ } else {
+ snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
+ };
+ (range, new_text)
+ })
+}
+
+fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
+ a.zip(b)
+ .take_while(|(a, b)| a == b)
+ .map(|(a, _)| a.len_utf8())
+ .sum()
+}
+
+#[cfg(test)]
+mod tests {
+ use std::path::PathBuf;
+
+ use super::*;
+ use cloud_llm_client::predict_edits_v3;
+ use gpui::{App, Entity, TestAppContext, prelude::*};
+ use indoc::indoc;
+ use language::{Buffer, ToOffset as _};
+
+ #[gpui::test]
+ async fn test_compute_edits(cx: &mut TestAppContext) {
+ let old = indoc! {r#"
+ fn main() {
+ let args =
+ println!("{}", args[1])
+ }
+ "#};
+
+ let new = indoc! {r#"
+ fn main() {
+ let args = std::env::args();
+ println!("{}", args[1]);
+ }
+ "#};
+
+ let buffer = cx.new(|cx| Buffer::local(old, cx));
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+
+ // TODO cover more cases when multi-file is supported
+ let big_edits = vec![predict_edits_v3::Edit {
+ path: PathBuf::from("test.txt"),
+ range: 0..old.len(),
+ content: new.into(),
+ }];
+
+ let edits = edits_from_response(&big_edits, &snapshot);
+ assert_eq!(edits.len(), 2);
+ assert_eq!(
+ edits[0].0.to_point(&snapshot).start,
+ language::Point::new(1, 14)
+ );
+ assert_eq!(edits[0].1, " std::env::args();");
+ assert_eq!(
+ edits[1].0.to_point(&snapshot).start,
+ language::Point::new(2, 27)
+ );
+ assert_eq!(edits[1].1, ";");
+ }
+
+ #[gpui::test]
+ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
+ let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
+ let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
+ to_prediction_edits(
+ [(2..5, "REM".to_string()), (9..11, "".to_string())],
+ &buffer,
+ cx,
+ )
+ .into()
+ });
+
+ let edit_preview = cx
+ .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
+ .await;
+
+ let prediction = EditPrediction {
+ id: EditPredictionId(Uuid::new_v4()),
+ edits,
+ snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
+ edit_preview,
+ };
+
+ cx.update(|cx| {
+ assert_eq!(
+ from_prediction_edits(
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
+ assert_eq!(
+ from_prediction_edits(
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.undo(cx));
+ assert_eq!(
+ from_prediction_edits(
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
+ assert_eq!(
+ from_prediction_edits(
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
+ assert_eq!(
+ from_prediction_edits(
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(4..4, "M".to_string()), (8..10, "".to_string())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
+ assert_eq!(
+ from_prediction_edits(
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(9..11, "".to_string())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
+ assert_eq!(
+ from_prediction_edits(
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(4..4, "M".to_string()), (8..10, "".to_string())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
+ assert_eq!(
+ from_prediction_edits(
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(4..4, "M".to_string())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
+ assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
+ })
+ }
+
+ fn to_prediction_edits(
+ iterator: impl IntoIterator<Item = (Range<usize>, String)>,
+ buffer: &Entity<Buffer>,
+ cx: &App,
+ ) -> Vec<(Range<Anchor>, String)> {
+ let buffer = buffer.read(cx);
+ iterator
+ .into_iter()
+ .map(|(range, text)| {
+ (
+ buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
+ text,
+ )
+ })
+ .collect()
+ }
+
+ fn from_prediction_edits(
+ editor_edits: &[(Range<Anchor>, String)],
+ buffer: &Entity<Buffer>,
+ cx: &App,
+ ) -> Vec<(Range<usize>, String)> {
+ let buffer = buffer.read(cx);
+ editor_edits
+ .iter()
+ .map(|(range, text)| {
+ (
+ range.start.to_offset(buffer)..range.end.to_offset(buffer),
+ text.clone(),
+ )
+ })
+ .collect()
+ }
+}
@@ -0,0 +1,322 @@
+use std::{
+ cmp,
+ sync::Arc,
+ time::{Duration, Instant},
+};
+
+use anyhow::Context as _;
+use arrayvec::ArrayVec;
+use client::{Client, UserStore};
+use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
+use gpui::{App, Entity, EntityId, Task, prelude::*};
+use language::{BufferSnapshot, ToPoint as _};
+use project::Project;
+use util::ResultExt as _;
+
+use crate::{Zeta, prediction::EditPrediction};
+
+pub struct ZetaEditPredictionProvider {
+ zeta: Entity<Zeta>,
+ current_prediction: Option<CurrentEditPrediction>,
+ next_pending_prediction_id: usize,
+ pending_predictions: ArrayVec<PendingPrediction, 2>,
+ last_request_timestamp: Instant,
+}
+
+impl ZetaEditPredictionProvider {
+ pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
+
+ pub fn new(
+ project: Option<&Entity<Project>>,
+ client: &Arc<Client>,
+ user_store: &Entity<UserStore>,
+ cx: &mut App,
+ ) -> Self {
+ let zeta = Zeta::global(client, user_store, cx);
+ if let Some(project) = project {
+ zeta.update(cx, |zeta, cx| {
+ zeta.register_project(project, cx);
+ });
+ }
+
+ Self {
+ zeta,
+ current_prediction: None,
+ next_pending_prediction_id: 0,
+ pending_predictions: ArrayVec::new(),
+ last_request_timestamp: Instant::now(),
+ }
+ }
+}
+
+#[derive(Clone)]
+struct CurrentEditPrediction {
+ buffer_id: EntityId,
+ prediction: EditPrediction,
+}
+
+impl CurrentEditPrediction {
+ fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
+ if self.buffer_id != old_prediction.buffer_id {
+ return true;
+ }
+
+ let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
+ return true;
+ };
+ let Some(new_edits) = self.prediction.interpolate(snapshot) else {
+ return false;
+ };
+
+ if old_edits.len() == 1 && new_edits.len() == 1 {
+ let (old_range, old_text) = &old_edits[0];
+ let (new_range, new_text) = &new_edits[0];
+ new_range == old_range && new_text.starts_with(old_text)
+ } else {
+ true
+ }
+ }
+}
+
+struct PendingPrediction {
+ id: usize,
+ _task: Task<()>,
+}
+
+impl EditPredictionProvider for ZetaEditPredictionProvider {
+ fn name() -> &'static str {
+ "zed-predict2"
+ }
+
+ fn display_name() -> &'static str {
+ "Zed's Edit Predictions 2"
+ }
+
+ fn show_completions_in_menu() -> bool {
+ true
+ }
+
+ fn show_tab_accept_marker() -> bool {
+ true
+ }
+
+ fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
+ // TODO [zeta2]
+ DataCollectionState::Unsupported
+ }
+
+ fn toggle_data_collection(&mut self, _cx: &mut App) {
+ // TODO [zeta2]
+ }
+
+ fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
+ self.zeta.read(cx).usage(cx)
+ }
+
+ fn is_enabled(
+ &self,
+ _buffer: &Entity<language::Buffer>,
+ _cursor_position: language::Anchor,
+ _cx: &App,
+ ) -> bool {
+ true
+ }
+
+ fn is_refreshing(&self) -> bool {
+ !self.pending_predictions.is_empty()
+ }
+
+ fn refresh(
+ &mut self,
+ project: Option<Entity<project::Project>>,
+ buffer: Entity<language::Buffer>,
+ cursor_position: language::Anchor,
+ _debounce: bool,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(project) = project else {
+ return;
+ };
+
+ if self
+ .zeta
+ .read(cx)
+ .user_store
+ .read_with(cx, |user_store, _cx| {
+ user_store.account_too_young() || user_store.has_overdue_invoices()
+ })
+ {
+ return;
+ }
+
+ if let Some(current_prediction) = self.current_prediction.as_ref() {
+ let snapshot = buffer.read(cx).snapshot();
+ if current_prediction
+ .prediction
+ .interpolate(&snapshot)
+ .is_some()
+ {
+ return;
+ }
+ }
+
+ let pending_prediction_id = self.next_pending_prediction_id;
+ self.next_pending_prediction_id += 1;
+ let last_request_timestamp = self.last_request_timestamp;
+
+ let task = cx.spawn(async move |this, cx| {
+ if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
+ .checked_duration_since(Instant::now())
+ {
+ cx.background_executor().timer(timeout).await;
+ }
+
+ let prediction_request = this.update(cx, |this, cx| {
+ this.last_request_timestamp = Instant::now();
+ this.zeta.update(cx, |zeta, cx| {
+ zeta.request_prediction(&project, &buffer, cursor_position, cx)
+ })
+ });
+
+ let prediction = match prediction_request {
+ Ok(prediction_request) => {
+ let prediction_request = prediction_request.await;
+ prediction_request.map(|c| {
+ c.map(|prediction| CurrentEditPrediction {
+ buffer_id: buffer.entity_id(),
+ prediction,
+ })
+ })
+ }
+ Err(error) => Err(error),
+ };
+
+ this.update(cx, |this, cx| {
+ if this.pending_predictions[0].id == pending_prediction_id {
+ this.pending_predictions.remove(0);
+ } else {
+ this.pending_predictions.clear();
+ }
+
+ let Some(new_prediction) = prediction
+ .context("edit prediction failed")
+ .log_err()
+ .flatten()
+ else {
+ cx.notify();
+ return;
+ };
+
+ if let Some(old_prediction) = this.current_prediction.as_ref() {
+ let snapshot = buffer.read(cx).snapshot();
+ if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
+ this.current_prediction = Some(new_prediction);
+ }
+ } else {
+ this.current_prediction = Some(new_prediction);
+ }
+
+ cx.notify();
+ })
+ .ok();
+ });
+
+ // We always maintain at most two pending predictions. When we already
+ // have two, we replace the newest one.
+ if self.pending_predictions.len() <= 1 {
+ self.pending_predictions.push(PendingPrediction {
+ id: pending_prediction_id,
+ _task: task,
+ });
+ } else if self.pending_predictions.len() == 2 {
+ self.pending_predictions.pop();
+ self.pending_predictions.push(PendingPrediction {
+ id: pending_prediction_id,
+ _task: task,
+ });
+ }
+
+ cx.notify();
+ }
+
+ fn cycle(
+ &mut self,
+ _buffer: Entity<language::Buffer>,
+ _cursor_position: language::Anchor,
+ _direction: Direction,
+ _cx: &mut Context<Self>,
+ ) {
+ }
+
+ fn accept(&mut self, _cx: &mut Context<Self>) {
+ // TODO [zeta2] report accept
+ self.current_prediction.take();
+ self.pending_predictions.clear();
+ }
+
+ fn discard(&mut self, _cx: &mut Context<Self>) {
+ self.pending_predictions.clear();
+ self.current_prediction.take();
+ }
+
+ fn suggest(
+ &mut self,
+ buffer: &Entity<language::Buffer>,
+ cursor_position: language::Anchor,
+ cx: &mut Context<Self>,
+ ) -> Option<edit_prediction::EditPrediction> {
+ let CurrentEditPrediction {
+ buffer_id,
+ prediction,
+ ..
+ } = self.current_prediction.as_mut()?;
+
+ // Invalidate previous prediction if it was generated for a different buffer.
+ if *buffer_id != buffer.entity_id() {
+ self.current_prediction.take();
+ return None;
+ }
+
+ let buffer = buffer.read(cx);
+ let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
+ self.current_prediction.take();
+ return None;
+ };
+
+ let cursor_row = cursor_position.to_point(buffer).row;
+ let (closest_edit_ix, (closest_edit_range, _)) =
+ edits.iter().enumerate().min_by_key(|(_, (range, _))| {
+ let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
+ let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
+ cmp::min(distance_from_start, distance_from_end)
+ })?;
+
+ let mut edit_start_ix = closest_edit_ix;
+ for (range, _) in edits[..edit_start_ix].iter().rev() {
+ let distance_from_closest_edit =
+ closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
+ if distance_from_closest_edit <= 1 {
+ edit_start_ix -= 1;
+ } else {
+ break;
+ }
+ }
+
+ let mut edit_end_ix = closest_edit_ix + 1;
+ for (range, _) in &edits[edit_end_ix..] {
+ let distance_from_closest_edit =
+ range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
+ if distance_from_closest_edit <= 1 {
+ edit_end_ix += 1;
+ } else {
+ break;
+ }
+ }
+
+ Some(edit_prediction::EditPrediction {
+ id: Some(prediction.id.to_string().into()),
+ edits: edits[edit_start_ix..edit_end_ix].to_vec(),
+ edit_preview: Some(prediction.edit_preview.clone()),
+ })
+ }
+}
@@ -1,5 +1,4 @@
use anyhow::{Context as _, Result, anyhow};
-use arrayvec::ArrayVec;
use chrono::TimeDelta;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, Signature};
@@ -7,7 +6,6 @@ use cloud_llm_client::{
EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
-use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
use edit_prediction_context::{
DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
SyntaxIndexState,
@@ -19,26 +17,26 @@ use gpui::{
App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
http_client, prelude::*,
};
-use language::{
- Anchor, Buffer, DiagnosticSet, LanguageServerId, OffsetRangeExt as _, ToOffset as _, ToPoint,
- text_diff,
-};
-use language::{BufferSnapshot, EditPreview};
+use language::BufferSnapshot;
+use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use project::Project;
use release_channel::AppVersion;
-use std::borrow::Cow;
-use std::cmp;
use std::collections::{HashMap, VecDeque, hash_map};
use std::path::PathBuf;
use std::str::FromStr as _;
+use std::sync::Arc;
use std::time::{Duration, Instant};
-use std::{ops::Range, sync::Arc};
use thiserror::Error;
-use util::{ResultExt as _, some_or_debug_panic};
-use uuid::Uuid;
+use util::some_or_debug_panic;
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
+mod prediction;
+mod provider;
+
+use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
+pub use provider::ZetaEditPredictionProvider;
+
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Maximum number of events to track.
@@ -441,7 +439,7 @@ impl Zeta {
}
let (response, usage) = response?;
- let edits = Self::compute_edits(&response.edits, &snapshot);
+ let edits = edits_from_response(&response.edits, &snapshot);
anyhow::Ok(Some((response.request_id, edits, usage)))
}
@@ -466,7 +464,7 @@ impl Zeta {
buffer.read_with(cx, |buffer, cx| {
let new_snapshot = buffer.snapshot();
let edits: Arc<[_]> =
- interpolate(&snapshot, &new_snapshot, edits)?.into();
+ interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
})?
else {
@@ -474,7 +472,7 @@ impl Zeta {
};
Ok(Some(EditPrediction {
- id: EditPredictionId(id),
+ id: id.into(),
edits,
snapshot,
edit_preview: edit_preview_task.await,
@@ -585,62 +583,6 @@ impl Zeta {
}
}
- fn compute_edits(
- edits: &[predict_edits_v3::Edit],
- snapshot: &BufferSnapshot,
- ) -> Arc<[(Range<Anchor>, String)]> {
- edits
- .iter()
- .flat_map(|edit| {
- // TODO multi-file edits
- let old_text = snapshot.text_for_range(edit.range.clone());
-
- Self::compute_excerpt_edits(
- old_text.collect::<Cow<str>>(),
- &edit.content,
- edit.range.start,
- &snapshot,
- )
- })
- .collect::<Vec<_>>()
- .into()
- }
-
- fn compute_excerpt_edits(
- old_text: Cow<str>,
- new_text: &str,
- offset: usize,
- snapshot: &BufferSnapshot,
- ) -> impl Iterator<Item = (Range<Anchor>, String)> {
- text_diff(&old_text, new_text)
- .into_iter()
- .map(move |(mut old_range, new_text)| {
- old_range.start += offset;
- old_range.end += offset;
-
- let prefix_len = common_prefix(
- snapshot.chars_for_range(old_range.clone()),
- new_text.chars(),
- );
- old_range.start += prefix_len;
-
- let suffix_len = common_prefix(
- snapshot.reversed_chars_for_range(old_range.clone()),
- new_text[prefix_len..].chars().rev(),
- );
- old_range.end = old_range.end.saturating_sub(suffix_len);
-
- let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
- let range = if old_range.is_empty() {
- let anchor = snapshot.anchor_after(old_range.start);
- anchor..anchor
- } else {
- snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
- };
- (range, new_text)
- })
- }
-
fn gather_nearby_diagnostics(
cursor_offset: usize,
diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
@@ -750,13 +692,6 @@ impl Zeta {
}
}
-fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
- a.zip(b)
- .take_while(|(a, b)| a == b)
- .map(|(a, _)| a.len_utf8())
- .sum()
-}
-
#[derive(Error, Debug)]
#[error(
"You must update to Zed version {minimum_version} or higher to continue using edit predictions."
@@ -765,341 +700,6 @@ pub struct ZedUpdateRequiredError {
minimum_version: SemanticVersion,
}
-pub struct ZetaEditPredictionProvider {
- zeta: Entity<Zeta>,
- current_prediction: Option<CurrentEditPrediction>,
- next_pending_prediction_id: usize,
- pending_predictions: ArrayVec<PendingPrediction, 2>,
- last_request_timestamp: Instant,
-}
-
-impl ZetaEditPredictionProvider {
- pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
-
- pub fn new(
- project: Option<&Entity<Project>>,
- client: &Arc<Client>,
- user_store: &Entity<UserStore>,
- cx: &mut App,
- ) -> Self {
- let zeta = Zeta::global(client, user_store, cx);
- if let Some(project) = project {
- zeta.update(cx, |zeta, cx| {
- zeta.register_project(project, cx);
- });
- }
-
- Self {
- zeta,
- current_prediction: None,
- next_pending_prediction_id: 0,
- pending_predictions: ArrayVec::new(),
- last_request_timestamp: Instant::now(),
- }
- }
-}
-
-#[derive(Clone)]
-struct CurrentEditPrediction {
- buffer_id: EntityId,
- prediction: EditPrediction,
-}
-
-impl CurrentEditPrediction {
- fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
- if self.buffer_id != old_prediction.buffer_id {
- return true;
- }
-
- let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
- return true;
- };
- let Some(new_edits) = self.prediction.interpolate(snapshot) else {
- return false;
- };
-
- if old_edits.len() == 1 && new_edits.len() == 1 {
- let (old_range, old_text) = &old_edits[0];
- let (new_range, new_text) = &new_edits[0];
- new_range == old_range && new_text.starts_with(old_text)
- } else {
- true
- }
- }
-}
-
-#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
-pub struct EditPredictionId(Uuid);
-
-impl From<EditPredictionId> for gpui::ElementId {
- fn from(value: EditPredictionId) -> Self {
- gpui::ElementId::Uuid(value.0)
- }
-}
-
-impl std::fmt::Display for EditPredictionId {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.0)
- }
-}
-
-#[derive(Clone)]
-pub struct EditPrediction {
- id: EditPredictionId,
- edits: Arc<[(Range<Anchor>, String)]>,
- snapshot: BufferSnapshot,
- edit_preview: EditPreview,
-}
-
-impl EditPrediction {
- fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
- interpolate(&self.snapshot, new_snapshot, self.edits.clone())
- }
-}
-
-struct PendingPrediction {
- id: usize,
- _task: Task<()>,
-}
-
-impl EditPredictionProvider for ZetaEditPredictionProvider {
- fn name() -> &'static str {
- "zed-predict2"
- }
-
- fn display_name() -> &'static str {
- "Zed's Edit Predictions 2"
- }
-
- fn show_completions_in_menu() -> bool {
- true
- }
-
- fn show_tab_accept_marker() -> bool {
- true
- }
-
- fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
- // TODO [zeta2]
- DataCollectionState::Unsupported
- }
-
- fn toggle_data_collection(&mut self, _cx: &mut App) {
- // TODO [zeta2]
- }
-
- fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
- self.zeta.read(cx).usage(cx)
- }
-
- fn is_enabled(
- &self,
- _buffer: &Entity<language::Buffer>,
- _cursor_position: language::Anchor,
- _cx: &App,
- ) -> bool {
- true
- }
-
- fn is_refreshing(&self) -> bool {
- !self.pending_predictions.is_empty()
- }
-
- fn refresh(
- &mut self,
- project: Option<Entity<project::Project>>,
- buffer: Entity<language::Buffer>,
- cursor_position: language::Anchor,
- _debounce: bool,
- cx: &mut Context<Self>,
- ) {
- let Some(project) = project else {
- return;
- };
-
- if self
- .zeta
- .read(cx)
- .user_store
- .read_with(cx, |user_store, _cx| {
- user_store.account_too_young() || user_store.has_overdue_invoices()
- })
- {
- return;
- }
-
- if let Some(current_prediction) = self.current_prediction.as_ref() {
- let snapshot = buffer.read(cx).snapshot();
- if current_prediction
- .prediction
- .interpolate(&snapshot)
- .is_some()
- {
- return;
- }
- }
-
- let pending_prediction_id = self.next_pending_prediction_id;
- self.next_pending_prediction_id += 1;
- let last_request_timestamp = self.last_request_timestamp;
-
- let task = cx.spawn(async move |this, cx| {
- if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
- .checked_duration_since(Instant::now())
- {
- cx.background_executor().timer(timeout).await;
- }
-
- let prediction_request = this.update(cx, |this, cx| {
- this.last_request_timestamp = Instant::now();
- this.zeta.update(cx, |zeta, cx| {
- zeta.request_prediction(&project, &buffer, cursor_position, cx)
- })
- });
-
- let prediction = match prediction_request {
- Ok(prediction_request) => {
- let prediction_request = prediction_request.await;
- prediction_request.map(|c| {
- c.map(|prediction| CurrentEditPrediction {
- buffer_id: buffer.entity_id(),
- prediction,
- })
- })
- }
- Err(error) => Err(error),
- };
-
- this.update(cx, |this, cx| {
- if this.pending_predictions[0].id == pending_prediction_id {
- this.pending_predictions.remove(0);
- } else {
- this.pending_predictions.clear();
- }
-
- let Some(new_prediction) = prediction
- .context("edit prediction failed")
- .log_err()
- .flatten()
- else {
- cx.notify();
- return;
- };
-
- if let Some(old_prediction) = this.current_prediction.as_ref() {
- let snapshot = buffer.read(cx).snapshot();
- if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
- this.current_prediction = Some(new_prediction);
- }
- } else {
- this.current_prediction = Some(new_prediction);
- }
-
- cx.notify();
- })
- .ok();
- });
-
- // We always maintain at most two pending predictions. When we already
- // have two, we replace the newest one.
- if self.pending_predictions.len() <= 1 {
- self.pending_predictions.push(PendingPrediction {
- id: pending_prediction_id,
- _task: task,
- });
- } else if self.pending_predictions.len() == 2 {
- self.pending_predictions.pop();
- self.pending_predictions.push(PendingPrediction {
- id: pending_prediction_id,
- _task: task,
- });
- }
-
- cx.notify();
- }
-
- fn cycle(
- &mut self,
- _buffer: Entity<language::Buffer>,
- _cursor_position: language::Anchor,
- _direction: Direction,
- _cx: &mut Context<Self>,
- ) {
- }
-
- fn accept(&mut self, _cx: &mut Context<Self>) {
- // TODO [zeta2] report accept
- self.current_prediction.take();
- self.pending_predictions.clear();
- }
-
- fn discard(&mut self, _cx: &mut Context<Self>) {
- self.pending_predictions.clear();
- self.current_prediction.take();
- }
-
- fn suggest(
- &mut self,
- buffer: &Entity<language::Buffer>,
- cursor_position: language::Anchor,
- cx: &mut Context<Self>,
- ) -> Option<edit_prediction::EditPrediction> {
- let CurrentEditPrediction {
- buffer_id,
- prediction,
- ..
- } = self.current_prediction.as_mut()?;
-
- // Invalidate previous prediction if it was generated for a different buffer.
- if *buffer_id != buffer.entity_id() {
- self.current_prediction.take();
- return None;
- }
-
- let buffer = buffer.read(cx);
- let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
- self.current_prediction.take();
- return None;
- };
-
- let cursor_row = cursor_position.to_point(buffer).row;
- let (closest_edit_ix, (closest_edit_range, _)) =
- edits.iter().enumerate().min_by_key(|(_, (range, _))| {
- let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
- let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
- cmp::min(distance_from_start, distance_from_end)
- })?;
-
- let mut edit_start_ix = closest_edit_ix;
- for (range, _) in edits[..edit_start_ix].iter().rev() {
- let distance_from_closest_edit =
- closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
- if distance_from_closest_edit <= 1 {
- edit_start_ix -= 1;
- } else {
- break;
- }
- }
-
- let mut edit_end_ix = closest_edit_ix + 1;
- for (range, _) in &edits[edit_end_ix..] {
- let distance_from_closest_edit =
- range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
- if distance_from_closest_edit <= 1 {
- edit_end_ix += 1;
- } else {
- break;
- }
- }
-
- Some(edit_prediction::EditPrediction {
- id: Some(prediction.id.to_string().into()),
- edits: edits[edit_start_ix..edit_end_ix].to_vec(),
- edit_preview: Some(prediction.edit_preview.clone()),
- })
- }
-}
-
fn make_cloud_request(
excerpt_path: PathBuf,
context: EditPredictionContext,
@@ -1215,238 +815,3 @@ fn add_signature(
declaration_to_signature_index.insert(declaration_id, signature_index);
Some(signature_index)
}
-
-fn interpolate(
- old_snapshot: &BufferSnapshot,
- new_snapshot: &BufferSnapshot,
- current_edits: Arc<[(Range<Anchor>, String)]>,
-) -> Option<Vec<(Range<Anchor>, String)>> {
- let mut edits = Vec::new();
-
- let mut model_edits = current_edits.iter().peekable();
- for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
- while let Some((model_old_range, _)) = model_edits.peek() {
- let model_old_range = model_old_range.to_offset(old_snapshot);
- if model_old_range.end < user_edit.old.start {
- let (model_old_range, model_new_text) = model_edits.next().unwrap();
- edits.push((model_old_range.clone(), model_new_text.clone()));
- } else {
- break;
- }
- }
-
- if let Some((model_old_range, model_new_text)) = model_edits.peek() {
- let model_old_offset_range = model_old_range.to_offset(old_snapshot);
- if user_edit.old == model_old_offset_range {
- let user_new_text = new_snapshot
- .text_for_range(user_edit.new.clone())
- .collect::<String>();
-
- if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
- if !model_suffix.is_empty() {
- let anchor = old_snapshot.anchor_after(user_edit.old.end);
- edits.push((anchor..anchor, model_suffix.to_string()));
- }
-
- model_edits.next();
- continue;
- }
- }
- }
-
- return None;
- }
-
- edits.extend(model_edits.cloned());
-
- if edits.is_empty() { None } else { Some(edits) }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use gpui::TestAppContext;
- use indoc::indoc;
-
- #[gpui::test]
- async fn test_compute_edits(cx: &mut TestAppContext) {
- let old = indoc! {r#"
- fn main() {
- let args =
- println!("{}", args[1])
- }
- "#};
-
- let new = indoc! {r#"
- fn main() {
- let args = std::env::args();
- println!("{}", args[1]);
- }
- "#};
-
- let buffer = cx.new(|cx| Buffer::local(old, cx));
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
-
- // TODO cover more cases when multi-file is supported
- let big_edits = vec![predict_edits_v3::Edit {
- path: PathBuf::from("test.txt"),
- range: 0..old.len(),
- content: new.into(),
- }];
-
- let edits = Zeta::compute_edits(&big_edits, &snapshot);
- assert_eq!(edits.len(), 2);
- assert_eq!(
- edits[0].0.to_point(&snapshot).start,
- language::Point::new(1, 14)
- );
- assert_eq!(edits[0].1, " std::env::args();");
- assert_eq!(
- edits[1].0.to_point(&snapshot).start,
- language::Point::new(2, 27)
- );
- assert_eq!(edits[1].1, ";");
- }
-
- #[gpui::test]
- async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
- let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
- let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
- to_prediction_edits(
- [(2..5, "REM".to_string()), (9..11, "".to_string())],
- &buffer,
- cx,
- )
- .into()
- });
-
- let edit_preview = cx
- .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
- .await;
-
- let prediction = EditPrediction {
- id: EditPredictionId(Uuid::new_v4()),
- edits,
- snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
- edit_preview,
- };
-
- cx.update(|cx| {
- assert_eq!(
- from_prediction_edits(
- &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
- assert_eq!(
- from_prediction_edits(
- &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.undo(cx));
- assert_eq!(
- from_prediction_edits(
- &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
- assert_eq!(
- from_prediction_edits(
- &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
- assert_eq!(
- from_prediction_edits(
- &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(4..4, "M".to_string()), (8..10, "".to_string())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
- assert_eq!(
- from_prediction_edits(
- &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(9..11, "".to_string())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
- assert_eq!(
- from_prediction_edits(
- &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(4..4, "M".to_string()), (8..10, "".to_string())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
- assert_eq!(
- from_prediction_edits(
- &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(4..4, "M".to_string())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
- assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
- })
- }
-
- fn to_prediction_edits(
- iterator: impl IntoIterator<Item = (Range<usize>, String)>,
- buffer: &Entity<Buffer>,
- cx: &App,
- ) -> Vec<(Range<Anchor>, String)> {
- let buffer = buffer.read(cx);
- iterator
- .into_iter()
- .map(|(range, text)| {
- (
- buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
- text,
- )
- })
- .collect()
- }
-
- fn from_prediction_edits(
- editor_edits: &[(Range<Anchor>, String)],
- buffer: &Entity<Buffer>,
- cx: &App,
- ) -> Vec<(Range<usize>, String)> {
- let buffer = buffer.read(cx);
- editor_edits
- .iter()
- .map(|(range, text)| {
- (
- range.start.to_offset(buffer)..range.end.to_offset(buffer),
- text.clone(),
- )
- })
- .collect()
- }
-}