zeta2: Support experimental 1120-seedcoder model (#43411)

Oleksiy Syvokon created

1. Introduce a common `PromptFormatter` trait
2. Let models define their generation params.
3. Add support for the experimental 1120-seedcoder prompt format


Release Notes:

- N/A

Change summary

crates/cloud_llm_client/src/predict_edits_v3.rs     |   3 
crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs | 152 +++++++++++++-
crates/zeta2/src/zeta2.rs                           |  10 
crates/zeta_cli/src/main.rs                         |   2 
4 files changed, 144 insertions(+), 23 deletions(-)

Detailed changes

crates/cloud_llm_client/src/predict_edits_v3.rs 🔗

@@ -80,6 +80,8 @@ pub enum PromptFormat {
     Minimal,
     /// One-sentence instructions + FIM-like template
     MinimalQwen,
+    /// No instructions, Qwen chat + Seed-Coder 1120 FIM-like template
+    SeedCoder1120,
 }
 
 impl PromptFormat {
@@ -108,6 +110,7 @@ impl std::fmt::Display for PromptFormat {
             PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
             PromptFormat::Minimal => write!(f, "Minimal"),
             PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
+            PromptFormat::SeedCoder1120 => write!(f, "Seed-Coder 1120"),
         }
     }
 }

crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs 🔗

@@ -169,15 +169,18 @@ pub fn build_prompt(
 ) -> Result<(String, SectionLabels)> {
     let mut section_labels = Default::default();
 
+    let prompt_data = PromptData {
+        events: request.events.clone(),
+        cursor_point: request.cursor_point,
+        cursor_path: request.excerpt_path.clone(),
+        included_files: request.included_files.clone(),
+    };
     match request.prompt_format {
         PromptFormat::MinimalQwen => {
-            let prompt = MinimalQwenPrompt {
-                events: request.events.clone(),
-                cursor_point: request.cursor_point,
-                cursor_path: request.excerpt_path.clone(),
-                included_files: request.included_files.clone(),
-            };
-            return Ok((prompt.render(), section_labels));
+            return Ok((MinimalQwenPrompt.render(&prompt_data), section_labels));
+        }
+        PromptFormat::SeedCoder1120 => {
+            return Ok((SeedCoder1120Prompt.render(&prompt_data), section_labels));
         }
         _ => (),
     };
@@ -208,6 +211,7 @@ pub fn build_prompt(
         }
         PromptFormat::OnlySnippets => vec![],
         PromptFormat::MinimalQwen => unreachable!(),
+        PromptFormat::SeedCoder1120 => unreachable!(),
     };
 
     let mut prompt = match request.prompt_format {
@@ -218,6 +222,7 @@ pub fn build_prompt(
         PromptFormat::OnlySnippets => String::new(),
         PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(),
         PromptFormat::MinimalQwen => unreachable!(),
+        PromptFormat::SeedCoder1120 => unreachable!(),
     };
 
     if request.events.is_empty() {
@@ -328,6 +333,13 @@ pub fn build_prompt(
     Ok((prompt, section_labels))
 }
 
+pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams {
+    match prompt_format {
+        PromptFormat::SeedCoder1120 => SeedCoder1120Prompt::generation_params(),
+        _ => GenerationParams::default(),
+    }
+}
+
 pub fn write_codeblock<'a>(
     path: &Path,
     excerpts: impl IntoIterator<Item = &'a Excerpt>,
@@ -786,6 +798,7 @@ impl<'a> SyntaxBasedPrompt<'a> {
                         }
                     }
                     PromptFormat::MinimalQwen => unreachable!(),
+                    PromptFormat::SeedCoder1120 => unreachable!(),
                 }
 
                 let push_full_snippet = |output: &mut String| {
@@ -896,19 +909,34 @@ fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle
     }
 }
 
