ep: Ensure predictions are not refreshed while following (#51489)

Ben Kunkle created

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/edit_prediction.rs       |  25 ++
crates/edit_prediction/src/edit_prediction_tests.rs | 170 ++++++++++++++
crates/editor/src/edit_prediction_tests.rs          |  66 +++++
crates/editor/src/editor.rs                         |  11 
crates/zeta_prompt/src/zeta_prompt.rs               |  26 +-
5 files changed, 279 insertions(+), 19 deletions(-)

Detailed changes

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -41,7 +41,7 @@ use settings::{
 use std::collections::{VecDeque, hash_map};
 use std::env;
 use text::{AnchorRangeExt, Edit};
-use workspace::Workspace;
+use workspace::{AppState, Workspace};
 use zeta_prompt::{ZetaFormat, ZetaPromptInput};
 
 use std::mem;
@@ -1912,6 +1912,10 @@ impl EditPredictionStore {
             return;
         }
 
+        if currently_following(&project, cx) {
+            return;
+        }
+
         let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
             return;
         };
@@ -2048,6 +2052,25 @@ impl EditPredictionStore {
     pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
 }
 
+fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
+    let Some(app_state) = AppState::try_global(cx).and_then(|app_state| app_state.upgrade()) else {
+        return false;
+    };
+
+    app_state
+        .workspace_store
+        .read(cx)
+        .workspaces()
+        .filter_map(|workspace| workspace.upgrade())
+        .any(|workspace| {
+            workspace.read(cx).project().entity_id() == project.entity_id()
+                && workspace
+                    .read(cx)
+                    .leader_for_pane(workspace.read(cx).active_pane())
+                    .is_some()
+        })
+}
+
 fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
     match provider {
         EditPredictionProvider::Zed

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -8,6 +8,7 @@ use cloud_llm_client::{
     EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
     predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
 };
+
 use futures::{
     AsyncReadExt, FutureExt, StreamExt,
     channel::{mpsc, oneshot},
@@ -35,11 +36,12 @@ use util::{
     test::{TextRangeMarker, marked_text_ranges_by},
 };
 use uuid::Uuid;
+use workspace::{AppState, CollaboratorId, MultiWorkspace};
 use zeta_prompt::ZetaPromptInput;
 
 use crate::{
     BufferEditPrediction, EDIT_PREDICTION_SETTLED_QUIESCENCE, EditPredictionId,
-    EditPredictionStore, REJECT_REQUEST_DEBOUNCE,
+    EditPredictionJumpsFeatureFlag, EditPredictionStore, REJECT_REQUEST_DEBOUNCE,
 };
 
 #[gpui::test]
@@ -178,6 +180,172 @@ async fn test_current_state(cx: &mut TestAppContext) {
     });
 }
 
