Fix zeta2 prompt format selection (#55338)

Ben Kunkle created

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [ ] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Closes https://github.com/zed-industries/zed/issues/52585

Release Notes:

- Fixed local zeta2 edit predictions using the wrong prompt format.

Change summary

crates/edit_prediction/src/edit_prediction.rs  |  8 +-
crates/edit_prediction/src/fim.rs              |  3 
crates/edit_prediction/src/zeta.rs             | 44 +++++++++---
crates/language/src/language.rs                |  5 +
crates/language/src/language_settings.rs       | 46 ++++++++++++
crates/settings_content/src/language.rs        |  7 +
crates/settings_ui/src/settings_ui.rs          |  2 
crates/zed/src/zed/edit_prediction_registry.rs | 18 +++--
crates/zeta_prompt/src/zeta_prompt.rs          |  2 
docs/src/ai/edit-prediction.md                 | 68 ++++++++++++++++++-
10 files changed, 164 insertions(+), 39 deletions(-)

Detailed changes

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -33,16 +33,16 @@ use gpui::{
 };
 use heapless::Vec as ArrayVec;
 use language::{
-    Anchor, Buffer, BufferSnapshot, EditPredictionsMode, EditPreview, File, OffsetRangeExt, Point,
-    TextBufferSnapshot, ToOffset, ToPoint, language_settings::all_language_settings,
+    Anchor, Buffer, BufferSnapshot, EditPredictionPromptFormat, EditPredictionsMode, EditPreview,
+    File, OffsetRangeExt, Point, TextBufferSnapshot, ToOffset, ToPoint,
+    language_settings::all_language_settings,
 };
 use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
 use release_channel::AppVersion;
 use semver::Version;
 use serde::de::DeserializeOwned;
 use settings::{
-    EditPredictionDataCollectionChoice, EditPredictionPromptFormat, EditPredictionProvider,
-    Settings as _, update_settings_file,
+    EditPredictionDataCollectionChoice, EditPredictionProvider, Settings as _, update_settings_file,
 };
 use std::collections::{VecDeque, hash_map};
 use std::env;

crates/edit_prediction/src/fim.rs 🔗

@@ -6,10 +6,9 @@ use crate::{
 use anyhow::{Context as _, Result, anyhow};
 use gpui::{App, AppContext as _, Entity, Task};
 use language::{
-    Anchor, Buffer, BufferSnapshot, ToOffset, ToPoint as _,
+    Anchor, Buffer, BufferSnapshot, EditPredictionPromptFormat, ToOffset, ToPoint as _,
     language_settings::all_language_settings,
 };
-use settings::EditPredictionPromptFormat;
 use std::{path::Path, sync::Arc, time::Instant};
 use zeta_prompt::{ZetaPromptInput, compute_editable_and_context_ranges};
 

crates/edit_prediction/src/zeta.rs 🔗

@@ -12,11 +12,10 @@ use cloud_llm_client::{
 use edit_prediction_types::PredictedCursorPosition;
 use gpui::{App, AppContext as _, Entity, Task, TaskExt, WeakEntity, prelude::*};
 use language::{
-    Buffer, BufferSnapshot, DiagnosticSeverity, OffsetRangeExt as _, ToOffset as _,
-    language_settings::all_language_settings, text_diff,
+    Buffer, BufferSnapshot, DiagnosticSeverity, EditPredictionPromptFormat, OffsetRangeExt as _,
+    ToOffset as _, ZetaVersion, language_settings::all_language_settings, text_diff,
 };
 use release_channel::AppVersion;
-use settings::EditPredictionPromptFormat;
 use text::{Anchor, Bias, Point};
 use ui::SharedString;
 use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
@@ -101,10 +100,30 @@ pub fn request_prediction_with_zeta(
 
     let request_task = cx.background_spawn({
         async move {
-            let zeta_version = raw_config
+            let local_zeta_version = custom_server_settings
+                .as_ref()
+                .and_then(|settings| match settings.prompt_format {
+                    EditPredictionPromptFormat::Zeta(version) => Some(version),
+                    EditPredictionPromptFormat::Infer => {
+                        match settings.model.to_ascii_lowercase().as_str() {
+                            "zeta" | "zeta1" => Some(ZetaVersion::Zeta1),
+                            "zeta2" => Some(ZetaVersion::Zeta2),
+                            "zeta2.1" => Some(ZetaVersion::Zeta2_1),
+                            _ => None,
+                        }
+                    }
+                    _ => None,
+                })
+                .unwrap_or_default();
+            let zeta_format = raw_config
                 .as_ref()
                 .map(|config| config.format)
-                .unwrap_or(ZetaFormat::default());
+                .or(match local_zeta_version {
+                    ZetaVersion::Zeta1 => None,
+                    ZetaVersion::Zeta2 => Some(ZetaFormat::V0211SeedCoder),
+                    ZetaVersion::Zeta2_1 => Some(ZetaFormat::V0318SeedMultiRegions),
+                })
+                .unwrap_or_default();
 
             let cursor_offset = position.to_offset(&snapshot);
             let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
@@ -119,7 +138,7 @@ pub fn request_prediction_with_zeta(
                 repo_url,
             );
 
-            let formatted_prompt = format_zeta_prompt(&prompt_input, zeta_version);
+            let formatted_prompt = format_zeta_prompt(&prompt_input, zeta_format);
 
             if let Some(debug_tx) = &debug_tx {
                 debug_tx
@@ -139,8 +158,8 @@ pub fn request_prediction_with_zeta(
                 (if let Some(custom_settings) = &custom_server_settings {
                     let max_tokens = custom_settings.max_output_tokens * 4;
 
-                    Some(match custom_settings.prompt_format {
-                        EditPredictionPromptFormat::Zeta => {
+                    Some(match local_zeta_version {
+                        ZetaVersion::Zeta1 => {
                             let ranges = &prompt_input.excerpt_ranges;
                             let editable_range_in_excerpt = ranges.editable_350.clone();
                             let prompt = zeta1::format_zeta1_from_input(
@@ -176,11 +195,11 @@ pub fn request_prediction_with_zeta(
 
                             (request_id, parsed_output, None, None)
                         }
-                        EditPredictionPromptFormat::Zeta2 => {
+                        ZetaVersion::Zeta2 | ZetaVersion::Zeta2_1 => {
                             let Some(prompt) = formatted_prompt.clone() else {
                                 return Ok((None, None));
                             };
-                            let prefill = get_prefill(&prompt_input, zeta_version);
+                            let prefill = get_prefill(&prompt_input, zeta_format);
                             let prompt = format!("{prompt}{prefill}");
 
                             let (response_text, request_id) = send_custom_server_request(
@@ -188,7 +207,7 @@ pub fn request_prediction_with_zeta(
                                 custom_settings,
                                 prompt,
                                 max_tokens,
-                                stop_tokens_for_format(zeta_version)
+                                stop_tokens_for_format(zeta_format)
                                     .iter()
                                     .map(|token| token.to_string())
                                     .collect(),
@@ -204,14 +223,13 @@ pub fn request_prediction_with_zeta(
                                 let output = format!("{prefill}{response_text}");
                                 Some(parse_zeta2_model_output(
                                     &output,
-                                    zeta_version,
+                                    zeta_format,
                                     &prompt_input,
                                 )?)
                             };
 
                             (request_id, output_text, None, None)
                         }
-                        _ => anyhow::bail!("unsupported prompt format"),
                     })
                 } else if let Some(config) = &raw_config {
                     let Some(prompt) = format_zeta_prompt(&prompt_input, config.format) else {

crates/language/src/language.rs 🔗

@@ -25,7 +25,10 @@ mod toolchain;
 #[cfg(test)]
 pub mod buffer_tests;
 
-pub use crate::language_settings::{AutoIndentMode, EditPredictionsMode, IndentGuideSettings};
+pub use crate::language_settings::{
+    AutoIndentMode, EditPredictionPromptFormat, EditPredictionsMode, IndentGuideSettings,
+    ZetaVersion,
+};
 use anyhow::{Context as _, Result};
 use async_trait::async_trait;
 use collections::{HashMap, HashSet};

crates/language/src/language_settings.rs 🔗

@@ -17,7 +17,7 @@ use settings::{DocumentFoldingRanges, DocumentSymbols, IntoGpui, SemanticTokens}
 
 pub use settings::{
     AutoIndentMode, CompletionSettingsContent, EditPredictionDataCollectionChoice,
-    EditPredictionPromptFormat, EditPredictionProvider, EditPredictionsMode, FormatOnSave,
+    EditPredictionPromptFormatContent, EditPredictionProvider, EditPredictionsMode, FormatOnSave,
     Formatter, FormatterList, InlayHintKind, LanguageSettingsContent, LineEndingSetting,
     LspInsertMode, RewrapBehavior, ShowWhitespaceSetting, SoftWrap, WordsCompletionMode,
 };
@@ -540,6 +540,46 @@ pub struct OpenAiCompatibleEditPredictionSettings {
     pub prompt_format: EditPredictionPromptFormat,
 }
 
+#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
+pub enum EditPredictionPromptFormat {
+    #[default]
+    Infer,
+    Zeta(ZetaVersion),
+    CodeLlama,
+    StarCoder,
+    DeepseekCoder,
+    Qwen,
+    CodeGemma,
+    Codestral,
+    Glm,
+}
+
+#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
+pub enum ZetaVersion {
+    Zeta1,
+    Zeta2,
+    #[default] // NOTE: make latest version default when adding
+    Zeta2_1,
+}
+
+impl From<EditPredictionPromptFormatContent> for EditPredictionPromptFormat {
+    fn from(value: EditPredictionPromptFormatContent) -> Self {
+        match value {
+            EditPredictionPromptFormatContent::Infer => Self::Infer,
+            EditPredictionPromptFormatContent::Zeta => Self::Zeta(ZetaVersion::Zeta1),
+            EditPredictionPromptFormatContent::Zeta2 => Self::Zeta(ZetaVersion::Zeta2),
+            EditPredictionPromptFormatContent::Zeta2_1 => Self::Zeta(ZetaVersion::Zeta2_1),
+            EditPredictionPromptFormatContent::CodeLlama => Self::CodeLlama,
+            EditPredictionPromptFormatContent::StarCoder => Self::StarCoder,
+            EditPredictionPromptFormatContent::DeepseekCoder => Self::DeepseekCoder,
+            EditPredictionPromptFormatContent::Qwen => Self::Qwen,
+            EditPredictionPromptFormatContent::CodeGemma => Self::CodeGemma,
+            EditPredictionPromptFormatContent::Codestral => Self::Codestral,
+            EditPredictionPromptFormatContent::Glm => Self::Glm,
+        }
+    }
+}
+
 impl AllLanguageSettings {
     /// Returns the [`LanguageSettings`] for the language with the specified name.
     pub fn language<'a>(
@@ -816,7 +856,7 @@ impl settings::Settings for AllLanguageSettings {
                 model: model.0,
                 max_output_tokens: ollama.max_output_tokens.unwrap(),
                 api_url: ollama.api_url.unwrap().into(),
-                prompt_format: ollama.prompt_format.unwrap(),
+                prompt_format: ollama.prompt_format.unwrap().into(),
             });
         let openai_compatible_settings = edit_predictions.open_ai_compatible_api.unwrap();
         let openai_compatible_settings = openai_compatible_settings
@@ -831,7 +871,7 @@ impl settings::Settings for AllLanguageSettings {
                 model,
                 max_output_tokens: openai_compatible_settings.max_output_tokens.unwrap(),
                 api_url: api_url.into(),
-                prompt_format: openai_compatible_settings.prompt_format.unwrap(),
+                prompt_format: openai_compatible_settings.prompt_format.unwrap().into(),
             });
 
         let mut file_types: FxHashMap<Arc<str>, (GlobSet, Vec<String>)> = FxHashMap::default();

crates/settings_content/src/language.rs 🔗

@@ -159,7 +159,7 @@ pub struct CustomEditPredictionProviderSettingsContent {
     /// The prompt format to use for completions. Set to `""` to have the format be derived from the model name.
     ///
     /// Default: ""
-    pub prompt_format: Option<EditPredictionPromptFormat>,
+    pub prompt_format: Option<EditPredictionPromptFormatContent>,
     /// The name of the model.
     ///
     /// Default: ""
@@ -185,11 +185,12 @@ pub struct CustomEditPredictionProviderSettingsContent {
     strum::VariantNames,
 )]
 #[serde(rename_all = "snake_case")]
-pub enum EditPredictionPromptFormat {
+pub enum EditPredictionPromptFormatContent {
     #[default]
     Infer,
     Zeta,
     Zeta2,
+    Zeta2_1,
     CodeLlama,
     StarCoder,
     DeepseekCoder,
@@ -280,7 +281,7 @@ pub struct OllamaEditPredictionSettingsContent {
     /// The prompt format to use for completions. Set to `""` to have the format be derived from the model name.
     ///
     /// Default: ""
-    pub prompt_format: Option<EditPredictionPromptFormat>,
+    pub prompt_format: Option<EditPredictionPromptFormatContent>,
 }
 
 /// Controls whether Zed collects training data when using Zed's Edit Predictions.

crates/settings_ui/src/settings_ui.rs 🔗

@@ -504,7 +504,7 @@ fn init_renderers(cx: &mut App) {
         .add_basic_renderer::<settings::AlternateScroll>(render_dropdown)
         .add_basic_renderer::<settings::TerminalBlink>(render_dropdown)
         .add_basic_renderer::<settings::CursorShapeContent>(render_dropdown)
-        .add_basic_renderer::<settings::EditPredictionPromptFormat>(render_dropdown)
+        .add_basic_renderer::<settings::EditPredictionPromptFormatContent>(render_dropdown)
         .add_basic_renderer::<settings::EditPredictionDataCollectionChoice>(render_dropdown)
         .add_basic_renderer::<f32>(render_editable_number_field)
         .add_basic_renderer::<u32>(render_editable_number_field)

crates/zed/src/zed/edit_prediction_registry.rs 🔗

@@ -5,9 +5,14 @@ use copilot::CopilotEditPredictionDelegate;
 use edit_prediction::{EditPredictionModel, ZedEditPredictionDelegate};
 use editor::Editor;
 use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
-use language::language_settings::{EditPredictionProvider, all_language_settings};
-
-use settings::{EditPredictionPromptFormat, SettingsStore};
+use language::{
+    ZetaVersion,
+    language_settings::{
+        EditPredictionPromptFormat, EditPredictionProvider, all_language_settings,
+    },
+};
+
+use settings::SettingsStore;
 use std::{cell::RefCell, rc::Rc, sync::Arc};
 use ui::Window;
 
@@ -132,10 +137,7 @@ fn edit_prediction_provider_config_for_settings(cx: &App) -> Option<EditPredicti
                 }
             }
 
-            if matches!(
-                format,
-                EditPredictionPromptFormat::Zeta | EditPredictionPromptFormat::Zeta2
-            ) {
+            if matches!(format, EditPredictionPromptFormat::Zeta(_)) {
                 Some(EditPredictionProviderConfig::Zed(EditPredictionModel::Zeta))
             } else {
                 Some(EditPredictionProviderConfig::Zed(
@@ -154,6 +156,8 @@ fn infer_prompt_format(model: &str) -> Option<EditPredictionPromptFormat> {
     let model_base = model.split(':').next().unwrap_or(model);
 
     Some(match model_base {
+        "zeta2" => EditPredictionPromptFormat::Zeta(ZetaVersion::Zeta2),
+        "zeta2.1" => EditPredictionPromptFormat::Zeta(ZetaVersion::Zeta2_1),
         "codellama" | "code-llama" => EditPredictionPromptFormat::CodeLlama,
         "starcoder" | "starcoder2" | "starcoderbase" => EditPredictionPromptFormat::StarCoder,
         "deepseek-coder" | "deepseek-coder-v2" => EditPredictionPromptFormat::DeepseekCoder,

crates/zeta_prompt/src/zeta_prompt.rs 🔗

@@ -99,6 +99,7 @@ pub enum ZetaFormat {
     #[default]
     V0131GitMergeMarkersPrefix,
     V0211Prefill,
+    #[serde(alias = "Zeta2")]
     V0211SeedCoder,
     V0331SeedCoderModelPy,
     v0226Hashline,
@@ -111,6 +112,7 @@ pub enum ZetaFormat {
     /// V0316, but marker numbers are relative to the cursor block (e.g. -1, -0, +1).
     V0317SeedMultiRegions,
     /// V0316 with larger block sizes.
+    #[serde(alias = "Zeta2.1")]
     V0318SeedMultiRegions,
     /// V0318-style markers over the full available current file excerpt with no related files.
     V0327SingleFile,

docs/src/ai/edit-prediction.md 🔗

@@ -286,11 +286,29 @@ After adding your API key, Codestral will appear in the provider dropdown in the
 }
 ```
 
-### Self-Hosted OpenAI-compatible servers
+### Local and self-hosted models
 
-You can use any self-hosted server that implements the OpenAI completion API format. This works with vLLM, llama.cpp server, LocalAI, and other compatible servers.
+You can use local or self-hosted edit prediction models through Ollama or any server that implements the OpenAI completion API format, such as vLLM, llama.cpp server, LocalAI, and other compatible servers.
 
-#### Configuration
+#### Ollama
+
+Set `ollama` as your provider and configure the local model:
+
+```json [settings]
+{
+  "edit_predictions": {
+    "provider": "ollama",
+    "ollama": {
+      "api_url": "http://localhost:11434",
+      "model": "qwen2.5-coder:7b-base",
+      "prompt_format": "infer",
+      "max_output_tokens": 512
+    }
+  }
+}
+```
+
+#### OpenAI-compatible servers
 
 Set `open_ai_compatible_api` as your provider and configure the API endpoint:
 
@@ -302,7 +320,7 @@ Set `open_ai_compatible_api` as your provider and configure the API endpoint:
       "api_url": "http://localhost:8080/v1/completions",
       "model": "deepseek-coder-6.7b-base",
       "prompt_format": "deepseek_coder",
-      "max_output_tokens": 64
+      "max_output_tokens": 512
     }
   }
 }
@@ -310,15 +328,55 @@ Set `open_ai_compatible_api` as your provider and configure the API endpoint:
 
 The `prompt_format` setting controls how code context is formatted for the model. Use `"infer"` to detect the format from the model name, or specify one explicitly:
 
+- `zeta` - Zeta 1 format
+- `zeta2` - Zeta 2 format
+- `zeta2_1` - Zeta 2.1 format
 - `code_llama` - CodeLlama format: `<PRE> prefix <SUF> suffix <MID>`
 - `star_coder` - StarCoder format: `<fim_prefix>prefix<fim_suffix>suffix<fim_middle>`
 - `deepseek_coder` - DeepSeek format with special unicode markers
 - `qwen` - Qwen/CodeGemma format: `<|fim_prefix|>prefix<|fim_suffix|>suffix<|fim_middle|>`
+- `code_gemma` - CodeGemma format: `<|fim_prefix|>prefix<|fim_suffix|>suffix<|fim_middle|>`
 - `codestral` - Codestral format: `[SUFFIX]suffix[PREFIX]prefix`
 - `glm` - GLM-4 format with code markers
 - `infer` - Auto-detect from model name (default)
 
-Your server must implement the OpenAI `/v1/completions` endpoint. Edit predictions will send POST requests with this format:
+With `"prompt_format": "infer"`, Zed automatically uses Zeta 2 format for models named `zeta2` and Zeta 2.1 format for models named `zeta2.1`.
+
+For example, to use Zeta 2 with Ollama:
+
+```json [settings]
+{
+  "edit_predictions": {
+    "provider": "ollama",
+    "ollama": {
+      "api_url": "http://localhost:11434",
+      "model": "zeta2",
+      "prompt_format": "infer",
+      "max_output_tokens": 512
+    }
+  }
+}
+```
+
+To use Zeta 2.1 with an OpenAI-compatible server:
+
+```json [settings]
+{
+  "edit_predictions": {
+    "provider": "open_ai_compatible_api",
+    "open_ai_compatible_api": {
+      "api_url": "http://localhost:8080/v1/completions",
+      "model": "zeta2.1",
+      "prompt_format": "infer",
+      "max_output_tokens": 512
+    }
+  }
+}
+```
+
+You can also set `"prompt_format": "zeta2"` or `"prompt_format": "zeta2_1"` explicitly when the model name does not match.
+
+Your OpenAI-compatible server must implement the OpenAI `/v1/completions` endpoint. Edit predictions will send POST requests with this format:
 
 ```json
 {