diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 4e5ef2c9134618d1cc67780dc24611fa5d459a71..2c0407b56b2b5c7baffc1094dfaf1cd4baf6393f 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -76,6 +76,7 @@ pub mod zeta; #[cfg(test)] mod edit_prediction_tests; +use crate::example_spec::ExampleSpec; use crate::license_detection::LicenseDetectionWatcher; use crate::mercury::Mercury; use crate::onboarding_modal::ZedPredictModal; @@ -498,6 +499,7 @@ impl std::ops::Deref for BufferEditPrediction<'_> { struct PendingSettledPrediction { request_id: EditPredictionId, editable_anchor_range: Range, + example: Option, enqueued_at: Instant, last_edit_at: Instant, } @@ -1572,6 +1574,7 @@ impl EditPredictionStore { EDIT_PREDICTION_SETTLED_EVENT, request_id = pending_prediction.request_id.0.clone(), settled_editable_region, + example = pending_prediction.example.take(), ); return false; @@ -1600,22 +1603,25 @@ impl EditPredictionStore { edited_buffer: &Entity, edited_buffer_snapshot: &BufferSnapshot, editable_offset_range: Range, + example: Option, cx: &mut Context, ) { - let project_state = self.get_or_init_project(project, cx); + let this = &mut *self; + let project_state = this.get_or_init_project(project, cx); if let Some(buffer) = project_state .registered_buffers .get_mut(&edited_buffer.entity_id()) { let now = cx.background_executor().now(); buffer.pending_predictions.push(PendingSettledPrediction { - request_id, + request_id: request_id, editable_anchor_range: edited_buffer_snapshot .anchor_range_around(editable_offset_range), + example, enqueued_at: now, last_edit_at: now, }); - self.settled_predictions_tx.unbounded_send(now).ok(); + this.settled_predictions_tx.unbounded_send(now).ok(); } } @@ -2226,14 +2232,16 @@ impl EditPredictionStore { && self.is_data_collection_enabled(cx) && matches!(self.edit_prediction_model, EditPredictionModel::Zeta); + let recent_paths = project_state.recent_paths.clone(); + let inputs = EditPredictionModelInput { project: project.clone(), - buffer: active_buffer.clone(), - snapshot: snapshot, + buffer: active_buffer, + snapshot, position, events, related_files, - recent_paths: project_state.recent_paths.clone(), + recent_paths, trigger, diagnostic_search_range: diagnostic_search_range, debug_tx, @@ -2242,21 +2250,12 @@ impl EditPredictionStore { is_open_source, }; - if can_collect_data && rand::random_ratio(1, 1000) { - if let Some(task) = capture_example( - project.clone(), - active_buffer, - position, - stored_events, - false, - cx, - ) { - task.detach(); - } - } + let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events); let task = match self.edit_prediction_model { - EditPredictionModel::Zeta => zeta::request_prediction_with_zeta(self, inputs, cx), + EditPredictionModel::Zeta => { + zeta::request_prediction_with_zeta(self, inputs, capture_data, cx) + } EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx), EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx), EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx), diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 594bfd482052950a5b3835798f83d5905573711c..bbad3c104e6f84f30c7906ba310df132ee66191e 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -2650,8 +2650,8 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) { .await .unwrap(); - let settled_events: Arc>> = - Arc::new(Mutex::new(Vec::new())); + type SettledEventRecord = (EditPredictionId, String); + let settled_events: Arc>> = Arc::new(Mutex::new(Vec::new())); ep_store.update(cx, |ep_store, cx| { ep_store.register_buffer(&buffer, &project, cx); @@ -2674,13 +2674,15 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) { // Region A: first 10 lines of the buffer. let editable_region_a = 0..snapshot_a.point_to_offset(Point::new(10, 0)); + ep_store.update(cx, |ep_store, cx| { ep_store.enqueue_settled_prediction( EditPredictionId("prediction-a".into()), &project, &buffer, &snapshot_a, - editable_region_a, + editable_region_a.clone(), + None, cx, ); }); @@ -2735,13 +2737,15 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) { let snapshot_b2 = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); let editable_region_b = line_20_offset..snapshot_b2.point_to_offset(Point::new(25, 0)); + ep_store.update(cx, |ep_store, cx| { ep_store.enqueue_settled_prediction( EditPredictionId("prediction-b".into()), &project, &buffer, &snapshot_b2, - editable_region_b, + editable_region_b.clone(), + None, cx, ); }); @@ -2767,7 +2771,7 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) { assert_eq!( events.len(), 1, - "only prediction A should have settled, got: {events:?}" + "prediction and capture_sample for A should have settled, got: {events:?}" ); assert_eq!(events[0].0, EditPredictionId("prediction-a".into())); } @@ -2784,7 +2788,7 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) { assert_eq!( events.len(), 2, - "both predictions should have settled, got: {events:?}" + "both prediction and capture_sample settled events should be emitted for each request, got: {events:?}" ); assert_eq!(events[1].0, EditPredictionId("prediction-b".into())); } diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index 3397d31276efcc7e1d68336f87ccf3e035f51f3a..3d111bfd9394a90e87a70e24ae96eb69a58afe91 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -2,7 +2,7 @@ use crate::cursor_excerpt::compute_excerpt_ranges; use crate::prediction::EditPredictionResult; use crate::{ CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, - EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, + EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, StoredEvent, }; use anyhow::Result; use cloud_llm_client::predict_edits_v3::RawCompletionRequest; @@ -41,6 +41,7 @@ pub fn request_prediction_with_zeta( is_open_source, .. }: EditPredictionModelInput, + capture_data: Option>, cx: &mut Context, ) -> Task>> { let settings = &all_language_settings(None, cx).edit_predictions; @@ -364,17 +365,44 @@ pub fn request_prediction_with_zeta( }; if can_collect_data { - this.update(cx, |this, cx| { - this.enqueue_settled_prediction( - id.clone(), - &project, - &edited_buffer, - &edited_buffer_snapshot, - editable_range_in_buffer, - cx, - ); + let weak_this = this.clone(); + let id = id.clone(); + let edited_buffer = edited_buffer.clone(); + let edited_buffer_snapshot = edited_buffer_snapshot.clone(); + let example_task = capture_data.and_then(|stored_events| { + cx.update(|cx| { + crate::capture_example( + project.clone(), + edited_buffer.clone(), + position, + stored_events, + false, + cx, + ) + }) + }); + cx.spawn(async move |cx| { + let example_spec = if let Some(task) = example_task { + task.await.ok() + } else { + None + }; + + weak_this + .update(cx, |this, cx| { + this.enqueue_settled_prediction( + id.clone(), + &project, + &edited_buffer, + &edited_buffer_snapshot, + editable_range_in_buffer, + example_spec, + cx, + ); + }) + .ok(); }) - .ok(); + .detach(); } Ok(Some(