diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index e17a92387e68b5cf6e0993ec91f382f6c14cc765..2d7a1aec52ae9cb007238dbd61e58597a9e81666 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -78,6 +78,8 @@ pub enum PromptFormat { OnlySnippets, /// One-sentence instructions used in fine-tuned models Minimal, + /// One-sentence instructions + FIM-like template + MinimalQwen, } impl PromptFormat { @@ -105,6 +107,7 @@ impl std::fmt::Display for PromptFormat { PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"), PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"), PromptFormat::Minimal => write!(f, "Minimal"), + PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"), } } } diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs index c84ba24ae3485f837278f61e1eeb8b40eb276840..48ab2097d4ca960c28f7edb498e57ded95e208f7 100644 --- a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs +++ b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs @@ -3,7 +3,8 @@ pub mod retrieval_prompt; use anyhow::{Context as _, Result, anyhow}; use cloud_llm_client::predict_edits_v3::{ - self, DiffPathFmt, Excerpt, Line, Point, PromptFormat, ReferencedDeclaration, + self, DiffPathFmt, Event, Excerpt, IncludedFile, Line, Point, PromptFormat, + ReferencedDeclaration, }; use indoc::indoc; use ordered_float::OrderedFloat; @@ -166,6 +167,21 @@ const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#" pub fn build_prompt( request: &predict_edits_v3::PredictEditsRequest, ) -> Result<(String, SectionLabels)> { + let mut section_labels = Default::default(); + + match request.prompt_format { + PromptFormat::MinimalQwen => { + let prompt = MinimalQwenPrompt { + events: request.events.clone(), + cursor_point: request.cursor_point, + cursor_path: request.excerpt_path.clone(), + included_files: request.included_files.clone(), + }; + return Ok((prompt.render(), section_labels)); + } + _ => (), + }; + let mut insertions = match request.prompt_format { PromptFormat::MarkedExcerpt => vec![ ( @@ -191,6 +207,7 @@ pub fn build_prompt( vec![(request.cursor_point, CURSOR_MARKER)] } PromptFormat::OnlySnippets => vec![], + PromptFormat::MinimalQwen => unreachable!(), }; let mut prompt = match request.prompt_format { @@ -200,6 +217,7 @@ pub fn build_prompt( PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(), PromptFormat::OnlySnippets => String::new(), PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(), + PromptFormat::MinimalQwen => unreachable!(), }; if request.events.is_empty() { @@ -251,8 +269,6 @@ pub fn build_prompt( prompt.push_str(excerpts_preamble); prompt.push('\n'); - let mut section_labels = Default::default(); - if !request.referenced_declarations.is_empty() || !request.signatures.is_empty() { let syntax_based_prompt = SyntaxBasedPrompt::populate(request)?; section_labels = syntax_based_prompt.write(&mut insertions, &mut prompt)?; @@ -769,6 +785,7 @@ impl<'a> SyntaxBasedPrompt<'a> { writeln!(output, "<|section_{}|>", section_index).ok(); } } + PromptFormat::MinimalQwen => unreachable!(), } let push_full_snippet = |output: &mut String| { @@ -878,3 +895,69 @@ fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle DeclarationStyle::Declaration => declaration.text.len(), } } + +struct MinimalQwenPrompt { + events: Vec, + cursor_point: Point, + cursor_path: Arc, // TODO: make a common struct with cursor_point + included_files: Vec, +} + +impl MinimalQwenPrompt { + const INSTRUCTIONS: &str = "You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.\n"; + + fn render(&self) -> String { + let edit_history = self.fmt_edit_history(); + let context = self.fmt_context(); + + format!( + "{instructions}\n\n{edit_history}\n\n{context}", + instructions = MinimalQwenPrompt::INSTRUCTIONS, + edit_history = edit_history, + context = context + ) + } + + fn fmt_edit_history(&self) -> String { + if self.events.is_empty() { + "(No edit history)\n\n".to_string() + } else { + let mut events_str = String::new(); + push_events(&mut events_str, &self.events); + format!( + "The following are the latest edits made by the user, from earlier to later.\n\n{}", + events_str + ) + } + } + + fn fmt_context(&self) -> String { + let mut context = String::new(); + let include_line_numbers = true; + + for related_file in &self.included_files { + writeln!(context, "<|file_sep|>{}", DiffPathFmt(&related_file.path)).unwrap(); + + if related_file.path == self.cursor_path { + write!(context, "<|fim_prefix|>").unwrap(); + write_excerpts( + &related_file.excerpts, + &[(self.cursor_point, "<|fim_suffix|>")], + related_file.max_row, + include_line_numbers, + &mut context, + ); + writeln!(context, "<|fim_middle|>").unwrap(); + } else { + write_excerpts( + &related_file.excerpts, + &[], + related_file.max_row, + include_line_numbers, + &mut context, + ); + } + } + context + } +} diff --git a/crates/zeta2/src/udiff.rs b/crates/zeta2/src/udiff.rs index d565fab1b0c2bbf1e27fe183df1c95e27cac871d..5ae029c6c16c2c6b6d0c2451cc961e8399a64a8f 100644 --- a/crates/zeta2/src/udiff.rs +++ b/crates/zeta2/src/udiff.rs @@ -391,10 +391,12 @@ impl<'a> DiffLine<'a> { return Some(Self::HunkHeader(None)); } - let (start_line_old, header) = header.strip_prefix('-')?.split_once(',')?; - let mut parts = header.split_ascii_whitespace(); - let count_old = parts.next()?; - let (start_line_new, count_new) = parts.next()?.strip_prefix('+')?.split_once(',')?; + let mut tokens = header.split_whitespace(); + let old_range = tokens.next()?.strip_prefix('-')?; + let new_range = tokens.next()?.strip_prefix('+')?; + + let (start_line_old, count_old) = old_range.split_once(',').unwrap_or((old_range, "1")); + let (start_line_new, count_new) = new_range.split_once(',').unwrap_or((new_range, "1")); Some(Self::HunkHeader(Some(HunkLocation { start_line_old: start_line_old.parse::().ok()?.saturating_sub(1), diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 1521fbd9291c7a69cc56152d193734f41cf0451e..881a7254f876e1b2df636513480115bf36489a24 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -1015,7 +1015,7 @@ impl Zeta { // TODO: Implement parsing of multi-file diffs crate::udiff::parse_diff(&output_text, get_buffer_from_context).await? } - PromptFormat::Minimal => { + PromptFormat::Minimal | PromptFormat::MinimalQwen => { if output_text.contains("--- a/\n+++ b/\nNo edits") { let edits = vec![]; (&active_snapshot, edits) diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 7305d3bb2479452e0b8a54392a0a84cbea1be426..517deb6ec7482ca2712a347531b24eca5ed16796 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -176,6 +176,7 @@ enum PromptFormat { NumberedLines, OldTextNewText, Minimal, + MinimalQwen, } impl Into for PromptFormat { @@ -187,6 +188,7 @@ impl Into for PromptFormat { Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff, Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText, Self::Minimal => predict_edits_v3::PromptFormat::Minimal, + Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen, } } }