diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 0dd387e627a29fcd48b0523dd72990bbc05a5311..c7497fa11da3c7ec6a260aa6fe388d019e8fe24a 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -41,7 +41,7 @@ use settings::{ use std::collections::{VecDeque, hash_map}; use std::env; use text::{AnchorRangeExt, Edit}; -use workspace::Workspace; +use workspace::{AppState, Workspace}; use zeta_prompt::{ZetaFormat, ZetaPromptInput}; use std::mem; @@ -1912,6 +1912,10 @@ impl EditPredictionStore { return; } + if currently_following(&project, cx) { + return; + } + let Some(project_state) = self.projects.get_mut(&project.entity_id()) else { return; }; @@ -2048,6 +2052,25 @@ impl EditPredictionStore { pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); } +fn currently_following(project: &Entity, cx: &App) -> bool { + let Some(app_state) = AppState::try_global(cx).and_then(|app_state| app_state.upgrade()) else { + return false; + }; + + app_state + .workspace_store + .read(cx) + .workspaces() + .filter_map(|workspace| workspace.upgrade()) + .any(|workspace| { + workspace.read(cx).project().entity_id() == project.entity_id() + && workspace + .read(cx) + .leader_for_pane(workspace.read(cx).active_pane()) + .is_some() + }) +} + fn is_ep_store_provider(provider: EditPredictionProvider) -> bool { match provider { EditPredictionProvider::Zed diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index f377f3f705f8d3e04fd4718bbfd650ae4189ba37..dc52ef6ab57428d6293cea126c695f7c659e2f53 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -8,6 +8,7 @@ use cloud_llm_client::{ EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody, predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response}, }; + use futures::{ AsyncReadExt, FutureExt, StreamExt, channel::{mpsc, oneshot}, @@ -35,11 +36,12 @@ use util::{ test::{TextRangeMarker, marked_text_ranges_by}, }; use uuid::Uuid; +use workspace::{AppState, CollaboratorId, MultiWorkspace}; use zeta_prompt::ZetaPromptInput; use crate::{ BufferEditPrediction, EDIT_PREDICTION_SETTLED_QUIESCENCE, EditPredictionId, - EditPredictionStore, REJECT_REQUEST_DEBOUNCE, + EditPredictionJumpsFeatureFlag, EditPredictionStore, REJECT_REQUEST_DEBOUNCE, }; #[gpui::test] @@ -178,6 +180,172 @@ async fn test_current_state(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_diagnostics_refresh_suppressed_while_following(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + + cx.update(|cx| { + cx.update_flags( + false, + vec![EditPredictionJumpsFeatureFlag::NAME.to_string()], + ); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "1.txt": "Hello!\nHow\nBye\n", + "2.txt": "Hola!\nComo\nAdios\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let app_state = cx.update(|cx| { + let app_state = AppState::test(cx); + AppState::set_global(Arc::downgrade(&app_state), cx); + app_state + }); + + let multi_workspace = + cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + let workspace = multi_workspace + .read_with(cx, |multi_workspace, _| multi_workspace.workspace().clone()) + .unwrap(); + cx.update(|cx| { + AppState::set_global(Arc::downgrade(workspace.read(cx).app_state()), cx); + }); + let _ = app_state; + + let buffer1 = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/1.txt"), cx).unwrap(); + project.set_active_path(Some(path.clone()), cx); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot1.anchor_before(language::Point::new(1, 3)); + + ep_store.update(cx, |ep_store, cx| { + ep_store.register_project(&project, cx); + ep_store.register_buffer(&buffer1, &project, cx); + ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx); + }); + + let (request, respond_tx) = requests.predict.next().await.unwrap(); + respond_tx + .send(model_response( + &request, + indoc! {r" + --- a/root/1.txt + +++ b/root/1.txt + @@ ... @@ + Hello! + -How + +How are you? + Bye + "}, + )) + .unwrap(); + cx.run_until_parked(); + + ep_store.update(cx, |ep_store, cx| { + ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx); + }); + + let _ = multi_workspace.update(cx, |multi_workspace, window, cx| { + multi_workspace.workspace().update(cx, |workspace, cx| { + workspace.start_following(CollaboratorId::Agent, window, cx); + }); + }); + cx.run_until_parked(); + + let diagnostic = lsp::Diagnostic { + range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)), + severity: Some(lsp::DiagnosticSeverity::ERROR), + message: "Sentence is incomplete".to_string(), + ..Default::default() + }; + + project.update(cx, |project, cx| { + project.lsp_store().update(cx, |lsp_store, cx| { + lsp_store + .update_diagnostics( + LanguageServerId(0), + lsp::PublishDiagnosticsParams { + uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(), + diagnostics: vec![diagnostic.clone()], + version: None, + }, + None, + language::DiagnosticSourceKind::Pushed, + &[], + cx, + ) + .unwrap(); + }); + }); + + cx.run_until_parked(); + assert_no_predict_request_ready(&mut requests.predict); + + let _ = multi_workspace.update(cx, |multi_workspace, window, cx| { + multi_workspace.workspace().update(cx, |workspace, cx| { + workspace.unfollow(CollaboratorId::Agent, window, cx); + }); + }); + cx.run_until_parked(); + + project.update(cx, |project, cx| { + project.lsp_store().update(cx, |lsp_store, cx| { + lsp_store + .update_diagnostics( + LanguageServerId(0), + lsp::PublishDiagnosticsParams { + uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(), + diagnostics: vec![diagnostic], + version: None, + }, + None, + language::DiagnosticSourceKind::Pushed, + &[], + cx, + ) + .unwrap(); + }); + }); + + let (request, respond_tx) = requests.predict.next().await.unwrap(); + respond_tx + .send(model_response( + &request, + indoc! {r#" + --- a/root/2.txt + +++ b/root/2.txt + @@ ... @@ + Hola! + -Como + +Como estas? + Adios + "#}, + )) + .unwrap(); + cx.run_until_parked(); + + ep_store.update(cx, |ep_store, cx| { + let prediction = ep_store + .prediction_at(&buffer1, None, &project, cx) + .unwrap(); + assert_matches!( + prediction, + BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt")) + ); + }); +} + #[gpui::test] async fn test_simple_request(cx: &mut TestAppContext) { let (ep_store, mut requests) = init_test_with_fake_client(cx); diff --git a/crates/editor/src/edit_prediction_tests.rs b/crates/editor/src/edit_prediction_tests.rs index a997a5f86dfbd3582c0566b8e3351777e0345219..c82915c686e977178398430948f28f8178f216df 100644 --- a/crates/editor/src/edit_prediction_tests.rs +++ b/crates/editor/src/edit_prediction_tests.rs @@ -4,7 +4,13 @@ use edit_prediction_types::{ use gpui::{Entity, KeyBinding, Modifiers, prelude::*}; use indoc::indoc; use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint}; -use std::{ops::Range, sync::Arc}; +use std::{ + ops::Range, + sync::{ + Arc, + atomic::{self, AtomicUsize}, + }, +}; use text::{Point, ToOffset}; use ui::prelude::*; @@ -12,6 +18,8 @@ use crate::{ AcceptEditPrediction, EditPrediction, MenuEditPredictionsPolicy, editor_tests::init_test, test::editor_test_context::EditorTestContext, }; +use rpc::proto::PeerId; +use workspace::CollaboratorId; #[gpui::test] async fn test_edit_prediction_insert(cx: &mut gpui::TestAppContext) { @@ -359,6 +367,60 @@ async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui: }); } +#[gpui::test] +async fn test_edit_prediction_refresh_suppressed_while_following(cx: &mut gpui::TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + let provider = cx.new(|_| FakeEditPredictionDelegate::default()); + assign_editor_completion_provider(provider.clone(), &mut cx); + cx.set_state("let x = ˇ;"); + + propose_edits(&provider, vec![(8..8, "42")], &mut cx); + + cx.update_editor(|editor, window, cx| { + editor.refresh_edit_prediction(false, false, window, cx); + editor.update_visible_edit_prediction(window, cx); + }); + + assert_eq!( + provider.read_with(&cx.cx, |provider, _| { + provider.refresh_count.load(atomic::Ordering::SeqCst) + }), + 1 + ); + cx.editor(|editor, _, _| { + assert!(editor.active_edit_prediction.is_some()); + }); + + cx.update_editor(|editor, window, cx| { + editor.leader_id = Some(CollaboratorId::PeerId(PeerId::default())); + editor.refresh_edit_prediction(false, false, window, cx); + }); + + assert_eq!( + provider.read_with(&cx.cx, |provider, _| { + provider.refresh_count.load(atomic::Ordering::SeqCst) + }), + 1 + ); + cx.editor(|editor, _, _| { + assert!(editor.active_edit_prediction.is_none()); + }); + + cx.update_editor(|editor, window, cx| { + editor.leader_id = None; + editor.refresh_edit_prediction(false, false, window, cx); + }); + + assert_eq!( + provider.read_with(&cx.cx, |provider, _| { + provider.refresh_count.load(atomic::Ordering::SeqCst) + }), + 2 + ); +} + #[gpui::test] async fn test_edit_prediction_preview_cleanup_on_toggle_off(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); @@ -567,6 +629,7 @@ fn assign_editor_completion_provider_non_zed( #[derive(Default, Clone)] pub struct FakeEditPredictionDelegate { pub completion: Option, + pub refresh_count: Arc, } impl FakeEditPredictionDelegate { @@ -619,6 +682,7 @@ impl EditPredictionDelegate for FakeEditPredictionDelegate { _debounce: bool, _cx: &mut gpui::Context, ) { + self.refresh_count.fetch_add(1, atomic::Ordering::SeqCst); } fn accept(&mut self, _cx: &mut gpui::Context) {} diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 8c2e03722c345a0f093572c336029a0eaa355537..fd830c254877463da84e98d21dd39b0e644ca433 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -7804,7 +7804,11 @@ impl Editor { window: &mut Window, cx: &mut Context, ) -> Option<()> { - let provider = self.edit_prediction_provider()?; + if self.leader_id.is_some() { + self.discard_edit_prediction(EditPredictionDiscardReason::Ignored, cx); + return None; + } + let cursor = self.selections.newest_anchor().head(); let (buffer, cursor_buffer_position) = self.buffer.read(cx).text_anchor_for_position(cursor, cx)?; @@ -7829,7 +7833,8 @@ impl Editor { return None; } - provider.refresh(buffer, cursor_buffer_position, debounce, cx); + self.edit_prediction_provider()? + .refresh(buffer, cursor_buffer_position, debounce, cx); Some(()) } @@ -7954,7 +7959,7 @@ impl Editor { cx: &App, ) -> bool { maybe!({ - if self.read_only(cx) { + if self.read_only(cx) || self.leader_id.is_some() { return Some(false); } let provider = self.edit_prediction_provider()?; diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 8dd4d88e2a89cadc39e1335b4bcdc18a0a144571..d79ded2b9781252855ef424e49247fc1cabd383f 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -2253,21 +2253,21 @@ pub mod hashline { Case { name: "insert_before_first_and_after_line", original: indoc! {" - a - b - "}, + a + b + "}, model_output: indoc! {" - <|insert|> - HEAD - <|insert|>0:61 - MID - "}, + <|insert|> + HEAD + <|insert|>0:61 + MID + "}, expected: indoc! {" - HEAD - a - MID - b - "}, + HEAD + a + MID + b + "}, }, ];