diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index 9e7772ab7450cb47785d034b39d9c7c642b931c2..d0b53ca18e8c74ec2588bff14c5130e3381f9444 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -1,6 +1,7 @@ use crate::PredictEditsRequestTrigger; use serde::{Deserialize, Serialize}; use std::borrow::Cow; +use std::ops::Range; #[derive(Debug, Deserialize, Serialize)] pub struct RawCompletionRequest { @@ -27,6 +28,11 @@ pub struct PredictEditsV3Request { pub struct PredictEditsV3Response { pub request_id: String, pub output: String, + /// The editable region byte range within `cursor_excerpt` that the + /// server used for this request. When present, the client should use + /// this range to extract the old text from its local excerpt for + /// diffing, rather than relying on its own format-derived range. + pub editable_range: Range<usize>, } #[derive(Debug, Deserialize, Serialize)] diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index eb76e0fd05182a1b9048bcf36f1bcebe8e808ef2..b0468e3c5610b8f618631be6707c74c4eaa451e5 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1687,12 +1687,18 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) { // Generate a model response that would apply the given diff to the active file. 
fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response { - let excerpt = - request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string(); + let editable_range = request + .input + .excerpt_ranges + .as_ref() + .map(|r| zeta_prompt::excerpt_range_for_format(Default::default(), r).1) + .unwrap_or(request.input.editable_range_in_excerpt.clone()); + let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string(); let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap(); PredictEditsV3Response { request_id: Uuid::new_v4().to_string(), + editable_range, output: new_excerpt, } } @@ -1700,6 +1706,7 @@ fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> Predi fn empty_response() -> PredictEditsV3Response { PredictEditsV3Response { request_id: Uuid::new_v4().to_string(), + editable_range: 0..0, output: String::new(), } } @@ -2018,13 +2025,15 @@ async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppConte ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); }); - let (_request, respond_tx) = requests.predict.next().await.unwrap(); + let (request, respond_tx) = requests.predict.next().await.unwrap(); // Model returns output WITH a trailing newline, even though the buffer doesn't have one. // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted. 
+ let excerpt_length = request.input.cursor_excerpt.len(); let response = PredictEditsV3Response { request_id: Uuid::new_v4().to_string(), output: "hello world\n".to_string(), + editable_range: 0..excerpt_length, }; respond_tx.send(response).unwrap(); @@ -2099,9 +2108,12 @@ async fn make_test_ep_store( let mut next_request_id = 0; move |req| { let completion_response = completion_response.clone(); + let method = req.method().clone(); + let uri = req.uri().path().to_string(); + let mut body = req.into_body(); async move { - match (req.method(), req.uri().path()) { - (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() + match (method, uri.as_str()) { + (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() .status(200) .body( serde_json::to_string(&CreateLlmTokenResponse { @@ -2111,13 +2123,20 @@ async fn make_test_ep_store( .into(), ) .unwrap()), - (&Method::POST, "/predict_edits/v3") => { + (Method::POST, "/predict_edits/v3") => { + let mut buf = Vec::new(); + body.read_to_end(&mut buf).await.ok(); + let decompressed = zstd::decode_all(&buf[..]).unwrap(); + let req: PredictEditsV3Request = + serde_json::from_slice(&decompressed).unwrap(); + next_request_id += 1; Ok(http_client::Response::builder() .status(200) .body( serde_json::to_string(&PredictEditsV3Response { request_id: format!("request-{next_request_id}"), + editable_range: 0..req.input.cursor_excerpt.len(), output: completion_response.lock().clone(), }) .unwrap() @@ -2127,7 +2146,7 @@ async fn make_test_ep_store( } _ => Ok(http_client::Response::builder() .status(404) - .body("Not Found".into()) + .body("Not Found".to_string().into()) .unwrap()), } } diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index 658071c9ccfbdf64a9a1ebead7724774cd5cc40e..f6d6eaf689eabd417c432b0879fdf7c1cec47139 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -79,7 +79,8 @@ pub fn request_prediction_with_zeta( 
.unwrap_or(ZetaFormat::default()); let cursor_offset = position.to_offset(&snapshot); - let (editable_offset_range, prompt_input) = zeta2_prompt_input( + let editable_range_in_excerpt: Range; + let (full_context_offset_range, prompt_input) = zeta2_prompt_input( &snapshot, related_files, events, @@ -124,113 +125,129 @@ pub fn request_prediction_with_zeta( log::trace!("Sending edit prediction request"); - let (request_id, output_text, usage) = - if let Some(custom_settings) = &custom_server_settings { - let max_tokens = custom_settings.max_output_tokens * 4; - - if is_zeta1 { - let ranges = excerpt_ranges; - let prompt = zeta1::format_zeta1_from_input( - &prompt_input, - ranges.editable_350.clone(), - ranges.editable_350_context_150.clone(), - ); - let stop_tokens = vec![ - EDITABLE_REGION_END_MARKER.to_string(), - format!("{EDITABLE_REGION_END_MARKER}\n"), - format!("{EDITABLE_REGION_END_MARKER}\n\n"), - format!("{EDITABLE_REGION_END_MARKER}\n\n\n"), - ]; - - let (response_text, request_id) = send_custom_server_request( - provider, - custom_settings, - prompt, - max_tokens, - stop_tokens, - &http_client, - ) - .await?; - - let request_id = EditPredictionId(request_id.into()); - let output_text = zeta1::clean_zeta1_model_output(&response_text); - - (request_id, output_text, None) - } else { - let prompt = format_zeta_prompt(&prompt_input, zeta_version); - let prefill = get_prefill(&prompt_input, zeta_version); - let prompt = format!("{prompt}{prefill}"); - - let (response_text, request_id) = send_custom_server_request( - provider, - custom_settings, - prompt, - max_tokens, - vec![], - &http_client, - ) - .await?; - - let request_id = EditPredictionId(request_id.into()); - let output_text = if response_text.is_empty() { - None - } else { - let output = format!("{prefill}{response_text}"); - Some(clean_zeta2_model_output(&output, zeta_version).to_string()) - }; - - (request_id, output_text, None) - } - } else if let Some(config) = &raw_config { - let prompt = 
format_zeta_prompt(&prompt_input, config.format); - let prefill = get_prefill(&prompt_input, config.format); - let prompt = format!("{prompt}{prefill}"); - let request = RawCompletionRequest { - model: config.model_id.clone().unwrap_or_default(), - prompt, - temperature: None, - stop: vec![], - max_tokens: Some(2048), - environment: Some(config.format.to_string().to_lowercase()), - }; + let (request_id, output_text, usage) = if let Some(custom_settings) = + &custom_server_settings + { + let max_tokens = custom_settings.max_output_tokens * 4; - let (mut response, usage) = EditPredictionStore::send_raw_llm_request( - request, - client, - None, - llm_token, - app_version, + if is_zeta1 { + let ranges = excerpt_ranges; + let prompt = zeta1::format_zeta1_from_input( + &prompt_input, + ranges.editable_350.clone(), + ranges.editable_350_context_150.clone(), + ); + editable_range_in_excerpt = ranges.editable_350.clone(); + let stop_tokens = vec![ + EDITABLE_REGION_END_MARKER.to_string(), + format!("{EDITABLE_REGION_END_MARKER}\n"), + format!("{EDITABLE_REGION_END_MARKER}\n\n"), + format!("{EDITABLE_REGION_END_MARKER}\n\n\n"), + ]; + + let (response_text, request_id) = send_custom_server_request( + provider, + custom_settings, + prompt, + max_tokens, + stop_tokens, + &http_client, ) .await?; - let request_id = EditPredictionId(response.id.clone().into()); - let output_text = response.choices.pop().map(|choice| { - let response = &choice.text; - let output = format!("{prefill}{response}"); - clean_zeta2_model_output(&output, config.format).to_string() - }); + let request_id = EditPredictionId(request_id.into()); + let output_text = zeta1::clean_zeta1_model_output(&response_text); - (request_id, output_text, usage) + (request_id, output_text, None) } else { - // Use V3 endpoint - server handles model/version selection and suffix stripping - let (response, usage) = EditPredictionStore::send_v3_request( - prompt_input.clone(), - client, - llm_token, - app_version, - trigger, + 
let prompt = format_zeta_prompt(&prompt_input, zeta_version); + let prefill = get_prefill(&prompt_input, zeta_version); + let prompt = format!("{prompt}{prefill}"); + + editable_range_in_excerpt = prompt_input + .excerpt_ranges + .as_ref() + .map(|ranges| zeta_prompt::excerpt_range_for_format(zeta_version, ranges).0) + .unwrap_or(prompt_input.editable_range_in_excerpt.clone()); + + let (response_text, request_id) = send_custom_server_request( + provider, + custom_settings, + prompt, + max_tokens, + vec![], + &http_client, ) .await?; - let request_id = EditPredictionId(response.request_id.into()); - let output_text = if response.output.is_empty() { + let request_id = EditPredictionId(request_id.into()); + let output_text = if response_text.is_empty() { None } else { - Some(response.output) + let output = format!("{prefill}{response_text}"); + Some(clean_zeta2_model_output(&output, zeta_version).to_string()) }; - (request_id, output_text, usage) + + (request_id, output_text, None) + } + } else if let Some(config) = &raw_config { + let prompt = format_zeta_prompt(&prompt_input, config.format); + let prefill = get_prefill(&prompt_input, config.format); + let prompt = format!("{prompt}{prefill}"); + let request = RawCompletionRequest { + model: config.model_id.clone().unwrap_or_default(), + prompt, + temperature: None, + stop: vec![], + max_tokens: Some(2048), + environment: Some(config.format.to_string().to_lowercase()), }; + editable_range_in_excerpt = prompt_input + .excerpt_ranges + .as_ref() + .map(|ranges| zeta_prompt::excerpt_range_for_format(config.format, ranges).1) + .unwrap_or(prompt_input.editable_range_in_excerpt.clone()); + + let (mut response, usage) = EditPredictionStore::send_raw_llm_request( + request, + client, + None, + llm_token, + app_version, + ) + .await?; + + let request_id = EditPredictionId(response.id.clone().into()); + let output_text = response.choices.pop().map(|choice| { + let response = &choice.text; + let output = 
format!("{prefill}{response}"); + clean_zeta2_model_output(&output, config.format).to_string() + }); + + (request_id, output_text, usage) + } else { + // Use V3 endpoint - server handles model/version selection and suffix stripping + let (response, usage) = EditPredictionStore::send_v3_request( + prompt_input.clone(), + client, + llm_token, + app_version, + trigger, + ) + .await?; + + let request_id = EditPredictionId(response.request_id.into()); + let output_text = if response.output.is_empty() { + None + } else { + Some(response.output) + }; + editable_range_in_excerpt = response.editable_range; + + (request_id, output_text, usage) + }; + let received_response_at = Instant::now(); log::trace!("Got edit prediction response"); @@ -258,8 +275,12 @@ pub fn request_prediction_with_zeta( .ok(); } + let editable_range_in_buffer = editable_range_in_excerpt.start + + full_context_offset_range.start + ..editable_range_in_excerpt.end + full_context_offset_range.start; + let mut old_text = snapshot - .text_for_range(editable_offset_range.clone()) + .text_for_range(editable_range_in_buffer.clone()) .collect::(); if !output_text.is_empty() && !output_text.ends_with('\n') { @@ -272,7 +293,7 @@ pub fn request_prediction_with_zeta( let (edits, cursor_position) = compute_edits_and_cursor_position( old_text, &output_text, - editable_offset_range.start, + editable_range_in_buffer.start, cursor_offset_in_output, &snapshot, ); @@ -343,7 +364,7 @@ pub fn zeta2_prompt_input( preferred_model: Option, is_open_source: bool, can_collect_data: bool, -) -> (std::ops::Range, zeta_prompt::ZetaPromptInput) { +) -> (Range, zeta_prompt::ZetaPromptInput) { let cursor_point = cursor_offset.to_point(snapshot); let (full_context, full_context_offset_range, excerpt_ranges) = @@ -362,8 +383,6 @@ pub fn zeta2_prompt_input( Some(EditPredictionModelKind::Zeta1) => excerpt_ranges.editable_350.clone(), _ => zeta_prompt::excerpt_range_for_format(zeta_format, &excerpt_ranges).0, }; - let 
absolute_editable_range = full_context_start_offset + editable_offset_range.start - ..full_context_start_offset + editable_offset_range.end; let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset; @@ -383,7 +402,7 @@ pub fn zeta2_prompt_input( in_open_source_repo: is_open_source, can_collect_data, }; - (absolute_editable_range, prompt_input) + (full_context_offset_range, prompt_input) } pub(crate) async fn send_custom_server_request(