+#[gpui::test]
+async fn test_diagnostics_refresh_suppressed_while_following(cx: &mut TestAppContext) {
+    let (ep_store, mut requests) = init_test_with_fake_client(cx);
+
+    cx.update(|cx| {
+        cx.update_flags(
+            false,
+            vec![EditPredictionJumpsFeatureFlag::NAME.to_string()],
+        );
+    });
+
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(
+        "/root",
+        json!({
+            "1.txt": "Hello!\nHow\nBye\n",
+            "2.txt": "Hola!\nComo\nAdios\n"
+        }),
+    )
+    .await;
+    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+    let app_state = cx.update(|cx| {
+        let app_state = AppState::test(cx);
+        AppState::set_global(Arc::downgrade(&app_state), cx);
+        app_state
+    });
+
+    let multi_workspace =
+        cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+    let workspace = multi_workspace
+        .read_with(cx, |multi_workspace, _| multi_workspace.workspace().clone())
+        .unwrap();
+    cx.update(|cx| {
+        AppState::set_global(Arc::downgrade(workspace.read(cx).app_state()), cx);
+    });
+    let _ = app_state;
+
+    let buffer1 = project
+        .update(cx, |project, cx| {
+            let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
+            project.set_active_path(Some(path.clone()), cx);
+            project.open_buffer(path, cx)
+        })
+        .await
+        .unwrap();
+    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
+    let position = snapshot1.anchor_before(language::Point::new(1, 3));
+
+    ep_store.update(cx, |ep_store, cx| {
+        ep_store.register_project(&project, cx);
+        ep_store.register_buffer(&buffer1, &project, cx);
+        ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx);
+    });
+
+    let (request, respond_tx) = requests.predict.next().await.unwrap();
+    respond_tx
+        .send(model_response(
+            &request,
+            indoc! {r"
+                --- a/root/1.txt
+                +++ b/root/1.txt
+                @@ ... @@
+                 Hello!
+                -How
+                +How are you?
+                 Bye
+            "},
+        ))
+        .unwrap();
+    cx.run_until_parked();
+
+    ep_store.update(cx, |ep_store, cx| {
+        ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
+    });
+
+    let _ = multi_workspace.update(cx, |multi_workspace, window, cx| {
+        multi_workspace.workspace().update(cx, |workspace, cx| {
+            workspace.start_following(CollaboratorId::Agent, window, cx);
+        });
+    });
+    cx.run_until_parked();
+
+    let diagnostic = lsp::Diagnostic {
+        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
+        severity: Some(lsp::DiagnosticSeverity::ERROR),
+        message: "Sentence is incomplete".to_string(),
+        ..Default::default()
+    };
+
+    project.update(cx, |project, cx| {
+        project.lsp_store().update(cx, |lsp_store, cx| {
+            lsp_store
+                .update_diagnostics(
+                    LanguageServerId(0),
+                    lsp::PublishDiagnosticsParams {
+                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
+                        diagnostics: vec![diagnostic.clone()],
+                        version: None,
+                    },
+                    None,
+                    language::DiagnosticSourceKind::Pushed,
+                    &[],
+                    cx,
+                )
+                .unwrap();
+        });
+    });
+
+    cx.run_until_parked();
+    assert_no_predict_request_ready(&mut requests.predict);
+
+    let _ = multi_workspace.update(cx, |multi_workspace, window, cx| {
+        multi_workspace.workspace().update(cx, |workspace, cx| {
+            workspace.unfollow(CollaboratorId::Agent, window, cx);
+        });
+    });
+    cx.run_until_parked();
+
+    project.update(cx, |project, cx| {
+        project.lsp_store().update(cx, |lsp_store, cx| {
+            lsp_store
+                .update_diagnostics(
+                    LanguageServerId(0),
+                    lsp::PublishDiagnosticsParams {
+                        uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
+                        diagnostics: vec![diagnostic],
+                        version: None,
+                    },
+                    None,
+                    language::DiagnosticSourceKind::Pushed,
+                    &[],
+                    cx,
+                )
+                .unwrap();
+        });
+    });
+
+    let (request, respond_tx) = requests.predict.next().await.unwrap();
+    respond_tx
+        .send(model_response(
+            &request,
+            indoc! {r#"
+                --- a/root/2.txt
+                +++ b/root/2.txt
+                @@ ... @@
+                 Hola!
+                -Como
+                +Como estas?
+                 Adios
+            "#},
+        ))
+        .unwrap();
+    cx.run_until_parked();
+
+    ep_store.update(cx, |ep_store, cx| {
+        let prediction = ep_store
+            .prediction_at(&buffer1, None, &project, cx)
+            .unwrap();
+        assert_matches!(
+            prediction,
+            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
+        );
+    });
+}
+
 #[gpui::test]
 async fn test_simple_request(cx: &mut TestAppContext) {
     let (ep_store, mut requests) = init_test_with_fake_client(cx);

crates/editor/src/edit_prediction_tests.rs 🔗

@@ -4,7 +4,13 @@ use edit_prediction_types::{
 use gpui::{Entity, KeyBinding, Modifiers, prelude::*};
 use indoc::indoc;
 use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint};
-use std::{ops::Range, sync::Arc};
+use std::{
+    ops::Range,
+    sync::{
+        Arc,
+        atomic::{self, AtomicUsize},
+    },
+};
 use text::{Point, ToOffset};
 use ui::prelude::*;
 
@@ -12,6 +18,8 @@ use crate::{
     AcceptEditPrediction, EditPrediction, MenuEditPredictionsPolicy, editor_tests::init_test,
     test::editor_test_context::EditorTestContext,
 };
+use rpc::proto::PeerId;
+use workspace::CollaboratorId;
 
 #[gpui::test]
 async fn test_edit_prediction_insert(cx: &mut gpui::TestAppContext) {
@@ -359,6 +367,60 @@ async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui:
     });
 }
 
