ep: Don't throw away model version in reject events (#54159)

Created by Ben Kunkle

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Closes: N/A (no linked issue — template placeholder; TODO fill in or remove)

Release Notes:

- N/A (internal telemetry fix: reject events now carry the model version instead of discarding it; no user-facing behavior change)

Change summary

crates/edit_prediction/src/edit_prediction.rs       | 15 +++--
crates/edit_prediction/src/edit_prediction_tests.rs | 25 ++++++----
crates/edit_prediction/src/prediction.rs            | 12 +++-
crates/edit_prediction/src/zeta.rs                  | 37 ++++++--------
4 files changed, 48 insertions(+), 41 deletions(-)

Detailed changes

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -353,7 +353,7 @@ impl ProjectState {
             drop(pending_prediction.task);
         } else {
             cx.spawn(async move |this, cx| {
-                let Some(prediction_id) = pending_prediction.task.await else {
+                let Some((prediction_id, model_version)) = pending_prediction.task.await else {
                     return;
                 };
 
@@ -362,7 +362,7 @@ impl ProjectState {
                         prediction_id,
                         EditPredictionRejectReason::Canceled,
                         false,
-                        None,
+                        model_version,
                         None,
                         cx,
                     );
@@ -460,7 +460,7 @@ pub enum DiagnosticSearchScope {
 #[derive(Debug)]
 struct PendingPrediction {
     id: usize,
-    task: Task<Option<EditPredictionId>>,
+    task: Task<Option<(EditPredictionId, Option<String>)>>,
     /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
     /// If false, the task is awaited to completion so rejection can be reported.
     drop_on_cancel: bool,
@@ -2039,6 +2039,7 @@ impl EditPredictionStore {
                                 EditPredictionResult {
                                     id: prediction_result.id,
                                     prediction: Err(EditPredictionRejectReason::CurrentPreferred),
+                                    model_version: prediction_result.model_version,
                                     e2e_latency: prediction_result.e2e_latency,
                                 }
                             },
@@ -2213,9 +2214,9 @@ impl EditPredictionStore {
             }
 
             let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
-            let new_prediction_id = new_prediction_result
+            let new_prediction_metadata = new_prediction_result
                 .as_ref()
-                .map(|(prediction, _)| prediction.id.clone());
+                .map(|(prediction, _)| (prediction.id.clone(), prediction.model_version.clone()));
 
             // When a prediction completes, remove it from the pending list, and cancel
             // any pending predictions that were enqueued before it.
@@ -2271,7 +2272,7 @@ impl EditPredictionStore {
                                 prediction_result.id,
                                 reject_reason,
                                 false,
-                                None,
+                                prediction_result.model_version,
                                 Some(prediction_result.e2e_latency),
                                 cx,
                             );
@@ -2303,7 +2304,7 @@ impl EditPredictionStore {
             })
             .ok();
 
-            new_prediction_id
+            new_prediction_metadata
         });
 
         if project_state.pending_predictions.len() < max_pending_predictions {

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -1374,7 +1374,8 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
     });
 
     let (request, respond_tx) = requests.predict.next().await.unwrap();
-    let response = model_response(&request, "");
+    let mut response = model_response(&request, "");
+    response.model_version = Some("zeta2:test-empty".to_string());
     let id = response.request_id.clone();
     respond_tx.send(response).unwrap();
 
@@ -1397,7 +1398,7 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
             request_id: id,
             reason: EditPredictionRejectReason::Empty,
             was_shown: false,
-            model_version: None,
+            model_version: Some("zeta2:test-empty".to_string()),
             e2e_latency_ms: Some(0),
         }]
     );
@@ -1436,7 +1437,8 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
         buffer.set_text("Hello!\nHow are you?\nBye", cx);
     });
 
-    let response = model_response(&request, SIMPLE_DIFF);
+    let mut response = model_response(&request, SIMPLE_DIFF);
+    response.model_version = Some("zeta2:test-interpolated-empty".to_string());
     let id = response.request_id.clone();
     respond_tx.send(response).unwrap();
 
@@ -1459,7 +1461,7 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
             request_id: id,
             reason: EditPredictionRejectReason::InterpolatedEmpty,
             was_shown: false,
-            model_version: None,
+            model_version: Some("zeta2:test-interpolated-empty".to_string()),
             e2e_latency_ms: Some(0),
         }]
     );
@@ -1611,7 +1613,7 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
 
     let (request, respond_tx) = requests.predict.next().await.unwrap();
     // worse than current prediction
-    let second_response = model_response(
+    let mut second_response = model_response(
         &request,
         indoc! { r"
             --- a/root/foo.md
@@ -1623,6 +1625,7 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
              Bye
         "},
     );
+    second_response.model_version = Some("zeta2:test-current-preferred".to_string());
     let second_id = second_response.request_id.clone();
     respond_tx.send(second_response).unwrap();
 
@@ -1649,7 +1652,7 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
             request_id: second_id,
             reason: EditPredictionRejectReason::CurrentPreferred,
             was_shown: false,
-            model_version: None,
+            model_version: Some("zeta2:test-current-preferred".to_string()),
             e2e_latency_ms: Some(0),
         }]
     );
