@@ -12,10 +12,20 @@ use std::ops::Range;
use std::sync::Arc;
use zeta_prompt::udiff;
use zeta_prompt::{
- ZetaFormat, encode_patch_as_output_for_format, excerpt_range_for_format, format_zeta_prompt,
- multi_region, output_end_marker_for_format, resolve_cursor_region,
+ ZetaFormat, encode_patch_as_output_for_format, format_zeta_prompt, multi_region,
+ output_end_marker_for_format, resolve_cursor_region,
};
+fn resolved_excerpt_ranges_for_format(
+ input: &zeta_prompt::ZetaPromptInput,
+ format: ZetaFormat,
+) -> (Range<usize>, Range<usize>) {
+ let (_, editable_range_in_context, context_range, _) = resolve_cursor_region(input, format);
+ let editable_range = (context_range.start + editable_range_in_context.start)
+ ..(context_range.start + editable_range_in_context.end);
+ (editable_range, context_range)
+}
+
pub async fn run_format_prompt(
example: &mut Example,
args: &FormatPromptArgs,
@@ -38,7 +48,7 @@ pub async fn run_format_prompt(
step_progress.set_substatus("formatting teacher prompt");
let (editable_range, context_range) =
- excerpt_range_for_format(zeta_format, &prompt_inputs.excerpt_ranges);
+ resolved_excerpt_ranges_for_format(prompt_inputs, zeta_format);
let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
example.prompt = Some(ExamplePrompt {
@@ -55,7 +65,7 @@ pub async fn run_format_prompt(
let zeta_format = ZetaFormat::default();
let (editable_range, context_range) =
- excerpt_range_for_format(zeta_format, &prompt_inputs.excerpt_ranges);
+ resolved_excerpt_ranges_for_format(prompt_inputs, zeta_format);
let prompt =
TeacherMultiRegionPrompt::format_prompt(example, editable_range, context_range);
@@ -454,8 +464,7 @@ impl TeacherMultiRegionPrompt {
.context("example is missing prompt inputs")?;
let zeta_format = ZetaFormat::default();
- let (editable_range, _) =
- excerpt_range_for_format(zeta_format, &prompt_inputs.excerpt_ranges);
+ let (editable_range, _) = resolved_excerpt_ranges_for_format(prompt_inputs, zeta_format);
let excerpt = prompt_inputs.cursor_excerpt.as_ref();
let old_editable_region = &excerpt[editable_range.clone()];
let marker_offsets = multi_region::compute_marker_offsets(old_editable_region);
@@ -940,4 +949,84 @@ mod tests {
assert!(parsed.0.is_empty());
assert!(parsed.1.is_none());
}
+
+ #[test]
+ fn test_v0327_teacher_prompt_uses_resolved_ranges() {
+ let excerpt = (0..80)
+ .map(|index| format!("line{index:02}\n"))
+ .collect::<String>();
+ let cursor_offset = excerpt.find("line40").expect("cursor line exists");
+ let prompt_inputs = zeta_prompt::ZetaPromptInput {
+ cursor_path: std::path::Path::new("src/main.rs").into(),
+ cursor_excerpt: excerpt.clone().into(),
+ cursor_offset_in_excerpt: cursor_offset,
+ excerpt_start_row: None,
+ events: Vec::new(),
+ related_files: Some(Vec::new()),
+ active_buffer_diagnostics: Vec::new(),
+ excerpt_ranges: zeta_prompt::ExcerptRanges {
+ editable_150: 0..32,
+ editable_180: 0..32,
+ editable_350: 0..32,
+ editable_512: None,
+ editable_150_context_350: 0..48,
+ editable_180_context_350: 0..48,
+ editable_350_context_150: 20..50,
+ editable_350_context_512: None,
+ editable_350_context_1024: None,
+ context_4096: None,
+ context_8192: Some(30..excerpt.len()),
+ },
+ syntax_ranges: None,
+ in_open_source_repo: false,
+ can_collect_data: false,
+ repo_url: None,
+ };
+
+ let (stored_editable_range, stored_context_range) = zeta_prompt::excerpt_range_for_format(
+ ZetaFormat::V0327SingleFile,
+ &prompt_inputs.excerpt_ranges,
+ );
+ assert!(stored_context_range.start > stored_editable_range.start);
+
+ let (editable_range, context_range) =
+ resolved_excerpt_ranges_for_format(&prompt_inputs, ZetaFormat::V0327SingleFile);
+ assert_eq!(context_range, 0..excerpt.len());
+ assert!(editable_range.start < cursor_offset);
+ assert!(editable_range.end > cursor_offset);
+
+ let prompt = TeacherPrompt::format_prompt(
+ &Example {
+ spec: edit_prediction::example_spec::ExampleSpec {
+ name: "test".to_string(),
+ repository_url: "https://github.com/zed-industries/zed.git".to_string(),
+ revision: "HEAD".to_string(),
+ tags: Vec::new(),
+ reasoning: None,
+ uncommitted_diff: String::new(),
+ cursor_path: std::sync::Arc::from(std::path::Path::new("src/main.rs")),
+ cursor_position: "0:0".to_string(),
+ edit_history: String::new(),
+ expected_patches: Vec::new(),
+ rejected_patch: None,
+ telemetry: None,
+ human_feedback: Vec::new(),
+ rating: None,
+ },
+ prompt_inputs: Some(prompt_inputs),
+ prompt: None,
+ predictions: Vec::new(),
+ score: Vec::new(),
+ qa: Vec::new(),
+ zed_version: None,
+ state: None,
+ },
+ editable_range,
+ context_range,
+ );
+
+ assert!(prompt.contains(TeacherPrompt::EDITABLE_REGION_START));
+ assert!(prompt.contains(TeacherPrompt::USER_CURSOR_MARKER));
+ assert!(prompt.contains("line40"));
+ }
}
@@ -415,16 +415,13 @@ impl std::str::FromStr for PredictionProvider {
let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
Ok(PredictionProvider::Zeta2(format))
}
- "teacher" => parse_teacher_args(arg),
+ "teacher" => {
+ let (backend, format) = parse_teacher_args(arg)?;
+ Ok(PredictionProvider::Teacher(backend, format))
+ }
"teacher-non-batching" | "teacher_non_batching" => {
- let backend = arg
- .map(|a| a.parse())
- .transpose()?
- .unwrap_or(TeacherBackend::default());
- Ok(PredictionProvider::TeacherNonBatching(
- backend,
- ZetaFormat::default(),
- ))
+ let (backend, format) = parse_teacher_args(arg)?;
+ Ok(PredictionProvider::TeacherNonBatching(backend, format))
}
"teacher-multi-region" | "teacher_multi_region" => {
let backend = arg
@@ -461,7 +458,7 @@ impl std::str::FromStr for PredictionProvider {
}
}
-fn parse_teacher_args(arg: Option<&str>) -> Result<PredictionProvider, anyhow::Error> {
+fn parse_teacher_args(arg: Option<&str>) -> Result<(TeacherBackend, ZetaFormat), anyhow::Error> {
let mut backend = TeacherBackend::default();
let mut format = ZetaFormat::default();
@@ -479,7 +476,7 @@ fn parse_teacher_args(arg: Option<&str>) -> Result<PredictionProvider, anyhow::E
}
}
- Ok(PredictionProvider::Teacher(backend, format))
+ Ok((backend, format))
}
impl Serialize for PredictionProvider {