Separate throttles for jump and edit based predictions (#49499)

Ben Kunkle and Zed Zippy created

Closes #ISSUE

Before you mark this PR as ready for review, make sure that you have:
- [ ] Added solid test coverage and/or screenshots from doing manual
testing
- [ ] Done a self-review taking into account security and performance
aspects
- [ ] Aligned any UI changes with the [UI
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)

Release Notes:

- Improved edit predictions by throttling jump-based and edit-based prediction refreshes independently, so a recent edit prediction no longer delays a diagnostics-triggered jump prediction (and vice versa).

---------

Co-authored-by: Zed Zippy <234243425+zed-zippy[bot]@users.noreply.github.com>

Change summary

crates/edit_prediction/src/edit_prediction.rs       | 300 ++++++++------
crates/edit_prediction/src/edit_prediction_tests.rs | 121 ++++++
2 files changed, 285 insertions(+), 136 deletions(-)

Detailed changes

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -310,7 +310,8 @@ struct ProjectState {
     next_pending_prediction_id: usize,
     pending_predictions: ArrayVec<PendingPrediction, 2>,
     debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
-    last_prediction_refresh: Option<(EntityId, Instant)>,
+    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
+    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
     cancelled_predictions: HashSet<usize>,
     context: Entity<RelatedExcerptStore>,
     license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
@@ -444,6 +445,14 @@ impl PredictionRequestedBy {
     }
 }
 
+const DIAGNOSTIC_LINES_RANGE: u32 = 20;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum DiagnosticSearchScope {
+    Local,
+    Global,
+}
+
 #[derive(Debug)]
 struct PendingPrediction {
     id: usize,
@@ -899,7 +908,8 @@ impl EditPredictionStore {
                 cancelled_predictions: HashSet::default(),
                 pending_predictions: ArrayVec::new(),
                 next_pending_prediction_id: 0,
-                last_prediction_refresh: None,
+                last_edit_prediction_refresh: None,
+                last_jump_prediction_refresh: None,
                 license_detection_watchers: HashMap::default(),
                 user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                 _subscriptions: [
@@ -1014,7 +1024,11 @@ impl EditPredictionStore {
             }
             project::Event::DiagnosticsUpdated { .. } => {
                 if cx.has_flag::<Zeta2FeatureFlag>() {
-                    self.refresh_prediction_from_diagnostics(project, cx);
+                    self.refresh_prediction_from_diagnostics(
+                        project,
+                        DiagnosticSearchScope::Global,
+                        cx,
+                    );
                 }
             }
             _ => (),
@@ -1458,38 +1472,45 @@ impl EditPredictionStore {
         position: language::Anchor,
         cx: &mut Context<Self>,
     ) {
-        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
-            let Some(request_task) = this
-                .update(cx, |this, cx| {
-                    this.request_prediction(
-                        &project,
-                        &buffer,
-                        position,
-                        PredictEditsRequestTrigger::Other,
-                        cx,
-                    )
-                })
-                .log_err()
-            else {
-                return Task::ready(anyhow::Ok(None));
-            };
-
-            cx.spawn(async move |_cx| {
-                request_task.await.map(|prediction_result| {
-                    prediction_result.map(|prediction_result| {
-                        (
-                            prediction_result,
-                            PredictionRequestedBy::Buffer(buffer.entity_id()),
+        self.queue_prediction_refresh(
+            project.clone(),
+            PredictEditsRequestTrigger::Other,
+            buffer.entity_id(),
+            cx,
+            move |this, cx| {
+                let Some(request_task) = this
+                    .update(cx, |this, cx| {
+                        this.request_prediction(
+                            &project,
+                            &buffer,
+                            position,
+                            PredictEditsRequestTrigger::Other,
+                            cx,
                         )
                     })
+                    .log_err()
+                else {
+                    return Task::ready(anyhow::Ok(None));
+                };
+
+                cx.spawn(async move |_cx| {
+                    request_task.await.map(|prediction_result| {
+                        prediction_result.map(|prediction_result| {
+                            (
+                                prediction_result,
+                                PredictionRequestedBy::Buffer(buffer.entity_id()),
+                            )
+                        })
+                    })
                 })
-            })
-        })
+            },
+        )
     }
 
     pub fn refresh_prediction_from_diagnostics(
         &mut self,
         project: Entity<Project>,
+        scope: DiagnosticSearchScope,
         cx: &mut Context<Self>,
     ) {
         let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
@@ -1499,79 +1520,96 @@ impl EditPredictionStore {
         // Prefer predictions from buffer
         if project_state.current_prediction.is_some() {
             return;
-        };
-
-        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
-            let Some((active_buffer, snapshot, cursor_point)) = this
-                .read_with(cx, |this, cx| {
-                    let project_state = this.projects.get(&project.entity_id())?;
-                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
-                    let snapshot = buffer.read(cx).snapshot();
-
-                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
-                        return None;
-                    }
+        }
 
-                    let cursor_point = position
-                        .map(|pos| pos.to_point(&snapshot))
-                        .unwrap_or_default();
+        self.queue_prediction_refresh(
+            project.clone(),
+            PredictEditsRequestTrigger::Diagnostics,
+            project.entity_id(),
+            cx,
+            move |this, cx| {
+                let Some((active_buffer, snapshot, cursor_point)) = this
+                    .read_with(cx, |this, cx| {
+                        let project_state = this.projects.get(&project.entity_id())?;
+                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
+                        let snapshot = buffer.read(cx).snapshot();
+
+                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
+                            return None;
+                        }
 
-                    Some((buffer, snapshot, cursor_point))
-                })
-                .log_err()
-                .flatten()
-            else {
-                return Task::ready(anyhow::Ok(None));
-            };
+                        let cursor_point = position
+                            .map(|pos| pos.to_point(&snapshot))
+                            .unwrap_or_default();
 
-            cx.spawn(async move |cx| {
-                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
-                    active_buffer,
-                    &snapshot,
-                    Default::default(),
-                    cursor_point,
-                    &project,
-                    cx,
-                )
-                .await?
+                        Some((buffer, snapshot, cursor_point))
+                    })
+                    .log_err()
+                    .flatten()
                 else {
-                    return anyhow::Ok(None);
+                    return Task::ready(anyhow::Ok(None));
                 };
 
-                let Some(prediction_result) = this
-                    .update(cx, |this, cx| {
-                        this.request_prediction(
-                            &project,
-                            &jump_buffer,
-                            jump_position,
-                            PredictEditsRequestTrigger::Diagnostics,
-                            cx,
-                        )
-                    })?
+                cx.spawn(async move |cx| {
+                    let diagnostic_search_range = match scope {
+                        DiagnosticSearchScope::Local => {
+                            let diagnostic_search_start =
+                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
+                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
+                            Point::new(diagnostic_search_start, 0)
+                                ..Point::new(diagnostic_search_end, 0)
+                        }
+                        DiagnosticSearchScope::Global => Default::default(),
+                    };
+
+                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
+                        active_buffer,
+                        &snapshot,
+                        diagnostic_search_range,
+                        cursor_point,
+                        &project,
+                        cx,
+                    )
                     .await?
-                else {
-                    return anyhow::Ok(None);
-                };
+                    else {
+                        return anyhow::Ok(None);
+                    };
 
-                this.update(cx, |this, cx| {
-                    Some((
-                        if this
-                            .get_or_init_project(&project, cx)
-                            .current_prediction
-                            .is_none()
-                        {
-                            prediction_result
-                        } else {
-                            EditPredictionResult {
-                                id: prediction_result.id,
-                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
-                            }
-                        },
-                        PredictionRequestedBy::DiagnosticsUpdate,
-                    ))
+                    let Some(prediction_result) = this
+                        .update(cx, |this, cx| {
+                            this.request_prediction(
+                                &project,
+                                &jump_buffer,
+                                jump_position,
+                                PredictEditsRequestTrigger::Diagnostics,
+                                cx,
+                            )
+                        })?
+                        .await?
+                    else {
+                        return anyhow::Ok(None);
+                    };
+
+                    this.update(cx, |this, cx| {
+                        Some((
+                            if this
+                                .get_or_init_project(&project, cx)
+                                .current_prediction
+                                .is_none()
+                            {
+                                prediction_result
+                            } else {
+                                EditPredictionResult {
+                                    id: prediction_result.id,
+                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
+                                }
+                            },
+                            PredictionRequestedBy::DiagnosticsUpdate,
+                        ))
+                    })
                 })
-            })
-        });
+            },
+        );
     }
 
     fn predictions_enabled_at(
@@ -1605,14 +1643,12 @@ impl EditPredictionStore {
         true
     }
 
-    #[cfg(not(test))]
     pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
-    #[cfg(test)]
-    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
 
     fn queue_prediction_refresh(
         &mut self,
         project: Entity<Project>,
+        request_trigger: PredictEditsRequestTrigger,
         throttle_entity: EntityId,
         cx: &mut Context<Self>,
         do_refresh: impl FnOnce(
@@ -1622,20 +1658,34 @@ impl EditPredictionStore {
             -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
         + 'static,
     ) {
+        fn select_throttle(
+            project_state: &mut ProjectState,
+            request_trigger: PredictEditsRequestTrigger,
+        ) -> &mut Option<(EntityId, Instant)> {
+            match request_trigger {
+                PredictEditsRequestTrigger::Diagnostics => {
+                    &mut project_state.last_jump_prediction_refresh
+                }
+                _ => &mut project_state.last_edit_prediction_refresh,
+            }
+        }
+
         let is_ollama = self.edit_prediction_model == EditPredictionModel::Ollama;
         let drop_on_cancel = is_ollama;
         let max_pending_predictions = if is_ollama { 1 } else { 2 };
+        let throttle_timeout = Self::THROTTLE_TIMEOUT;
         let project_state = self.get_or_init_project(&project, cx);
         let pending_prediction_id = project_state.next_pending_prediction_id;
         project_state.next_pending_prediction_id += 1;
-        let last_request = project_state.last_prediction_refresh;
+        let last_request = *select_throttle(project_state, request_trigger);
 
         let task = cx.spawn(async move |this, cx| {
-            if let Some((last_entity, last_timestamp)) = last_request
-                && throttle_entity == last_entity
-                && let Some(timeout) =
-                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
-            {
+            if let Some(timeout) = last_request.and_then(|(last_entity, last_timestamp)| {
+                if throttle_entity != last_entity {
+                    return None;
+                }
+                (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
+            }) {
                 cx.background_executor().timer(timeout).await;
             }
 
@@ -1644,11 +1694,12 @@ impl EditPredictionStore {
             let mut is_cancelled = true;
             this.update(cx, |this, cx| {
                 let project_state = this.get_or_init_project(&project, cx);
-                if !project_state
+                let was_cancelled = project_state
                     .cancelled_predictions
-                    .remove(&pending_prediction_id)
-                {
-                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
+                    .remove(&pending_prediction_id);
+                if !was_cancelled {
+                    let new_refresh = (throttle_entity, Instant::now());
+                    *select_throttle(project_state, request_trigger) = Some(new_refresh);
                     is_cancelled = false;
                 }
             })
@@ -1785,8 +1836,6 @@ impl EditPredictionStore {
         allow_jump: bool,
         cx: &mut Context<Self>,
     ) -> Task<Result<Option<EditPredictionResult>>> {
-        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
-
         self.get_or_init_project(&project, cx);
         let project_state = self.projects.get(&project.entity_id()).unwrap();
         let stored_events = project_state.events(cx);
@@ -1828,14 +1877,14 @@ impl EditPredictionStore {
 
         let inputs = EditPredictionModelInput {
             project: project.clone(),
-            buffer: active_buffer.clone(),
-            snapshot: snapshot.clone(),
+            buffer: active_buffer,
+            snapshot: snapshot,
             position,
             events,
             related_files,
             recent_paths: project_state.recent_paths.clone(),
             trigger,
-            diagnostic_search_range: diagnostic_search_range.clone(),
+            diagnostic_search_range: diagnostic_search_range,
             debug_tx,
             user_actions,
         };
@@ -1861,33 +1910,14 @@ impl EditPredictionStore {
         cx.spawn(async move |this, cx| {
             let prediction = task.await?;
 
-            if prediction.is_none() && allow_jump {
-                let cursor_point = position.to_point(&snapshot);
-                if has_events
-                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
-                        active_buffer.clone(),
-                        &snapshot,
-                        diagnostic_search_range,
-                        cursor_point,
-                        &project,
+            if prediction.is_none() && allow_jump && has_events {
+                this.update(cx, |this, cx| {
+                    this.refresh_prediction_from_diagnostics(
+                        project,
+                        DiagnosticSearchScope::Local,
                         cx,
-                    )
-                    .await?
-                {
-                    return this
-                        .update(cx, |this, cx| {
-                            this.request_prediction_internal(
-                                project,
-                                jump_buffer,
-                                jump_position,
-                                trigger,
-                                false,
-                                cx,
-                            )
-                        })?
-                        .await;
-                }
-
+                    );
+                })?;
                 return anyhow::Ok(None);
             }
 

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -8,7 +8,7 @@ use cloud_llm_client::{
     predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
 };
 use futures::{
-    AsyncReadExt, StreamExt,
+    AsyncReadExt, FutureExt, StreamExt,
     channel::{mpsc, oneshot},
 };
 use gpui::App;
@@ -1375,6 +1375,107 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
     );
 }
 
+#[gpui::test]
+async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
+    let (ep_store, mut requests) = init_test_with_fake_client(cx);
+
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(
+        "/root",
+        json!({
+            "foo.md":  "Hello!\nHow\nBye\n",
+            "bar.md": "Hola!\nComo\nAdios\n"
+        }),
+    )
+    .await;
+    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+    let buffer = project
+        .update(cx, |project, cx| {
+            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+            project.set_active_path(Some(path.clone()), cx);
+            project.open_buffer(path, cx)
+        })
+        .await
+        .unwrap();
+    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+    let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+    ep_store.update(cx, |ep_store, cx| {
+        ep_store.register_project(&project, cx);
+        ep_store.register_buffer(&buffer, &project, cx);
+    });
+
+    // First edit request - no prior edit, so not throttled.
+    ep_store.update(cx, |ep_store, cx| {
+        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+    });
+    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
+    edit_response_tx.send(empty_response()).unwrap();
+    cx.run_until_parked();
+
+    let diagnostic = lsp::Diagnostic {
+        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
+        severity: Some(lsp::DiagnosticSeverity::ERROR),
+        message: "Sentence is incomplete".to_string(),
+        ..Default::default()
+    };
+
+    // First jump request triggered by diagnostic event on buffer - no prior jump, so not throttled (independent from edit).
+    project.update(cx, |project, cx| {
+        project.lsp_store().update(cx, |lsp_store, cx| {
+            lsp_store
+                .update_diagnostics(
+                    LanguageServerId(0),
+                    lsp::PublishDiagnosticsParams {
+                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
+                        diagnostics: vec![diagnostic],
+                        version: None,
+                    },
+                    None,
+                    language::DiagnosticSourceKind::Pushed,
+                    &[],
+                    cx,
+                )
+                .unwrap();
+        });
+    });
+    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
+    jump_response_tx.send(empty_response()).unwrap();
+    cx.run_until_parked();
+
+    // Second edit request - should be throttled by the first edit.
+    ep_store.update(cx, |ep_store, cx| {
+        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+    });
+    assert_no_predict_request_ready(&mut requests.predict);
+
+    // Second jump request - should be throttled by the first jump.
+    ep_store.update(cx, |ep_store, cx| {
+        ep_store.refresh_prediction_from_diagnostics(
+            project.clone(),
+            DiagnosticSearchScope::Global,
+            cx,
+        );
+    });
+    assert_no_predict_request_ready(&mut requests.predict);
+
+    // Wait for both throttles to expire.
+    cx.background_executor
+        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
+    cx.background_executor.run_until_parked();
+    cx.run_until_parked();
+
+    // Both requests should now go through.
+    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
+    response_tx_1.send(empty_response()).unwrap();
+    cx.run_until_parked();
+
+    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
+    response_tx_2.send(empty_response()).unwrap();
+    cx.run_until_parked();
+}
+
 #[gpui::test]
 async fn test_rejections_flushing(cx: &mut TestAppContext) {
     let (ep_store, mut requests) = init_test_with_fake_client(cx);
@@ -1596,10 +1697,28 @@ fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> Predi
     }
 }
 
+fn empty_response() -> PredictEditsV3Response {
+    PredictEditsV3Response {
+        request_id: Uuid::new_v4().to_string(),
+        output: String::new(),
+    }
+}
+
 fn prompt_from_request(request: &PredictEditsV3Request) -> String {
     zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
 }
 
+fn assert_no_predict_request_ready(
+    requests: &mut mpsc::UnboundedReceiver<(
+        PredictEditsV3Request,
+        oneshot::Sender<PredictEditsV3Response>,
+    )>,
+) {
+    if requests.next().now_or_never().flatten().is_some() {
+        panic!("Unexpected prediction request while throttled.");
+    }
+}
+
 struct RequestChannels {
     predict: mpsc::UnboundedReceiver<(
         PredictEditsV3Request,