Include optional model version with EP acceptance and rejection messages (#50262)

Max Brunsfeld created

Release Notes:

- N/A

Change summary

crates/cloud_llm_client/src/cloud_llm_client.rs     |  4 +
crates/cloud_llm_client/src/predict_edits_v3.rs     |  2 
crates/edit_prediction/src/edit_prediction.rs       | 21 +++++++
crates/edit_prediction/src/edit_prediction_tests.rs | 37 +++++++++++---
crates/edit_prediction/src/fim.rs                   |  1 
crates/edit_prediction/src/mercury.rs               |  1 
crates/edit_prediction/src/prediction.rs            |  4 +
crates/edit_prediction/src/sweep_ai.rs              |  1 
crates/edit_prediction/src/zeta.rs                  | 19 ++++--
9 files changed, 72 insertions(+), 18 deletions(-)

Detailed changes

crates/cloud_llm_client/src/cloud_llm_client.rs 🔗

@@ -142,6 +142,8 @@ pub struct PredictEditsResponse {
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct AcceptEditPredictionBody {
     pub request_id: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub model_version: Option<String>,
 }
 
 #[derive(Debug, Clone, Deserialize)]
@@ -160,6 +162,8 @@ pub struct EditPredictionRejection {
     #[serde(default)]
     pub reason: EditPredictionRejectReason,
     pub was_shown: bool,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub model_version: Option<String>,
 }
 
 #[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]

crates/cloud_llm_client/src/predict_edits_v3.rs 🔗

@@ -33,6 +33,8 @@ pub struct PredictEditsV3Response {
     /// this range to extract the old text from its local excerpt for
     /// diffing, rather than relying on its own format-derived range.
     pub editable_range: Range<usize>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub model_version: Option<String>,
 }
 
 #[derive(Debug, Deserialize, Serialize)]

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -361,6 +361,7 @@ impl ProjectState {
                         prediction_id,
                         EditPredictionRejectReason::Canceled,
                         false,
+                        None,
                         cx,
                     );
                 })
@@ -1394,7 +1395,14 @@ impl EditPredictionStore {
         if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
             project_state.pending_predictions.clear();
             if let Some(prediction) = project_state.current_prediction.take() {
-                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
+                let model_version = prediction.prediction.model_version.clone();
+                self.reject_prediction(
+                    prediction.prediction.id,
+                    reason,
+                    prediction.was_shown,
+                    model_version,
+                    cx,
+                );
             }
         };
     }
@@ -1453,6 +1461,7 @@ impl EditPredictionStore {
         prediction_id: EditPredictionId,
         reason: EditPredictionRejectReason,
         was_shown: bool,
+        model_version: Option<String>,
         cx: &App,
     ) {
         match self.edit_prediction_model {
@@ -1467,6 +1476,7 @@ impl EditPredictionStore {
                             request_id: prediction_id.to_string(),
                             reason,
                             was_shown,
+                            model_version,
                         })
                         .log_err();
                 }
@@ -1812,6 +1822,7 @@ impl EditPredictionStore {
                                         new_prediction.prediction.id,
                                         EditPredictionRejectReason::CurrentPreferred,
                                         false,
+                                        new_prediction.prediction.model_version,
                                         cx,
                                     );
                                     None
@@ -1821,7 +1832,13 @@ impl EditPredictionStore {
                             }
                         }
                         Err(reject_reason) => {
-                            this.reject_prediction(prediction_result.id, reject_reason, false, cx);
+                            this.reject_prediction(
+                                prediction_result.id,
+                                reject_reason,
+                                false,
+                                None,
+                                cx,
+                            );
                             None
                         }
                     }

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -897,7 +897,8 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
         &[EditPredictionRejection {
             request_id: id,
             reason: EditPredictionRejectReason::Empty,
-            was_shown: false
+            was_shown: false,
+            model_version: None,
         }]
     );
 }
@@ -957,7 +958,8 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
         &[EditPredictionRejection {
             request_id: id,
             reason: EditPredictionRejectReason::InterpolatedEmpty,
-            was_shown: false
+            was_shown: false,
+            model_version: None,
         }]
     );
 }
@@ -1049,7 +1051,8 @@ async fn test_replace_current(cx: &mut TestAppContext) {
         &[EditPredictionRejection {
             request_id: first_id,
             reason: EditPredictionRejectReason::Replaced,
-            was_shown: false
+            was_shown: false,
+            model_version: None,
         }]
     );
 }
