diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index d61cba71922582b98bdc64444bc8227c0043fa2e..16ce0d659905528b2e63d50671914d4d6b12a1f4 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -353,7 +353,7 @@ impl ProjectState { drop(pending_prediction.task); } else { cx.spawn(async move |this, cx| { - let Some(prediction_id) = pending_prediction.task.await else { + let Some((prediction_id, model_version)) = pending_prediction.task.await else { return; }; @@ -362,7 +362,7 @@ impl ProjectState { prediction_id, EditPredictionRejectReason::Canceled, false, - None, + model_version, None, cx, ); @@ -460,7 +460,7 @@ pub enum DiagnosticSearchScope { #[derive(Debug)] struct PendingPrediction { id: usize, - task: Task>, + task: Task)>>, /// If true, the task is dropped immediately on cancel (cancelling the HTTP request). /// If false, the task is awaited to completion so rejection can be reported. drop_on_cancel: bool, @@ -2039,6 +2039,7 @@ impl EditPredictionStore { EditPredictionResult { id: prediction_result.id, prediction: Err(EditPredictionRejectReason::CurrentPreferred), + model_version: prediction_result.model_version, e2e_latency: prediction_result.e2e_latency, } }, @@ -2213,9 +2214,9 @@ impl EditPredictionStore { } let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten(); - let new_prediction_id = new_prediction_result + let new_prediction_metadata = new_prediction_result .as_ref() - .map(|(prediction, _)| prediction.id.clone()); + .map(|(prediction, _)| (prediction.id.clone(), prediction.model_version.clone())); // When a prediction completes, remove it from the pending list, and cancel // any pending predictions that were enqueued before it. @@ -2271,7 +2272,7 @@ impl EditPredictionStore { prediction_result.id, reject_reason, false, - None, + prediction_result.model_version, Some(prediction_result.e2e_latency), cx, ); @@ -2303,7 +2304,7 @@ impl EditPredictionStore { }) .ok(); - new_prediction_id + new_prediction_metadata }); if project_state.pending_predictions.len() < max_pending_predictions { diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 7a0c5f57992e1996be715c8220aaeb4e398f5601..0a8cd1b066adad88e639129bb932d4eca8b690dc 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1374,7 +1374,8 @@ async fn test_empty_prediction(cx: &mut TestAppContext) { }); let (request, respond_tx) = requests.predict.next().await.unwrap(); - let response = model_response(&request, ""); + let mut response = model_response(&request, ""); + response.model_version = Some("zeta2:test-empty".to_string()); let id = response.request_id.clone(); respond_tx.send(response).unwrap(); @@ -1397,7 +1398,7 @@ async fn test_empty_prediction(cx: &mut TestAppContext) { request_id: id, reason: EditPredictionRejectReason::Empty, was_shown: false, - model_version: None, + model_version: Some("zeta2:test-empty".to_string()), e2e_latency_ms: Some(0), }] ); @@ -1436,7 +1437,8 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) { buffer.set_text("Hello!\nHow are you?\nBye", cx); }); - let response = model_response(&request, SIMPLE_DIFF); + let mut response = model_response(&request, SIMPLE_DIFF); + response.model_version = Some("zeta2:test-interpolated-empty".to_string()); let id = response.request_id.clone(); respond_tx.send(response).unwrap(); @@ -1459,7 +1461,7 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) { request_id: id, reason: EditPredictionRejectReason::InterpolatedEmpty, was_shown: false, - model_version: None, + model_version: Some("zeta2:test-interpolated-empty".to_string()), e2e_latency_ms: Some(0), }] ); @@ -1611,7 +1613,7 @@ async fn test_current_preferred(cx: &mut TestAppContext) { let (request, respond_tx) = requests.predict.next().await.unwrap(); // worse than current prediction - let second_response = model_response( + let mut second_response = model_response( &request, indoc! { r" --- a/root/foo.md @@ -1623,6 +1625,7 @@ async fn test_current_preferred(cx: &mut TestAppContext) { Bye "}, ); + second_response.model_version = Some("zeta2:test-current-preferred".to_string()); let second_id = second_response.request_id.clone(); respond_tx.send(second_response).unwrap(); @@ -1649,7 +1652,7 @@ async fn test_current_preferred(cx: &mut TestAppContext) { request_id: second_id, reason: EditPredictionRejectReason::CurrentPreferred, was_shown: false, - model_version: None, + model_version: Some("zeta2:test-current-preferred".to_string()), e2e_latency_ms: Some(0), }] ); @@ -1713,7 +1716,8 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) { ); }); - let first_response = model_response(&request1, SIMPLE_DIFF); + let mut first_response = model_response(&request1, SIMPLE_DIFF); + first_response.model_version = Some("zeta2:test-canceled".to_string()); let first_id = first_response.request_id.clone(); respond_first.send(first_response).unwrap(); @@ -1742,7 +1746,7 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) { request_id: first_id, reason: EditPredictionRejectReason::Canceled, was_shown: false, - model_version: None, + model_version: Some("zeta2:test-canceled".to_string()), e2e_latency_ms: None, }] ); @@ -1826,7 +1830,8 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { ); }); - let cancelled_response = model_response(&request2, SIMPLE_DIFF); + let mut cancelled_response = model_response(&request2, SIMPLE_DIFF); + cancelled_response.model_version = Some("zeta2:test-canceled-second".to_string()); let cancelled_id = cancelled_response.request_id.clone(); respond_second.send(cancelled_response).unwrap(); @@ -1874,7 +1879,7 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { request_id: cancelled_id, reason: EditPredictionRejectReason::Canceled, was_shown: false, - model_version: None, + model_version: Some("zeta2:test-canceled-second".to_string()), e2e_latency_ms: None, }, EditPredictionRejection { diff --git a/crates/edit_prediction/src/prediction.rs b/crates/edit_prediction/src/prediction.rs index eb45832d4cccad52ae950c6d6ff685092270aaa0..b115ad795b12cb0c54f876df264f769c570510d4 100644 --- a/crates/edit_prediction/src/prediction.rs +++ b/crates/edit_prediction/src/prediction.rs @@ -25,6 +25,7 @@ impl std::fmt::Display for EditPredictionId { pub struct EditPredictionResult { pub id: EditPredictionId, pub prediction: Result, + pub model_version: Option, pub e2e_latency: std::time::Duration, } @@ -43,8 +44,9 @@ impl EditPredictionResult { if edits.is_empty() { return Self { id, - e2e_latency, prediction: Err(EditPredictionRejectReason::Empty), + model_version, + e2e_latency, }; } @@ -59,8 +61,9 @@ impl EditPredictionResult { else { return Self { id, - e2e_latency, prediction: Err(EditPredictionRejectReason::InterpolatedEmpty), + model_version, + e2e_latency, }; }; @@ -68,7 +71,6 @@ impl EditPredictionResult { Self { id: id.clone(), - e2e_latency, prediction: Ok(EditPrediction { id, edits, @@ -77,8 +79,10 @@ impl EditPredictionResult { edit_preview, inputs, buffer: edited_buffer.clone(), - model_version, + model_version: model_version.clone(), }), + model_version, + e2e_latency, } } } diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index 60157781eb769553d8f4c266df3e05d75d335ec3..7b12453353478d766e9e54093ecfb4c62d235564 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -102,7 +102,6 @@ pub fn request_prediction_with_zeta( edits: Vec<(Range, Arc)>, cursor_position: Option, editable_range_in_buffer: Range, - model_version: Option, } let request_task = cx.background_spawn({ @@ -305,7 +304,7 @@ pub fn request_prediction_with_zeta( cursor_offset_in_new_editable_region: cursor_offset_in_output, }) = output else { - return Ok((Some((request_id, None)), None)); + return Ok((Some((request_id, None, model_version)), None)); }; let editable_range_in_buffer = editable_range_in_excerpt.start @@ -343,26 +342,23 @@ pub fn request_prediction_with_zeta( &snapshot, ); - anyhow::Ok(( - Some(( - request_id, - Some(Prediction { - prompt_input, - buffer, - snapshot: snapshot.clone(), - edits, - cursor_position, - editable_range_in_buffer, - model_version, - }), - )), - usage, - )) + let prediction = Some(Prediction { + prompt_input, + buffer, + snapshot: snapshot.clone(), + edits, + cursor_position, + editable_range_in_buffer, + }); + + anyhow::Ok((Some((request_id, prediction, model_version)), usage)) } }); cx.spawn(async move |this, cx| { - let Some((id, prediction)) = handle_api_response(&this, request_task.await, cx)? else { + let Some((id, prediction, model_version)) = + handle_api_response(&this, request_task.await, cx)? + else { return Ok(None); }; let request_duration = cx.background_executor().now() - request_start; @@ -374,13 +370,14 @@ pub fn request_prediction_with_zeta( edits, cursor_position, editable_range_in_buffer, - model_version, + .. }) = prediction else { return Ok(Some(EditPredictionResult { id, - e2e_latency: request_duration, prediction: Err(EditPredictionRejectReason::Empty), + model_version, + e2e_latency: request_duration, })); };