diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index 20e4d49bb3e42e0e9ce92e61bb0dfa377d9c2ad6..93f4270b48caf0a339694613ef0c486ecf3eac54 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -142,6 +142,8 @@ pub struct PredictEditsResponse { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AcceptEditPredictionBody { pub request_id: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub model_version: Option, } #[derive(Debug, Clone, Deserialize)] @@ -160,6 +162,8 @@ pub struct EditPredictionRejection { #[serde(default)] pub reason: EditPredictionRejectReason, pub was_shown: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub model_version: Option, } #[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index d0b53ca18e8c74ec2588bff14c5130e3381f9444..5002c1a770ec1955d2a96c97098867f20f9bd05d 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -33,6 +33,8 @@ pub struct PredictEditsV3Response { /// this range to extract the old text from its local excerpt for /// diffing, rather than relying on its own format-derived range. pub editable_range: Range, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub model_version: Option, } #[derive(Debug, Deserialize, Serialize)] diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 78f42db2120b45f04dbf83c5e706a42163ee8067..836b4a477f62e2da6674568d0a7a1ccfc2b603cf 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -361,6 +361,7 @@ impl ProjectState { prediction_id, EditPredictionRejectReason::Canceled, false, + None, cx, ); }) @@ -1394,7 +1395,14 @@ impl EditPredictionStore { if let Some(project_state) = self.projects.get_mut(&project.entity_id()) { project_state.pending_predictions.clear(); if let Some(prediction) = project_state.current_prediction.take() { - self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx); + let model_version = prediction.prediction.model_version.clone(); + self.reject_prediction( + prediction.prediction.id, + reason, + prediction.was_shown, + model_version, + cx, + ); } }; } @@ -1453,6 +1461,7 @@ impl EditPredictionStore { prediction_id: EditPredictionId, reason: EditPredictionRejectReason, was_shown: bool, + model_version: Option, cx: &App, ) { match self.edit_prediction_model { @@ -1467,6 +1476,7 @@ impl EditPredictionStore { request_id: prediction_id.to_string(), reason, was_shown, + model_version, }) .log_err(); } @@ -1812,6 +1822,7 @@ impl EditPredictionStore { new_prediction.prediction.id, EditPredictionRejectReason::CurrentPreferred, false, + new_prediction.prediction.model_version, cx, ); None @@ -1821,7 +1832,13 @@ impl EditPredictionStore { } } Err(reject_reason) => { - this.reject_prediction(prediction_result.id, reject_reason, false, cx); + this.reject_prediction( + prediction_result.id, + reject_reason, + false, + None, + cx, + ); None } } diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index b0468e3c5610b8f618631be6707c74c4eaa451e5..abe522494fc8962a995313ffb1a57b8672c22ca4 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -897,7 +897,8 @@ async fn test_empty_prediction(cx: &mut TestAppContext) { &[EditPredictionRejection { request_id: id, reason: EditPredictionRejectReason::Empty, - was_shown: false + was_shown: false, + model_version: None, }] ); } @@ -957,7 +958,8 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) { &[EditPredictionRejection { request_id: id, reason: EditPredictionRejectReason::InterpolatedEmpty, - was_shown: false + was_shown: false, + model_version: None, }] ); } @@ -1049,7 +1051,8 @@ async fn test_replace_current(cx: &mut TestAppContext) { &[EditPredictionRejection { request_id: first_id, reason: EditPredictionRejectReason::Replaced, - was_shown: false + was_shown: false, + model_version: None, }] ); } @@ -1143,7 +1146,8 @@ async fn test_current_preferred(cx: &mut TestAppContext) { &[EditPredictionRejection { request_id: second_id, reason: EditPredictionRejectReason::CurrentPreferred, - was_shown: false + was_shown: false, + model_version: None, }] ); } @@ -1234,7 +1238,8 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) { &[EditPredictionRejection { request_id: first_id, reason: EditPredictionRejectReason::Canceled, - was_shown: false + was_shown: false, + model_version: None, }] ); } @@ -1364,12 +1369,14 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { EditPredictionRejection { request_id: cancelled_id, reason: EditPredictionRejectReason::Canceled, - was_shown: false + was_shown: false, + model_version: None, }, EditPredictionRejection { request_id: first_id, reason: EditPredictionRejectReason::Replaced, - was_shown: false + was_shown: false, + model_version: None, } ] ); @@ -1485,12 +1492,14 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) { EditPredictionId("test-1".into()), EditPredictionRejectReason::Discarded, false, + None, cx, ); ep_store.reject_prediction( EditPredictionId("test-2".into()), EditPredictionRejectReason::Canceled, true, + None, cx, ); }); @@ -1508,7 +1517,8 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) { EditPredictionRejection { request_id: "test-1".to_string(), reason: EditPredictionRejectReason::Discarded, - was_shown: false + was_shown: false, + model_version: None, } ); assert_eq!( @@ -1516,7 +1526,8 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) { EditPredictionRejection { request_id: "test-2".to_string(), reason: EditPredictionRejectReason::Canceled, - was_shown: true + was_shown: true, + model_version: None, } ); @@ -1527,6 +1538,7 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) { EditPredictionId(format!("batch-{}", i).into()), EditPredictionRejectReason::Discarded, false, + None, cx, ); } @@ -1558,6 +1570,7 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) { EditPredictionId("retry-1".into()), EditPredictionRejectReason::Discarded, false, + None, cx, ); }); @@ -1577,6 +1590,7 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) { EditPredictionId("retry-2".into()), EditPredictionRejectReason::Discarded, false, + None, cx, ); }); @@ -1700,6 +1714,7 @@ fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> Predi request_id: Uuid::new_v4().to_string(), editable_range, output: new_excerpt, + model_version: None, } } @@ -1708,6 +1723,7 @@ fn empty_response() -> PredictEditsV3Response { request_id: Uuid::new_v4().to_string(), editable_range: 0..0, output: String::new(), + model_version: None, } } @@ -1837,6 +1853,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { }, buffer_snapshotted_at: Instant::now(), response_received_at: Instant::now(), + model_version: None, }; cx.update(|cx| { @@ -2034,6 +2051,7 @@ async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppConte request_id: Uuid::new_v4().to_string(), output: "hello world\n".to_string(), editable_range: 0..excerpt_length, + model_version: None, }; respond_tx.send(response).unwrap(); @@ -2138,6 +2156,7 @@ async fn make_test_ep_store( request_id: format!("request-{next_request_id}"), editable_range: 0..req.input.cursor_excerpt.len(), output: completion_response.lock().clone(), + model_version: None, }) .unwrap() .into(), diff --git a/crates/edit_prediction/src/fim.rs b/crates/edit_prediction/src/fim.rs index 7ba6c6bef77c5b2229d1b3a4072e8070e5c4a6f1..dda008133d3726f5e7ba32ec05c770878d16585f 100644 --- a/crates/edit_prediction/src/fim.rs +++ b/crates/edit_prediction/src/fim.rs @@ -141,6 +141,7 @@ pub fn request_prediction( output.buffer_snapshotted_at, output.response_received_at, output.inputs, + None, cx, ) .await, diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs index 4187881639d8c363582f7a2c7603f2bb51e09fa7..f3adba55e620e77ffd7bb12b0e950fd4d3f011fc 100644 --- a/crates/edit_prediction/src/mercury.rs +++ b/crates/edit_prediction/src/mercury.rs @@ -218,6 +218,7 @@ impl Mercury { buffer_snapshotted_at, response_received_at, inputs, + None, cx, ) .await, diff --git a/crates/edit_prediction/src/prediction.rs b/crates/edit_prediction/src/prediction.rs index 750b1a435ae4a7a281ef41973e1f6d0d2158445e..9c17f29fe29bc711f6750cf6fe24586067bfc619 100644 --- a/crates/edit_prediction/src/prediction.rs +++ b/crates/edit_prediction/src/prediction.rs @@ -41,6 +41,7 @@ impl EditPredictionResult { buffer_snapshotted_at: Instant, response_received_at: Instant, inputs: ZetaPromptInput, + model_version: Option, cx: &mut AsyncApp, ) -> Self { if edits.is_empty() { @@ -79,6 +80,7 @@ impl EditPredictionResult { buffer: edited_buffer.clone(), buffer_snapshotted_at, response_received_at, + model_version, }), } } @@ -95,6 +97,7 @@ pub struct EditPrediction { pub buffer_snapshotted_at: Instant, pub response_received_at: Instant, pub inputs: zeta_prompt::ZetaPromptInput, + pub model_version: Option, } impl EditPrediction { @@ -150,6 +153,7 @@ mod tests { snapshot: cx.read(|cx| buffer.read(cx).snapshot()), buffer: buffer.clone(), edit_preview, + model_version: None, inputs: ZetaPromptInput { events: vec![], related_files: vec![], diff --git a/crates/edit_prediction/src/sweep_ai.rs b/crates/edit_prediction/src/sweep_ai.rs index 1253916487894d757c74293c21f4ace1c681cd11..5a9fcf0e6ce7bfa5476d6c48245068994178f7bc 100644 --- a/crates/edit_prediction/src/sweep_ai.rs +++ b/crates/edit_prediction/src/sweep_ai.rs @@ -303,6 +303,7 @@ impl SweepAi { buffer_snapshotted_at, response_received_at, inputs, + None, cx, ) .await, diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index 41877d10d6e3ede2ad6055e7580400075533a265..25f9900dcba4a8f29f7e1268560bcbb40ded9778 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -125,7 +125,7 @@ pub fn request_prediction_with_zeta( log::trace!("Sending edit prediction request"); - let (request_id, output_text, usage) = if let Some(custom_settings) = + let (request_id, output_text, model_version, usage) = if let Some(custom_settings) = &custom_server_settings { let max_tokens = custom_settings.max_output_tokens * 4; @@ -158,7 +158,7 @@ pub fn request_prediction_with_zeta( let request_id = EditPredictionId(request_id.into()); let output_text = zeta1::clean_zeta1_model_output(&response_text); - (request_id, output_text, None) + (request_id, output_text, None, None) } else { let prompt = format_zeta_prompt(&prompt_input, zeta_version); let prefill = get_prefill(&prompt_input, zeta_version); @@ -188,7 +188,7 @@ pub fn request_prediction_with_zeta( Some(clean_zeta2_model_output(&output, zeta_version).to_string()) }; - (request_id, output_text, None) + (request_id, output_text, None, None) } } else if let Some(config) = &raw_config { let prompt = format_zeta_prompt(&prompt_input, config.format); @@ -225,7 +225,7 @@ pub fn request_prediction_with_zeta( clean_zeta2_model_output(&output, config.format).to_string() }); - (request_id, output_text, usage) + (request_id, output_text, None, usage) } else { // Use V3 endpoint - server handles model/version selection and suffix stripping let (response, usage) = EditPredictionStore::send_v3_request( @@ -244,8 +244,9 @@ pub fn request_prediction_with_zeta( Some(response.output) }; editable_range_in_excerpt = response.editable_range; + let model_version = response.model_version; - (request_id, output_text, usage) + (request_id, output_text, model_version, usage) }; let received_response_at = Instant::now(); @@ -253,7 +254,7 @@ pub fn request_prediction_with_zeta( log::trace!("Got edit prediction response"); let Some(mut output_text) = output_text else { - return Ok((Some((request_id, None)), usage)); + return Ok((Some((request_id, None, model_version)), usage)); }; // Client-side cursor marker processing (applies to both raw and v3 responses) @@ -311,6 +312,7 @@ pub fn request_prediction_with_zeta( full_context_offset_range, editable_range_in_buffer, )), + model_version, )), usage, )) @@ -318,7 +320,7 @@ pub fn request_prediction_with_zeta( }); cx.spawn(async move |this, cx| { - let Some((id, prediction)) = + let Some((id, prediction, model_version)) = EditPredictionStore::handle_api_response(&this, request_task.await, cx)? else { return Ok(None); @@ -392,6 +394,7 @@ pub fn request_prediction_with_zeta( buffer_snapshotted_at, received_response_at, inputs, + model_version, cx, ) .await, @@ -521,6 +524,7 @@ pub(crate) fn edit_prediction_accepted( } let request_id = current_prediction.prediction.id.to_string(); + let model_version = current_prediction.prediction.model_version; let require_auth = custom_accept_url.is_none(); let client = store.client.clone(); let llm_token = store.llm_token.clone(); @@ -540,6 +544,7 @@ pub(crate) fn edit_prediction_accepted( let req = builder.uri(url.as_ref()).body( serde_json::to_string(&AcceptEditPredictionBody { request_id: request_id.clone(), + model_version: model_version.clone(), })? .into(), );