-struct MinimalQwenPrompt {
+struct PromptData {
     events: Vec<Event>,
     cursor_point: Point,
     cursor_path: Arc<Path>, // TODO: make a common struct with cursor_point
     included_files: Vec<IncludedFile>,
 }
 
-impl MinimalQwenPrompt {
-    const INSTRUCTIONS: &str = "You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.\n";
+#[derive(Default)]
+pub struct GenerationParams {
+    pub temperature: Option<f32>,
+    pub top_p: Option<f32>,
+    pub stop: Option<Vec<String>>,
+}
+
+trait PromptFormatter {
+    fn render(&self, data: &PromptData) -> String;
 
-    fn render(&self) -> String {
-        let edit_history = self.fmt_edit_history();
-        let context = self.fmt_context();
+    fn generation_params() -> GenerationParams {
+        return GenerationParams::default();
+    }
+}
+
+struct MinimalQwenPrompt;
+
+impl PromptFormatter for MinimalQwenPrompt {
+    fn render(&self, data: &PromptData) -> String {
+        let edit_history = self.fmt_edit_history(data);
+        let context = self.fmt_context(data);
 
         format!(
             "{instructions}\n\n{edit_history}\n\n{context}",
@@ -917,13 +945,17 @@ impl MinimalQwenPrompt {
             context = context
         )
     }
+}
 
-    fn fmt_edit_history(&self) -> String {
-        if self.events.is_empty() {
+impl MinimalQwenPrompt {
+    const INSTRUCTIONS: &str = "You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.\n";
+
+    fn fmt_edit_history(&self, data: &PromptData) -> String {
+        if data.events.is_empty() {
             "(No edit history)\n\n".to_string()
         } else {
             let mut events_str = String::new();
-            push_events(&mut events_str, &self.events);
+            push_events(&mut events_str, &data.events);
             format!(
                 "The following are the latest edits made by the user, from earlier to later.\n\n{}",
                 events_str
@@ -931,18 +963,18 @@ impl MinimalQwenPrompt {
         }
     }
 
-    fn fmt_context(&self) -> String {
+    fn fmt_context(&self, data: &PromptData) -> String {
         let mut context = String::new();
         let include_line_numbers = true;
 
-        for related_file in &self.included_files {
+        for related_file in &data.included_files {
             writeln!(context, "<|file_sep|>{}", DiffPathFmt(&related_file.path)).unwrap();
 
-            if related_file.path == self.cursor_path {
+            if related_file.path == data.cursor_path {
                 write!(context, "<|fim_prefix|>").unwrap();
                 write_excerpts(
                     &related_file.excerpts,
-                    &[(self.cursor_point, "<|fim_suffix|>")],
+                    &[(data.cursor_point, "<|fim_suffix|>")],
                     related_file.max_row,
                     include_line_numbers,
                     &mut context,
@@ -961,3 +993,83 @@ impl MinimalQwenPrompt {
         context
     }
 }
+
+struct SeedCoder1120Prompt;
+
+impl PromptFormatter for SeedCoder1120Prompt {
+    fn render(&self, data: &PromptData) -> String {
+        let edit_history = self.fmt_edit_history(data);
+        let context = self.fmt_context(data);
+
+        format!(
+            "# Edit History:\n{edit_history}\n\n{context}",
+            edit_history = edit_history,
+            context = context
+        )
+    }
+
+    fn generation_params() -> GenerationParams {
+        GenerationParams {
+            temperature: Some(0.2),
+            top_p: Some(0.9),
+            stop: Some(vec!["<[end_of_sentence]>".into()]),
+        }
+    }
+}
+
+impl SeedCoder1120Prompt {
+    fn fmt_edit_history(&self, data: &PromptData) -> String {
+        if data.events.is_empty() {
+            "(No edit history)\n\n".to_string()
+        } else {
+            let mut events_str = String::new();
+            push_events(&mut events_str, &data.events);
+            events_str
+        }
+    }
+
+    fn fmt_context(&self, data: &PromptData) -> String {
+        let mut context = String::new();
+        let include_line_numbers = true;
+
+        for related_file in &data.included_files {
+            writeln!(context, "# Path: {}\n", DiffPathFmt(&related_file.path)).unwrap();
+
+            if related_file.path == data.cursor_path {
+                let fim_prompt = self.fmt_fim(&related_file, data.cursor_point);
+                context.push_str(&fim_prompt);
+            } else {
+                write_excerpts(
+                    &related_file.excerpts,
+                    &[],
+                    related_file.max_row,
+                    include_line_numbers,
+                    &mut context,
+                );
+            }
+        }
+        context
+    }
+
+    fn fmt_fim(&self, file: &IncludedFile, cursor_point: Point) -> String {
+        let mut buf = String::new();
+        const FIM_SUFFIX: &str = "<[fim-suffix]>";
+        const FIM_PREFIX: &str = "<[fim-prefix]>";
+        const FIM_MIDDLE: &str = "<[fim-middle]>";
+        write!(buf, "{}", FIM_PREFIX).unwrap();
+        write_excerpts(
+            &file.excerpts,
+            &[(cursor_point, FIM_SUFFIX)],
+            file.max_row,
+            true,
+            &mut buf,
+        );
+
+        // Swap prefix and suffix parts
+        let index = buf.find(FIM_SUFFIX).unwrap();
+        let prefix = &buf[..index];
+        let suffix = &buf[index..];
+
+        format!("{}{}{}", suffix, prefix, FIM_MIDDLE)
+    }
+}

crates/zeta2/src/zeta2.rs 🔗

@@ -1562,6 +1562,8 @@ impl Zeta {
                 }
 
                 let (prompt, _) = prompt_result?;
+                let generation_params =
+                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
                 let request = open_ai::Request {
                     model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                     messages: vec![open_ai::RequestMessage::User {
@@ -1569,8 +1571,8 @@ impl Zeta {
                     }],
                     stream: false,
                     max_completion_tokens: None,
-                    stop: Default::default(),
-                    temperature: 0.7,
+                    stop: generation_params.stop.unwrap_or_default(),
+                    temperature: generation_params.temperature.unwrap_or(0.7),
                     tool_choice: None,
                     parallel_tool_calls: None,
                     tools: vec![],
@@ -1636,7 +1638,9 @@ impl Zeta {
                         // TODO: Implement parsing of multi-file diffs
                         crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                     }
-                    PromptFormat::Minimal | PromptFormat::MinimalQwen => {
+                    PromptFormat::Minimal
+                    | PromptFormat::MinimalQwen
+                    | PromptFormat::SeedCoder1120 => {
                         if output_text.contains("--- a/\n+++ b/\nNo edits") {
                             let edits = vec![];
                             (&active_snapshot, edits)

crates/zeta_cli/src/main.rs 🔗

@@ -230,6 +230,7 @@ enum PromptFormat {
     OldTextNewText,
     Minimal,
     MinimalQwen,
+    SeedCoder1120,
 }
 
 impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
@@ -242,6 +243,7 @@ impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
             Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
             Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
             Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
+            Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
         }
     }
 }