@@ -1713,7 +1716,8 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
         );
     });
 
-    let first_response = model_response(&request1, SIMPLE_DIFF);
+    let mut first_response = model_response(&request1, SIMPLE_DIFF);
+    first_response.model_version = Some("zeta2:test-canceled".to_string());
     let first_id = first_response.request_id.clone();
     respond_first.send(first_response).unwrap();
 
@@ -1742,7 +1746,7 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
             request_id: first_id,
             reason: EditPredictionRejectReason::Canceled,
             was_shown: false,
-            model_version: None,
+            model_version: Some("zeta2:test-canceled".to_string()),
             e2e_latency_ms: None,
         }]
     );
@@ -1826,7 +1830,8 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
         );
     });
 
-    let cancelled_response = model_response(&request2, SIMPLE_DIFF);
+    let mut cancelled_response = model_response(&request2, SIMPLE_DIFF);
+    cancelled_response.model_version = Some("zeta2:test-canceled-second".to_string());
     let cancelled_id = cancelled_response.request_id.clone();
     respond_second.send(cancelled_response).unwrap();
 
@@ -1874,7 +1879,7 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
                 request_id: cancelled_id,
                 reason: EditPredictionRejectReason::Canceled,
                 was_shown: false,
-                model_version: None,
+                model_version: Some("zeta2:test-canceled-second".to_string()),
                 e2e_latency_ms: None,
             },
             EditPredictionRejection {

crates/edit_prediction/src/prediction.rs 🔗

@@ -25,6 +25,7 @@ impl std::fmt::Display for EditPredictionId {
 pub struct EditPredictionResult {
     pub id: EditPredictionId,
     pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
+    pub model_version: Option<String>,
     pub e2e_latency: std::time::Duration,
 }
 
@@ -43,8 +44,9 @@ impl EditPredictionResult {
         if edits.is_empty() {
             return Self {
                 id,
-                e2e_latency,
                 prediction: Err(EditPredictionRejectReason::Empty),
+                model_version,
+                e2e_latency,
             };
         }
 
@@ -59,8 +61,9 @@ impl EditPredictionResult {
         else {
             return Self {
                 id,
-                e2e_latency,
                 prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
+                model_version,
+                e2e_latency,
             };
         };
 
@@ -68,7 +71,6 @@ impl EditPredictionResult {
 
         Self {
             id: id.clone(),
-            e2e_latency,
             prediction: Ok(EditPrediction {
                 id,
                 edits,
@@ -77,8 +79,10 @@ impl EditPredictionResult {
                 edit_preview,
                 inputs,
                 buffer: edited_buffer.clone(),
-                model_version,
+                model_version: model_version.clone(),
             }),
+            model_version,
+            e2e_latency,
         }
     }
 }

crates/edit_prediction/src/zeta.rs 🔗

@@ -102,7 +102,6 @@ pub fn request_prediction_with_zeta(
         edits: Vec<(Range<Anchor>, Arc<str>)>,
         cursor_position: Option<PredictedCursorPosition>,
         editable_range_in_buffer: Range<usize>,
-        model_version: Option<String>,
     }
 
     let request_task = cx.background_spawn({
@@ -305,7 +304,7 @@ pub fn request_prediction_with_zeta(
                 cursor_offset_in_new_editable_region: cursor_offset_in_output,
             }) = output
             else {
-                return Ok((Some((request_id, None)), None));
+                return Ok((Some((request_id, None, model_version)), None));
             };
 
             let editable_range_in_buffer = editable_range_in_excerpt.start
@@ -343,26 +342,23 @@ pub fn request_prediction_with_zeta(
                 &snapshot,
             );
 
-            anyhow::Ok((
-                Some((
-                    request_id,
-                    Some(Prediction {
-                        prompt_input,
-                        buffer,
-                        snapshot: snapshot.clone(),
-                        edits,
-                        cursor_position,
-                        editable_range_in_buffer,
-                        model_version,
-                    }),
-                )),
-                usage,
-            ))
+            let prediction = Some(Prediction {
+                prompt_input,
+                buffer,
+                snapshot: snapshot.clone(),
+                edits,
+                cursor_position,
+                editable_range_in_buffer,
+            });
+
+            anyhow::Ok((Some((request_id, prediction, model_version)), usage))
         }
     });
 
     cx.spawn(async move |this, cx| {
-        let Some((id, prediction)) = handle_api_response(&this, request_task.await, cx)? else {
+        let Some((id, prediction, model_version)) =
+            handle_api_response(&this, request_task.await, cx)?
+        else {
             return Ok(None);
         };
         let request_duration = cx.background_executor().now() - request_start;
@@ -374,13 +370,14 @@ pub fn request_prediction_with_zeta(
             edits,
             cursor_position,
             editable_range_in_buffer,
-            model_version,
+            ..
         }) = prediction
         else {
             return Ok(Some(EditPredictionResult {
                 id,
-                e2e_latency: request_duration,
                 prediction: Err(EditPredictionRejectReason::Empty),
+                model_version,
+                e2e_latency: request_duration,
             }));
         };