Detailed changes
@@ -76,6 +76,8 @@ pub enum PromptFormat {
OldTextNewText,
/// Prompt format intended for use via zeta_cli
OnlySnippets,
+ /// One-sentence instructions used in fine-tuned models
+ Minimal,
}
impl PromptFormat {
@@ -102,6 +104,7 @@ impl std::fmt::Display for PromptFormat {
PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
+ PromptFormat::Minimal => write!(f, "Minimal"),
}
}
}
@@ -86,6 +86,13 @@ const NUMBERED_LINES_INSTRUCTIONS: &str = indoc! {r#"
"#};
+const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
+ You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.
+
+ # Edit History:
+
+ "#};
+
const UNIFIED_DIFF_REMINDER: &str = indoc! {"
---
@@ -100,6 +107,14 @@ const UNIFIED_DIFF_REMINDER: &str = indoc! {"
to uniquely identify it amongst all excerpts of code provided.
"};
+const MINIMAL_PROMPT_REMINDER: &str = indoc! {"
+ ---
+
+ Please analyze the edit history and the files, then provide the unified diff for your predicted edits.
+ Do not include the cursor marker in your output.
+ If you're editing multiple files, be sure to reflect filename in the hunk's header.
+ "};
+
const XML_TAGS_INSTRUCTIONS: &str = indoc! {r#"
# Instructions
@@ -171,6 +186,7 @@ pub fn build_prompt(
],
PromptFormat::LabeledSections
| PromptFormat::NumLinesUniDiff
+ | PromptFormat::Minimal
| PromptFormat::OldTextNewText => {
vec![(request.cursor_point, CURSOR_MARKER)]
}
@@ -183,28 +199,47 @@ pub fn build_prompt(
PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(),
PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(),
PromptFormat::OnlySnippets => String::new(),
+ PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(),
};
if request.events.is_empty() {
prompt.push_str("(No edit history)\n\n");
} else {
- prompt.push_str("Here are the latest edits made by the user, from earlier to later.\n\n");
+ let edit_preamble = if request.prompt_format == PromptFormat::Minimal {
+ "The following are the latest edits made by the user, from earlier to later.\n\n"
+ } else {
+ "Here are the latest edits made by the user, from earlier to later.\n\n"
+ };
+ prompt.push_str(edit_preamble);
push_events(&mut prompt, &request.events);
}
- prompt.push_str(indoc! {"
- # Code Excerpts
-
- The cursor marker <|user_cursor|> indicates the current user cursor position.
- The file is in current state, edits from edit history have been applied.
- "});
-
- if request.prompt_format == PromptFormat::NumLinesUniDiff {
- prompt.push_str(indoc! {"
+ let excerpts_preamble = match request.prompt_format {
+ PromptFormat::Minimal => indoc! {"
+ # Part of the file under the cursor:
+
+ (The cursor marker <|user_cursor|> indicates the current user cursor position.
+ The file is in current state, edits from edit history has been applied.
+ We only show part of the file around the cursor.
+ You can only edit exactly this part of the file.
+ We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.)
+ "},
+ PromptFormat::NumLinesUniDiff => indoc! {"
+ # Code Excerpts
+
+ The cursor marker <|user_cursor|> indicates the current user cursor position.
+ The file is in current state, edits from edit history have been applied.
We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.
- "});
- }
+ "},
+ _ => indoc! {"
+ # Code Excerpts
+ The cursor marker <|user_cursor|> indicates the current user cursor position.
+ The file is in current state, edits from edit history have been applied.
+ "},
+ };
+
+ prompt.push_str(excerpts_preamble);
prompt.push('\n');
let mut section_labels = Default::default();
@@ -217,19 +252,38 @@ pub fn build_prompt(
anyhow::bail!("PromptFormat::LabeledSections cannot be used with ContextMode::Llm");
}
+ let include_line_numbers = matches!(
+ request.prompt_format,
+ PromptFormat::NumLinesUniDiff | PromptFormat::Minimal
+ );
for related_file in &request.included_files {
- write_codeblock(
- &related_file.path,
- &related_file.excerpts,
- if related_file.path == request.excerpt_path {
- &insertions
- } else {
- &[]
- },
- related_file.max_row,
- request.prompt_format == PromptFormat::NumLinesUniDiff,
- &mut prompt,
- );
+ if request.prompt_format == PromptFormat::Minimal {
+ write_codeblock_with_filename(
+ &related_file.path,
+ &related_file.excerpts,
+ if related_file.path == request.excerpt_path {
+ &insertions
+ } else {
+ &[]
+ },
+ related_file.max_row,
+ include_line_numbers,
+ &mut prompt,
+ );
+ } else {
+ write_codeblock(
+ &related_file.path,
+ &related_file.excerpts,
+ if related_file.path == request.excerpt_path {
+ &insertions
+ } else {
+ &[]
+ },
+ related_file.max_row,
+ include_line_numbers,
+ &mut prompt,
+ );
+ }
}
}
@@ -240,6 +294,9 @@ pub fn build_prompt(
PromptFormat::OldTextNewText => {
prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER);
}
+ PromptFormat::Minimal => {
+ prompt.push_str(MINIMAL_PROMPT_REMINDER);
+ }
_ => {}
}
@@ -255,6 +312,27 @@ pub fn write_codeblock<'a>(
output: &'a mut String,
) {
writeln!(output, "`````{}", DiffPathFmt(path)).unwrap();
+
+ write_excerpts(
+ excerpts,
+ sorted_insertions,
+ file_line_count,
+ include_line_numbers,
+ output,
+ );
+ write!(output, "`````\n\n").unwrap();
+}
+
+fn write_codeblock_with_filename<'a>(
+ path: &Path,
+ excerpts: impl IntoIterator<Item = &'a Excerpt>,
+ sorted_insertions: &[(Point, &str)],
+ file_line_count: Line,
+ include_line_numbers: bool,
+ output: &'a mut String,
+) {
+ writeln!(output, "`````filename={}", DiffPathFmt(path)).unwrap();
+
write_excerpts(
excerpts,
sorted_insertions,
@@ -666,6 +744,7 @@ impl<'a> SyntaxBasedPrompt<'a> {
PromptFormat::MarkedExcerpt
| PromptFormat::OnlySnippets
| PromptFormat::OldTextNewText
+ | PromptFormat::Minimal
| PromptFormat::NumLinesUniDiff => {
if range.start.0 > 0 && !skipped_last_snippet {
output.push_str("…\n");
@@ -571,10 +571,15 @@ mod tests {
expected_output: &str,
cx: &mut TestAppContext,
) {
- let results =
- run_retrieval_searches(vec![query], project.clone(), None, &mut cx.to_async())
- .await
- .unwrap();
+ let results = run_retrieval_searches(
+ vec![query],
+ project.clone(),
+ #[cfg(feature = "eval-support")]
+ None,
+ &mut cx.to_async(),
+ )
+ .await
+ .unwrap();
let mut results = results.into_iter().collect::<Vec<_>>();
results.sort_by_key(|results| {
@@ -49,7 +49,7 @@ pub async fn parse_diff<'a>(
DiffEvent::FileEnd { renamed_to } => {
let (buffer, _) = edited_buffer
.take()
- .expect("Got a FileEnd event before an Hunk event");
+ .context("Got a FileEnd event before an Hunk event")?;
if renamed_to.is_some() {
anyhow::bail!("edit predictions cannot rename files");
@@ -133,7 +133,7 @@ pub async fn apply_diff<'a>(
DiffEvent::FileEnd { renamed_to } => {
let (buffer, _) = current_file
.take()
- .expect("Got a FileEnd event before an Hunk event");
+ .context("Got a FileEnd event before an Hunk event")?;
if let Some(renamed_to) = renamed_to {
project
@@ -91,13 +91,22 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
static USE_OLLAMA: LazyLock<bool> =
LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
-static MODEL_ID: LazyLock<String> = LazyLock::new(|| {
- env::var("ZED_ZETA2_MODEL").unwrap_or(if *USE_OLLAMA {
+static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
+ env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
"qwen3-coder:30b".to_string()
} else {
"yqvev8r3".to_string()
})
});
+static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
+ match env::var("ZED_ZETA2_MODEL").as_deref() {
+ Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
+ Ok(model) => model,
+ Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
+ Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
+ }
+ .to_string()
+});
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
if *USE_OLLAMA {
@@ -826,7 +835,7 @@ impl Zeta {
} else {
included_files.push((
active_buffer.clone(),
- active_snapshot,
+ active_snapshot.clone(),
excerpt_path.clone(),
vec![excerpt_anchor_range],
));
@@ -940,7 +949,7 @@ impl Zeta {
let (prompt, _) = prompt_result?;
let request = open_ai::Request {
- model: MODEL_ID.clone(),
+ model: EDIT_PREDICTIONS_MODEL_ID.clone(),
messages: vec![open_ai::RequestMessage::User {
content: open_ai::MessageContent::Plain(prompt),
}],
@@ -1010,8 +1019,17 @@ impl Zeta {
let (edited_buffer_snapshot, edits) = match options.prompt_format {
PromptFormat::NumLinesUniDiff => {
+ // TODO: Implement parsing of multi-file diffs
crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
}
+ PromptFormat::Minimal => {
+ if output_text.contains("--- a/\n+++ b/\nNo edits") {
+ let edits = vec![];
+ (&active_snapshot, edits)
+ } else {
+ crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
+ }
+ }
PromptFormat::OldTextNewText => {
crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
.await?
@@ -1363,7 +1381,7 @@ impl Zeta {
let (tool_schema, tool_description) = TOOL_SCHEMA.clone();
let request = open_ai::Request {
- model: MODEL_ID.clone(),
+ model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
messages: vec![open_ai::RequestMessage::User {
content: open_ai::MessageContent::Plain(prompt),
}],
@@ -54,7 +54,6 @@ pub async fn run_evaluate(
let tasks = zetas.into_iter().enumerate().map(|(repetition_ix, zeta)| {
let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
-
let example = example.clone();
let project = project.clone();
@@ -208,7 +207,7 @@ fn write_eval_result(
"## Actual edit prediction:\n\n```diff\n{}\n```\n",
compare_diffs(&predictions.diff, &example.example.expected_patch)
)?;
- writeln!(out, "{}", evaluation_result)?;
+ writeln!(out, "{:#}", evaluation_result)?;
anyhow::Ok(())
}
@@ -304,6 +303,16 @@ False Negatives : {}",
impl std::fmt::Display for EvaluationResult {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ if f.alternate() {
+ self.fmt_table(f)
+ } else {
+ self.fmt_markdown(f)
+ }
+ }
+}
+
+impl EvaluationResult {
+ fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
r#"
@@ -317,6 +326,38 @@ impl std::fmt::Display for EvaluationResult {
self.edit_prediction.to_markdown()
)
}
+
+ fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ writeln!(f, "### Scores\n")?;
+ writeln!(
+ f,
+ " TP FP FN Precision Recall F1"
+ )?;
+ writeln!(
+ f,
+ "──────────────────────────────────────────────────────────────────"
+ )?;
+ writeln!(
+ f,
+ "Context Retrieval {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
+ self.context.true_positives,
+ self.context.false_positives,
+ self.context.false_negatives,
+ self.context.precision() * 100.0,
+ self.context.recall() * 100.0,
+ self.context.f1_score() * 100.0
+ )?;
+ writeln!(
+ f,
+ "Edit Prediction {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
+ self.edit_prediction.true_positives,
+ self.edit_prediction.false_positives,
+ self.edit_prediction.false_negatives,
+ self.edit_prediction.precision() * 100.0,
+ self.edit_prediction.recall() * 100.0,
+ self.edit_prediction.f1_score() * 100.0
+ )
+ }
}
pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResult {
@@ -175,6 +175,7 @@ enum PromptFormat {
#[default]
NumberedLines,
OldTextNewText,
+ Minimal,
}
impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
@@ -185,6 +186,7 @@ impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
+ Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
}
}
}
@@ -126,7 +126,7 @@ pub async fn zeta2_predict(
example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
}
fs::create_dir_all(&example_run_dir)?;
- if LATEST_EXAMPLE_RUN_DIR.exists() {
+ if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
}