diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index 98ca0748934d663d204c64544af8a3e83fcd704d..e17a92387e68b5cf6e0993ec91f382f6c14cc765 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -76,6 +76,8 @@ pub enum PromptFormat { OldTextNewText, /// Prompt format intended for use via zeta_cli OnlySnippets, + /// One-sentence instructions used in fine-tuned models + Minimal, } impl PromptFormat { @@ -102,6 +104,7 @@ impl std::fmt::Display for PromptFormat { PromptFormat::OnlySnippets => write!(f, "Only Snippets"), PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"), PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"), + PromptFormat::Minimal => write!(f, "Minimal"), } } } diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs index 3f0bd476c50b9e6f92a9f457af15899fcb33b8ed..89c7536f88e1c0bdcce7b67fb2f2704052b5a677 100644 --- a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs +++ b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs @@ -86,6 +86,13 @@ const NUMBERED_LINES_INSTRUCTIONS: &str = indoc! {r#" "#}; +const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#" + You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase. + + # Edit History: + + "#}; + const UNIFIED_DIFF_REMINDER: &str = indoc! {" --- @@ -100,6 +107,14 @@ const UNIFIED_DIFF_REMINDER: &str = indoc! {" to uniquely identify it amongst all excerpts of code provided. "}; +const MINIMAL_PROMPT_REMINDER: &str = indoc! {" + --- + + Please analyze the edit history and the files, then provide the unified diff for your predicted edits. + Do not include the cursor marker in your output. + If you're editing multiple files, be sure to reflect filename in the hunk's header. + "}; + const XML_TAGS_INSTRUCTIONS: &str = indoc! {r#" # Instructions @@ -171,6 +186,7 @@ pub fn build_prompt( ], PromptFormat::LabeledSections | PromptFormat::NumLinesUniDiff + | PromptFormat::Minimal | PromptFormat::OldTextNewText => { vec![(request.cursor_point, CURSOR_MARKER)] } @@ -183,28 +199,47 @@ pub fn build_prompt( PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(), PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(), PromptFormat::OnlySnippets => String::new(), + PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(), }; if request.events.is_empty() { prompt.push_str("(No edit history)\n\n"); } else { - prompt.push_str("Here are the latest edits made by the user, from earlier to later.\n\n"); + let edit_preamble = if request.prompt_format == PromptFormat::Minimal { + "The following are the latest edits made by the user, from earlier to later.\n\n" + } else { + "Here are the latest edits made by the user, from earlier to later.\n\n" + }; + prompt.push_str(edit_preamble); push_events(&mut prompt, &request.events); } - prompt.push_str(indoc! {" - # Code Excerpts - - The cursor marker <|user_cursor|> indicates the current user cursor position. - The file is in current state, edits from edit history have been applied. - "}); - - if request.prompt_format == PromptFormat::NumLinesUniDiff { - prompt.push_str(indoc! {" + let excerpts_preamble = match request.prompt_format { + PromptFormat::Minimal => indoc! {" + # Part of the file under the cursor: + + (The cursor marker <|user_cursor|> indicates the current user cursor position. + The file is in current state, edits from edit history has been applied. + We only show part of the file around the cursor. + You can only edit exactly this part of the file. + We prepend line numbers (e.g., `123|`); they are not part of the file.) + "}, + PromptFormat::NumLinesUniDiff => indoc! {" + # Code Excerpts + + The cursor marker <|user_cursor|> indicates the current user cursor position. + The file is in current state, edits from edit history have been applied. We prepend line numbers (e.g., `123|`); they are not part of the file. - "}); - } + "}, + _ => indoc! {" + # Code Excerpts + The cursor marker <|user_cursor|> indicates the current user cursor position. + The file is in current state, edits from edit history have been applied. + "}, + }; + + prompt.push_str(excerpts_preamble); prompt.push('\n'); let mut section_labels = Default::default(); @@ -217,19 +252,38 @@ pub fn build_prompt( anyhow::bail!("PromptFormat::LabeledSections cannot be used with ContextMode::Llm"); } + let include_line_numbers = matches!( + request.prompt_format, + PromptFormat::NumLinesUniDiff | PromptFormat::Minimal + ); for related_file in &request.included_files { - write_codeblock( - &related_file.path, - &related_file.excerpts, - if related_file.path == request.excerpt_path { - &insertions - } else { - &[] - }, - related_file.max_row, - request.prompt_format == PromptFormat::NumLinesUniDiff, - &mut prompt, - ); + if request.prompt_format == PromptFormat::Minimal { + write_codeblock_with_filename( + &related_file.path, + &related_file.excerpts, + if related_file.path == request.excerpt_path { + &insertions + } else { + &[] + }, + related_file.max_row, + include_line_numbers, + &mut prompt, + ); + } else { + write_codeblock( + &related_file.path, + &related_file.excerpts, + if related_file.path == request.excerpt_path { + &insertions + } else { + &[] + }, + related_file.max_row, + include_line_numbers, + &mut prompt, + ); + } } } @@ -240,6 +294,9 @@ pub fn build_prompt( PromptFormat::OldTextNewText => { prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER); } + PromptFormat::Minimal => { + prompt.push_str(MINIMAL_PROMPT_REMINDER); + } _ => {} } @@ -255,6 +312,27 @@ pub fn write_codeblock<'a>( output: &'a mut String, ) { writeln!(output, "`````{}", DiffPathFmt(path)).unwrap(); + + write_excerpts( + excerpts, + sorted_insertions, + file_line_count, + include_line_numbers, + output, + ); + write!(output, "`````\n\n").unwrap(); +} + +fn write_codeblock_with_filename<'a>( + path: &Path, + excerpts: impl IntoIterator, + sorted_insertions: &[(Point, &str)], + file_line_count: Line, + include_line_numbers: bool, + output: &'a mut String, +) { + writeln!(output, "`````filename={}", DiffPathFmt(path)).unwrap(); + write_excerpts( excerpts, sorted_insertions, @@ -666,6 +744,7 @@ impl<'a> SyntaxBasedPrompt<'a> { PromptFormat::MarkedExcerpt | PromptFormat::OnlySnippets | PromptFormat::OldTextNewText + | PromptFormat::Minimal | PromptFormat::NumLinesUniDiff => { if range.start.0 > 0 && !skipped_last_snippet { output.push_str("…\n"); diff --git a/crates/zeta2/src/retrieval_search.rs b/crates/zeta2/src/retrieval_search.rs index 76501fb1e5c73a22ff8eebc5c29d117d45389beb..d642c2edaa1fbc897b3c74b0b5c8b1fb71227e84 100644 --- a/crates/zeta2/src/retrieval_search.rs +++ b/crates/zeta2/src/retrieval_search.rs @@ -571,10 +571,15 @@ mod tests { expected_output: &str, cx: &mut TestAppContext, ) { - let results = - run_retrieval_searches(vec![query], project.clone(), None, &mut cx.to_async()) - .await - .unwrap(); + let results = run_retrieval_searches( + vec![query], + project.clone(), + #[cfg(feature = "eval-support")] + None, + &mut cx.to_async(), + ) + .await + .unwrap(); let mut results = results.into_iter().collect::>(); results.sort_by_key(|results| { diff --git a/crates/zeta2/src/udiff.rs b/crates/zeta2/src/udiff.rs index d765a64345f839b9314632444d209fa79e9ca5ce..d565fab1b0c2bbf1e27fe183df1c95e27cac871d 100644 --- a/crates/zeta2/src/udiff.rs +++ b/crates/zeta2/src/udiff.rs @@ -49,7 +49,7 @@ pub async fn parse_diff<'a>( DiffEvent::FileEnd { renamed_to } => { let (buffer, _) = edited_buffer .take() - .expect("Got a FileEnd event before an Hunk event"); + .context("Got a FileEnd event before an Hunk event")?; if renamed_to.is_some() { anyhow::bail!("edit predictions cannot rename files"); @@ -133,7 +133,7 @@ pub async fn apply_diff<'a>( DiffEvent::FileEnd { renamed_to } => { let (buffer, _) = current_file .take() - .expect("Got a FileEnd event before an Hunk event"); + .context("Got a FileEnd event before an Hunk event")?; if let Some(renamed_to) = renamed_to { project diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index b32e902b71a1b4a20e5f935eea854ecf115ae0f1..7322cb4b6e6882ad2f3597abb505224cc24dbd5e 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -91,13 +91,22 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { static USE_OLLAMA: LazyLock = LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty())); -static MODEL_ID: LazyLock = LazyLock::new(|| { - env::var("ZED_ZETA2_MODEL").unwrap_or(if *USE_OLLAMA { +static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock = LazyLock::new(|| { + env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA { "qwen3-coder:30b".to_string() } else { "yqvev8r3".to_string() }) }); +static EDIT_PREDICTIONS_MODEL_ID: LazyLock = LazyLock::new(|| { + match env::var("ZED_ZETA2_MODEL").as_deref() { + Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten + Ok(model) => model, + Err(_) if *USE_OLLAMA => "qwen3-coder:30b", + Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten + } + .to_string() +}); static PREDICT_EDITS_URL: LazyLock> = LazyLock::new(|| { env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| { if *USE_OLLAMA { @@ -826,7 +835,7 @@ impl Zeta { } else { included_files.push(( active_buffer.clone(), - active_snapshot, + active_snapshot.clone(), excerpt_path.clone(), vec![excerpt_anchor_range], )); @@ -940,7 +949,7 @@ impl Zeta { let (prompt, _) = prompt_result?; let request = open_ai::Request { - model: MODEL_ID.clone(), + model: EDIT_PREDICTIONS_MODEL_ID.clone(), messages: vec![open_ai::RequestMessage::User { content: open_ai::MessageContent::Plain(prompt), }], @@ -1010,8 +1019,17 @@ impl Zeta { let (edited_buffer_snapshot, edits) = match options.prompt_format { PromptFormat::NumLinesUniDiff => { + // TODO: Implement parsing of multi-file diffs crate::udiff::parse_diff(&output_text, get_buffer_from_context).await? } + PromptFormat::Minimal => { + if output_text.contains("--- a/\n+++ b/\nNo edits") { + let edits = vec![]; + (&active_snapshot, edits) + } else { + crate::udiff::parse_diff(&output_text, get_buffer_from_context).await? + } + } PromptFormat::OldTextNewText => { crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context) .await? @@ -1363,7 +1381,7 @@ impl Zeta { let (tool_schema, tool_description) = TOOL_SCHEMA.clone(); let request = open_ai::Request { - model: MODEL_ID.clone(), + model: CONTEXT_RETRIEVAL_MODEL_ID.clone(), messages: vec![open_ai::RequestMessage::User { content: open_ai::MessageContent::Plain(prompt), }], diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/zeta_cli/src/evaluate.rs index 4f8e984a7de36a96c4e8ad3ac7e5d9e9bfda244b..d255d1a56102d836cc18ce4df10586edad0ca957 100644 --- a/crates/zeta_cli/src/evaluate.rs +++ b/crates/zeta_cli/src/evaluate.rs @@ -54,7 +54,6 @@ pub async fn run_evaluate( let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| { let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16); - let example = example.clone(); let project = project.clone(); @@ -208,7 +207,7 @@ fn write_eval_result( "## Actual edit prediction:\n\n```diff\n{}\n```\n", compare_diffs(&predictions.diff, &example.example.expected_patch) )?; - writeln!(out, "{}", evaluation_result)?; + writeln!(out, "{:#}", evaluation_result)?; anyhow::Ok(()) } @@ -304,6 +303,16 @@ False Negatives : {}", impl std::fmt::Display for EvaluationResult { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if f.alternate() { + self.fmt_table(f) + } else { + self.fmt_markdown(f) + } + } +} + +impl EvaluationResult { + fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, r#" @@ -317,6 +326,38 @@ impl std::fmt::Display for EvaluationResult { self.edit_prediction.to_markdown() ) } + + fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "### Scores\n")?; + writeln!( + f, + " TP FP FN Precision Recall F1" + )?; + writeln!( + f, + "──────────────────────────────────────────────────────────────────" + )?; + writeln!( + f, + "Context Retrieval {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}", + self.context.true_positives, + self.context.false_positives, + self.context.false_negatives, + self.context.precision() * 100.0, + self.context.recall() * 100.0, + self.context.f1_score() * 100.0 + )?; + writeln!( + f, + "Edit Prediction {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}", + self.edit_prediction.true_positives, + self.edit_prediction.false_positives, + self.edit_prediction.false_negatives, + self.edit_prediction.precision() * 100.0, + self.edit_prediction.recall() * 100.0, + self.edit_prediction.f1_score() * 100.0 + ) + } } pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult { diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index f75b4a7e25020395f24d2638af88d4ba8b390e77..7305d3bb2479452e0b8a54392a0a84cbea1be426 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -175,6 +175,7 @@ enum PromptFormat { #[default] NumberedLines, OldTextNewText, + Minimal, } impl Into for PromptFormat { @@ -185,6 +186,7 @@ impl Into for PromptFormat { Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets, Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff, Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText, + Self::Minimal => predict_edits_v3::PromptFormat::Minimal, } } } diff --git a/crates/zeta_cli/src/predict.rs b/crates/zeta_cli/src/predict.rs index 0cfc7421547b1b00bc552f157ae22b2a8afad541..1f419fd09a87d1270d73bc90fe4b312cbaf0b4a4 100644 --- a/crates/zeta_cli/src/predict.rs +++ b/crates/zeta_cli/src/predict.rs @@ -126,7 +126,7 @@ pub async fn zeta2_predict( example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix)); } fs::create_dir_all(&example_run_dir)?; - if LATEST_EXAMPLE_RUN_DIR.exists() { + if LATEST_EXAMPLE_RUN_DIR.is_symlink() { fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?; }