zeta2: Use editable range returned by cloud for prediction diffs (#50029)

Created by Ben Kunkle and Max.

Closes #ISSUE

Before you mark this PR as ready for review, make sure that you have:
- [ ] Added a solid test coverage and/or screenshots from doing manual
testing
- [ ] Done a self-review taking into account security and performance
aspects
- [ ] Aligned any UI changes with the [UI
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)

Release Notes:

- Improved edit prediction diffing by using the editable range returned by the server instead of a client-derived range.

Co-authored-by: Max <max@zed.dev>

Change summary

crates/cloud_llm_client/src/predict_edits_v3.rs     |   6 
crates/edit_prediction/src/edit_prediction_tests.rs |  33 +
crates/edit_prediction/src/zeta.rs                  | 219 ++++++++------
3 files changed, 151 insertions(+), 107 deletions(-)

Detailed changes

crates/cloud_llm_client/src/predict_edits_v3.rs 🔗

@@ -1,6 +1,7 @@
 use crate::PredictEditsRequestTrigger;
 use serde::{Deserialize, Serialize};
 use std::borrow::Cow;
+use std::ops::Range;
 
 #[derive(Debug, Deserialize, Serialize)]
 pub struct RawCompletionRequest {
@@ -27,6 +28,11 @@ pub struct PredictEditsV3Request {
 pub struct PredictEditsV3Response {
     pub request_id: String,
     pub output: String,
+    /// The editable region byte range within `cursor_excerpt` that the
+    /// server used for this request. The client should use this range to
+    /// extract the old text from its local excerpt for diffing, rather
+    /// than relying on its own format-derived range.
+    pub editable_range: Range<usize>,
 }
 
 #[derive(Debug, Deserialize, Serialize)]

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -1687,12 +1687,18 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
 
 // Generate a model response that would apply the given diff to the active file.
 fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> PredictEditsV3Response {
-    let excerpt =
-        request.input.cursor_excerpt[request.input.editable_range_in_excerpt.clone()].to_string();
+    let editable_range = request
+        .input
+        .excerpt_ranges
+        .as_ref()
+        .map(|r| zeta_prompt::excerpt_range_for_format(Default::default(), r).1)
+        .unwrap_or(request.input.editable_range_in_excerpt.clone());
+    let excerpt = request.input.cursor_excerpt[editable_range.clone()].to_string();
     let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
 
     PredictEditsV3Response {
         request_id: Uuid::new_v4().to_string(),
+        editable_range,
         output: new_excerpt,
     }
 }
@@ -1700,6 +1706,7 @@ fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> Predi
 fn empty_response() -> PredictEditsV3Response {
     PredictEditsV3Response {
         request_id: Uuid::new_v4().to_string(),
+        editable_range: 0..0,
         output: String::new(),
     }
 }
@@ -2018,13 +2025,15 @@ async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppConte
         ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
     });
 
-    let (_request, respond_tx) = requests.predict.next().await.unwrap();
+    let (request, respond_tx) = requests.predict.next().await.unwrap();
 
     // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
     // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
+    let excerpt_length = request.input.cursor_excerpt.len();
     let response = PredictEditsV3Response {
         request_id: Uuid::new_v4().to_string(),
         output: "hello world\n".to_string(),
+        editable_range: 0..excerpt_length,
     };
     respond_tx.send(response).unwrap();
 
@@ -2099,9 +2108,12 @@ async fn make_test_ep_store(
         let mut next_request_id = 0;
         move |req| {
             let completion_response = completion_response.clone();
+            let method = req.method().clone();
+            let uri = req.uri().path().to_string();
+            let mut body = req.into_body();
             async move {
-                match (req.method(), req.uri().path()) {
-                    (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
+                match (method, uri.as_str()) {
+                    (Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
                         .status(200)
                         .body(
                             serde_json::to_string(&CreateLlmTokenResponse {
@@ -2111,13 +2123,20 @@ async fn make_test_ep_store(
                             .into(),
                         )
                         .unwrap()),
-                    (&Method::POST, "/predict_edits/v3") => {
+                    (Method::POST, "/predict_edits/v3") => {
+                        let mut buf = Vec::new();
+                        body.read_to_end(&mut buf).await.ok();
+                        let decompressed = zstd::decode_all(&buf[..]).unwrap();
+                        let req: PredictEditsV3Request =
+                            serde_json::from_slice(&decompressed).unwrap();
+
                         next_request_id += 1;
                         Ok(http_client::Response::builder()
                             .status(200)
                             .body(
                                 serde_json::to_string(&PredictEditsV3Response {
                                     request_id: format!("request-{next_request_id}"),
+                                    editable_range: 0..req.input.cursor_excerpt.len(),
                                     output: completion_response.lock().clone(),
                                 })
                                 .unwrap()
@@ -2127,7 +2146,7 @@ async fn make_test_ep_store(
                     }
                     _ => Ok(http_client::Response::builder()
                         .status(404)
-                        .body("Not Found".into())
+                        .body("Not Found".to_string().into())
                         .unwrap()),
                 }
             }

crates/edit_prediction/src/zeta.rs 🔗

@@ -79,7 +79,8 @@ pub fn request_prediction_with_zeta(
                 .unwrap_or(ZetaFormat::default());
 
             let cursor_offset = position.to_offset(&snapshot);
-            let (editable_offset_range, prompt_input) = zeta2_prompt_input(
+            let editable_range_in_excerpt: Range<usize>;
+            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                 &snapshot,
                 related_files,
                 events,
@@ -124,113 +125,129 @@ pub fn request_prediction_with_zeta(
 
             log::trace!("Sending edit prediction request");
 
-            let (request_id, output_text, usage) =
-                if let Some(custom_settings) = &custom_server_settings {
-                    let max_tokens = custom_settings.max_output_tokens * 4;
-
-                    if is_zeta1 {
-                        let ranges = excerpt_ranges;
-                        let prompt = zeta1::format_zeta1_from_input(
-                            &prompt_input,
-                            ranges.editable_350.clone(),
-                            ranges.editable_350_context_150.clone(),
-                        );
-                        let stop_tokens = vec![
-                            EDITABLE_REGION_END_MARKER.to_string(),
-                            format!("{EDITABLE_REGION_END_MARKER}\n"),
-                            format!("{EDITABLE_REGION_END_MARKER}\n\n"),
-                            format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
-                        ];
-
-                        let (response_text, request_id) = send_custom_server_request(
-                            provider,
-                            custom_settings,
-                            prompt,
-                            max_tokens,
-                            stop_tokens,
-                            &http_client,
-                        )
-                        .await?;
-
-                        let request_id = EditPredictionId(request_id.into());
-                        let output_text = zeta1::clean_zeta1_model_output(&response_text);
-
-                        (request_id, output_text, None)
-                    } else {
-                        let prompt = format_zeta_prompt(&prompt_input, zeta_version);
-                        let prefill = get_prefill(&prompt_input, zeta_version);
-                        let prompt = format!("{prompt}{prefill}");
-
-                        let (response_text, request_id) = send_custom_server_request(
-                            provider,
-                            custom_settings,
-                            prompt,
-                            max_tokens,
-                            vec![],
-                            &http_client,
-                        )
-                        .await?;
-
-                        let request_id = EditPredictionId(request_id.into());
-                        let output_text = if response_text.is_empty() {
-                            None
-                        } else {
-                            let output = format!("{prefill}{response_text}");
-                            Some(clean_zeta2_model_output(&output, zeta_version).to_string())
-                        };
-
-                        (request_id, output_text, None)
-                    }
-                } else if let Some(config) = &raw_config {
-                    let prompt = format_zeta_prompt(&prompt_input, config.format);
-                    let prefill = get_prefill(&prompt_input, config.format);
-                    let prompt = format!("{prompt}{prefill}");
-                    let request = RawCompletionRequest {
-                        model: config.model_id.clone().unwrap_or_default(),
-                        prompt,
-                        temperature: None,
-                        stop: vec![],
-                        max_tokens: Some(2048),
-                        environment: Some(config.format.to_string().to_lowercase()),
-                    };
+            let (request_id, output_text, usage) = if let Some(custom_settings) =
+                &custom_server_settings
+            {
+                let max_tokens = custom_settings.max_output_tokens * 4;
 
-                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
-                        request,
-                        client,
-                        None,
-                        llm_token,
-                        app_version,
+                if is_zeta1 {
+                    let ranges = excerpt_ranges;
+                    let prompt = zeta1::format_zeta1_from_input(
+                        &prompt_input,
+                        ranges.editable_350.clone(),
+                        ranges.editable_350_context_150.clone(),
+                    );
+                    editable_range_in_excerpt = ranges.editable_350.clone();
+                    let stop_tokens = vec![
+                        EDITABLE_REGION_END_MARKER.to_string(),
+                        format!("{EDITABLE_REGION_END_MARKER}\n"),
+                        format!("{EDITABLE_REGION_END_MARKER}\n\n"),
+                        format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
+                    ];
+
+                    let (response_text, request_id) = send_custom_server_request(
+                        provider,
+                        custom_settings,
+                        prompt,
+                        max_tokens,
+                        stop_tokens,
+                        &http_client,
                     )
                     .await?;
 
-                    let request_id = EditPredictionId(response.id.clone().into());
-                    let output_text = response.choices.pop().map(|choice| {
-                        let response = &choice.text;
-                        let output = format!("{prefill}{response}");
-                        clean_zeta2_model_output(&output, config.format).to_string()
-                    });
+                    let request_id = EditPredictionId(request_id.into());
+                    let output_text = zeta1::clean_zeta1_model_output(&response_text);
 
-                    (request_id, output_text, usage)
+                    (request_id, output_text, None)
                 } else {
-                    // Use V3 endpoint - server handles model/version selection and suffix stripping
-                    let (response, usage) = EditPredictionStore::send_v3_request(
-                        prompt_input.clone(),
-                        client,
-                        llm_token,
-                        app_version,
-                        trigger,
+                    let prompt = format_zeta_prompt(&prompt_input, zeta_version);
+                    let prefill = get_prefill(&prompt_input, zeta_version);
+                    let prompt = format!("{prompt}{prefill}");
+
+                    editable_range_in_excerpt = prompt_input
+                        .excerpt_ranges
+                        .as_ref()
+                        .map(|ranges| zeta_prompt::excerpt_range_for_format(zeta_version, ranges).0)
+                        .unwrap_or(prompt_input.editable_range_in_excerpt.clone());
+
+                    let (response_text, request_id) = send_custom_server_request(
+                        provider,
+                        custom_settings,
+                        prompt,
+                        max_tokens,
+                        vec![],
+                        &http_client,
                     )
                     .await?;
 
-                    let request_id = EditPredictionId(response.request_id.into());
-                    let output_text = if response.output.is_empty() {
+                    let request_id = EditPredictionId(request_id.into());
+                    let output_text = if response_text.is_empty() {
                         None
                     } else {
-                        Some(response.output)
+                        let output = format!("{prefill}{response_text}");
+                        Some(clean_zeta2_model_output(&output, zeta_version).to_string())
                     };
-                    (request_id, output_text, usage)
+
+                    (request_id, output_text, None)
+                }
+            } else if let Some(config) = &raw_config {
+                let prompt = format_zeta_prompt(&prompt_input, config.format);
+                let prefill = get_prefill(&prompt_input, config.format);
+                let prompt = format!("{prompt}{prefill}");
+                let request = RawCompletionRequest {
+                    model: config.model_id.clone().unwrap_or_default(),
+                    prompt,
+                    temperature: None,
+                    stop: vec![],
+                    max_tokens: Some(2048),
+                    environment: Some(config.format.to_string().to_lowercase()),
                 };
 
+                editable_range_in_excerpt = prompt_input
+                    .excerpt_ranges
+                    .as_ref()
+                    .map(|ranges| zeta_prompt::excerpt_range_for_format(config.format, ranges).1)
+                    .unwrap_or(prompt_input.editable_range_in_excerpt.clone());
+
+                let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
+                    request,
+                    client,
+                    None,
+                    llm_token,
+                    app_version,
+                )
+                .await?;
+
+                let request_id = EditPredictionId(response.id.clone().into());
+                let output_text = response.choices.pop().map(|choice| {
+                    let response = &choice.text;
+                    let output = format!("{prefill}{response}");
+                    clean_zeta2_model_output(&output, config.format).to_string()
+                });
+
+                (request_id, output_text, usage)
+            } else {
+                // Use V3 endpoint - server handles model/version selection and suffix stripping
+                let (response, usage) = EditPredictionStore::send_v3_request(
+                    prompt_input.clone(),
+                    client,
+                    llm_token,
+                    app_version,
+                    trigger,
+                )
+                .await?;
+
+                let request_id = EditPredictionId(response.request_id.into());
+                let output_text = if response.output.is_empty() {
+                    None
+                } else {
+                    Some(response.output)
+                };
+                editable_range_in_excerpt = response.editable_range;
+
+                (request_id, output_text, usage)
+            };
+
             let received_response_at = Instant::now();
 
             log::trace!("Got edit prediction response");
@@ -258,8 +275,12 @@ pub fn request_prediction_with_zeta(
                     .ok();
             }
 
+            let editable_range_in_buffer = editable_range_in_excerpt.start
+                + full_context_offset_range.start
+                ..editable_range_in_excerpt.end + full_context_offset_range.start;
+
             let mut old_text = snapshot
-                .text_for_range(editable_offset_range.clone())
+                .text_for_range(editable_range_in_buffer.clone())
                 .collect::<String>();
 
             if !output_text.is_empty() && !output_text.ends_with('\n') {
@@ -272,7 +293,7 @@ pub fn request_prediction_with_zeta(
             let (edits, cursor_position) = compute_edits_and_cursor_position(
                 old_text,
                 &output_text,
-                editable_offset_range.start,
+                editable_range_in_buffer.start,
                 cursor_offset_in_output,
                 &snapshot,
             );
@@ -343,7 +364,7 @@ pub fn zeta2_prompt_input(
     preferred_model: Option<EditPredictionModelKind>,
     is_open_source: bool,
     can_collect_data: bool,
-) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
+) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
     let cursor_point = cursor_offset.to_point(snapshot);
 
     let (full_context, full_context_offset_range, excerpt_ranges) =
@@ -362,8 +383,6 @@ pub fn zeta2_prompt_input(
         Some(EditPredictionModelKind::Zeta1) => excerpt_ranges.editable_350.clone(),
         _ => zeta_prompt::excerpt_range_for_format(zeta_format, &excerpt_ranges).0,
     };
-    let absolute_editable_range = full_context_start_offset + editable_offset_range.start
-        ..full_context_start_offset + editable_offset_range.end;
 
     let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
 
@@ -383,7 +402,7 @@ pub fn zeta2_prompt_input(
         in_open_source_repo: is_open_source,
         can_collect_data,
     };
-    (absolute_editable_range, prompt_input)
+    (full_context_offset_range, prompt_input)
 }
 
 pub(crate) async fn send_custom_server_request(