Detailed changes
@@ -142,6 +142,8 @@ pub struct PredictEditsResponse {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcceptEditPredictionBody {
pub request_id: String,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub model_version: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
@@ -160,6 +162,8 @@ pub struct EditPredictionRejection {
#[serde(default)]
pub reason: EditPredictionRejectReason,
pub was_shown: bool,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub model_version: Option<String>,
}
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
@@ -33,6 +33,8 @@ pub struct PredictEditsV3Response {
/// this range to extract the old text from its local excerpt for
/// diffing, rather than relying on its own format-derived range.
pub editable_range: Range<usize>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub model_version: Option<String>,
}
#[derive(Debug, Deserialize, Serialize)]
@@ -361,6 +361,7 @@ impl ProjectState {
prediction_id,
EditPredictionRejectReason::Canceled,
false,
+ None,
cx,
);
})
@@ -1394,7 +1395,14 @@ impl EditPredictionStore {
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
project_state.pending_predictions.clear();
if let Some(prediction) = project_state.current_prediction.take() {
- self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
+ let model_version = prediction.prediction.model_version.clone();
+ self.reject_prediction(
+ prediction.prediction.id,
+ reason,
+ prediction.was_shown,
+ model_version,
+ cx,
+ );
}
};
}
@@ -1453,6 +1461,7 @@ impl EditPredictionStore {
prediction_id: EditPredictionId,
reason: EditPredictionRejectReason,
was_shown: bool,
+ model_version: Option<String>,
cx: &App,
) {
match self.edit_prediction_model {
@@ -1467,6 +1476,7 @@ impl EditPredictionStore {
request_id: prediction_id.to_string(),
reason,
was_shown,
+ model_version,
})
.log_err();
}
@@ -1812,6 +1822,7 @@ impl EditPredictionStore {
new_prediction.prediction.id,
EditPredictionRejectReason::CurrentPreferred,
false,
+ new_prediction.prediction.model_version,
cx,
);
None
@@ -1821,7 +1832,13 @@ impl EditPredictionStore {
}
}
Err(reject_reason) => {
- this.reject_prediction(prediction_result.id, reject_reason, false, cx);
+ this.reject_prediction(
+ prediction_result.id,
+ reject_reason,
+ false,
+ None,
+ cx,
+ );
None
}
}
@@ -897,7 +897,8 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
&[EditPredictionRejection {
request_id: id,
reason: EditPredictionRejectReason::Empty,
- was_shown: false
+ was_shown: false,
+ model_version: None,
}]
);
}
@@ -957,7 +958,8 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
&[EditPredictionRejection {
request_id: id,
reason: EditPredictionRejectReason::InterpolatedEmpty,
- was_shown: false
+ was_shown: false,
+ model_version: None,
}]
);
}
@@ -1049,7 +1051,8 @@ async fn test_replace_current(cx: &mut TestAppContext) {
&[EditPredictionRejection {
request_id: first_id,
reason: EditPredictionRejectReason::Replaced,
- was_shown: false
+ was_shown: false,
+ model_version: None,
}]
);
}
@@ -1143,7 +1146,8 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
&[EditPredictionRejection {
request_id: second_id,
reason: EditPredictionRejectReason::CurrentPreferred,
- was_shown: false
+ was_shown: false,
+ model_version: None,
}]
);
}
@@ -1234,7 +1238,8 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
&[EditPredictionRejection {
request_id: first_id,
reason: EditPredictionRejectReason::Canceled,
- was_shown: false
+ was_shown: false,
+ model_version: None,
}]
);
}
@@ -1364,12 +1369,14 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
EditPredictionRejection {
request_id: cancelled_id,
reason: EditPredictionRejectReason::Canceled,
- was_shown: false
+ was_shown: false,
+ model_version: None,
},
EditPredictionRejection {
request_id: first_id,
reason: EditPredictionRejectReason::Replaced,
- was_shown: false
+ was_shown: false,
+ model_version: None,
}
]
);
@@ -1485,12 +1492,14 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
EditPredictionId("test-1".into()),
EditPredictionRejectReason::Discarded,
false,
+ None,
cx,
);
ep_store.reject_prediction(
EditPredictionId("test-2".into()),
EditPredictionRejectReason::Canceled,
true,
+ None,
cx,
);
});
@@ -1508,7 +1517,8 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
EditPredictionRejection {
request_id: "test-1".to_string(),
reason: EditPredictionRejectReason::Discarded,
- was_shown: false
+ was_shown: false,
+ model_version: None,
}
);
assert_eq!(
@@ -1516,7 +1526,8 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
EditPredictionRejection {
request_id: "test-2".to_string(),
reason: EditPredictionRejectReason::Canceled,
- was_shown: true
+ was_shown: true,
+ model_version: None,
}
);
@@ -1527,6 +1538,7 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
EditPredictionId(format!("batch-{}", i).into()),
EditPredictionRejectReason::Discarded,
false,
+ None,
cx,
);
}
@@ -1558,6 +1570,7 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
EditPredictionId("retry-1".into()),
EditPredictionRejectReason::Discarded,
false,
+ None,
cx,
);
});
@@ -1577,6 +1590,7 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
EditPredictionId("retry-2".into()),
EditPredictionRejectReason::Discarded,
false,
+ None,
cx,
);
});
@@ -1700,6 +1714,7 @@ fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> Predi
request_id: Uuid::new_v4().to_string(),
editable_range,
output: new_excerpt,
+ model_version: None,
}
}
@@ -1708,6 +1723,7 @@ fn empty_response() -> PredictEditsV3Response {
request_id: Uuid::new_v4().to_string(),
editable_range: 0..0,
output: String::new(),
+ model_version: None,
}
}
@@ -1837,6 +1853,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
+ model_version: None,
};
cx.update(|cx| {
@@ -2034,6 +2051,7 @@ async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppConte
request_id: Uuid::new_v4().to_string(),
output: "hello world\n".to_string(),
editable_range: 0..excerpt_length,
+ model_version: None,
};
respond_tx.send(response).unwrap();
@@ -2138,6 +2156,7 @@ async fn make_test_ep_store(
request_id: format!("request-{next_request_id}"),
editable_range: 0..req.input.cursor_excerpt.len(),
output: completion_response.lock().clone(),
+ model_version: None,
})
.unwrap()
.into(),
@@ -141,6 +141,7 @@ pub fn request_prediction(
output.buffer_snapshotted_at,
output.response_received_at,
output.inputs,
+ None,
cx,
)
.await,
@@ -218,6 +218,7 @@ impl Mercury {
buffer_snapshotted_at,
response_received_at,
inputs,
+ None,
cx,
)
.await,
@@ -41,6 +41,7 @@ impl EditPredictionResult {
buffer_snapshotted_at: Instant,
response_received_at: Instant,
inputs: ZetaPromptInput,
+ model_version: Option<String>,
cx: &mut AsyncApp,
) -> Self {
if edits.is_empty() {
@@ -79,6 +80,7 @@ impl EditPredictionResult {
buffer: edited_buffer.clone(),
buffer_snapshotted_at,
response_received_at,
+ model_version,
}),
}
}
@@ -95,6 +97,7 @@ pub struct EditPrediction {
pub buffer_snapshotted_at: Instant,
pub response_received_at: Instant,
pub inputs: zeta_prompt::ZetaPromptInput,
+ pub model_version: Option<String>,
}
impl EditPrediction {
@@ -150,6 +153,7 @@ mod tests {
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
buffer: buffer.clone(),
edit_preview,
+ model_version: None,
inputs: ZetaPromptInput {
events: vec![],
related_files: vec![],
@@ -303,6 +303,7 @@ impl SweepAi {
buffer_snapshotted_at,
response_received_at,
inputs,
+ None,
cx,
)
.await,
@@ -125,7 +125,7 @@ pub fn request_prediction_with_zeta(
log::trace!("Sending edit prediction request");
- let (request_id, output_text, usage) = if let Some(custom_settings) =
+ let (request_id, output_text, model_version, usage) = if let Some(custom_settings) =
&custom_server_settings
{
let max_tokens = custom_settings.max_output_tokens * 4;
@@ -158,7 +158,7 @@ pub fn request_prediction_with_zeta(
let request_id = EditPredictionId(request_id.into());
let output_text = zeta1::clean_zeta1_model_output(&response_text);
- (request_id, output_text, None)
+ (request_id, output_text, None, None)
} else {
let prompt = format_zeta_prompt(&prompt_input, zeta_version);
let prefill = get_prefill(&prompt_input, zeta_version);
@@ -188,7 +188,7 @@ pub fn request_prediction_with_zeta(
Some(clean_zeta2_model_output(&output, zeta_version).to_string())
};
- (request_id, output_text, None)
+ (request_id, output_text, None, None)
}
} else if let Some(config) = &raw_config {
let prompt = format_zeta_prompt(&prompt_input, config.format);
@@ -225,7 +225,7 @@ pub fn request_prediction_with_zeta(
clean_zeta2_model_output(&output, config.format).to_string()
});
- (request_id, output_text, usage)
+ (request_id, output_text, None, usage)
} else {
// Use V3 endpoint - server handles model/version selection and suffix stripping
let (response, usage) = EditPredictionStore::send_v3_request(
@@ -244,8 +244,9 @@ pub fn request_prediction_with_zeta(
Some(response.output)
};
editable_range_in_excerpt = response.editable_range;
+ let model_version = response.model_version;
- (request_id, output_text, usage)
+ (request_id, output_text, model_version, usage)
};
let received_response_at = Instant::now();
@@ -253,7 +254,7 @@ pub fn request_prediction_with_zeta(
log::trace!("Got edit prediction response");
let Some(mut output_text) = output_text else {
- return Ok((Some((request_id, None)), usage));
+ return Ok((Some((request_id, None, model_version)), usage));
};
// Client-side cursor marker processing (applies to both raw and v3 responses)
@@ -309,6 +310,7 @@ pub fn request_prediction_with_zeta(
cursor_position,
received_response_at,
)),
+ model_version,
)),
usage,
))
@@ -316,7 +318,7 @@ pub fn request_prediction_with_zeta(
});
cx.spawn(async move |this, cx| {
- let Some((id, prediction)) =
+ let Some((id, prediction, model_version)) =
EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
else {
return Ok(None);
@@ -347,6 +349,7 @@ pub fn request_prediction_with_zeta(
buffer_snapshotted_at,
received_response_at,
inputs,
+ model_version,
cx,
)
.await,
@@ -476,6 +479,7 @@ pub(crate) fn edit_prediction_accepted(
}
let request_id = current_prediction.prediction.id.to_string();
+ let model_version = current_prediction.prediction.model_version;
let require_auth = custom_accept_url.is_none();
let client = store.client.clone();
let llm_token = store.llm_token.clone();
@@ -495,6 +499,7 @@ pub(crate) fn edit_prediction_accepted(
let req = builder.uri(url.as_ref()).body(
serde_json::to_string(&AcceptEditPredictionBody {
request_id: request_id.clone(),
+ model_version: model_version.clone(),
})?
.into(),
);