+#[gpui::test]
+async fn test_edit_prediction_refresh_suppressed_while_following(cx: &mut gpui::TestAppContext) {
+    init_test(cx, |_| {});
+
+    let mut cx = EditorTestContext::new(cx).await;
+    let provider = cx.new(|_| FakeEditPredictionDelegate::default());
+    assign_editor_completion_provider(provider.clone(), &mut cx);
+    cx.set_state("let x = ˇ;");
+
+    propose_edits(&provider, vec![(8..8, "42")], &mut cx);
+
+    cx.update_editor(|editor, window, cx| {
+        editor.refresh_edit_prediction(false, false, window, cx);
+        editor.update_visible_edit_prediction(window, cx);
+    });
+
+    assert_eq!(
+        provider.read_with(&cx.cx, |provider, _| {
+            provider.refresh_count.load(atomic::Ordering::SeqCst)
+        }),
+        1
+    );
+    cx.editor(|editor, _, _| {
+        assert!(editor.active_edit_prediction.is_some());
+    });
+
+    cx.update_editor(|editor, window, cx| {
+        editor.leader_id = Some(CollaboratorId::PeerId(PeerId::default()));
+        editor.refresh_edit_prediction(false, false, window, cx);
+    });
+
+    assert_eq!(
+        provider.read_with(&cx.cx, |provider, _| {
+            provider.refresh_count.load(atomic::Ordering::SeqCst)
+        }),
+        1
+    );
+    cx.editor(|editor, _, _| {
+        assert!(editor.active_edit_prediction.is_none());
+    });
+
+    cx.update_editor(|editor, window, cx| {
+        editor.leader_id = None;
+        editor.refresh_edit_prediction(false, false, window, cx);
+    });
+
+    assert_eq!(
+        provider.read_with(&cx.cx, |provider, _| {
+            provider.refresh_count.load(atomic::Ordering::SeqCst)
+        }),
+        2
+    );
+}
+
 #[gpui::test]
 async fn test_edit_prediction_preview_cleanup_on_toggle_off(cx: &mut gpui::TestAppContext) {
     init_test(cx, |_| {});
@@ -567,6 +629,7 @@ fn assign_editor_completion_provider_non_zed(
 #[derive(Default, Clone)]
 pub struct FakeEditPredictionDelegate {
     pub completion: Option<edit_prediction_types::EditPrediction>,
+    pub refresh_count: Arc<AtomicUsize>,
 }
 
 impl FakeEditPredictionDelegate {
@@ -619,6 +682,7 @@ impl EditPredictionDelegate for FakeEditPredictionDelegate {
         _debounce: bool,
         _cx: &mut gpui::Context<Self>,
     ) {
+        self.refresh_count.fetch_add(1, atomic::Ordering::SeqCst);
     }
 
     fn accept(&mut self, _cx: &mut gpui::Context<Self>) {}

crates/editor/src/editor.rs 🔗

@@ -7804,7 +7804,11 @@ impl Editor {
         window: &mut Window,
         cx: &mut Context<Self>,
     ) -> Option<()> {
-        let provider = self.edit_prediction_provider()?;
+        if self.leader_id.is_some() {
+            self.discard_edit_prediction(EditPredictionDiscardReason::Ignored, cx);
+            return None;
+        }
+
         let cursor = self.selections.newest_anchor().head();
         let (buffer, cursor_buffer_position) =
             self.buffer.read(cx).text_anchor_for_position(cursor, cx)?;
@@ -7829,7 +7833,8 @@ impl Editor {
             return None;
         }
 
-        provider.refresh(buffer, cursor_buffer_position, debounce, cx);
+        self.edit_prediction_provider()?
+            .refresh(buffer, cursor_buffer_position, debounce, cx);
         Some(())
     }
 
@@ -7954,7 +7959,7 @@ impl Editor {
         cx: &App,
     ) -> bool {
         maybe!({
-            if self.read_only(cx) {
+            if self.read_only(cx) || self.leader_id.is_some() {
                 return Some(false);
             }
             let provider = self.edit_prediction_provider()?;

crates/zeta_prompt/src/zeta_prompt.rs 🔗

@@ -2253,21 +2253,21 @@ pub mod hashline {
                 Case {
                     name: "insert_before_first_and_after_line",
                     original: indoc! {"
-                    a
-                    b
-                "},
+                        a
+                        b
+                    "},
                     model_output: indoc! {"
-                    <|insert|>
-                    HEAD
-                    <|insert|>0:61
-                    MID
-                "},
+                        <|insert|>
+                        HEAD
+                        <|insert|>0:61
+                        MID
+                    "},
                     expected: indoc! {"
-                    HEAD
-                    a
-                    MID
-                    b
-                "},
+                        HEAD
+                        a
+                        MID
+                        b
+                    "},
                 },
             ];