diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index e61cafa6adced14a36a032bf03a77e14be7cb0a2..7f835dfbdf1d564538efc2596771daa20bef9d36 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -33,16 +33,16 @@ use gpui::{ }; use heapless::Vec as ArrayVec; use language::{ - Anchor, Buffer, BufferSnapshot, EditPredictionsMode, EditPreview, File, OffsetRangeExt, Point, - TextBufferSnapshot, ToOffset, ToPoint, language_settings::all_language_settings, + Anchor, Buffer, BufferSnapshot, EditPredictionPromptFormat, EditPredictionsMode, EditPreview, + File, OffsetRangeExt, Point, TextBufferSnapshot, ToOffset, ToPoint, + language_settings::all_language_settings, }; use project::{DisableAiSettings, Project, ProjectPath, WorktreeId}; use release_channel::AppVersion; use semver::Version; use serde::de::DeserializeOwned; use settings::{ - EditPredictionDataCollectionChoice, EditPredictionPromptFormat, EditPredictionProvider, - Settings as _, update_settings_file, + EditPredictionDataCollectionChoice, EditPredictionProvider, Settings as _, update_settings_file, }; use std::collections::{VecDeque, hash_map}; use std::env; diff --git a/crates/edit_prediction/src/fim.rs b/crates/edit_prediction/src/fim.rs index 44a5b2541fbd4c65b15f929ee9fb5a5bd7fe929b..301ca7fb468f96240e5aab04481c93638307b456 100644 --- a/crates/edit_prediction/src/fim.rs +++ b/crates/edit_prediction/src/fim.rs @@ -6,10 +6,9 @@ use crate::{ use anyhow::{Context as _, Result, anyhow}; use gpui::{App, AppContext as _, Entity, Task}; use language::{ - Anchor, Buffer, BufferSnapshot, ToOffset, ToPoint as _, + Anchor, Buffer, BufferSnapshot, EditPredictionPromptFormat, ToOffset, ToPoint as _, language_settings::all_language_settings, }; -use settings::EditPredictionPromptFormat; use std::{path::Path, sync::Arc, time::Instant}; use zeta_prompt::{ZetaPromptInput, compute_editable_and_context_ranges}; diff 
--git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index 33b347be1757678e733bb56ec9becf904a1fa791..a5637ca3cec6d03f4af51415a94edaa179def01f 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -12,11 +12,10 @@ use cloud_llm_client::{ use edit_prediction_types::PredictedCursorPosition; use gpui::{App, AppContext as _, Entity, Task, TaskExt, WeakEntity, prelude::*}; use language::{ - Buffer, BufferSnapshot, DiagnosticSeverity, OffsetRangeExt as _, ToOffset as _, - language_settings::all_language_settings, text_diff, + Buffer, BufferSnapshot, DiagnosticSeverity, EditPredictionPromptFormat, OffsetRangeExt as _, + ToOffset as _, ZetaVersion, language_settings::all_language_settings, text_diff, }; use release_channel::AppVersion; -use settings::EditPredictionPromptFormat; use text::{Anchor, Bias, Point}; use ui::SharedString; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; @@ -101,10 +100,30 @@ pub fn request_prediction_with_zeta( let request_task = cx.background_spawn({ async move { - let zeta_version = raw_config + let local_zeta_version = custom_server_settings + .as_ref() + .and_then(|settings| match settings.prompt_format { + EditPredictionPromptFormat::Zeta(version) => Some(version), + EditPredictionPromptFormat::Infer => { + match settings.model.to_ascii_lowercase().as_str() { + "zeta" | "zeta1" => Some(ZetaVersion::Zeta1), + "zeta2" => Some(ZetaVersion::Zeta2), + "zeta2.1" => Some(ZetaVersion::Zeta2_1), + _ => None, + } + } + _ => None, + }) + .unwrap_or_default(); + let zeta_format = raw_config .as_ref() .map(|config| config.format) - .unwrap_or(ZetaFormat::default()); + .or(match local_zeta_version { + ZetaVersion::Zeta1 => None, + ZetaVersion::Zeta2 => Some(ZetaFormat::V0211SeedCoder), + ZetaVersion::Zeta2_1 => Some(ZetaFormat::V0318SeedMultiRegions), + }) + .unwrap_or_default(); let cursor_offset = position.to_offset(&snapshot); let 
(full_context_offset_range, prompt_input) = zeta2_prompt_input( @@ -119,7 +138,7 @@ pub fn request_prediction_with_zeta( repo_url, ); - let formatted_prompt = format_zeta_prompt(&prompt_input, zeta_version); + let formatted_prompt = format_zeta_prompt(&prompt_input, zeta_format); if let Some(debug_tx) = &debug_tx { debug_tx @@ -139,8 +158,8 @@ pub fn request_prediction_with_zeta( (if let Some(custom_settings) = &custom_server_settings { let max_tokens = custom_settings.max_output_tokens * 4; - Some(match custom_settings.prompt_format { - EditPredictionPromptFormat::Zeta => { + Some(match local_zeta_version { + ZetaVersion::Zeta1 => { let ranges = &prompt_input.excerpt_ranges; let editable_range_in_excerpt = ranges.editable_350.clone(); let prompt = zeta1::format_zeta1_from_input( @@ -176,11 +195,11 @@ pub fn request_prediction_with_zeta( (request_id, parsed_output, None, None) } - EditPredictionPromptFormat::Zeta2 => { + ZetaVersion::Zeta2 | ZetaVersion::Zeta2_1 => { let Some(prompt) = formatted_prompt.clone() else { return Ok((None, None)); }; - let prefill = get_prefill(&prompt_input, zeta_version); + let prefill = get_prefill(&prompt_input, zeta_format); let prompt = format!("{prompt}{prefill}"); let (response_text, request_id) = send_custom_server_request( @@ -188,7 +207,7 @@ pub fn request_prediction_with_zeta( custom_settings, prompt, max_tokens, - stop_tokens_for_format(zeta_version) + stop_tokens_for_format(zeta_format) .iter() .map(|token| token.to_string()) .collect(), @@ -204,14 +223,13 @@ pub fn request_prediction_with_zeta( let output = format!("{prefill}{response_text}"); Some(parse_zeta2_model_output( &output, - zeta_version, + zeta_format, &prompt_input, )?) 
}; (request_id, output_text, None, None) } - _ => anyhow::bail!("unsupported prompt format"), }) } else if let Some(config) = &raw_config { let Some(prompt) = format_zeta_prompt(&prompt_input, config.format) else { diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index a6e05fb586c37a2feaf29aacbd80df3a1af3d822..cb19b5e6dbbeeea7fe63b3e00874a0d7a1d45f62 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -25,7 +25,10 @@ mod toolchain; #[cfg(test)] pub mod buffer_tests; -pub use crate::language_settings::{AutoIndentMode, EditPredictionsMode, IndentGuideSettings}; +pub use crate::language_settings::{ + AutoIndentMode, EditPredictionPromptFormat, EditPredictionsMode, IndentGuideSettings, + ZetaVersion, +}; use anyhow::{Context as _, Result}; use async_trait::async_trait; use collections::{HashMap, HashSet}; diff --git a/crates/language/src/language_settings.rs b/crates/language/src/language_settings.rs index 701b363d9eb9ca3d0a7e446081dbe12084513e05..3d90d8d06e65c50aaecdb4018100cebcc0f9e720 100644 --- a/crates/language/src/language_settings.rs +++ b/crates/language/src/language_settings.rs @@ -17,7 +17,7 @@ use settings::{DocumentFoldingRanges, DocumentSymbols, IntoGpui, SemanticTokens} pub use settings::{ AutoIndentMode, CompletionSettingsContent, EditPredictionDataCollectionChoice, - EditPredictionPromptFormat, EditPredictionProvider, EditPredictionsMode, FormatOnSave, + EditPredictionPromptFormatContent, EditPredictionProvider, EditPredictionsMode, FormatOnSave, Formatter, FormatterList, InlayHintKind, LanguageSettingsContent, LineEndingSetting, LspInsertMode, RewrapBehavior, ShowWhitespaceSetting, SoftWrap, WordsCompletionMode, }; @@ -540,6 +540,46 @@ pub struct OpenAiCompatibleEditPredictionSettings { pub prompt_format: EditPredictionPromptFormat, } +#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] +pub enum EditPredictionPromptFormat { + #[default] + Infer, + Zeta(ZetaVersion), + CodeLlama, + 
StarCoder, + DeepseekCoder, + Qwen, + CodeGemma, + Codestral, + Glm, +} + +#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] +pub enum ZetaVersion { + Zeta1, + Zeta2, + #[default] // NOTE: make latest version default when adding + Zeta2_1, +} + +impl From for EditPredictionPromptFormat { + fn from(value: EditPredictionPromptFormatContent) -> Self { + match value { + EditPredictionPromptFormatContent::Infer => Self::Infer, + EditPredictionPromptFormatContent::Zeta => Self::Zeta(ZetaVersion::Zeta1), + EditPredictionPromptFormatContent::Zeta2 => Self::Zeta(ZetaVersion::Zeta2), + EditPredictionPromptFormatContent::Zeta2_1 => Self::Zeta(ZetaVersion::Zeta2_1), + EditPredictionPromptFormatContent::CodeLlama => Self::CodeLlama, + EditPredictionPromptFormatContent::StarCoder => Self::StarCoder, + EditPredictionPromptFormatContent::DeepseekCoder => Self::DeepseekCoder, + EditPredictionPromptFormatContent::Qwen => Self::Qwen, + EditPredictionPromptFormatContent::CodeGemma => Self::CodeGemma, + EditPredictionPromptFormatContent::Codestral => Self::Codestral, + EditPredictionPromptFormatContent::Glm => Self::Glm, + } + } +} + impl AllLanguageSettings { /// Returns the [`LanguageSettings`] for the language with the specified name. 
pub fn language<'a>( @@ -816,7 +856,7 @@ impl settings::Settings for AllLanguageSettings { model: model.0, max_output_tokens: ollama.max_output_tokens.unwrap(), api_url: ollama.api_url.unwrap().into(), - prompt_format: ollama.prompt_format.unwrap(), + prompt_format: ollama.prompt_format.unwrap().into(), }); let openai_compatible_settings = edit_predictions.open_ai_compatible_api.unwrap(); let openai_compatible_settings = openai_compatible_settings @@ -831,7 +871,7 @@ impl settings::Settings for AllLanguageSettings { model, max_output_tokens: openai_compatible_settings.max_output_tokens.unwrap(), api_url: api_url.into(), - prompt_format: openai_compatible_settings.prompt_format.unwrap(), + prompt_format: openai_compatible_settings.prompt_format.unwrap().into(), }); let mut file_types: FxHashMap, (GlobSet, Vec)> = FxHashMap::default(); diff --git a/crates/settings_content/src/language.rs b/crates/settings_content/src/language.rs index d3f0e6a4195bae10b23100fdf584467ff5f7d3b8..081406a6846b816f58fe0a1b3e3353b7a40170fd 100644 --- a/crates/settings_content/src/language.rs +++ b/crates/settings_content/src/language.rs @@ -159,7 +159,7 @@ pub struct CustomEditPredictionProviderSettingsContent { /// The prompt format to use for completions. Set to `""` to have the format be derived from the model name. /// /// Default: "" - pub prompt_format: Option, + pub prompt_format: Option, /// The name of the model. /// /// Default: "" @@ -185,11 +185,12 @@ pub struct CustomEditPredictionProviderSettingsContent { strum::VariantNames, )] #[serde(rename_all = "snake_case")] -pub enum EditPredictionPromptFormat { +pub enum EditPredictionPromptFormatContent { #[default] Infer, Zeta, Zeta2, + Zeta2_1, CodeLlama, StarCoder, DeepseekCoder, @@ -280,7 +281,7 @@ pub struct OllamaEditPredictionSettingsContent { /// The prompt format to use for completions. Set to `""` to have the format be derived from the model name. 
/// /// Default: "" - pub prompt_format: Option, + pub prompt_format: Option, } /// Controls whether Zed collects training data when using Zed's Edit Predictions. diff --git a/crates/settings_ui/src/settings_ui.rs b/crates/settings_ui/src/settings_ui.rs index 7b88c2affe99aebd43bdd8993403ee53833fc715..a5c36671ea01ee4cfbc5c27b85f8f786b3198713 100644 --- a/crates/settings_ui/src/settings_ui.rs +++ b/crates/settings_ui/src/settings_ui.rs @@ -504,7 +504,7 @@ fn init_renderers(cx: &mut App) { .add_basic_renderer::(render_dropdown) .add_basic_renderer::(render_dropdown) .add_basic_renderer::(render_dropdown) - .add_basic_renderer::(render_dropdown) + .add_basic_renderer::(render_dropdown) .add_basic_renderer::(render_dropdown) .add_basic_renderer::(render_editable_number_field) .add_basic_renderer::(render_editable_number_field) diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 5e41024589df1b96b83562648be4073d2e380659..f0968bf9efe0e1795ba193e5f35be73204ef1f33 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -5,9 +5,14 @@ use copilot::CopilotEditPredictionDelegate; use edit_prediction::{EditPredictionModel, ZedEditPredictionDelegate}; use editor::Editor; use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity}; -use language::language_settings::{EditPredictionProvider, all_language_settings}; - -use settings::{EditPredictionPromptFormat, SettingsStore}; +use language::{ + ZetaVersion, + language_settings::{ + EditPredictionPromptFormat, EditPredictionProvider, all_language_settings, + }, +}; + +use settings::SettingsStore; use std::{cell::RefCell, rc::Rc, sync::Arc}; use ui::Window; @@ -132,10 +137,7 @@ fn edit_prediction_provider_config_for_settings(cx: &App) -> Option Option { let model_base = model.split(':').next().unwrap_or(model); Some(match model_base { + "zeta2" => 
EditPredictionPromptFormat::Zeta(ZetaVersion::Zeta2), + "zeta2.1" => EditPredictionPromptFormat::Zeta(ZetaVersion::Zeta2_1), "codellama" | "code-llama" => EditPredictionPromptFormat::CodeLlama, "starcoder" | "starcoder2" | "starcoderbase" => EditPredictionPromptFormat::StarCoder, "deepseek-coder" | "deepseek-coder-v2" => EditPredictionPromptFormat::DeepseekCoder, diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 7bc37ff698c86b567f069fb82fab7b624ca718f3..7bee9fc9b0935201b9d92707b5579af88e1e0fd3 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -99,6 +99,7 @@ pub enum ZetaFormat { #[default] V0131GitMergeMarkersPrefix, V0211Prefill, + #[serde(alias = "Zeta2")] V0211SeedCoder, V0331SeedCoderModelPy, v0226Hashline, @@ -111,6 +112,7 @@ pub enum ZetaFormat { /// V0316, but marker numbers are relative to the cursor block (e.g. -1, -0, +1). V0317SeedMultiRegions, /// V0316 with larger block sizes. + #[serde(alias = "Zeta2.1")] V0318SeedMultiRegions, /// V0318-style markers over the full available current file excerpt with no related files. V0327SingleFile, diff --git a/docs/src/ai/edit-prediction.md b/docs/src/ai/edit-prediction.md index 865693036c2a999e48a72bcef9b536706084db56..1f5b3e8adcee4402d920fac3d66cc0df42c11ed8 100644 --- a/docs/src/ai/edit-prediction.md +++ b/docs/src/ai/edit-prediction.md @@ -286,11 +286,29 @@ After adding your API key, Codestral will appear in the provider dropdown in the } ``` -### Self-Hosted OpenAI-compatible servers +### Local and self-hosted models -You can use any self-hosted server that implements the OpenAI completion API format. This works with vLLM, llama.cpp server, LocalAI, and other compatible servers. +You can use local or self-hosted edit prediction models through Ollama or any server that implements the OpenAI completion API format. This works with Ollama, vLLM, llama.cpp server, LocalAI, and other compatible servers. 
-#### Configuration +#### Ollama + +Set `ollama` as your provider and configure the local model: + +```json [settings] +{ + "edit_predictions": { + "provider": "ollama", + "ollama": { + "api_url": "http://localhost:11434", + "model": "qwen2.5-coder:7b-base", + "prompt_format": "infer", + "max_output_tokens": 512 + } + } +} +``` + +#### OpenAI-compatible servers Set `open_ai_compatible_api` as your provider and configure the API endpoint: @@ -302,7 +320,7 @@ Set `open_ai_compatible_api` as your provider and configure the API endpoint: "api_url": "http://localhost:8080/v1/completions", "model": "deepseek-coder-6.7b-base", "prompt_format": "deepseek_coder", - "max_output_tokens": 64 + "max_output_tokens": 512 } } } @@ -310,15 +328,55 @@ Set `open_ai_compatible_api` as your provider and configure the API endpoint: The `prompt_format` setting controls how code context is formatted for the model. Use `"infer"` to detect the format from the model name, or specify one explicitly: +- `zeta` - Zeta 1 format +- `zeta2` - Zeta 2 format +- `zeta2_1` - Zeta 2.1 format - `code_llama` - CodeLlama format: `
 <PRE> prefix <SUF> suffix <MID>`
 - `star_coder` - StarCoder format: `<fim_prefix>prefix<fim_suffix>suffix<fim_middle>`
 - `deepseek_coder` - DeepSeek format with special unicode markers
 - `qwen` - Qwen/CodeGemma format: `<|fim_prefix|>prefix<|fim_suffix|>suffix<|fim_middle|>`
+- `code_gemma` - CodeGemma format: `<|fim_prefix|>prefix<|fim_suffix|>suffix<|fim_middle|>`
 - `codestral` - Codestral format: `[SUFFIX]suffix[PREFIX]prefix`
 - `glm` - GLM-4 format with code markers
 - `infer` - Auto-detect from model name (default)
 
-Your server must implement the OpenAI `/v1/completions` endpoint. Edit predictions will send POST requests with this format:
+With `"prompt_format": "infer"`, Zed automatically uses Zeta 2 format for models named `zeta2` and Zeta 2.1 format for models named `zeta2.1`.
+
+For example, to use Zeta 2 with Ollama:
+
+```json [settings]
+{
+  "edit_predictions": {
+    "provider": "ollama",
+    "ollama": {
+      "api_url": "http://localhost:11434",
+      "model": "zeta2",
+      "prompt_format": "infer",
+      "max_output_tokens": 512
+    }
+  }
+}
+```
+
+To use Zeta 2.1 with an OpenAI-compatible server:
+
+```json [settings]
+{
+  "edit_predictions": {
+    "provider": "open_ai_compatible_api",
+    "open_ai_compatible_api": {
+      "api_url": "http://localhost:8080/v1/completions",
+      "model": "zeta2.1",
+      "prompt_format": "infer",
+      "max_output_tokens": 512
+    }
+  }
+}
+```
+
+You can also set `"prompt_format": "zeta2"` or `"prompt_format": "zeta2_1"` explicitly when the model name does not match.
+
+Your OpenAI-compatible server must implement the OpenAI `/v1/completions` endpoint. Edit predictions will send POST requests with this format:
 
 ```json
 {