@@ -310,7 +310,8 @@ struct ProjectState {
next_pending_prediction_id: usize,
pending_predictions: ArrayVec<PendingPrediction, 2>,
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
- last_prediction_refresh: Option<(EntityId, Instant)>,
+ last_edit_prediction_refresh: Option<(EntityId, Instant)>,
+ last_jump_prediction_refresh: Option<(EntityId, Instant)>,
cancelled_predictions: HashSet<usize>,
context: Entity<RelatedExcerptStore>,
license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
@@ -444,6 +445,14 @@ impl PredictionRequestedBy {
}
}
/// Number of lines above and below the cursor that are searched for
/// diagnostics when the search scope is [`DiagnosticSearchScope::Local`].
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// How far to look for a diagnostic to jump to when producing a
/// diagnostics-triggered ("jump") prediction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    /// Search only within `DIAGNOSTIC_LINES_RANGE` lines of the cursor.
    Local,
    /// Search beyond the cursor's vicinity; the search range is left as
    /// `Default::default()` — presumably treated as "unrestricted" by
    /// `next_diagnostic_location`; confirm there.
    Global,
}
+
#[derive(Debug)]
struct PendingPrediction {
id: usize,
@@ -899,7 +908,8 @@ impl EditPredictionStore {
cancelled_predictions: HashSet::default(),
pending_predictions: ArrayVec::new(),
next_pending_prediction_id: 0,
- last_prediction_refresh: None,
+ last_edit_prediction_refresh: None,
+ last_jump_prediction_refresh: None,
license_detection_watchers: HashMap::default(),
user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
_subscriptions: [
@@ -1014,7 +1024,11 @@ impl EditPredictionStore {
}
project::Event::DiagnosticsUpdated { .. } => {
if cx.has_flag::<Zeta2FeatureFlag>() {
- self.refresh_prediction_from_diagnostics(project, cx);
+ self.refresh_prediction_from_diagnostics(
+ project,
+ DiagnosticSearchScope::Global,
+ cx,
+ );
}
}
_ => (),
@@ -1458,38 +1472,45 @@ impl EditPredictionStore {
position: language::Anchor,
cx: &mut Context<Self>,
) {
- self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
- let Some(request_task) = this
- .update(cx, |this, cx| {
- this.request_prediction(
- &project,
- &buffer,
- position,
- PredictEditsRequestTrigger::Other,
- cx,
- )
- })
- .log_err()
- else {
- return Task::ready(anyhow::Ok(None));
- };
-
- cx.spawn(async move |_cx| {
- request_task.await.map(|prediction_result| {
- prediction_result.map(|prediction_result| {
- (
- prediction_result,
- PredictionRequestedBy::Buffer(buffer.entity_id()),
+ self.queue_prediction_refresh(
+ project.clone(),
+ PredictEditsRequestTrigger::Other,
+ buffer.entity_id(),
+ cx,
+ move |this, cx| {
+ let Some(request_task) = this
+ .update(cx, |this, cx| {
+ this.request_prediction(
+ &project,
+ &buffer,
+ position,
+ PredictEditsRequestTrigger::Other,
+ cx,
)
})
+ .log_err()
+ else {
+ return Task::ready(anyhow::Ok(None));
+ };
+
+ cx.spawn(async move |_cx| {
+ request_task.await.map(|prediction_result| {
+ prediction_result.map(|prediction_result| {
+ (
+ prediction_result,
+ PredictionRequestedBy::Buffer(buffer.entity_id()),
+ )
+ })
+ })
})
- })
- })
+ },
+ )
}
pub fn refresh_prediction_from_diagnostics(
&mut self,
project: Entity<Project>,
+ scope: DiagnosticSearchScope,
cx: &mut Context<Self>,
) {
let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
@@ -1499,79 +1520,96 @@ impl EditPredictionStore {
// Prefer predictions from buffer
if project_state.current_prediction.is_some() {
return;
- };
-
- self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
- let Some((active_buffer, snapshot, cursor_point)) = this
- .read_with(cx, |this, cx| {
- let project_state = this.projects.get(&project.entity_id())?;
- let (buffer, position) = project_state.active_buffer(&project, cx)?;
- let snapshot = buffer.read(cx).snapshot();
-
- if !Self::predictions_enabled_at(&snapshot, position, cx) {
- return None;
- }
+ }
- let cursor_point = position
- .map(|pos| pos.to_point(&snapshot))
- .unwrap_or_default();
+ self.queue_prediction_refresh(
+ project.clone(),
+ PredictEditsRequestTrigger::Diagnostics,
+ project.entity_id(),
+ cx,
+ move |this, cx| {
+ let Some((active_buffer, snapshot, cursor_point)) = this
+ .read_with(cx, |this, cx| {
+ let project_state = this.projects.get(&project.entity_id())?;
+ let (buffer, position) = project_state.active_buffer(&project, cx)?;
+ let snapshot = buffer.read(cx).snapshot();
+
+ if !Self::predictions_enabled_at(&snapshot, position, cx) {
+ return None;
+ }
- Some((buffer, snapshot, cursor_point))
- })
- .log_err()
- .flatten()
- else {
- return Task::ready(anyhow::Ok(None));
- };
+ let cursor_point = position
+ .map(|pos| pos.to_point(&snapshot))
+ .unwrap_or_default();
- cx.spawn(async move |cx| {
- let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
- active_buffer,
- &snapshot,
- Default::default(),
- cursor_point,
- &project,
- cx,
- )
- .await?
+ Some((buffer, snapshot, cursor_point))
+ })
+ .log_err()
+ .flatten()
else {
- return anyhow::Ok(None);
+ return Task::ready(anyhow::Ok(None));
};
- let Some(prediction_result) = this
- .update(cx, |this, cx| {
- this.request_prediction(
- &project,
- &jump_buffer,
- jump_position,
- PredictEditsRequestTrigger::Diagnostics,
- cx,
- )
- })?
+ cx.spawn(async move |cx| {
+ let diagnostic_search_range = match scope {
+ DiagnosticSearchScope::Local => {
+ let diagnostic_search_start =
+ cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
+ let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
+ Point::new(diagnostic_search_start, 0)
+ ..Point::new(diagnostic_search_end, 0)
+ }
+ DiagnosticSearchScope::Global => Default::default(),
+ };
+
+ let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
+ active_buffer,
+ &snapshot,
+ diagnostic_search_range,
+ cursor_point,
+ &project,
+ cx,
+ )
.await?
- else {
- return anyhow::Ok(None);
- };
+ else {
+ return anyhow::Ok(None);
+ };
- this.update(cx, |this, cx| {
- Some((
- if this
- .get_or_init_project(&project, cx)
- .current_prediction
- .is_none()
- {
- prediction_result
- } else {
- EditPredictionResult {
- id: prediction_result.id,
- prediction: Err(EditPredictionRejectReason::CurrentPreferred),
- }
- },
- PredictionRequestedBy::DiagnosticsUpdate,
- ))
+ let Some(prediction_result) = this
+ .update(cx, |this, cx| {
+ this.request_prediction(
+ &project,
+ &jump_buffer,
+ jump_position,
+ PredictEditsRequestTrigger::Diagnostics,
+ cx,
+ )
+ })?
+ .await?
+ else {
+ return anyhow::Ok(None);
+ };
+
+ this.update(cx, |this, cx| {
+ Some((
+ if this
+ .get_or_init_project(&project, cx)
+ .current_prediction
+ .is_none()
+ {
+ prediction_result
+ } else {
+ EditPredictionResult {
+ id: prediction_result.id,
+ prediction: Err(EditPredictionRejectReason::CurrentPreferred),
+ }
+ },
+ PredictionRequestedBy::DiagnosticsUpdate,
+ ))
+ })
})
- })
- });
+ },
+ );
}
fn predictions_enabled_at(
@@ -1605,14 +1643,12 @@ impl EditPredictionStore {
true
}
- #[cfg(not(test))]
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
- #[cfg(test)]
- pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
fn queue_prediction_refresh(
&mut self,
project: Entity<Project>,
+ request_trigger: PredictEditsRequestTrigger,
throttle_entity: EntityId,
cx: &mut Context<Self>,
do_refresh: impl FnOnce(
@@ -1622,20 +1658,34 @@ impl EditPredictionStore {
-> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
+ 'static,
) {
        // Picks which throttle slot applies to this request. Diagnostics-triggered
        // ("jump") predictions and all other ("edit") predictions are tracked in
        // separate `last_*_prediction_refresh` fields so that throttling one kind
        // never delays the other.
        fn select_throttle(
            project_state: &mut ProjectState,
            request_trigger: PredictEditsRequestTrigger,
        ) -> &mut Option<(EntityId, Instant)> {
            match request_trigger {
                PredictEditsRequestTrigger::Diagnostics => {
                    &mut project_state.last_jump_prediction_refresh
                }
                // All non-diagnostics triggers share the "edit" throttle slot.
                _ => &mut project_state.last_edit_prediction_refresh,
            }
        }
+
let is_ollama = self.edit_prediction_model == EditPredictionModel::Ollama;
let drop_on_cancel = is_ollama;
let max_pending_predictions = if is_ollama { 1 } else { 2 };
+ let throttle_timeout = Self::THROTTLE_TIMEOUT;
let project_state = self.get_or_init_project(&project, cx);
let pending_prediction_id = project_state.next_pending_prediction_id;
project_state.next_pending_prediction_id += 1;
- let last_request = project_state.last_prediction_refresh;
+ let last_request = *select_throttle(project_state, request_trigger);
let task = cx.spawn(async move |this, cx| {
- if let Some((last_entity, last_timestamp)) = last_request
- && throttle_entity == last_entity
- && let Some(timeout) =
- (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
- {
+ if let Some(timeout) = last_request.and_then(|(last_entity, last_timestamp)| {
+ if throttle_entity != last_entity {
+ return None;
+ }
+ (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
+ }) {
cx.background_executor().timer(timeout).await;
}
@@ -1644,11 +1694,12 @@ impl EditPredictionStore {
let mut is_cancelled = true;
this.update(cx, |this, cx| {
let project_state = this.get_or_init_project(&project, cx);
- if !project_state
+ let was_cancelled = project_state
.cancelled_predictions
- .remove(&pending_prediction_id)
- {
- project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
+ .remove(&pending_prediction_id);
+ if !was_cancelled {
+ let new_refresh = (throttle_entity, Instant::now());
+ *select_throttle(project_state, request_trigger) = Some(new_refresh);
is_cancelled = false;
}
})
@@ -1785,8 +1836,6 @@ impl EditPredictionStore {
allow_jump: bool,
cx: &mut Context<Self>,
) -> Task<Result<Option<EditPredictionResult>>> {
- const DIAGNOSTIC_LINES_RANGE: u32 = 20;
-
self.get_or_init_project(&project, cx);
let project_state = self.projects.get(&project.entity_id()).unwrap();
let stored_events = project_state.events(cx);
@@ -1828,14 +1877,14 @@ impl EditPredictionStore {
let inputs = EditPredictionModelInput {
project: project.clone(),
- buffer: active_buffer.clone(),
- snapshot: snapshot.clone(),
+ buffer: active_buffer,
+ snapshot: snapshot,
position,
events,
related_files,
recent_paths: project_state.recent_paths.clone(),
trigger,
- diagnostic_search_range: diagnostic_search_range.clone(),
+ diagnostic_search_range: diagnostic_search_range,
debug_tx,
user_actions,
};
@@ -1861,33 +1910,14 @@ impl EditPredictionStore {
cx.spawn(async move |this, cx| {
let prediction = task.await?;
- if prediction.is_none() && allow_jump {
- let cursor_point = position.to_point(&snapshot);
- if has_events
- && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
- active_buffer.clone(),
- &snapshot,
- diagnostic_search_range,
- cursor_point,
- &project,
+ if prediction.is_none() && allow_jump && has_events {
+ this.update(cx, |this, cx| {
+ this.refresh_prediction_from_diagnostics(
+ project,
+ DiagnosticSearchScope::Local,
cx,
- )
- .await?
- {
- return this
- .update(cx, |this, cx| {
- this.request_prediction_internal(
- project,
- jump_buffer,
- jump_position,
- trigger,
- false,
- cx,
- )
- })?
- .await;
- }
-
+ );
+ })?;
return anyhow::Ok(None);
}
@@ -8,7 +8,7 @@ use cloud_llm_client::{
predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
};
use futures::{
- AsyncReadExt, StreamExt,
+ AsyncReadExt, FutureExt, StreamExt,
channel::{mpsc, oneshot},
};
use gpui::App;
@@ -1375,6 +1375,107 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
);
}
// Verifies that diagnostics-triggered ("jump") predictions and buffer-triggered
// ("edit") predictions use independent throttle timers: issuing one kind of
// request must not throttle the other kind, and each kind is throttled by its
// own previous request until THROTTLE_TIMEOUT elapses.
#[gpui::test]
async fn test_jump_and_edit_throttles_are_independent(cx: &mut TestAppContext) {
    let (ep_store, mut requests) = init_test_with_fake_client(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n",
            "bar.md": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.set_active_path(Some(path.clone()), cx);
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    ep_store.update(cx, |ep_store, cx| {
        ep_store.register_project(&project, cx);
        ep_store.register_buffer(&buffer, &project, cx);
    });

    // First edit request - no prior edit, so not throttled.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    let (_edit_request, edit_response_tx) = requests.predict.next().await.unwrap();
    edit_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    let diagnostic = lsp::Diagnostic {
        range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
        severity: Some(lsp::DiagnosticSeverity::ERROR),
        message: "Sentence is incomplete".to_string(),
        ..Default::default()
    };

    // First jump request triggered by diagnostic event on buffer - no prior jump, so not throttled (independent from edit).
    project.update(cx, |project, cx| {
        project.lsp_store().update(cx, |lsp_store, cx| {
            lsp_store
                .update_diagnostics(
                    LanguageServerId(0),
                    lsp::PublishDiagnosticsParams {
                        uri: lsp::Uri::from_file_path(path!("/root/bar.md")).unwrap(),
                        diagnostics: vec![diagnostic],
                        version: None,
                    },
                    None,
                    language::DiagnosticSourceKind::Pushed,
                    &[],
                    cx,
                )
                .unwrap();
        });
    });
    let (_jump_request, jump_response_tx) = requests.predict.next().await.unwrap();
    jump_response_tx.send(empty_response()).unwrap();
    cx.run_until_parked();

    // Second edit request - should be throttled by the first edit.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Second jump request - should be throttled by the first jump.
    ep_store.update(cx, |ep_store, cx| {
        ep_store.refresh_prediction_from_diagnostics(
            project.clone(),
            DiagnosticSearchScope::Global,
            cx,
        );
    });
    assert_no_predict_request_ready(&mut requests.predict);

    // Wait for both throttles to expire.
    cx.background_executor
        .advance_clock(EditPredictionStore::THROTTLE_TIMEOUT);
    cx.background_executor.run_until_parked();
    cx.run_until_parked();

    // Both requests should now go through.
    let (_request_1, response_tx_1) = requests.predict.next().await.unwrap();
    response_tx_1.send(empty_response()).unwrap();
    cx.run_until_parked();

    let (_request_2, response_tx_2) = requests.predict.next().await.unwrap();
    response_tx_2.send(empty_response()).unwrap();
    cx.run_until_parked();
}
+
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
let (ep_store, mut requests) = init_test_with_fake_client(cx);
@@ -1596,10 +1697,28 @@ fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> Predi
}
}
+fn empty_response() -> PredictEditsV3Response {
+ PredictEditsV3Response {
+ request_id: Uuid::new_v4().to_string(),
+ output: String::new(),
+ }
+}
+
fn prompt_from_request(request: &PredictEditsV3Request) -> String {
zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
}
+fn assert_no_predict_request_ready(
+ requests: &mut mpsc::UnboundedReceiver<(
+ PredictEditsV3Request,
+ oneshot::Sender<PredictEditsV3Response>,
+ )>,
+) {
+ if requests.next().now_or_never().flatten().is_some() {
+ panic!("Unexpected prediction request while throttled.");
+ }
+}
+
struct RequestChannels {
predict: mpsc::UnboundedReceiver<(
PredictEditsV3Request,