Cargo.lock 🔗
@@ -5126,7 +5126,6 @@ dependencies = [
"client",
"gpui",
"language",
- "project",
"workspace-hack",
]
Bennet Bo Fenner and Agus Zubiaga created
Release Notes:
- N/A
---------
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Cargo.lock | 1
crates/cloud_llm_client/src/predict_edits_v3.rs | 13
crates/copilot/src/copilot_completion_provider.rs | 4
crates/edit_prediction/Cargo.toml | 1
crates/edit_prediction/src/edit_prediction.rs | 24
crates/editor/src/edit_prediction_tests.rs | 11
crates/editor/src/editor.rs | 191 +++++-
crates/editor/src/editor_tests.rs | 2
crates/supermaven/Cargo.toml | 1
crates/supermaven/src/supermaven_completion_provider.rs | 4
crates/zed/src/zed/edit_prediction_registry.rs | 70 +-
crates/zeta/src/zeta.rs | 15
crates/zeta2/src/prediction.rs | 143 +++
crates/zeta2/src/provider.rs | 179 +---
crates/zeta2/src/zeta2.rs | 332 ++++++++--
crates/zeta2_tools/src/zeta2_tools.rs | 2
16 files changed, 682 insertions(+), 311 deletions(-)
@@ -5126,7 +5126,6 @@ dependencies = [
"client",
"gpui",
"language",
- "project",
"workspace-hack",
]
@@ -43,15 +43,24 @@ pub struct PredictEditsRequest {
pub prompt_format: PromptFormat,
}
-#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
- #[default]
MarkedExcerpt,
LabeledSections,
/// Prompt format intended for use via zeta_cli
OnlySnippets,
}
+impl PromptFormat {
+ pub const DEFAULT: PromptFormat = PromptFormat::LabeledSections;
+}
+
+impl Default for PromptFormat {
+ fn default() -> Self {
+ Self::DEFAULT
+ }
+}
+
impl PromptFormat {
pub fn iter() -> impl Iterator<Item = Self> {
<Self as strum::IntoEnumIterator>::iter()
@@ -3,7 +3,6 @@ use anyhow::Result;
use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
use gpui::{App, Context, Entity, EntityId, Task};
use language::{Buffer, OffsetRangeExt, ToOffset, language_settings::AllLanguageSettings};
-use project::Project;
use settings::Settings;
use std::{path::Path, time::Duration};
@@ -84,7 +83,6 @@ impl EditPredictionProvider for CopilotCompletionProvider {
fn refresh(
&mut self,
- _project: Option<Entity<Project>>,
buffer: Entity<Buffer>,
cursor_position: language::Anchor,
debounce: bool,
@@ -249,7 +247,7 @@ impl EditPredictionProvider for CopilotCompletionProvider {
None
} else {
let position = cursor_position.bias_right(buffer);
- Some(EditPrediction {
+ Some(EditPrediction::Local {
id: None,
edits: vec![(position..position, completion_text.into())],
edit_preview: None,
@@ -15,5 +15,4 @@ path = "src/edit_prediction.rs"
client.workspace = true
gpui.workspace = true
language.workspace = true
-project.workspace = true
workspace-hack.workspace = true
@@ -3,7 +3,6 @@ use std::ops::Range;
use client::EditPredictionUsage;
use gpui::{App, Context, Entity, SharedString};
use language::Buffer;
-use project::Project;
// TODO: Find a better home for `Direction`.
//
@@ -16,11 +15,19 @@ pub enum Direction {
}
#[derive(Clone)]
-pub struct EditPrediction {
- /// The ID of the completion, if it has one.
- pub id: Option<SharedString>,
- pub edits: Vec<(Range<language::Anchor>, String)>,
- pub edit_preview: Option<language::EditPreview>,
+pub enum EditPrediction {
+ /// Edits within the buffer that requested the prediction
+ Local {
+ id: Option<SharedString>,
+ edits: Vec<(Range<language::Anchor>, String)>,
+ edit_preview: Option<language::EditPreview>,
+ },
+ /// Jump to a different file from the one that requested the prediction
+ Jump {
+ id: Option<SharedString>,
+ snapshot: language::BufferSnapshot,
+ target: language::Anchor,
+ },
}
pub enum DataCollectionState {
@@ -83,7 +90,6 @@ pub trait EditPredictionProvider: 'static + Sized {
fn is_refreshing(&self) -> bool;
fn refresh(
&mut self,
- project: Option<Entity<Project>>,
buffer: Entity<Buffer>,
cursor_position: language::Anchor,
debounce: bool,
@@ -124,7 +130,6 @@ pub trait EditPredictionProviderHandle {
fn is_refreshing(&self, cx: &App) -> bool;
fn refresh(
&self,
- project: Option<Entity<Project>>,
buffer: Entity<Buffer>,
cursor_position: language::Anchor,
debounce: bool,
@@ -198,14 +203,13 @@ where
fn refresh(
&self,
- project: Option<Entity<Project>>,
buffer: Entity<Buffer>,
cursor_position: language::Anchor,
debounce: bool,
cx: &mut App,
) {
self.update(cx, |this, cx| {
- this.refresh(project, buffer, cursor_position, debounce, cx)
+ this.refresh(buffer, cursor_position, debounce, cx)
})
}
@@ -2,7 +2,6 @@ use edit_prediction::EditPredictionProvider;
use gpui::{Entity, prelude::*};
use indoc::indoc;
use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint};
-use project::Project;
use std::ops::Range;
use text::{Point, ToOffset};
@@ -261,7 +260,7 @@ async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui:
EditPrediction::Edit { .. } => {
// This is expected for non-Zed providers
}
- EditPrediction::Move { .. } => {
+ EditPrediction::MoveWithin { .. } | EditPrediction::MoveOutside { .. } => {
panic!(
"Non-Zed providers should not show Move predictions (jump functionality)"
);
@@ -299,7 +298,7 @@ fn assert_editor_active_move_completion(
.as_ref()
.expect("editor has no active completion");
- if let EditPrediction::Move { target, .. } = &completion_state.completion {
+ if let EditPrediction::MoveWithin { target, .. } = &completion_state.completion {
assert(editor.buffer().read(cx).snapshot(cx), *target);
} else {
panic!("expected move completion");
@@ -326,7 +325,7 @@ fn propose_edits<T: ToOffset>(
cx.update(|_, cx| {
provider.update(cx, |provider, _| {
- provider.set_edit_prediction(Some(edit_prediction::EditPrediction {
+ provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
id: None,
edits: edits.collect(),
edit_preview: None,
@@ -357,7 +356,7 @@ fn propose_edits_non_zed<T: ToOffset>(
cx.update(|_, cx| {
provider.update(cx, |provider, _| {
- provider.set_edit_prediction(Some(edit_prediction::EditPrediction {
+ provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
id: None,
edits: edits.collect(),
edit_preview: None,
@@ -418,7 +417,6 @@ impl EditPredictionProvider for FakeEditPredictionProvider {
fn refresh(
&mut self,
- _project: Option<Entity<Project>>,
_buffer: gpui::Entity<language::Buffer>,
_cursor_position: language::Anchor,
_debounce: bool,
@@ -492,7 +490,6 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
fn refresh(
&mut self,
- _project: Option<Entity<Project>>,
_buffer: gpui::Entity<language::Buffer>,
_cursor_position: language::Anchor,
_debounce: bool,
@@ -638,17 +638,23 @@ enum EditPrediction {
display_mode: EditDisplayMode,
snapshot: BufferSnapshot,
},
- Move {
+ /// Move to a specific location in the active editor
+ MoveWithin {
target: Anchor,
snapshot: BufferSnapshot,
},
+ /// Move to a specific location in a different editor (not the active one)
+ MoveOutside {
+ target: language::Anchor,
+ snapshot: BufferSnapshot,
+ },
}
struct EditPredictionState {
inlay_ids: Vec<InlayId>,
completion: EditPrediction,
completion_id: Option<SharedString>,
- invalidation_range: Range<Anchor>,
+ invalidation_range: Option<Range<Anchor>>,
}
enum EditPredictionSettings {
@@ -7175,13 +7181,7 @@ impl Editor {
return None;
}
- provider.refresh(
- self.project.clone(),
- buffer,
- cursor_buffer_position,
- debounce,
- cx,
- );
+ provider.refresh(buffer, cursor_buffer_position, debounce, cx);
Some(())
}
@@ -7424,10 +7424,8 @@ impl Editor {
return;
};
- self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx);
-
match &active_edit_prediction.completion {
- EditPrediction::Move { target, .. } => {
+ EditPrediction::MoveWithin { target, .. } => {
let target = *target;
if let Some(position_map) = &self.last_position_map {
@@ -7469,7 +7467,19 @@ impl Editor {
}
}
}
+ EditPrediction::MoveOutside { snapshot, target } => {
+ if let Some(workspace) = self.workspace() {
+ Self::open_editor_at_anchor(snapshot, *target, &workspace, window, cx)
+ .detach_and_log_err(cx);
+ }
+ }
EditPrediction::Edit { edits, .. } => {
+ self.report_edit_prediction_event(
+ active_edit_prediction.completion_id.clone(),
+ true,
+ cx,
+ );
+
if let Some(provider) = self.edit_prediction_provider() {
provider.accept(cx);
}
@@ -7522,10 +7532,8 @@ impl Editor {
return;
}
- self.report_edit_prediction_event(active_edit_prediction.completion_id.clone(), true, cx);
-
match &active_edit_prediction.completion {
- EditPrediction::Move { target, .. } => {
+ EditPrediction::MoveWithin { target, .. } => {
let target = *target;
self.change_selections(
SelectionEffects::scroll(Autoscroll::newest()),
@@ -7536,7 +7544,19 @@ impl Editor {
},
);
}
+ EditPrediction::MoveOutside { snapshot, target } => {
+ if let Some(workspace) = self.workspace() {
+ Self::open_editor_at_anchor(snapshot, *target, &workspace, window, cx)
+ .detach_and_log_err(cx);
+ }
+ }
EditPrediction::Edit { edits, .. } => {
+ self.report_edit_prediction_event(
+ active_edit_prediction.completion_id.clone(),
+ true,
+ cx,
+ );
+
// Find an insertion that starts at the cursor position.
let snapshot = self.buffer.read(cx).snapshot(cx);
let cursor_offset = self.selections.newest::<usize>(cx).head();
@@ -7631,6 +7651,36 @@ impl Editor {
);
}
+ fn open_editor_at_anchor(
+ snapshot: &language::BufferSnapshot,
+ target: language::Anchor,
+ workspace: &Entity<Workspace>,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> Task<Result<()>> {
+ workspace.update(cx, |workspace, cx| {
+ let path = snapshot.file().map(|file| file.full_path(cx));
+ let Some(path) =
+ path.and_then(|path| workspace.project().read(cx).find_project_path(path, cx))
+ else {
+ return Task::ready(Err(anyhow::anyhow!("Project path not found")));
+ };
+ let target = text::ToPoint::to_point(&target, snapshot);
+ let item = workspace.open_path(path, None, true, window, cx);
+ window.spawn(cx, async move |cx| {
+ let Some(editor) = item.await?.downcast::<Editor>() else {
+ return Ok(());
+ };
+ editor
+ .update_in(cx, |editor, window, cx| {
+ editor.go_to_singleton_buffer_point(target, window, cx);
+ })
+ .ok();
+ anyhow::Ok(())
+ })
+ })
+ }
+
pub fn has_active_edit_prediction(&self) -> bool {
self.active_edit_prediction.is_some()
}
@@ -7846,7 +7896,10 @@ impl Editor {
.active_edit_prediction
.as_ref()
.is_some_and(|completion| {
- let invalidation_range = completion.invalidation_range.to_offset(&multibuffer);
+ let Some(invalidation_range) = completion.invalidation_range.as_ref() else {
+ return false;
+ };
+ let invalidation_range = invalidation_range.to_offset(&multibuffer);
let invalidation_range = invalidation_range.start..=invalidation_range.end;
!invalidation_range.contains(&offset_selection.head())
})
@@ -7882,8 +7935,31 @@ impl Editor {
}
let edit_prediction = provider.suggest(&buffer, cursor_buffer_position, cx)?;
- let edits = edit_prediction
- .edits
+
+ let (completion_id, edits, edit_preview) = match edit_prediction {
+ edit_prediction::EditPrediction::Local {
+ id,
+ edits,
+ edit_preview,
+ } => (id, edits, edit_preview),
+ edit_prediction::EditPrediction::Jump {
+ id,
+ snapshot,
+ target,
+ } => {
+ self.stale_edit_prediction_in_menu = None;
+ self.active_edit_prediction = Some(EditPredictionState {
+ inlay_ids: vec![],
+ completion: EditPrediction::MoveOutside { snapshot, target },
+ completion_id: id,
+ invalidation_range: None,
+ });
+ cx.notify();
+ return Some(());
+ }
+ };
+
+ let edits = edits
.into_iter()
.flat_map(|(range, new_text)| {
let start = multibuffer.anchor_in_excerpt(excerpt_id, range.start)?;
@@ -7928,7 +8004,7 @@ impl Editor {
invalidation_row_range =
move_invalidation_row_range.unwrap_or(edit_start_row..edit_end_row);
let target = first_edit_start;
- EditPrediction::Move { target, snapshot }
+ EditPrediction::MoveWithin { target, snapshot }
} else {
let show_completions_in_buffer = !self.edit_prediction_visible_in_cursor_popover(true)
&& !self.edit_predictions_hidden_for_vim_mode;
@@ -7977,7 +8053,7 @@ impl Editor {
EditPrediction::Edit {
edits,
- edit_preview: edit_prediction.edit_preview,
+ edit_preview,
display_mode,
snapshot,
}
@@ -7994,8 +8070,8 @@ impl Editor {
self.active_edit_prediction = Some(EditPredictionState {
inlay_ids,
completion,
- completion_id: edit_prediction.id,
- invalidation_range,
+ completion_id,
+ invalidation_range: Some(invalidation_range),
});
cx.notify();
@@ -8581,7 +8657,7 @@ impl Editor {
}
match &active_edit_prediction.completion {
- EditPrediction::Move { target, .. } => {
+ EditPrediction::MoveWithin { target, .. } => {
let target_display_point = target.to_display_point(editor_snapshot);
if self.edit_prediction_requires_modifier() {
@@ -8666,6 +8742,28 @@ impl Editor {
window,
cx,
),
+ EditPrediction::MoveOutside { snapshot, .. } => {
+ let file_name = snapshot
+ .file()
+ .map(|file| file.file_name(cx))
+ .unwrap_or("untitled");
+ let mut element = self
+ .render_edit_prediction_line_popover(
+ format!("Jump to {file_name}"),
+ Some(IconName::ZedPredict),
+ window,
+ cx,
+ )
+ .into_any();
+
+ let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
+ let origin_x = text_bounds.size.width / 2. - size.width / 2.;
+ let origin_y = text_bounds.size.height - size.height - px(30.);
+ let origin = text_bounds.origin + gpui::Point::new(origin_x, origin_y);
+ element.prepaint_at(origin, window, cx);
+
+ Some((element, origin))
+ }
}
}
@@ -8730,13 +8828,13 @@ impl Editor {
.items_end()
.when(flag_on_right, |el| el.items_start())
.child(if flag_on_right {
- self.render_edit_prediction_line_popover("Jump", None, window, cx)?
+ self.render_edit_prediction_line_popover("Jump", None, window, cx)
.rounded_bl(px(0.))
.rounded_tl(px(0.))
.border_l_2()
.border_color(border_color)
} else {
- self.render_edit_prediction_line_popover("Jump", None, window, cx)?
+ self.render_edit_prediction_line_popover("Jump", None, window, cx)
.rounded_br(px(0.))
.rounded_tr(px(0.))
.border_r_2()
@@ -8776,7 +8874,7 @@ impl Editor {
cx: &mut App,
) -> Option<(AnyElement, gpui::Point<Pixels>)> {
let mut element = self
- .render_edit_prediction_line_popover("Scroll", Some(scroll_icon), window, cx)?
+ .render_edit_prediction_line_popover("Scroll", Some(scroll_icon), window, cx)
.into_any();
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
@@ -8816,7 +8914,7 @@ impl Editor {
Some(IconName::ArrowUp),
window,
cx,
- )?
+ )
.into_any();
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
@@ -8835,7 +8933,7 @@ impl Editor {
Some(IconName::ArrowDown),
window,
cx,
- )?
+ )
.into_any();
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
@@ -8882,7 +8980,7 @@ impl Editor {
);
let mut element = self
- .render_edit_prediction_line_popover(label, None, window, cx)?
+ .render_edit_prediction_line_popover(label, None, window, cx)
.into_any();
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
@@ -8909,7 +9007,7 @@ impl Editor {
};
element = self
- .render_edit_prediction_line_popover(label, Some(icon), window, cx)?
+ .render_edit_prediction_line_popover(label, Some(icon), window, cx)
.into_any();
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
@@ -9163,13 +9261,13 @@ impl Editor {
icon: Option<IconName>,
window: &mut Window,
cx: &App,
- ) -> Option<Stateful<Div>> {
+ ) -> Stateful<Div> {
let padding_right = if icon.is_some() { px(4.) } else { px(8.) };
let keybind = self.render_edit_prediction_accept_keybind(window, cx);
let has_keybind = keybind.is_some();
- let result = h_flex()
+ h_flex()
.id("ep-line-popover")
.py_0p5()
.pl_1()
@@ -9215,9 +9313,7 @@ impl Editor {
.mt(px(1.5))
.child(Icon::new(icon).size(IconSize::Small)),
)
- });
-
- Some(result)
+ })
}
fn edit_prediction_line_popover_bg_color(cx: &App) -> Hsla {
@@ -9281,7 +9377,7 @@ impl Editor {
.rounded_tl(px(0.))
.overflow_hidden()
.child(div().px_1p5().child(match &prediction.completion {
- EditPrediction::Move { target, snapshot } => {
+ EditPrediction::MoveWithin { target, snapshot } => {
use text::ToPoint as _;
if target.text_anchor.to_point(snapshot).row > cursor_point.row
{
@@ -9290,6 +9386,10 @@ impl Editor {
Icon::new(IconName::ZedPredictUp)
}
}
+ EditPrediction::MoveOutside { .. } => {
+ // TODO [zeta2] custom icon for external jump?
+ Icon::new(provider_icon)
+ }
EditPrediction::Edit { .. } => Icon::new(provider_icon),
}))
.child(
@@ -9472,7 +9572,7 @@ impl Editor {
.unwrap_or(true);
match &completion.completion {
- EditPrediction::Move {
+ EditPrediction::MoveWithin {
target, snapshot, ..
} => {
if !supports_jump {
@@ -9494,7 +9594,20 @@ impl Editor {
.child(Label::new("Jump to Edit")),
)
}
-
+ EditPrediction::MoveOutside { snapshot, .. } => {
+ let file_name = snapshot
+ .file()
+ .map(|file| file.file_name(cx))
+ .unwrap_or("untitled");
+ Some(
+ h_flex()
+ .px_2()
+ .gap_2()
+ .flex_1()
+ .child(Icon::new(IconName::ZedPredict))
+ .child(Label::new(format!("Jump to {file_name}"))),
+ )
+ }
EditPrediction::Edit {
edits,
edit_preview,
@@ -21418,7 +21531,7 @@ impl Editor {
{
self.hide_context_menu(window, cx);
}
- self.discard_edit_prediction(false, cx);
+ self.take_active_edit_prediction(cx);
cx.emit(EditorEvent::Blurred);
cx.notify();
}
@@ -8272,7 +8272,7 @@ async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext)
cx.update(|_, cx| {
provider.update(cx, |provider, _| {
- provider.set_edit_prediction(Some(edit_prediction::EditPrediction {
+ provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
id: None,
edits: vec![(edit_position..edit_position, "X".into())],
edit_preview: None,
@@ -22,7 +22,6 @@ gpui.workspace = true
language.workspace = true
log.workspace = true
postage.workspace = true
-project.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
@@ -4,7 +4,6 @@ use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
use futures::StreamExt as _;
use gpui::{App, Context, Entity, EntityId, Task};
use language::{Anchor, Buffer, BufferSnapshot};
-use project::Project;
use std::{
ops::{AddAssign, Range},
path::Path,
@@ -94,7 +93,7 @@ fn completion_from_diff(
edits.push((edit_range, edit_text));
}
- EditPrediction {
+ EditPrediction::Local {
id: None,
edits,
edit_preview: None,
@@ -132,7 +131,6 @@ impl EditPredictionProvider for SupermavenCompletionProvider {
fn refresh(
&mut self,
- _project: Option<Entity<Project>>,
buffer_handle: Entity<Buffer>,
cursor_position: Anchor,
debounce: bool,
@@ -205,42 +205,48 @@ fn assign_edit_prediction_provider(
}
}
- if std::env::var("ZED_ZETA2").is_ok() {
- let zeta = zeta2::Zeta::global(client, &user_store, cx);
- let provider = cx.new(|cx| {
- zeta2::ZetaEditPredictionProvider::new(
- editor.project(),
- &client,
- &user_store,
- cx,
- )
- });
-
- if let Some(buffer) = &singleton_buffer
- && buffer.read(cx).file().is_some()
- && let Some(project) = editor.project()
- {
- zeta.update(cx, |zeta, cx| {
- zeta.register_buffer(buffer, project, cx);
+ if let Some(project) = editor.project() {
+ if std::env::var("ZED_ZETA2").is_ok() {
+ let zeta = zeta2::Zeta::global(client, &user_store, cx);
+ let provider = cx.new(|cx| {
+ zeta2::ZetaEditPredictionProvider::new(
+ project.clone(),
+ &client,
+ &user_store,
+ cx,
+ )
});
- }
- editor.set_edit_prediction_provider(Some(provider), window, cx);
- } else {
- let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
-
- if let Some(buffer) = &singleton_buffer
- && buffer.read(cx).file().is_some()
- && let Some(project) = editor.project()
- {
- zeta.update(cx, |zeta, cx| {
- zeta.register_buffer(buffer, project, cx);
+ // TODO [zeta2] handle multibuffers
+ if let Some(buffer) = &singleton_buffer
+ && buffer.read(cx).file().is_some()
+ {
+ zeta.update(cx, |zeta, cx| {
+ zeta.register_buffer(buffer, project, cx);
+ });
+ }
+
+ editor.set_edit_prediction_provider(Some(provider), window, cx);
+ } else {
+ let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
+
+ if let Some(buffer) = &singleton_buffer
+ && buffer.read(cx).file().is_some()
+ {
+ zeta.update(cx, |zeta, cx| {
+ zeta.register_buffer(buffer, project, cx);
+ });
+ }
+
+ let provider = cx.new(|_| {
+ zeta::ZetaEditPredictionProvider::new(
+ zeta,
+ project.clone(),
+ singleton_buffer,
+ )
});
+ editor.set_edit_prediction_provider(Some(provider), window, cx);
}
-
- let provider =
- cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, singleton_buffer));
- editor.set_edit_prediction_provider(Some(provider), window, cx);
}
}
}
@@ -1316,12 +1316,17 @@ pub struct ZetaEditPredictionProvider {
next_pending_completion_id: usize,
current_completion: Option<CurrentEditPrediction>,
last_request_timestamp: Instant,
+ project: Entity<Project>,
}
impl ZetaEditPredictionProvider {
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
- pub fn new(zeta: Entity<Zeta>, singleton_buffer: Option<Entity<Buffer>>) -> Self {
+ pub fn new(
+ zeta: Entity<Zeta>,
+ project: Entity<Project>,
+ singleton_buffer: Option<Entity<Buffer>>,
+ ) -> Self {
Self {
zeta,
singleton_buffer,
@@ -1329,6 +1334,7 @@ impl ZetaEditPredictionProvider {
next_pending_completion_id: 0,
current_completion: None,
last_request_timestamp: Instant::now(),
+ project,
}
}
}
@@ -1394,7 +1400,6 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
fn refresh(
&mut self,
- project: Option<Entity<Project>>,
buffer: Entity<Buffer>,
position: language::Anchor,
_debounce: bool,
@@ -1403,9 +1408,6 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
if self.zeta.read(cx).update_required {
return;
}
- let Some(project) = project else {
- return;
- };
if self
.zeta
@@ -1433,6 +1435,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
self.next_pending_completion_id += 1;
let last_request_timestamp = self.last_request_timestamp;
+ let project = self.project.clone();
let task = cx.spawn(async move |this, cx| {
if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
.checked_duration_since(Instant::now())
@@ -1604,7 +1607,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
}
}
- Some(edit_prediction::EditPrediction {
+ Some(edit_prediction::EditPrediction::Local {
id: Some(completion.id.to_string().into()),
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
edit_preview: Some(completion.edit_preview.clone()),
@@ -1,50 +1,146 @@
-use std::{borrow::Cow, ops::Range, sync::Arc};
+use std::{borrow::Cow, ops::Range, path::Path, sync::Arc};
+use anyhow::Context as _;
use cloud_llm_client::predict_edits_v3;
-use language::{Anchor, BufferSnapshot, EditPreview, OffsetRangeExt, text_diff};
+use gpui::{App, AsyncApp, Entity};
+use language::{
+ Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot, text_diff,
+};
+use project::Project;
+use util::ResultExt;
use uuid::Uuid;
+#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
+pub struct EditPredictionId(Uuid);
+
+impl From<EditPredictionId> for gpui::ElementId {
+ fn from(value: EditPredictionId) -> Self {
+ gpui::ElementId::Uuid(value.0)
+ }
+}
+
+impl std::fmt::Display for EditPredictionId {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", self.0)
+ }
+}
+
#[derive(Clone)]
pub struct EditPrediction {
pub id: EditPredictionId,
+ pub path: Arc<Path>,
pub edits: Arc<[(Range<Anchor>, String)]>,
pub snapshot: BufferSnapshot,
pub edit_preview: EditPreview,
+ // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
+ _buffer: Entity<Buffer>,
}
impl EditPrediction {
+ pub async fn from_response(
+ response: predict_edits_v3::PredictEditsResponse,
+ active_buffer_old_snapshot: &TextBufferSnapshot,
+ active_buffer: &Entity<Buffer>,
+ project: &Entity<Project>,
+ cx: &mut AsyncApp,
+ ) -> Option<Self> {
+ // TODO only allow cloud to return one path
+ let Some(path) = response.edits.first().map(|e| e.path.clone()) else {
+ return None;
+ };
+
+ let is_same_path = active_buffer
+ .read_with(cx, |buffer, cx| buffer_path_eq(buffer, &path, cx))
+ .ok()?;
+
+ let (buffer, edits, snapshot, edit_preview_task) = if is_same_path {
+ active_buffer
+ .read_with(cx, |buffer, cx| {
+ let new_snapshot = buffer.snapshot();
+ let edits = edits_from_response(&response.edits, &active_buffer_old_snapshot);
+ let edits: Arc<[_]> =
+ interpolate_edits(active_buffer_old_snapshot, &new_snapshot, edits)?.into();
+
+ Some((
+ active_buffer.clone(),
+ edits.clone(),
+ new_snapshot,
+ buffer.preview_edits(edits, cx),
+ ))
+ })
+ .ok()??
+ } else {
+ let buffer_handle = project
+ .update(cx, |project, cx| {
+ let project_path = project
+ .find_project_path(&path, cx)
+ .context("Failed to find project path for zeta edit")?;
+ anyhow::Ok(project.open_buffer(project_path, cx))
+ })
+ .ok()?
+ .log_err()?
+ .await
+ .context("Failed to open buffer for zeta edit")
+ .log_err()?;
+
+ buffer_handle
+ .read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot();
+ let edits = edits_from_response(&response.edits, &snapshot);
+ if edits.is_empty() {
+ return None;
+ }
+ Some((
+ buffer_handle.clone(),
+ edits.clone(),
+ snapshot,
+ buffer.preview_edits(edits, cx),
+ ))
+ })
+ .ok()??
+ };
+
+ let edit_preview = edit_preview_task.await;
+
+ Some(EditPrediction {
+ id: EditPredictionId(response.request_id),
+ path,
+ edits,
+ snapshot,
+ edit_preview,
+ _buffer: buffer,
+ })
+ }
+
pub fn interpolate(
&self,
- new_snapshot: &BufferSnapshot,
+ new_snapshot: &TextBufferSnapshot,
) -> Option<Vec<(Range<Anchor>, String)>> {
interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
}
-}
-#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
-pub struct EditPredictionId(Uuid);
-
-impl From<Uuid> for EditPredictionId {
- fn from(value: Uuid) -> Self {
- EditPredictionId(value)
+ pub fn targets_buffer(&self, buffer: &Buffer, cx: &App) -> bool {
+ buffer_path_eq(buffer, &self.path, cx)
}
}
-impl From<EditPredictionId> for gpui::ElementId {
- fn from(value: EditPredictionId) -> Self {
- gpui::ElementId::Uuid(value.0)
+impl std::fmt::Debug for EditPrediction {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("EditPrediction")
+ .field("id", &self.id)
+ .field("path", &self.path)
+ .field("edits", &self.edits)
+ .finish()
}
}
-impl std::fmt::Display for EditPredictionId {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.0)
- }
+pub fn buffer_path_eq(buffer: &Buffer, path: &Path, cx: &App) -> bool {
+ buffer.file().map(|p| p.full_path(cx)).as_deref() == Some(path)
}
pub fn interpolate_edits(
- old_snapshot: &BufferSnapshot,
- new_snapshot: &BufferSnapshot,
+ old_snapshot: &TextBufferSnapshot,
+ new_snapshot: &TextBufferSnapshot,
current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
let mut edits = Vec::new();
@@ -88,14 +184,13 @@ pub fn interpolate_edits(
if edits.is_empty() { None } else { Some(edits) }
}
-pub fn edits_from_response(
+fn edits_from_response(
edits: &[predict_edits_v3::Edit],
- snapshot: &BufferSnapshot,
+ snapshot: &TextBufferSnapshot,
) -> Arc<[(Range<Anchor>, String)]> {
edits
.iter()
.flat_map(|edit| {
- // TODO multi-file edits
let old_text = snapshot.text_for_range(edit.range.clone());
excerpt_edits_from_response(
@@ -113,7 +208,7 @@ fn excerpt_edits_from_response(
old_text: Cow<str>,
new_text: &str,
offset: usize,
- snapshot: &BufferSnapshot,
+ snapshot: &TextBufferSnapshot,
) -> impl Iterator<Item = (Range<Anchor>, String)> {
text_diff(&old_text, new_text)
.into_iter()
@@ -221,6 +316,8 @@ mod tests {
id: EditPredictionId(Uuid::new_v4()),
edits,
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
+ path: Path::new("test.txt").into(),
+ _buffer: buffer.clone(),
edit_preview,
};
@@ -4,76 +4,44 @@ use std::{
time::{Duration, Instant},
};
-use anyhow::Context as _;
use arrayvec::ArrayVec;
use client::{Client, UserStore};
use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
-use gpui::{App, Entity, EntityId, Task, prelude::*};
-use language::{BufferSnapshot, ToPoint as _};
+use gpui::{App, Entity, Task, prelude::*};
+use language::ToPoint as _;
use project::Project;
use util::ResultExt as _;
-use crate::{Zeta, prediction::EditPrediction};
+use crate::{BufferEditPrediction, Zeta};
pub struct ZetaEditPredictionProvider {
zeta: Entity<Zeta>,
- current_prediction: Option<CurrentEditPrediction>,
next_pending_prediction_id: usize,
pending_predictions: ArrayVec<PendingPrediction, 2>,
last_request_timestamp: Instant,
+ project: Entity<Project>,
}
impl ZetaEditPredictionProvider {
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
pub fn new(
- project: Option<&Entity<Project>>,
+ project: Entity<Project>,
client: &Arc<Client>,
user_store: &Entity<UserStore>,
cx: &mut App,
) -> Self {
let zeta = Zeta::global(client, user_store, cx);
- if let Some(project) = project {
- zeta.update(cx, |zeta, cx| {
- zeta.register_project(project, cx);
- });
- }
+ zeta.update(cx, |zeta, cx| {
+ zeta.register_project(&project, cx);
+ });
Self {
zeta,
- current_prediction: None,
next_pending_prediction_id: 0,
pending_predictions: ArrayVec::new(),
last_request_timestamp: Instant::now(),
- }
- }
-}
-
-#[derive(Clone)]
-struct CurrentEditPrediction {
- buffer_id: EntityId,
- prediction: EditPrediction,
-}
-
-impl CurrentEditPrediction {
- fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
- if self.buffer_id != old_prediction.buffer_id {
- return true;
- }
-
- let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
- return true;
- };
- let Some(new_edits) = self.prediction.interpolate(snapshot) else {
- return false;
- };
-
- if old_edits.len() == 1 && new_edits.len() == 1 {
- let (old_range, old_text) = &old_edits[0];
- let (new_range, new_text) = &new_edits[0];
- new_range == old_range && new_text.starts_with(old_text)
- } else {
- true
+ project: project,
}
}
}
@@ -128,42 +96,31 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
fn refresh(
&mut self,
- project: Option<Entity<project::Project>>,
buffer: Entity<language::Buffer>,
cursor_position: language::Anchor,
_debounce: bool,
cx: &mut Context<Self>,
) {
- let Some(project) = project else {
- return;
- };
+ let zeta = self.zeta.read(cx);
- if self
- .zeta
- .read(cx)
- .user_store
- .read_with(cx, |user_store, _cx| {
- user_store.account_too_young() || user_store.has_overdue_invoices()
- })
- {
+ if zeta.user_store.read_with(cx, |user_store, _cx| {
+ user_store.account_too_young() || user_store.has_overdue_invoices()
+ }) {
return;
}
- if let Some(current_prediction) = self.current_prediction.as_ref() {
- let snapshot = buffer.read(cx).snapshot();
- if current_prediction
- .prediction
- .interpolate(&snapshot)
- .is_some()
- {
- return;
- }
+ if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
+ && let BufferEditPrediction::Local { prediction } = current
+ && prediction.interpolate(buffer.read(cx)).is_some()
+ {
+ return;
}
let pending_prediction_id = self.next_pending_prediction_id;
self.next_pending_prediction_id += 1;
let last_request_timestamp = self.last_request_timestamp;
+ let project = self.project.clone();
let task = cx.spawn(async move |this, cx| {
if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
.checked_duration_since(Instant::now())
@@ -171,25 +128,16 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
cx.background_executor().timer(timeout).await;
}
- let prediction_request = this.update(cx, |this, cx| {
+ let refresh_task = this.update(cx, |this, cx| {
this.last_request_timestamp = Instant::now();
this.zeta.update(cx, |zeta, cx| {
- zeta.request_prediction(&project, &buffer, cursor_position, cx)
+ zeta.refresh_prediction(&project, &buffer, cursor_position, cx)
})
});
- let prediction = match prediction_request {
- Ok(prediction_request) => {
- let prediction_request = prediction_request.await;
- prediction_request.map(|c| {
- c.map(|prediction| CurrentEditPrediction {
- buffer_id: buffer.entity_id(),
- prediction,
- })
- })
- }
- Err(error) => Err(error),
- };
+ if let Some(refresh_task) = refresh_task.ok() {
+ refresh_task.await.log_err();
+ }
this.update(cx, |this, cx| {
if this.pending_predictions[0].id == pending_prediction_id {
@@ -198,24 +146,6 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
this.pending_predictions.clear();
}
- let Some(new_prediction) = prediction
- .context("edit prediction failed")
- .log_err()
- .flatten()
- else {
- cx.notify();
- return;
- };
-
- if let Some(old_prediction) = this.current_prediction.as_ref() {
- let snapshot = buffer.read(cx).snapshot();
- if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
- this.current_prediction = Some(new_prediction);
- }
- } else {
- this.current_prediction = Some(new_prediction);
- }
-
cx.notify();
})
.ok();
@@ -248,15 +178,18 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
) {
}
- fn accept(&mut self, _cx: &mut Context<Self>) {
- // TODO [zeta2] report accept
- self.current_prediction.take();
+ fn accept(&mut self, cx: &mut Context<Self>) {
+ self.zeta.update(cx, |zeta, _cx| {
+ zeta.accept_current_prediction(&self.project);
+ });
self.pending_predictions.clear();
}
- fn discard(&mut self, _cx: &mut Context<Self>) {
+ fn discard(&mut self, cx: &mut Context<Self>) {
+ self.zeta.update(cx, |zeta, _cx| {
+ zeta.discard_current_prediction(&self.project);
+ });
self.pending_predictions.clear();
- self.current_prediction.take();
}
fn suggest(
@@ -265,36 +198,44 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
cursor_position: language::Anchor,
cx: &mut Context<Self>,
) -> Option<edit_prediction::EditPrediction> {
- let CurrentEditPrediction {
- buffer_id,
- prediction,
- ..
- } = self.current_prediction.as_mut()?;
-
- // Invalidate previous prediction if it was generated for a different buffer.
- if *buffer_id != buffer.entity_id() {
- self.current_prediction.take();
- return None;
- }
+ let prediction =
+ self.zeta
+ .read(cx)
+ .current_prediction_for_buffer(buffer, &self.project, cx)?;
+
+ let prediction = match prediction {
+ BufferEditPrediction::Local { prediction } => prediction,
+ BufferEditPrediction::Jump { prediction } => {
+ return Some(edit_prediction::EditPrediction::Jump {
+ id: Some(prediction.id.to_string().into()),
+ snapshot: prediction.snapshot.clone(),
+ target: prediction.edits.first().unwrap().0.start,
+ });
+ }
+ };
let buffer = buffer.read(cx);
- let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
- self.current_prediction.take();
+ let snapshot = buffer.snapshot();
+
+ let Some(edits) = prediction.interpolate(&snapshot) else {
+ self.zeta.update(cx, |zeta, _cx| {
+ zeta.discard_current_prediction(&self.project);
+ });
return None;
};
- let cursor_row = cursor_position.to_point(buffer).row;
+ let cursor_row = cursor_position.to_point(&snapshot).row;
let (closest_edit_ix, (closest_edit_range, _)) =
edits.iter().enumerate().min_by_key(|(_, (range, _))| {
- let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
- let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
+ let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
+ let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
cmp::min(distance_from_start, distance_from_end)
})?;
let mut edit_start_ix = closest_edit_ix;
for (range, _) in edits[..edit_start_ix].iter().rev() {
- let distance_from_closest_edit =
- closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
+ let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
+ - range.end.to_point(&snapshot).row;
if distance_from_closest_edit <= 1 {
edit_start_ix -= 1;
} else {
@@ -305,7 +246,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
let mut edit_end_ix = closest_edit_ix + 1;
for (range, _) in &edits[edit_end_ix..] {
let distance_from_closest_edit =
- range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
+ range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
if distance_from_closest_edit <= 1 {
edit_end_ix += 1;
} else {
@@ -313,7 +254,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
}
}
- Some(edit_prediction::EditPrediction {
+ Some(edit_prediction::EditPrediction::Local {
id: Some(prediction.id.to_string().into()),
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
edit_preview: Some(prediction.edit_preview.clone()),
@@ -17,8 +17,8 @@ use gpui::{
App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
http_client, prelude::*,
};
-use language::BufferSnapshot;
use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
+use language::{BufferSnapshot, TextBufferSnapshot};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use project::Project;
use release_channel::AppVersion;
@@ -35,7 +35,7 @@ use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_noti
mod prediction;
mod provider;
-use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
+use crate::prediction::EditPrediction;
pub use provider::ZetaEditPredictionProvider;
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
@@ -53,7 +53,7 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
excerpt: DEFAULT_EXCERPT_OPTIONS,
max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
max_diagnostic_bytes: 2048,
- prompt_format: PromptFormat::MarkedExcerpt,
+ prompt_format: PromptFormat::DEFAULT,
};
#[derive(Clone)]
@@ -94,6 +94,47 @@ struct ZetaProject {
syntax_index: Entity<SyntaxIndex>,
events: VecDeque<Event>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
+ current_prediction: Option<CurrentEditPrediction>,
+}
+
+#[derive(Clone)]
+struct CurrentEditPrediction {
+ pub requested_by_buffer_id: EntityId,
+ pub prediction: EditPrediction,
+}
+
+impl CurrentEditPrediction {
+ fn should_replace_prediction(
+ &self,
+ old_prediction: &Self,
+ snapshot: &TextBufferSnapshot,
+ ) -> bool {
+ if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id {
+ return true;
+ }
+
+ let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
+ return true;
+ };
+
+ let Some(new_edits) = self.prediction.interpolate(snapshot) else {
+ return false;
+ };
+ if old_edits.len() == 1 && new_edits.len() == 1 {
+ let (old_range, old_text) = &old_edits[0];
+ let (new_range, new_text) = &new_edits[0];
+ new_range == old_range && new_text.starts_with(old_text)
+ } else {
+ true
+ }
+ }
+}
+
+/// A prediction from the perspective of a buffer.
+#[derive(Debug)]
+enum BufferEditPrediction<'a> {
+ Local { prediction: &'a EditPrediction },
+ Jump { prediction: &'a EditPrediction },
}
struct RegisteredBuffer {
@@ -204,6 +245,7 @@ impl Zeta {
syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
events: VecDeque::new(),
registered_buffers: HashMap::new(),
+ current_prediction: None,
})
}
@@ -305,7 +347,83 @@ impl Zeta {
events.push_back(event);
}
- pub fn request_prediction(
+ fn current_prediction_for_buffer(
+ &self,
+ buffer: &Entity<Buffer>,
+ project: &Entity<Project>,
+ cx: &App,
+ ) -> Option<BufferEditPrediction<'_>> {
+ let project_state = self.projects.get(&project.entity_id())?;
+
+ let CurrentEditPrediction {
+ requested_by_buffer_id,
+ prediction,
+ } = project_state.current_prediction.as_ref()?;
+
+ if prediction.targets_buffer(buffer.read(cx), cx) {
+ Some(BufferEditPrediction::Local { prediction })
+ } else if *requested_by_buffer_id == buffer.entity_id() {
+ Some(BufferEditPrediction::Jump { prediction })
+ } else {
+ None
+ }
+ }
+
+ fn accept_current_prediction(&mut self, project: &Entity<Project>) {
+ if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
+ project_state.current_prediction.take();
+ };
+ // TODO report accepted
+ }
+
+ fn discard_current_prediction(&mut self, project: &Entity<Project>) {
+ if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
+ project_state.current_prediction.take();
+ };
+ }
+
+ pub fn refresh_prediction(
+ &mut self,
+ project: &Entity<Project>,
+ buffer: &Entity<Buffer>,
+ position: language::Anchor,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<()>> {
+ let request_task = self.request_prediction(project, buffer, position, cx);
+ let buffer = buffer.clone();
+ let project = project.clone();
+
+ cx.spawn(async move |this, cx| {
+ if let Some(prediction) = request_task.await? {
+ this.update(cx, |this, cx| {
+ let project_state = this
+ .projects
+ .get_mut(&project.entity_id())
+ .context("Project not found")?;
+
+ let new_prediction = CurrentEditPrediction {
+ requested_by_buffer_id: buffer.entity_id(),
+ prediction: prediction,
+ };
+
+ if project_state
+ .current_prediction
+ .as_ref()
+ .is_none_or(|old_prediction| {
+ new_prediction
+ .should_replace_prediction(&old_prediction, buffer.read(cx))
+ })
+ {
+ project_state.current_prediction = Some(new_prediction);
+ }
+ anyhow::Ok(())
+ })??;
+ }
+ Ok(())
+ })
+ }
+
+ fn request_prediction(
&mut self,
project: &Entity<Project>,
buffer: &Entity<Buffer>,
@@ -457,74 +575,63 @@ impl Zeta {
.ok();
}
- let (response, usage) = response?;
- let edits = edits_from_response(&response.edits, &snapshot);
-
- anyhow::Ok(Some((response.request_id, edits, usage)))
+ anyhow::Ok(Some(response?))
}
});
let buffer = buffer.clone();
- cx.spawn(async move |this, cx| {
- match request_task.await {
- Ok(Some((id, edits, usage))) => {
- if let Some(usage) = usage {
- this.update(cx, |this, cx| {
- this.user_store.update(cx, |user_store, cx| {
- user_store.update_edit_prediction_usage(usage, cx);
- });
- })
- .ok();
- }
+ cx.spawn({
+ let project = project.clone();
+ async move |this, cx| {
+ match request_task.await {
+ Ok(Some((response, usage))) => {
+ if let Some(usage) = usage {
+ this.update(cx, |this, cx| {
+ this.user_store.update(cx, |user_store, cx| {
+ user_store.update_edit_prediction_usage(usage, cx);
+ });
+ })
+ .ok();
+ }
- // TODO telemetry: duration, etc
- let Some((edits, snapshot, edit_preview_task)) =
- buffer.read_with(cx, |buffer, cx| {
- let new_snapshot = buffer.snapshot();
- let edits: Arc<[_]> =
- interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
- Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
- })?
- else {
- return Ok(None);
- };
+ let prediction = EditPrediction::from_response(
+ response, &snapshot, &buffer, &project, cx,
+ )
+ .await;
- Ok(Some(EditPrediction {
- id: id.into(),
- edits,
- snapshot,
- edit_preview: edit_preview_task.await,
- }))
- }
- Ok(None) => Ok(None),
- Err(err) => {
- if err.is::<ZedUpdateRequiredError>() {
- cx.update(|cx| {
- this.update(cx, |this, _cx| {
- this.update_required = true;
+ // TODO telemetry: duration, etc
+ Ok(prediction)
+ }
+ Ok(None) => Ok(None),
+ Err(err) => {
+ if err.is::<ZedUpdateRequiredError>() {
+ cx.update(|cx| {
+ this.update(cx, |this, _cx| {
+ this.update_required = true;
+ })
+ .ok();
+
+ let error_message: SharedString = err.to_string().into();
+ show_app_notification(
+ NotificationId::unique::<ZedUpdateRequiredError>(),
+ cx,
+ move |cx| {
+ cx.new(|cx| {
+ ErrorMessagePrompt::new(error_message.clone(), cx)
+ .with_link_button(
+ "Update Zed",
+ "https://zed.dev/releases",
+ )
+ })
+ },
+ );
})
.ok();
+ }
- let error_message: SharedString = err.to_string().into();
- show_app_notification(
- NotificationId::unique::<ZedUpdateRequiredError>(),
- cx,
- move |cx| {
- cx.new(|cx| {
- ErrorMessagePrompt::new(error_message.clone(), cx)
- .with_link_button(
- "Update Zed",
- "https://zed.dev/releases",
- )
- })
- },
- );
- })
- .ok();
+ Err(err)
}
-
- Err(err)
}
}
})
@@ -859,13 +966,113 @@ mod tests {
};
use indoc::indoc;
use language::{LanguageServerId, OffsetRangeExt as _};
+ use pretty_assertions::{assert_eq, assert_matches};
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use util::path;
use uuid::Uuid;
- use crate::Zeta;
+ use crate::{BufferEditPrediction, Zeta};
+
+ #[gpui::test]
+ async fn test_current_state(cx: &mut TestAppContext) {
+ let (zeta, mut req_rx) = init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "1.txt": "Hello!\nHow\nBye",
+ "2.txt": "Hola!\nComo\nAdios"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ zeta.update(cx, |zeta, cx| {
+ zeta.register_project(&project, cx);
+ });
+
+ let buffer1 = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot1.anchor_before(language::Point::new(1, 3));
+
+ // Prediction for current file
+
+ let prediction_task = zeta.update(cx, |zeta, cx| {
+ zeta.refresh_prediction(&project, &buffer1, position, cx)
+ });
+ let (_request, respond_tx) = req_rx.next().await.unwrap();
+ respond_tx
+ .send(predict_edits_v3::PredictEditsResponse {
+ request_id: Uuid::new_v4(),
+ edits: vec![predict_edits_v3::Edit {
+ path: Path::new(path!("root/1.txt")).into(),
+ range: 0..snapshot1.len(),
+ content: "Hello!\nHow are you?\nBye".into(),
+ }],
+ debug_info: None,
+ })
+ .unwrap();
+ prediction_task.await.unwrap();
+
+ zeta.read_with(cx, |zeta, cx| {
+ let prediction = zeta
+ .current_prediction_for_buffer(&buffer1, &project, cx)
+ .unwrap();
+ assert_matches!(prediction, BufferEditPrediction::Local { .. });
+ });
+
+ // Prediction for another file
+
+ let prediction_task = zeta.update(cx, |zeta, cx| {
+ zeta.refresh_prediction(&project, &buffer1, position, cx)
+ });
+ let (_request, respond_tx) = req_rx.next().await.unwrap();
+ respond_tx
+ .send(predict_edits_v3::PredictEditsResponse {
+ request_id: Uuid::new_v4(),
+ edits: vec![predict_edits_v3::Edit {
+ path: Path::new(path!("root/2.txt")).into(),
+ range: 0..snapshot1.len(),
+ content: "Hola!\nComo estas?\nAdios".into(),
+ }],
+ debug_info: None,
+ })
+ .unwrap();
+ prediction_task.await.unwrap();
+
+ zeta.read_with(cx, |zeta, cx| {
+ let prediction = zeta
+ .current_prediction_for_buffer(&buffer1, &project, cx)
+ .unwrap();
+ assert_matches!(
+ prediction,
+ BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
+ );
+ });
+
+ let buffer2 = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+
+ zeta.read_with(cx, |zeta, cx| {
+ let prediction = zeta
+ .current_prediction_for_buffer(&buffer2, &project, cx)
+ .unwrap();
+ assert_matches!(prediction, BufferEditPrediction::Local { .. });
+ });
+ }
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
@@ -1146,6 +1353,7 @@ mod tests {
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let zeta = Zeta::global(&client, &user_store, cx);
+
(zeta, req_rx)
})
}
@@ -185,7 +185,7 @@ impl Zeta2Inspector {
cx.background_executor().timer(THROTTLE_TIME).await;
if let Some(task) = zeta
.update(cx, |zeta, cx| {
- zeta.request_prediction(&project, &buffer, position, cx)
+ zeta.refresh_prediction(&project, &buffer, position, cx)
})
.ok()
{