diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index 2d7a1aec52ae9cb007238dbd61e58597a9e81666..32a5a34d9d3b63332008a9f7df84a1990f87f17c 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -80,6 +80,8 @@ pub enum PromptFormat { Minimal, /// One-sentence instructions + FIM-like template MinimalQwen, + /// No instructions, Qwen chat + Seed-Coder 1120 FIM-like template + SeedCoder1120, } impl PromptFormat { @@ -108,6 +110,7 @@ impl std::fmt::Display for PromptFormat { PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"), PromptFormat::Minimal => write!(f, "Minimal"), PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"), + PromptFormat::SeedCoder1120 => write!(f, "Seed-Coder 1120"), } } } diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs index 48ab2097d4ca960c28f7edb498e57ded95e208f7..2ddabf750be763542bfc10b794afcb034ff08443 100644 --- a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs +++ b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs @@ -169,15 +169,18 @@ pub fn build_prompt( ) -> Result<(String, SectionLabels)> { let mut section_labels = Default::default(); + let prompt_data = PromptData { + events: request.events.clone(), + cursor_point: request.cursor_point, + cursor_path: request.excerpt_path.clone(), + included_files: request.included_files.clone(), + }; match request.prompt_format { PromptFormat::MinimalQwen => { - let prompt = MinimalQwenPrompt { - events: request.events.clone(), - cursor_point: request.cursor_point, - cursor_path: request.excerpt_path.clone(), - included_files: request.included_files.clone(), - }; - return Ok((prompt.render(), section_labels)); + return Ok((MinimalQwenPrompt.render(&prompt_data), section_labels)); + } + PromptFormat::SeedCoder1120 => { + return Ok((SeedCoder1120Prompt.render(&prompt_data), section_labels)); } _ => (), }; @@ -208,6 +211,7 @@ pub fn build_prompt( } PromptFormat::OnlySnippets => vec![], PromptFormat::MinimalQwen => unreachable!(), + PromptFormat::SeedCoder1120 => unreachable!(), }; let mut prompt = match request.prompt_format { @@ -218,6 +222,7 @@ pub fn build_prompt( PromptFormat::OnlySnippets => String::new(), PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(), PromptFormat::MinimalQwen => unreachable!(), + PromptFormat::SeedCoder1120 => unreachable!(), }; if request.events.is_empty() { @@ -328,6 +333,13 @@ pub fn build_prompt( Ok((prompt, section_labels)) } +pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams { + match prompt_format { + PromptFormat::SeedCoder1120 => SeedCoder1120Prompt::generation_params(), + _ => GenerationParams::default(), + } +} + pub fn write_codeblock<'a>( path: &Path, excerpts: impl IntoIterator, @@ -786,6 +798,7 @@ impl<'a> SyntaxBasedPrompt<'a> { } } PromptFormat::MinimalQwen => unreachable!(), + PromptFormat::SeedCoder1120 => unreachable!(), } let push_full_snippet = |output: &mut String| { @@ -896,19 +909,34 @@ fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle } } -struct MinimalQwenPrompt { +struct PromptData { events: Vec, cursor_point: Point, cursor_path: Arc, // TODO: make a common struct with cursor_point included_files: Vec, } -impl MinimalQwenPrompt { - const INSTRUCTIONS: &str = "You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.\n"; +#[derive(Default)] +pub struct GenerationParams { + pub temperature: Option, + pub top_p: Option, + pub stop: Option>, +} + +trait PromptFormatter { + fn render(&self, data: &PromptData) -> String; - fn render(&self) -> String { - let edit_history = self.fmt_edit_history(); - let context = self.fmt_context(); + fn generation_params() -> GenerationParams { + return GenerationParams::default(); + } +} + +struct MinimalQwenPrompt; + +impl PromptFormatter for MinimalQwenPrompt { + fn render(&self, data: &PromptData) -> String { + let edit_history = self.fmt_edit_history(data); + let context = self.fmt_context(data); format!( "{instructions}\n\n{edit_history}\n\n{context}", @@ -917,13 +945,17 @@ impl MinimalQwenPrompt { context = context ) } +} - fn fmt_edit_history(&self) -> String { - if self.events.is_empty() { +impl MinimalQwenPrompt { + const INSTRUCTIONS: &str = "You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.\n"; + + fn fmt_edit_history(&self, data: &PromptData) -> String { + if data.events.is_empty() { "(No edit history)\n\n".to_string() } else { let mut events_str = String::new(); - push_events(&mut events_str, &self.events); + push_events(&mut events_str, &data.events); format!( "The following are the latest edits made by the user, from earlier to later.\n\n{}", events_str @@ -931,18 +963,18 @@ impl MinimalQwenPrompt { } } - fn fmt_context(&self) -> String { + fn fmt_context(&self, data: &PromptData) -> String { let mut context = String::new(); let include_line_numbers = true; - for related_file in &self.included_files { + for related_file in &data.included_files { writeln!(context, "<|file_sep|>{}", DiffPathFmt(&related_file.path)).unwrap(); - if related_file.path == self.cursor_path { + if related_file.path == data.cursor_path { write!(context, "<|fim_prefix|>").unwrap(); write_excerpts( &related_file.excerpts, - &[(self.cursor_point, "<|fim_suffix|>")], + &[(data.cursor_point, "<|fim_suffix|>")], related_file.max_row, include_line_numbers, &mut context, @@ -961,3 +993,83 @@ impl MinimalQwenPrompt { context } } + +struct SeedCoder1120Prompt; + +impl PromptFormatter for SeedCoder1120Prompt { + fn render(&self, data: &PromptData) -> String { + let edit_history = self.fmt_edit_history(data); + let context = self.fmt_context(data); + + format!( + "# Edit History:\n{edit_history}\n\n{context}", + edit_history = edit_history, + context = context + ) + } + + fn generation_params() -> GenerationParams { + GenerationParams { + temperature: Some(0.2), + top_p: Some(0.9), + stop: Some(vec!["<[end_of_sentence]>".into()]), + } + } +} + +impl SeedCoder1120Prompt { + fn fmt_edit_history(&self, data: &PromptData) -> String { + if data.events.is_empty() { + "(No edit history)\n\n".to_string() + } else { + let mut events_str = String::new(); + push_events(&mut events_str, &data.events); + events_str + } + } + + fn fmt_context(&self, data: &PromptData) -> String { + let mut context = String::new(); + let include_line_numbers = true; + + for related_file in &data.included_files { + writeln!(context, "# Path: {}\n", DiffPathFmt(&related_file.path)).unwrap(); + + if related_file.path == data.cursor_path { + let fim_prompt = self.fmt_fim(&related_file, data.cursor_point); + context.push_str(&fim_prompt); + } else { + write_excerpts( + &related_file.excerpts, + &[], + related_file.max_row, + include_line_numbers, + &mut context, + ); + } + } + context + } + + fn fmt_fim(&self, file: &IncludedFile, cursor_point: Point) -> String { + let mut buf = String::new(); + const FIM_SUFFIX: &str = "<[fim-suffix]>"; + const FIM_PREFIX: &str = "<[fim-prefix]>"; + const FIM_MIDDLE: &str = "<[fim-middle]>"; + write!(buf, "{}", FIM_PREFIX).unwrap(); + write_excerpts( + &file.excerpts, + &[(cursor_point, FIM_SUFFIX)], + file.max_row, + true, + &mut buf, + ); + + // Swap prefix and suffix parts + let index = buf.find(FIM_SUFFIX).unwrap(); + let prefix = &buf[..index]; + let suffix = &buf[index..]; + + format!("{}{}{}", suffix, prefix, FIM_MIDDLE) + } +} diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 1cee72ce60e2fcc97d2e4f3b50f274d90a080ee9..255b294d7cc25fade197c3a50d39130bc6bb99c5 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -1562,6 +1562,8 @@ impl Zeta { } let (prompt, _) = prompt_result?; + let generation_params = + cloud_zeta2_prompt::generation_params(cloud_request.prompt_format); let request = open_ai::Request { model: EDIT_PREDICTIONS_MODEL_ID.clone(), messages: vec![open_ai::RequestMessage::User { @@ -1569,8 +1571,8 @@ impl Zeta { }], stream: false, max_completion_tokens: None, - stop: Default::default(), - temperature: 0.7, + stop: generation_params.stop.unwrap_or_default(), + temperature: generation_params.temperature.unwrap_or(0.7), tool_choice: None, parallel_tool_calls: None, tools: vec![], @@ -1636,7 +1638,9 @@ impl Zeta { // TODO: Implement parsing of multi-file diffs crate::udiff::parse_diff(&output_text, get_buffer_from_context).await? } - PromptFormat::Minimal | PromptFormat::MinimalQwen => { + PromptFormat::Minimal + | PromptFormat::MinimalQwen + | PromptFormat::SeedCoder1120 => { if output_text.contains("--- a/\n+++ b/\nNo edits") { let edits = vec![]; (&active_snapshot, edits) diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 53f231599b7d0449b1f2a9cdef8227a7c3e6bbd5..914b141915cd3a89cd35a02bc6c9463094f0de96 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -230,6 +230,7 @@ enum PromptFormat { OldTextNewText, Minimal, MinimalQwen, + SeedCoder1120, } impl Into for PromptFormat { @@ -242,6 +243,7 @@ impl Into for PromptFormat { Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText, Self::Minimal => predict_edits_v3::PromptFormat::Minimal, Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen, + Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120, } } }