@@ -1143,7 +1146,8 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
         &[EditPredictionRejection {
             request_id: second_id,
             reason: EditPredictionRejectReason::CurrentPreferred,
-            was_shown: false
+            was_shown: false,
+            model_version: None,
         }]
     );
 }
@@ -1234,7 +1238,8 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
         &[EditPredictionRejection {
             request_id: first_id,
             reason: EditPredictionRejectReason::Canceled,
-            was_shown: false
+            was_shown: false,
+            model_version: None,
         }]
     );
 }
@@ -1364,12 +1369,14 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
             EditPredictionRejection {
                 request_id: cancelled_id,
                 reason: EditPredictionRejectReason::Canceled,
-                was_shown: false
+                was_shown: false,
+                model_version: None,
             },
             EditPredictionRejection {
                 request_id: first_id,
                 reason: EditPredictionRejectReason::Replaced,
-                was_shown: false
+                was_shown: false,
+                model_version: None,
             }
         ]
     );
@@ -1485,12 +1492,14 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
             EditPredictionId("test-1".into()),
             EditPredictionRejectReason::Discarded,
             false,
+            None,
             cx,
         );
         ep_store.reject_prediction(
             EditPredictionId("test-2".into()),
             EditPredictionRejectReason::Canceled,
             true,
+            None,
             cx,
         );
     });
@@ -1508,7 +1517,8 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
         EditPredictionRejection {
             request_id: "test-1".to_string(),
             reason: EditPredictionRejectReason::Discarded,
-            was_shown: false
+            was_shown: false,
+            model_version: None,
         }
     );
     assert_eq!(
@@ -1516,7 +1526,8 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
         EditPredictionRejection {
             request_id: "test-2".to_string(),
             reason: EditPredictionRejectReason::Canceled,
-            was_shown: true
+            was_shown: true,
+            model_version: None,
         }
     );
 
@@ -1527,6 +1538,7 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
                 EditPredictionId(format!("batch-{}", i).into()),
                 EditPredictionRejectReason::Discarded,
                 false,
+                None,
                 cx,
             );
         }
@@ -1558,6 +1570,7 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
             EditPredictionId("retry-1".into()),
             EditPredictionRejectReason::Discarded,
             false,
+            None,
             cx,
         );
     });
@@ -1577,6 +1590,7 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
             EditPredictionId("retry-2".into()),
             EditPredictionRejectReason::Discarded,
             false,
+            None,
             cx,
         );
     });
@@ -1700,6 +1714,7 @@ fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> Predi
         request_id: Uuid::new_v4().to_string(),
         editable_range,
         output: new_excerpt,
+        model_version: None,
     }
 }
 
@@ -1708,6 +1723,7 @@ fn empty_response() -> PredictEditsV3Response {
         request_id: Uuid::new_v4().to_string(),
         editable_range: 0..0,
         output: String::new(),
+        model_version: None,
     }
 }
 
@@ -1837,6 +1853,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
         },
         buffer_snapshotted_at: Instant::now(),
         response_received_at: Instant::now(),
+        model_version: None,
     };
 
     cx.update(|cx| {
@@ -2034,6 +2051,7 @@ async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppConte
         request_id: Uuid::new_v4().to_string(),
         output: "hello world\n".to_string(),
         editable_range: 0..excerpt_length,
+        model_version: None,
     };
     respond_tx.send(response).unwrap();
 
@@ -2138,6 +2156,7 @@ async fn make_test_ep_store(
                                     request_id: format!("request-{next_request_id}"),
                                     editable_range: 0..req.input.cursor_excerpt.len(),
                                     output: completion_response.lock().clone(),
+                                    model_version: None,
                                 })
                                 .unwrap()
                                 .into(),

crates/edit_prediction/src/prediction.rs 🔗

@@ -41,6 +41,7 @@ impl EditPredictionResult {
         buffer_snapshotted_at: Instant,
         response_received_at: Instant,
         inputs: ZetaPromptInput,
+        model_version: Option<String>,
         cx: &mut AsyncApp,
     ) -> Self {
         if edits.is_empty() {
@@ -79,6 +80,7 @@ impl EditPredictionResult {
                 buffer: edited_buffer.clone(),
                 buffer_snapshotted_at,
                 response_received_at,
+                model_version,
             }),
         }
     }
