ep: Fix teacher's output parser for v0327 (#54416)

Oleksiy Syvokon created

Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/format_prompt.rs | 101 +++++++++++++++++-
crates/edit_prediction_cli/src/main.rs          |  19 +--
2 files changed, 103 insertions(+), 17 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/format_prompt.rs 🔗

@@ -12,10 +12,20 @@ use std::ops::Range;
 use std::sync::Arc;
 use zeta_prompt::udiff;
 use zeta_prompt::{
-    ZetaFormat, encode_patch_as_output_for_format, excerpt_range_for_format, format_zeta_prompt,
-    multi_region, output_end_marker_for_format, resolve_cursor_region,
+    ZetaFormat, encode_patch_as_output_for_format, format_zeta_prompt, multi_region,
+    output_end_marker_for_format, resolve_cursor_region,
 };
 
+fn resolved_excerpt_ranges_for_format(
+    input: &zeta_prompt::ZetaPromptInput,
+    format: ZetaFormat,
+) -> (Range<usize>, Range<usize>) {
+    let (_, editable_range_in_context, context_range, _) = resolve_cursor_region(input, format);
+    let editable_range = (context_range.start + editable_range_in_context.start)
+        ..(context_range.start + editable_range_in_context.end);
+    (editable_range, context_range)
+}
+
 pub async fn run_format_prompt(
     example: &mut Example,
     args: &FormatPromptArgs,
@@ -38,7 +48,7 @@ pub async fn run_format_prompt(
             step_progress.set_substatus("formatting teacher prompt");
 
             let (editable_range, context_range) =
-                excerpt_range_for_format(zeta_format, &prompt_inputs.excerpt_ranges);
+                resolved_excerpt_ranges_for_format(prompt_inputs, zeta_format);
 
             let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
             example.prompt = Some(ExamplePrompt {
@@ -55,7 +65,7 @@ pub async fn run_format_prompt(
 
             let zeta_format = ZetaFormat::default();
             let (editable_range, context_range) =
-                excerpt_range_for_format(zeta_format, &prompt_inputs.excerpt_ranges);
+                resolved_excerpt_ranges_for_format(prompt_inputs, zeta_format);
 
             let prompt =
                 TeacherMultiRegionPrompt::format_prompt(example, editable_range, context_range);
@@ -454,8 +464,7 @@ impl TeacherMultiRegionPrompt {
             .context("example is missing prompt inputs")?;
 
         let zeta_format = ZetaFormat::default();
-        let (editable_range, _) =
-            excerpt_range_for_format(zeta_format, &prompt_inputs.excerpt_ranges);
+        let (editable_range, _) = resolved_excerpt_ranges_for_format(prompt_inputs, zeta_format);
         let excerpt = prompt_inputs.cursor_excerpt.as_ref();
         let old_editable_region = &excerpt[editable_range.clone()];
         let marker_offsets = multi_region::compute_marker_offsets(old_editable_region);
@@ -940,4 +949,84 @@ mod tests {
         assert!(parsed.0.is_empty());
         assert!(parsed.1.is_none());
     }
+
+    #[test]
+    fn test_v0327_teacher_prompt_uses_resolved_ranges() {
+        let excerpt = (0..80)
+            .map(|index| format!("line{index:02}\n"))
+            .collect::<String>();
+        let cursor_offset = excerpt.find("line40").expect("cursor line exists");
+        let prompt_inputs = zeta_prompt::ZetaPromptInput {
+            cursor_path: std::path::Path::new("src/main.rs").into(),
+            cursor_excerpt: excerpt.clone().into(),
+            cursor_offset_in_excerpt: cursor_offset,
+            excerpt_start_row: None,
+            events: Vec::new(),
+            related_files: Some(Vec::new()),
+            active_buffer_diagnostics: Vec::new(),
+            excerpt_ranges: zeta_prompt::ExcerptRanges {
+                editable_150: 0..32,
+                editable_180: 0..32,
+                editable_350: 0..32,
+                editable_512: None,
+                editable_150_context_350: 0..48,
+                editable_180_context_350: 0..48,
+                editable_350_context_150: 20..50,
+                editable_350_context_512: None,
+                editable_350_context_1024: None,
+                context_4096: None,
+                context_8192: Some(30..excerpt.len()),
+            },
+            syntax_ranges: None,
+            in_open_source_repo: false,
+            can_collect_data: false,
+            repo_url: None,
+        };
+
+        let (stored_editable_range, stored_context_range) = zeta_prompt::excerpt_range_for_format(
+            ZetaFormat::V0327SingleFile,
+            &prompt_inputs.excerpt_ranges,
+        );
+        assert!(stored_context_range.start > stored_editable_range.start);
+
+        let (editable_range, context_range) =
+            resolved_excerpt_ranges_for_format(&prompt_inputs, ZetaFormat::V0327SingleFile);
+        assert_eq!(context_range, 0..excerpt.len());
+        assert!(editable_range.start < cursor_offset);
+        assert!(editable_range.end > cursor_offset);
+
+        let prompt = TeacherPrompt::format_prompt(
+            &Example {
+                spec: edit_prediction::example_spec::ExampleSpec {
+                    name: "test".to_string(),
+                    repository_url: "https://github.com/zed-industries/zed.git".to_string(),
+                    revision: "HEAD".to_string(),
+                    tags: Vec::new(),
+                    reasoning: None,
+                    uncommitted_diff: String::new(),
+                    cursor_path: std::sync::Arc::from(std::path::Path::new("src/main.rs")),
+                    cursor_position: "0:0".to_string(),
+                    edit_history: String::new(),
+                    expected_patches: Vec::new(),
+                    rejected_patch: None,
+                    telemetry: None,
+                    human_feedback: Vec::new(),
+                    rating: None,
+                },
+                prompt_inputs: Some(prompt_inputs),
+                prompt: None,
+                predictions: Vec::new(),
+                score: Vec::new(),
+                qa: Vec::new(),
+                zed_version: None,
+                state: None,
+            },
+            editable_range,
+            context_range,
+        );
+
+        assert!(prompt.contains(TeacherPrompt::EDITABLE_REGION_START));
+        assert!(prompt.contains(TeacherPrompt::USER_CURSOR_MARKER));
+        assert!(prompt.contains("line40"));
+    }
 }

crates/edit_prediction_cli/src/main.rs 🔗

@@ -415,16 +415,13 @@ impl std::str::FromStr for PredictionProvider {
                 let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
                 Ok(PredictionProvider::Zeta2(format))
             }
-            "teacher" => parse_teacher_args(arg),
+            "teacher" => {
+                let (backend, format) = parse_teacher_args(arg)?;
+                Ok(PredictionProvider::Teacher(backend, format))
+            }
             "teacher-non-batching" | "teacher_non_batching" => {
-                let backend = arg
-                    .map(|a| a.parse())
-                    .transpose()?
-                    .unwrap_or(TeacherBackend::default());
-                Ok(PredictionProvider::TeacherNonBatching(
-                    backend,
-                    ZetaFormat::default(),
-                ))
+                let (backend, format) = parse_teacher_args(arg)?;
+                Ok(PredictionProvider::TeacherNonBatching(backend, format))
             }
             "teacher-multi-region" | "teacher_multi_region" => {
                 let backend = arg
@@ -461,7 +458,7 @@ impl std::str::FromStr for PredictionProvider {
     }
 }
 
-fn parse_teacher_args(arg: Option<&str>) -> Result<PredictionProvider, anyhow::Error> {
+fn parse_teacher_args(arg: Option<&str>) -> Result<(TeacherBackend, ZetaFormat), anyhow::Error> {
     let mut backend = TeacherBackend::default();
     let mut format = ZetaFormat::default();
 
@@ -479,7 +476,7 @@ fn parse_teacher_args(arg: Option<&str>) -> Result<PredictionProvider, anyhow::E
         }
     }
 
-    Ok(PredictionProvider::Teacher(backend, format))
+    Ok((backend, format))
 }
 
 impl Serialize for PredictionProvider {