diff --git a/crates/edit_prediction/src/fim.rs b/crates/edit_prediction/src/fim.rs index 02053aae7154acdfa22a01a4f84d6b732a9ca696..79df739e60bc28ba5c6b9f53699dcf398fc8310e 100644 --- a/crates/edit_prediction/src/fim.rs +++ b/crates/edit_prediction/src/fim.rs @@ -73,7 +73,7 @@ pub fn request_prediction( let inputs = ZetaPromptInput { events, - related_files: Vec::new(), + related_files: Some(Vec::new()), cursor_offset_in_excerpt: cursor_offset - excerpt_offset_range.start, cursor_path: full_path.clone(), excerpt_start_row: Some(excerpt_range.start.row), diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs index cbb4e027253bb4d69b684c0668ff0da60f4e6aaf..0d63005feb18acb9a434ff107811080a7bcf1f12 100644 --- a/crates/edit_prediction/src/mercury.rs +++ b/crates/edit_prediction/src/mercury.rs @@ -91,7 +91,7 @@ impl Mercury { let inputs = zeta_prompt::ZetaPromptInput { events, - related_files, + related_files: Some(related_files), cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_offset_range.start, cursor_path: full_path.clone(), @@ -260,7 +260,7 @@ fn build_prompt(inputs: &ZetaPromptInput) -> String { &mut prompt, RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END, |prompt| { - for related_file in inputs.related_files.iter() { + for related_file in inputs.related_files.as_deref().unwrap_or_default().iter() { for related_excerpt in &related_file.excerpts { push_delimited( prompt, diff --git a/crates/edit_prediction/src/prediction.rs b/crates/edit_prediction/src/prediction.rs index 263409043b397e2df1ac32514a0ce76656fbefe1..1c281453b93d0ab7c601f575b290c46fe63d2eae 100644 --- a/crates/edit_prediction/src/prediction.rs +++ b/crates/edit_prediction/src/prediction.rs @@ -156,7 +156,7 @@ mod tests { model_version: None, inputs: ZetaPromptInput { events: vec![], - related_files: vec![], + related_files: Some(vec![]), cursor_path: Path::new("path.txt").into(), cursor_offset_in_excerpt: 0, cursor_excerpt: "".into(), diff --git a/crates/edit_prediction/src/sweep_ai.rs b/crates/edit_prediction/src/sweep_ai.rs index d8ce180801aa8902bfff79044cabaae7570ed05f..ff5128e56e49191f308a574d5502f8139db9bc3f 100644 --- a/crates/edit_prediction/src/sweep_ai.rs +++ b/crates/edit_prediction/src/sweep_ai.rs @@ -212,7 +212,7 @@ impl SweepAi { let ep_inputs = zeta_prompt::ZetaPromptInput { events: inputs.events, - related_files: inputs.related_files.clone(), + related_files: Some(inputs.related_files.clone()), cursor_path: full_path.clone(), cursor_excerpt: request_body.file_contents.clone().into(), cursor_offset_in_excerpt: request_body.cursor_position, diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index 1217cbd5ba6f8ecd5b13aa1eec3b1a88bf26dbc2..1a4d0b445a8c3d5876eb48646a0a1622a8b725a2 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -509,7 +509,7 @@ pub fn zeta2_prompt_input( cursor_offset_in_excerpt, excerpt_start_row: Some(full_context_start_row), events, - related_files, + related_files: Some(related_files), excerpt_ranges, experiment: preferred_experiment, in_open_source_repo: is_open_source, diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index fe7dff5935aed035e803b1451c8c06df8f79b810..324c297ba4c75d10a24b53c7961bd35e1f42e2cd 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -259,7 +259,10 @@ impl TeacherPrompt { } pub fn format_context(example: &Example) -> String { - let related_files = example.prompt_inputs.as_ref().map(|pi| &pi.related_files); + let related_files = example + .prompt_inputs + .as_ref() + .and_then(|pi| pi.related_files.as_deref()); let Some(related_files) = related_files else { return "(No context)".to_string(); }; diff --git a/crates/edit_prediction_cli/src/load_project.rs b/crates/edit_prediction_cli/src/load_project.rs index df458770519be5accd72f33a56893bb13c9b88a9..f7e27ca432baacd38c468e5b4c6f97b62cb8ee3e 100644 --- a/crates/edit_prediction_cli/src/load_project.rs +++ b/crates/edit_prediction_cli/src/load_project.rs @@ -71,8 +71,7 @@ pub async fn run_load_project( let existing_related_files = example .prompt_inputs .take() - .map(|inputs| inputs.related_files) - .unwrap_or_default(); + .and_then(|inputs| inputs.related_files); let (prompt_inputs, language_name) = buffer.read_with(&cx, |buffer, _cx| { let snapshot = buffer.snapshot(); diff --git a/crates/edit_prediction_cli/src/retrieve_context.rs b/crates/edit_prediction_cli/src/retrieve_context.rs index a5fb00b39a67a15a7afcced897b4d109f1f3406f..971bdf24d3e8cd1d8184a9009903cec25d3000d1 100644 --- a/crates/edit_prediction_cli/src/retrieve_context.rs +++ b/crates/edit_prediction_cli/src/retrieve_context.rs @@ -20,18 +20,12 @@ pub async fn run_context_retrieval( example_progress: &ExampleProgress, mut cx: AsyncApp, ) -> anyhow::Result<()> { - if example.prompt_inputs.is_some() { - if example.spec.repository_url.is_empty() { - return Ok(()); - } - - if example - .prompt_inputs - .as_ref() - .is_some_and(|inputs| !inputs.related_files.is_empty()) - { - return Ok(()); - } + if example + .prompt_inputs + .as_ref() + .is_some_and(|inputs| inputs.related_files.is_some()) + { + return Ok(()); } run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?; @@ -72,7 +66,7 @@ pub async fn run_context_retrieval( step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal); if let Some(prompt_inputs) = example.prompt_inputs.as_mut() { - prompt_inputs.related_files = context_files; + prompt_inputs.related_files = Some(context_files); } Ok(()) } diff --git a/crates/edit_prediction_cli/src/reversal_tracking.rs b/crates/edit_prediction_cli/src/reversal_tracking.rs index cb955dbdf7dd2375395e8c0ecd52df849e33fb38..398ae24309bbb9368bb7947c94ad4f481c03ab9e 100644 --- a/crates/edit_prediction_cli/src/reversal_tracking.rs +++ b/crates/edit_prediction_cli/src/reversal_tracking.rs @@ -668,7 +668,7 @@ mod tests { cursor_offset_in_excerpt: 0, excerpt_start_row, events, - related_files: Vec::new(), + related_files: Some(Vec::new()), excerpt_ranges: ExcerptRanges { editable_150: 0..content.len(), editable_180: 0..content.len(), diff --git a/crates/edit_prediction_ui/src/rate_prediction_modal.rs b/crates/edit_prediction_ui/src/rate_prediction_modal.rs index d07dbe9bad72c2252ee2e33c8a014778d1331e96..1c4328d8a1d301b7cc01aa520c166bda4b40e32d 100644 --- a/crates/edit_prediction_ui/src/rate_prediction_modal.rs +++ b/crates/edit_prediction_ui/src/rate_prediction_modal.rs @@ -402,7 +402,13 @@ impl RatePredictionsModal { write!(&mut formatted_inputs, "## Related files\n\n").unwrap(); - for included_file in prediction.inputs.related_files.iter() { + for included_file in prediction + .inputs + .related_files + .as_deref() + .unwrap_or_default() + .iter() + { write!( &mut formatted_inputs, "### {}\n\n", diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 3f7839305bd840f32a3f27182b0c5d02c1166099..774ac7cb9baebb943c9223645aae8d16cd730998 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -51,7 +51,8 @@ pub struct ZetaPromptInput { #[serde(default, skip_serializing_if = "Option::is_none")] pub excerpt_start_row: Option, pub events: Vec>, - pub related_files: Vec, + #[serde(default)] + pub related_files: Option>, /// These ranges let the server select model-appropriate subsets. pub excerpt_ranges: ExcerptRanges, /// The name of the edit prediction model experiment to use. @@ -350,17 +351,19 @@ pub fn format_prompt_with_budget_for_format( resolve_cursor_region(input, format); let path = &*input.cursor_path; + let empty_files = Vec::new(); + let input_related_files = input.related_files.as_deref().unwrap_or(&empty_files); let related_files = if let Some(cursor_excerpt_start_row) = input.excerpt_start_row { let relative_row_range = offset_range_to_row_range(&input.cursor_excerpt, context_range); let row_range = relative_row_range.start + cursor_excerpt_start_row ..relative_row_range.end + cursor_excerpt_start_row; &filter_redundant_excerpts( - input.related_files.clone(), + input_related_files.to_vec(), input.cursor_path.as_ref(), row_range, ) } else { - &input.related_files + input_related_files }; match format { @@ -3863,7 +3866,7 @@ mod tests { cursor_offset_in_excerpt: cursor_offset, excerpt_start_row: None, events: events.into_iter().map(Arc::new).collect(), - related_files, + related_files: Some(related_files), excerpt_ranges: ExcerptRanges { editable_150: editable_range.clone(), editable_180: editable_range.clone(), @@ -3892,7 +3895,7 @@ mod tests { cursor_offset_in_excerpt: cursor_offset, excerpt_start_row: None, events: vec![], - related_files: vec![], + related_files: Some(vec![]), excerpt_ranges: ExcerptRanges { editable_150: editable_range.clone(), editable_180: editable_range.clone(), @@ -4475,7 +4478,7 @@ mod tests { cursor_offset_in_excerpt: 30, excerpt_start_row: Some(0), events: vec![Arc::new(make_event("other.rs", "-old\n+new\n"))], - related_files: vec![], + related_files: Some(vec![]), excerpt_ranges: ExcerptRanges { editable_150: 15..41, editable_180: 15..41, @@ -4538,7 +4541,7 @@ mod tests { cursor_offset_in_excerpt: 15, excerpt_start_row: Some(10), events: vec![], - related_files: vec![], + related_files: Some(vec![]), excerpt_ranges: ExcerptRanges { editable_150: 0..28, editable_180: 0..28, @@ -4596,7 +4599,7 @@ mod tests { cursor_offset_in_excerpt: 25, excerpt_start_row: Some(0), events: vec![], - related_files: vec![], + related_files: Some(vec![]), excerpt_ranges: ExcerptRanges { editable_150: editable_range.clone(), editable_180: editable_range.clone(),