@@ -95,6 +97,7 @@ pub struct EditPrediction {
     pub buffer_snapshotted_at: Instant,
     pub response_received_at: Instant,
     pub inputs: zeta_prompt::ZetaPromptInput,
+    pub model_version: Option<String>,
 }
 
 impl EditPrediction {
@@ -150,6 +153,7 @@ mod tests {
             snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
             buffer: buffer.clone(),
             edit_preview,
+            model_version: None,
             inputs: ZetaPromptInput {
                 events: vec![],
                 related_files: vec![],

crates/edit_prediction/src/zeta.rs 🔗

@@ -125,7 +125,7 @@ pub fn request_prediction_with_zeta(
 
             log::trace!("Sending edit prediction request");
 
-            let (request_id, output_text, usage) = if let Some(custom_settings) =
+            let (request_id, output_text, model_version, usage) = if let Some(custom_settings) =
                 &custom_server_settings
             {
                 let max_tokens = custom_settings.max_output_tokens * 4;
@@ -158,7 +158,7 @@ pub fn request_prediction_with_zeta(
                     let request_id = EditPredictionId(request_id.into());
                     let output_text = zeta1::clean_zeta1_model_output(&response_text);
 
-                    (request_id, output_text, None)
+                    (request_id, output_text, None, None)
                 } else {
                     let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                     let prefill = get_prefill(&prompt_input, zeta_version);
@@ -188,7 +188,7 @@ pub fn request_prediction_with_zeta(
                         Some(clean_zeta2_model_output(&output, zeta_version).to_string())
                     };
 
-                    (request_id, output_text, None)
+                    (request_id, output_text, None, None)
                 }
             } else if let Some(config) = &raw_config {
                 let prompt = format_zeta_prompt(&prompt_input, config.format);
@@ -225,7 +225,7 @@ pub fn request_prediction_with_zeta(
                     clean_zeta2_model_output(&output, config.format).to_string()
                 });
 
-                (request_id, output_text, usage)
+                (request_id, output_text, None, usage)
             } else {
                 // Use V3 endpoint - server handles model/version selection and suffix stripping
                 let (response, usage) = EditPredictionStore::send_v3_request(
@@ -244,8 +244,9 @@ pub fn request_prediction_with_zeta(
                     Some(response.output)
                 };
                 editable_range_in_excerpt = response.editable_range;
+                let model_version = response.model_version;
 
-                (request_id, output_text, usage)
+                (request_id, output_text, model_version, usage)
             };
 
             let received_response_at = Instant::now();
@@ -253,7 +254,7 @@ pub fn request_prediction_with_zeta(
             log::trace!("Got edit prediction response");
 
             let Some(mut output_text) = output_text else {
-                return Ok((Some((request_id, None)), usage));
+                return Ok((Some((request_id, None, model_version)), usage));
             };
 
             // Client-side cursor marker processing (applies to both raw and v3 responses)
@@ -309,6 +310,7 @@ pub fn request_prediction_with_zeta(
                         cursor_position,
                         received_response_at,
                     )),
+                    model_version,
                 )),
                 usage,
             ))
@@ -316,7 +318,7 @@ pub fn request_prediction_with_zeta(
     });
 
     cx.spawn(async move |this, cx| {
-        let Some((id, prediction)) =
+        let Some((id, prediction, model_version)) =
             EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
         else {
             return Ok(None);
@@ -347,6 +349,7 @@ pub fn request_prediction_with_zeta(
                 buffer_snapshotted_at,
                 received_response_at,
                 inputs,
+                model_version,
                 cx,
             )
             .await,
@@ -476,6 +479,7 @@ pub(crate) fn edit_prediction_accepted(
     }
 
     let request_id = current_prediction.prediction.id.to_string();
+    let model_version = current_prediction.prediction.model_version;
     let require_auth = custom_accept_url.is_none();
     let client = store.client.clone();
     let llm_token = store.llm_token.clone();
@@ -495,6 +499,7 @@ pub(crate) fn edit_prediction_accepted(
                 let req = builder.uri(url.as_ref()).body(
                     serde_json::to_string(&AcceptEditPredictionBody {
                         request_id: request_id.clone(),
+                        model_version: model_version.clone(),
                     })?
                     .into(),
                 );