diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index 15932706684ec47bac00048407426a617c21a23b..7b14af087fe369146379ec729f1fadef3d7602b6 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -12,10 +12,20 @@ use std::ops::Range; use std::sync::Arc; use zeta_prompt::udiff; use zeta_prompt::{ - ZetaFormat, encode_patch_as_output_for_format, excerpt_range_for_format, format_zeta_prompt, - multi_region, output_end_marker_for_format, resolve_cursor_region, + ZetaFormat, encode_patch_as_output_for_format, format_zeta_prompt, multi_region, + output_end_marker_for_format, resolve_cursor_region, }; +fn resolved_excerpt_ranges_for_format( + input: &zeta_prompt::ZetaPromptInput, + format: ZetaFormat, +) -> (Range, Range) { + let (_, editable_range_in_context, context_range, _) = resolve_cursor_region(input, format); + let editable_range = (context_range.start + editable_range_in_context.start) + ..(context_range.start + editable_range_in_context.end); + (editable_range, context_range) +} + pub async fn run_format_prompt( example: &mut Example, args: &FormatPromptArgs, @@ -38,7 +48,7 @@ pub async fn run_format_prompt( step_progress.set_substatus("formatting teacher prompt"); let (editable_range, context_range) = - excerpt_range_for_format(zeta_format, &prompt_inputs.excerpt_ranges); + resolved_excerpt_ranges_for_format(prompt_inputs, zeta_format); let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range); example.prompt = Some(ExamplePrompt { @@ -55,7 +65,7 @@ pub async fn run_format_prompt( let zeta_format = ZetaFormat::default(); let (editable_range, context_range) = - excerpt_range_for_format(zeta_format, &prompt_inputs.excerpt_ranges); + resolved_excerpt_ranges_for_format(prompt_inputs, zeta_format); let prompt = TeacherMultiRegionPrompt::format_prompt(example, editable_range, context_range); @@ -454,8 +464,7 @@ impl TeacherMultiRegionPrompt { .context("example is missing prompt inputs")?; let zeta_format = ZetaFormat::default(); - let (editable_range, _) = - excerpt_range_for_format(zeta_format, &prompt_inputs.excerpt_ranges); + let (editable_range, _) = resolved_excerpt_ranges_for_format(prompt_inputs, zeta_format); let excerpt = prompt_inputs.cursor_excerpt.as_ref(); let old_editable_region = &excerpt[editable_range.clone()]; let marker_offsets = multi_region::compute_marker_offsets(old_editable_region); @@ -940,4 +949,84 @@ mod tests { assert!(parsed.0.is_empty()); assert!(parsed.1.is_none()); } + + #[test] + fn test_v0327_teacher_prompt_uses_resolved_ranges() { + let excerpt = (0..80) + .map(|index| format!("line{index:02}\n")) + .collect::(); + let cursor_offset = excerpt.find("line40").expect("cursor line exists"); + let prompt_inputs = zeta_prompt::ZetaPromptInput { + cursor_path: std::path::Path::new("src/main.rs").into(), + cursor_excerpt: excerpt.clone().into(), + cursor_offset_in_excerpt: cursor_offset, + excerpt_start_row: None, + events: Vec::new(), + related_files: Some(Vec::new()), + active_buffer_diagnostics: Vec::new(), + excerpt_ranges: zeta_prompt::ExcerptRanges { + editable_150: 0..32, + editable_180: 0..32, + editable_350: 0..32, + editable_512: None, + editable_150_context_350: 0..48, + editable_180_context_350: 0..48, + editable_350_context_150: 20..50, + editable_350_context_512: None, + editable_350_context_1024: None, + context_4096: None, + context_8192: Some(30..excerpt.len()), + }, + syntax_ranges: None, + in_open_source_repo: false, + can_collect_data: false, + repo_url: None, + }; + + let (stored_editable_range, stored_context_range) = zeta_prompt::excerpt_range_for_format( + ZetaFormat::V0327SingleFile, + &prompt_inputs.excerpt_ranges, + ); + assert!(stored_context_range.start > stored_editable_range.start); + + let (editable_range, context_range) = + resolved_excerpt_ranges_for_format(&prompt_inputs, ZetaFormat::V0327SingleFile); + assert_eq!(context_range, 0..excerpt.len()); + assert!(editable_range.start < cursor_offset); + assert!(editable_range.end > cursor_offset); + + let prompt = TeacherPrompt::format_prompt( + &Example { + spec: edit_prediction::example_spec::ExampleSpec { + name: "test".to_string(), + repository_url: "https://github.com/zed-industries/zed.git".to_string(), + revision: "HEAD".to_string(), + tags: Vec::new(), + reasoning: None, + uncommitted_diff: String::new(), + cursor_path: std::sync::Arc::from(std::path::Path::new("src/main.rs")), + cursor_position: "0:0".to_string(), + edit_history: String::new(), + expected_patches: Vec::new(), + rejected_patch: None, + telemetry: None, + human_feedback: Vec::new(), + rating: None, + }, + prompt_inputs: Some(prompt_inputs), + prompt: None, + predictions: Vec::new(), + score: Vec::new(), + qa: Vec::new(), + zed_version: None, + state: None, + }, + editable_range, + context_range, + ); + + assert!(prompt.contains(TeacherPrompt::EDITABLE_REGION_START)); + assert!(prompt.contains(TeacherPrompt::USER_CURSOR_MARKER)); + assert!(prompt.contains("line40")); + } } diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 2df477e2693c074e12668173db6a38627ca57213..b4951ae9d9f1175fd5c4a78ea4f93fa08ed4d83d 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -415,16 +415,13 @@ impl std::str::FromStr for PredictionProvider { let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default(); Ok(PredictionProvider::Zeta2(format)) } - "teacher" => parse_teacher_args(arg), + "teacher" => { + let (backend, format) = parse_teacher_args(arg)?; + Ok(PredictionProvider::Teacher(backend, format)) + } "teacher-non-batching" | "teacher_non_batching" => { - let backend = arg - .map(|a| a.parse()) - .transpose()? - .unwrap_or(TeacherBackend::default()); - Ok(PredictionProvider::TeacherNonBatching( - backend, - ZetaFormat::default(), - )) + let (backend, format) = parse_teacher_args(arg)?; + Ok(PredictionProvider::TeacherNonBatching(backend, format)) } "teacher-multi-region" | "teacher_multi_region" => { let backend = arg @@ -461,7 +458,7 @@ impl std::str::FromStr for PredictionProvider { } } -fn parse_teacher_args(arg: Option<&str>) -> Result { +fn parse_teacher_args(arg: Option<&str>) -> Result<(TeacherBackend, ZetaFormat), anyhow::Error> { let mut backend = TeacherBackend::default(); let mut format = ZetaFormat::default(); @@ -479,7 +476,7 @@ fn parse_teacher_args(arg: Option<&str>) -> Result