Detailed changes
@@ -1557,6 +1557,13 @@
"ollama": {
"api_url": "http://localhost:11434",
"model": "qwen2.5-coder:7b-base",
+ "prompt_format": "infer",
+ "max_output_tokens": 64,
+ },
+ "open_ai_compatible_api": {
+ "api_url": "",
+ "model": "",
+ "prompt_format": "infer",
"max_output_tokens": 64,
},
// Whether edit predictions are enabled when editing text threads in the agent panel.
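
For orientation, the new `open_ai_compatible_api` block sits alongside `ollama` in the defaults above and is read from the same `edit_predictions` settings object. A minimal user-settings sketch, illustrative only: the provider string is assumed to mirror the new enum variant, and the URL and model values are placeholders.

```jsonc
{
  "edit_predictions": {
    // Assumed key for the new EditPredictionProvider::OpenAiCompatibleApi variant.
    "provider": "open_ai_compatible_api",
    "open_ai_compatible_api": {
      "api_url": "http://localhost:8000", // placeholder endpoint
      "model": "my-fim-model",            // placeholder model name
      "prompt_format": "infer",           // default shown in the hunk above
      "max_output_tokens": 64
    }
  }
}
```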
@@ -406,6 +406,7 @@ fn update_command_palette_filter(cx: &mut App) {
EditPredictionProvider::Zed
| EditPredictionProvider::Codestral
| EditPredictionProvider::Ollama
+ | EditPredictionProvider::OpenAiCompatibleApi
| EditPredictionProvider::Sweep
| EditPredictionProvider::Mercury
| EditPredictionProvider::Experimental(_) => {
@@ -35,7 +35,9 @@ use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
use semver::Version;
use serde::de::DeserializeOwned;
-use settings::{EditPredictionProvider, Settings as _, update_settings_file};
+use settings::{
+ EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
+};
use std::collections::{VecDeque, hash_map};
use std::env;
use text::Edit;
@@ -55,6 +57,7 @@ use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_noti
pub mod cursor_excerpt;
pub mod example_spec;
+pub mod fim;
mod license_detection;
pub mod mercury;
pub mod ollama;
@@ -67,15 +70,13 @@ pub mod udiff;
mod capture_example;
mod zed_edit_prediction_delegate;
-pub mod zeta1;
-pub mod zeta2;
+pub mod zeta;
#[cfg(test)]
mod edit_prediction_tests;
use crate::license_detection::LicenseDetectionWatcher;
use crate::mercury::Mercury;
-use crate::ollama::Ollama;
use crate::onboarding_modal::ZedPredictModal;
pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
@@ -138,21 +139,19 @@ pub struct EditPredictionStore {
zeta2_raw_config: Option<Zeta2RawConfig>,
pub sweep_ai: SweepAi,
pub mercury: Mercury,
- pub ollama: Ollama,
data_collection_choice: DataCollectionChoice,
reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
shown_predictions: VecDeque<EditPrediction>,
rated_predictions: HashSet<EditPredictionId>,
}
-#[derive(Copy, Clone, Default, PartialEq, Eq)]
+#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
- #[default]
Zeta1,
Zeta2,
+ Fim { format: EditPredictionPromptFormat },
Sweep,
Mercury,
- Ollama,
}
#[derive(Clone)]
@@ -697,7 +696,6 @@ impl EditPredictionStore {
zeta2_raw_config: Self::zeta2_raw_config_from_env(),
sweep_ai: SweepAi::new(cx),
mercury: Mercury::new(cx),
- ollama: Ollama::new(),
data_collection_choice,
reject_predictions_tx: reject_tx,
@@ -727,7 +725,7 @@ impl EditPredictionStore {
self.zeta2_raw_config.as_ref()
}
- pub fn icons(&self) -> edit_prediction_types::EditPredictionIconSet {
+ pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
use ui::IconName;
match self.edit_prediction_model {
EditPredictionModel::Sweep => {
@@ -747,8 +745,16 @@ impl EditPredictionStore {
.with_down(IconName::ZedPredictDown)
.with_error(IconName::ZedPredictError)
}
- EditPredictionModel::Ollama => {
- edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
+ EditPredictionModel::Fim { .. } => {
+ let settings = &all_language_settings(None, cx).edit_predictions;
+ match settings.provider {
+ EditPredictionProvider::Ollama => {
+ edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
+ }
+ _ => {
+ edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
+ }
+ }
}
}
}
@@ -861,7 +867,10 @@ impl EditPredictionStore {
}
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
- if matches!(self.edit_prediction_model, EditPredictionModel::Zeta2) {
+ if matches!(
+ self.edit_prediction_model,
+ EditPredictionModel::Zeta2 | EditPredictionModel::Zeta1
+ ) {
self.user_store.read(cx).edit_prediction_usage()
} else {
None
@@ -1300,10 +1309,16 @@ impl EditPredictionStore {
cx,
);
}
- EditPredictionModel::Ollama => {}
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
- zeta2::edit_prediction_accepted(self, current_prediction, cx)
+ let is_cloud = !matches!(
+ all_language_settings(None, cx).edit_predictions.provider,
+ EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
+ );
+ if is_cloud {
+ zeta::edit_prediction_accepted(self, current_prediction, cx)
+ }
}
+ EditPredictionModel::Fim { .. } => {}
}
}
@@ -1438,15 +1453,20 @@ impl EditPredictionStore {
) {
match self.edit_prediction_model {
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
- self.reject_predictions_tx
- .unbounded_send(EditPredictionRejection {
- request_id: prediction_id.to_string(),
- reason,
- was_shown,
- })
- .log_err();
+ let is_cloud = !matches!(
+ all_language_settings(None, cx).edit_predictions.provider,
+ EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
+ );
+ if is_cloud {
+ self.reject_predictions_tx
+ .unbounded_send(EditPredictionRejection {
+ request_id: prediction_id.to_string(),
+ reason,
+ was_shown,
+ })
+ .log_err();
+ }
}
- EditPredictionModel::Sweep | EditPredictionModel::Ollama => {}
EditPredictionModel::Mercury => {
mercury::edit_prediction_rejected(
prediction_id,
@@ -1456,6 +1476,7 @@ impl EditPredictionStore {
cx,
);
}
+ EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
}
}
@@ -1670,9 +1691,21 @@ impl EditPredictionStore {
}
}
- let is_ollama = self.edit_prediction_model == EditPredictionModel::Ollama;
- let drop_on_cancel = is_ollama;
- let max_pending_predictions = if is_ollama { 1 } else { 2 };
+ let (needs_acceptance_tracking, max_pending_predictions) =
+ match all_language_settings(None, cx).edit_predictions.provider {
+ EditPredictionProvider::Zed
+ | EditPredictionProvider::Sweep
+ | EditPredictionProvider::Mercury
+ | EditPredictionProvider::Experimental(_) => (true, 2),
+ EditPredictionProvider::Ollama => (false, 1),
+ EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
+ EditPredictionProvider::None
+ | EditPredictionProvider::Copilot
+ | EditPredictionProvider::Supermaven
+ | EditPredictionProvider::Codestral => unreachable!(),
+ };
+
+ let drop_on_cancel = !needs_acceptance_tracking;
let throttle_timeout = Self::THROTTLE_TIMEOUT;
let project_state = self.get_or_init_project(&project, cx);
let pending_prediction_id = project_state.next_pending_prediction_id;
@@ -1889,22 +1922,22 @@ impl EditPredictionStore {
user_actions,
};
- let task = match &self.edit_prediction_model {
- EditPredictionModel::Zeta1 => zeta2::request_prediction_with_zeta2(
+ let task = match self.edit_prediction_model {
+ EditPredictionModel::Zeta1 => zeta::request_prediction_with_zeta(
self,
inputs,
Some(zeta_prompt::EditPredictionModelKind::Zeta1),
cx,
),
- EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
+ EditPredictionModel::Zeta2 => zeta::request_prediction_with_zeta(
self,
inputs,
Some(zeta_prompt::EditPredictionModelKind::Zeta2),
cx,
),
+ EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
- EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx),
};
cx.spawn(async move |this, cx| {
@@ -0,0 +1,227 @@
+use crate::{
+ EditPredictionId, EditPredictionModelInput, cursor_excerpt, prediction::EditPredictionResult,
+ zeta,
+};
+use anyhow::{Context as _, Result, anyhow};
+use gpui::{App, AppContext as _, Entity, Task};
+use language::{
+ Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, ToOffset, ToPoint as _,
+ language_settings::all_language_settings,
+};
+use settings::EditPredictionPromptFormat;
+use std::{path::Path, sync::Arc, time::Instant};
+use zeta_prompt::ZetaPromptInput;
+
+const FIM_CONTEXT_TOKENS: usize = 512;
+
+struct FimRequestOutput {
+ request_id: String,
+ edits: Vec<(std::ops::Range<Anchor>, Arc<str>)>,
+ snapshot: BufferSnapshot,
+ response_received_at: Instant,
+ inputs: ZetaPromptInput,
+ buffer: Entity<Buffer>,
+ buffer_snapshotted_at: Instant,
+}
+
+pub fn request_prediction(
+ EditPredictionModelInput {
+ buffer,
+ snapshot,
+ position,
+ events,
+ ..
+ }: EditPredictionModelInput,
+ prompt_format: EditPredictionPromptFormat,
+ cx: &mut App,
+) -> Task<Result<Option<EditPredictionResult>>> {
+ let settings = &all_language_settings(None, cx).edit_predictions;
+ let provider = settings.provider;
+
+ let full_path: Arc<Path> = snapshot
+ .file()
+ .map(|file| file.full_path(cx))
+ .unwrap_or_else(|| "untitled".into())
+ .into();
+
+ let http_client = cx.http_client();
+ let cursor_point = position.to_point(&snapshot);
+ let buffer_snapshotted_at = Instant::now();
+
+ let Some(settings) = (match provider {
+ settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
+ settings::EditPredictionProvider::OpenAiCompatibleApi => {
+ settings.open_ai_compatible_api.clone()
+ }
+ _ => None,
+ }) else {
+ return Task::ready(Err(anyhow!("Unsupported edit prediction provider for FIM")));
+ };
+
+ let result = cx.background_spawn(async move {
+ let (excerpt_range, _) = cursor_excerpt::editable_and_context_ranges_for_cursor_position(
+ cursor_point,
+ &snapshot,
+ FIM_CONTEXT_TOKENS,
+ 0,
+ );
+ let excerpt_offset_range = excerpt_range.to_offset(&snapshot);
+ let cursor_offset = cursor_point.to_offset(&snapshot);
+
+ let inputs = ZetaPromptInput {
+ events,
+ related_files: Vec::new(),
+ cursor_offset_in_excerpt: cursor_offset - excerpt_offset_range.start,
+ editable_range_in_excerpt: cursor_offset - excerpt_offset_range.start
+ ..cursor_offset - excerpt_offset_range.start,
+ cursor_path: full_path.clone(),
+ excerpt_start_row: Some(excerpt_range.start.row),
+ cursor_excerpt: snapshot
+ .text_for_range(excerpt_range)
+ .collect::<String>()
+ .into(),
+ excerpt_ranges: None,
+ preferred_model: None,
+ in_open_source_repo: false,
+ can_collect_data: false,
+ };
+
+ let prefix = inputs.cursor_excerpt[..inputs.cursor_offset_in_excerpt].to_string();
+ let suffix = inputs.cursor_excerpt[inputs.cursor_offset_in_excerpt..].to_string();
+ let prompt = format_fim_prompt(prompt_format, &prefix, &suffix);
+ let stop_tokens = get_fim_stop_tokens();
+
+ let max_tokens = settings.max_output_tokens;
+ let (response_text, request_id) = zeta::send_custom_server_request(
+ provider,
+ &settings,
+ prompt,
+ max_tokens,
+ stop_tokens,
+ &http_client,
+ )
+ .await?;
+
+ let response_received_at = Instant::now();
+
+ log::debug!(
+ "fim: completion received ({:.2}s)",
+ (response_received_at - buffer_snapshotted_at).as_secs_f64()
+ );
+
+ let completion: Arc<str> = clean_fim_completion(&response_text).into();
+ let edits = if completion.is_empty() {
+ vec![]
+ } else {
+ let cursor_offset = cursor_point.to_offset(&snapshot);
+ let anchor = snapshot.anchor_after(cursor_offset);
+ vec![(anchor..anchor, completion)]
+ };
+
+ anyhow::Ok(FimRequestOutput {
+ request_id,
+ edits,
+ snapshot,
+ response_received_at,
+ inputs,
+ buffer,
+ buffer_snapshotted_at,
+ })
+ });
+
+ cx.spawn(async move |cx: &mut gpui::AsyncApp| {
+ let output = result.await.context("fim edit prediction failed")?;
+ anyhow::Ok(Some(
+ EditPredictionResult::new(
+ EditPredictionId(output.request_id.into()),
+ &output.buffer,
+ &output.snapshot,
+ output.edits.into(),
+ None,
+ output.buffer_snapshotted_at,
+ output.response_received_at,
+ output.inputs,
+ cx,
+ )
+ .await,
+ ))
+ })
+}
+
+fn format_fim_prompt(
+ prompt_format: EditPredictionPromptFormat,
+ prefix: &str,
+ suffix: &str,
+) -> String {
+ match prompt_format {
+ EditPredictionPromptFormat::CodeLlama => {
+ format!("<PRE> {prefix} <SUF>{suffix} <MID>")
+ }
+ EditPredictionPromptFormat::StarCoder => {
+ format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
+ }
+ EditPredictionPromptFormat::DeepseekCoder => {
+ format!("<ο½fimβbeginο½>{prefix}<ο½fimβholeο½>{suffix}<ο½fimβendο½>")
+ }
+ EditPredictionPromptFormat::Qwen | EditPredictionPromptFormat::CodeGemma => {
+ format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
+ }
+ EditPredictionPromptFormat::Codestral => {
+ format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
+ }
+ EditPredictionPromptFormat::Glm => {
+ format!("<|code_prefix|>{prefix}<|code_suffix|>{suffix}<|code_middle|>")
+ }
+ _ => {
+ format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
+ }
+ }
+}
+
+fn get_fim_stop_tokens() -> Vec<String> {
+ vec![
+ "<|endoftext|>".to_string(),
+ "<|file_separator|>".to_string(),
+ "<|fim_pad|>".to_string(),
+ "<|fim_prefix|>".to_string(),
+ "<|fim_middle|>".to_string(),
+ "<|fim_suffix|>".to_string(),
+ "<fim_prefix>".to_string(),
+ "<fim_middle>".to_string(),
+ "<fim_suffix>".to_string(),
+ "<PRE>".to_string(),
+ "<SUF>".to_string(),
+ "<MID>".to_string(),
+ "[PREFIX]".to_string(),
+ "[SUFFIX]".to_string(),
+ ]
+}
+
+fn clean_fim_completion(response: &str) -> String {
+ let mut result = response.to_string();
+
+ let end_tokens = [
+ "<|endoftext|>",
+ "<|file_separator|>",
+ "<|fim_pad|>",
+ "<|fim_prefix|>",
+ "<|fim_middle|>",
+ "<|fim_suffix|>",
+ "<fim_prefix>",
+ "<fim_middle>",
+ "<fim_suffix>",
+ "<PRE>",
+ "<SUF>",
+ "<MID>",
+ "[PREFIX]",
+ "[SUFFIX]",
+ ];
+
+ for token in &end_tokens {
+ if let Some(pos) = result.find(token) {
+ result.truncate(pos);
+ }
+ }
+
+ result
+}
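
To make the FIM round trip concrete, here is an illustrative sketch (not part of the diff) of the two helpers defined above; the assertions follow directly from the match arms and the stop-token list.

```rust
// Sketch only: exercises format_fim_prompt and clean_fim_completion as defined above.
fn fim_round_trip_example() {
    let prefix = "fn add(a: i32, b: i32) -> i32 {\n    ";
    let suffix = "\n}";

    // Qwen/CodeGemma arm: sentinels wrap the text before and after the cursor.
    let prompt = format_fim_prompt(EditPredictionPromptFormat::Qwen, prefix, suffix);
    assert!(prompt.starts_with("<|fim_prefix|>"));
    assert!(prompt.ends_with("<|fim_middle|>"));

    // Models often echo a sentinel at the end of the infill; clean_fim_completion
    // truncates at the first known token, leaving only the inserted text.
    assert_eq!(clean_fim_completion("a + b<|endoftext|>"), "a + b");
}
```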
@@ -1,7 +1,7 @@
use crate::{
DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
- prediction::EditPredictionResult, zeta1::compute_edits,
+ prediction::EditPredictionResult, zeta::compute_edits,
};
use anyhow::{Context as _, Result};
use cloud_llm_client::EditPredictionRejectReason;
@@ -1,32 +1,16 @@
-use crate::{
- EditPredictionId, EditPredictionModelInput, cursor_excerpt,
- prediction::EditPredictionResult,
- zeta1::{
- self, MAX_CONTEXT_TOKENS as ZETA_MAX_CONTEXT_TOKENS,
- MAX_EVENT_TOKENS as ZETA_MAX_EVENT_TOKENS,
- },
-};
use anyhow::{Context as _, Result};
use futures::AsyncReadExt as _;
-use gpui::{App, AppContext as _, Entity, SharedString, Task, http_client};
-use language::{
- Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, ToOffset, ToPoint as _,
- language_settings::all_language_settings,
+use gpui::{
+ App, SharedString,
+ http_client::{self, HttpClient},
};
+use language::language_settings::OpenAiCompatibleEditPredictionSettings;
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use serde::{Deserialize, Serialize};
-use std::{path::Path, sync::Arc, time::Instant};
-use zeta_prompt::{
- ZetaPromptInput,
- zeta1::{EDITABLE_REGION_END_MARKER, format_zeta1_prompt},
-};
-
-const FIM_CONTEXT_TOKENS: usize = 512;
-
-pub struct Ollama;
+use std::sync::Arc;
#[derive(Debug, Serialize)]
-struct OllamaGenerateRequest {
+pub(crate) struct OllamaGenerateRequest {
model: String,
prompt: String,
raw: bool,
@@ -36,7 +20,7 @@ struct OllamaGenerateRequest {
}
#[derive(Debug, Serialize)]
-struct OllamaGenerateOptions {
+pub(crate) struct OllamaGenerateOptions {
#[serde(skip_serializing_if = "Option::is_none")]
num_predict: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
@@ -46,9 +30,9 @@ struct OllamaGenerateOptions {
}
#[derive(Debug, Deserialize)]
-struct OllamaGenerateResponse {
- created_at: String,
- response: String,
+pub(crate) struct OllamaGenerateResponse {
+ pub created_at: String,
+ pub response: String,
}
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
@@ -79,332 +63,46 @@ pub fn fetch_models(cx: &mut App) -> Vec<SharedString> {
models
}
-/// Output from the Ollama HTTP request, containing all data needed to create the prediction result.
-struct OllamaRequestOutput {
- created_at: String,
- edits: Vec<(std::ops::Range<Anchor>, Arc<str>)>,
- snapshot: BufferSnapshot,
- response_received_at: Instant,
- inputs: ZetaPromptInput,
- buffer: Entity<Buffer>,
- buffer_snapshotted_at: Instant,
-}
-
-impl Ollama {
- pub fn new() -> Self {
- Self
- }
-
- pub fn request_prediction(
- &self,
- EditPredictionModelInput {
- buffer,
- snapshot,
- position,
- events,
- ..
- }: EditPredictionModelInput,
- cx: &mut App,
- ) -> Task<Result<Option<EditPredictionResult>>> {
- let settings = &all_language_settings(None, cx).edit_predictions.ollama;
- let Some(model) = settings.model.clone() else {
- return Task::ready(Ok(None));
- };
- let api_url = settings.api_url.clone();
-
- log::debug!("Ollama: Requesting completion (model: {})", model);
-
- let full_path: Arc<Path> = snapshot
- .file()
- .map(|file| file.full_path(cx))
- .unwrap_or_else(|| "untitled".into())
- .into();
-
- let http_client = cx.http_client();
- let cursor_point = position.to_point(&snapshot);
- let buffer_snapshotted_at = Instant::now();
-
- let is_zeta = is_zeta_model(&model);
-
- // Zeta generates more tokens than FIM models. Ideally, we'd use MAX_REWRITE_TOKENS,
- // but this might be too slow for local deployments. So we make it configurable,
- // but we also have this hardcoded multiplier for now.
- let max_output_tokens = if is_zeta {
- settings.max_output_tokens * 4
- } else {
- settings.max_output_tokens
- };
-
- let result = cx.background_spawn(async move {
- let zeta_editable_region_tokens = max_output_tokens as usize;
-
- // For zeta models, use the dedicated zeta1 functions which handle their own
- // range computation with the correct token limits.
- let (prompt, stop_tokens, editable_range_override, inputs) = if is_zeta {
- let path_str = full_path.to_string_lossy();
- let input_excerpt = zeta1::excerpt_for_cursor_position(
- cursor_point,
- &path_str,
- &snapshot,
- zeta_editable_region_tokens,
- ZETA_MAX_CONTEXT_TOKENS,
- );
- let input_events = zeta1::prompt_for_events(&events, ZETA_MAX_EVENT_TOKENS);
- let prompt = format_zeta1_prompt(&input_events, &input_excerpt.prompt);
- let editable_offset_range = input_excerpt.editable_range.to_offset(&snapshot);
- let context_offset_range = input_excerpt.context_range.to_offset(&snapshot);
- let stop_tokens = get_zeta_stop_tokens();
-
- let inputs = ZetaPromptInput {
- events,
- related_files: Vec::new(),
- cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
- - context_offset_range.start,
- cursor_path: full_path.clone(),
- cursor_excerpt: snapshot
- .text_for_range(input_excerpt.context_range.clone())
- .collect::<String>()
- .into(),
- editable_range_in_excerpt: (editable_offset_range.start
- - context_offset_range.start)
- ..(editable_offset_range.end - context_offset_range.start),
- excerpt_start_row: Some(input_excerpt.context_range.start.row),
- excerpt_ranges: None,
- preferred_model: None,
- in_open_source_repo: false,
- can_collect_data: false,
- };
-
- (prompt, stop_tokens, Some(editable_offset_range), inputs)
- } else {
- let (excerpt_range, _) =
- cursor_excerpt::editable_and_context_ranges_for_cursor_position(
- cursor_point,
- &snapshot,
- FIM_CONTEXT_TOKENS,
- 0,
- );
- let excerpt_offset_range = excerpt_range.to_offset(&snapshot);
- let cursor_offset = cursor_point.to_offset(&snapshot);
-
- let inputs = ZetaPromptInput {
- events,
- related_files: Vec::new(),
- cursor_offset_in_excerpt: cursor_offset - excerpt_offset_range.start,
- editable_range_in_excerpt: cursor_offset - excerpt_offset_range.start
- ..cursor_offset - excerpt_offset_range.start,
- cursor_path: full_path.clone(),
- excerpt_start_row: Some(excerpt_range.start.row),
- cursor_excerpt: snapshot
- .text_for_range(excerpt_range)
- .collect::<String>()
- .into(),
- excerpt_ranges: None,
- preferred_model: None,
- in_open_source_repo: false,
- can_collect_data: false,
- };
-
- let prefix = inputs.cursor_excerpt[..inputs.cursor_offset_in_excerpt].to_string();
- let suffix = inputs.cursor_excerpt[inputs.cursor_offset_in_excerpt..].to_string();
- let prompt = format_fim_prompt(&model, &prefix, &suffix);
- let stop_tokens = get_fim_stop_tokens();
-
- (prompt, stop_tokens, None, inputs)
- };
-
- let request = OllamaGenerateRequest {
- model: model.clone(),
- prompt,
- raw: true,
- stream: false,
- options: Some(OllamaGenerateOptions {
- num_predict: Some(max_output_tokens),
- temperature: Some(0.2),
- stop: Some(stop_tokens),
- }),
- };
-
- let request_body = serde_json::to_string(&request)?;
- let http_request = http_client::Request::builder()
- .method(http_client::Method::POST)
- .uri(format!("{}/api/generate", api_url))
- .header("Content-Type", "application/json")
- .body(http_client::AsyncBody::from(request_body))?;
-
- let mut response = http_client.send(http_request).await?;
- let status = response.status();
-
- log::debug!("Ollama: Response status: {}", status);
-
- if !status.is_success() {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- return Err(anyhow::anyhow!("Ollama API error: {} - {}", status, body));
- }
-
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
-
- let ollama_response: OllamaGenerateResponse =
- serde_json::from_str(&body).context("Failed to parse Ollama response")?;
-
- let response_received_at = Instant::now();
-
- log::debug!(
- "Ollama: Completion received ({:.2}s)",
- (response_received_at - buffer_snapshotted_at).as_secs_f64()
- );
-
- let edits = if is_zeta {
- let editable_range =
- editable_range_override.expect("zeta model should have editable range");
-
- log::trace!("ollama response: {}", ollama_response.response);
-
- let response = clean_zeta_completion(&ollama_response.response);
- match zeta1::parse_edits(&response, editable_range, &snapshot) {
- Ok(edits) => edits,
- Err(err) => {
- log::warn!("Ollama zeta: Failed to parse response: {}", err);
- vec![]
- }
- }
- } else {
- let completion: Arc<str> = clean_fim_completion(&ollama_response.response).into();
- if completion.is_empty() {
- vec![]
- } else {
- let cursor_offset = cursor_point.to_offset(&snapshot);
- let anchor = snapshot.anchor_after(cursor_offset);
- vec![(anchor..anchor, completion)]
- }
- };
-
- anyhow::Ok(OllamaRequestOutput {
- created_at: ollama_response.created_at,
- edits,
- snapshot,
- response_received_at,
- inputs,
- buffer,
- buffer_snapshotted_at,
- })
- });
-
- cx.spawn(async move |cx: &mut gpui::AsyncApp| {
- let output = result.await.context("Ollama edit prediction failed")?;
- anyhow::Ok(Some(
- EditPredictionResult::new(
- EditPredictionId(output.created_at.into()),
- &output.buffer,
- &output.snapshot,
- output.edits.into(),
- None,
- output.buffer_snapshotted_at,
- output.response_received_at,
- output.inputs,
- cx,
- )
- .await,
- ))
- })
- }
-}
-
-fn is_zeta_model(model: &str) -> bool {
- model.to_lowercase().contains("zeta")
-}
-
-fn get_zeta_stop_tokens() -> Vec<String> {
- vec![EDITABLE_REGION_END_MARKER.to_string(), "```".to_string()]
-}
+pub(crate) async fn make_request(
+ settings: OpenAiCompatibleEditPredictionSettings,
+ prompt: String,
+ stop_tokens: Vec<String>,
+ http_client: Arc<dyn HttpClient>,
+) -> Result<OllamaGenerateResponse> {
+ let request = OllamaGenerateRequest {
+ model: settings.model.clone(),
+ prompt,
+ raw: true,
+ stream: false,
+ options: Some(OllamaGenerateOptions {
+ num_predict: Some(settings.max_output_tokens),
+ temperature: Some(0.2),
+ stop: Some(stop_tokens),
+ }),
+ };
-fn format_fim_prompt(model: &str, prefix: &str, suffix: &str) -> String {
- let model_base = model.split(':').next().unwrap_or(model);
+ let request_body = serde_json::to_string(&request)?;
+ let http_request = http_client::Request::builder()
+ .method(http_client::Method::POST)
+ .uri(format!("{}/api/generate", settings.api_url))
+ .header("Content-Type", "application/json")
+ .body(http_client::AsyncBody::from(request_body))?;
- match model_base {
- "codellama" | "code-llama" => {
- format!("<PRE> {prefix} <SUF>{suffix} <MID>")
- }
- "starcoder" | "starcoder2" | "starcoderbase" => {
- format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
- }
- "deepseek-coder" | "deepseek-coder-v2" => {
- format!("<ο½fimβbeginο½>{prefix}<ο½fimβholeο½>{suffix}<ο½fimβendο½>")
- }
- "qwen2.5-coder" | "qwen-coder" | "qwen" => {
- format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
- }
- "codegemma" => {
- format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
- }
- "codestral" | "mistral" => {
- format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
- }
- "glm" | "glm-4" | "glm-4.5" => {
- format!("<|code_prefix|>{prefix}<|code_suffix|>{suffix}<|code_middle|>")
- }
- _ => {
- format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
- }
- }
-}
+ let mut response = http_client.send(http_request).await?;
+ let status = response.status();
-fn get_fim_stop_tokens() -> Vec<String> {
- vec![
- "<|endoftext|>".to_string(),
- "<|file_separator|>".to_string(),
- "<|fim_pad|>".to_string(),
- "<|fim_prefix|>".to_string(),
- "<|fim_middle|>".to_string(),
- "<|fim_suffix|>".to_string(),
- "<fim_prefix>".to_string(),
- "<fim_middle>".to_string(),
- "<fim_suffix>".to_string(),
- "<PRE>".to_string(),
- "<SUF>".to_string(),
- "<MID>".to_string(),
- "[PREFIX]".to_string(),
- "[SUFFIX]".to_string(),
- ]
-}
+ log::debug!("Ollama: Response status: {}", status);
-fn clean_zeta_completion(mut response: &str) -> &str {
- if let Some(last_newline_ix) = response.rfind('\n') {
- let last_line = &response[last_newline_ix + 1..];
- if EDITABLE_REGION_END_MARKER.starts_with(&last_line) {
- response = &response[..last_newline_ix]
- }
+ if !status.is_success() {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ return Err(anyhow::anyhow!("Ollama API error: {} - {}", status, body));
}
- response
-}
-fn clean_fim_completion(response: &str) -> String {
- let mut result = response.to_string();
-
- let end_tokens = [
- "<|endoftext|>",
- "<|file_separator|>",
- "<|fim_pad|>",
- "<|fim_prefix|>",
- "<|fim_middle|>",
- "<|fim_suffix|>",
- "<fim_prefix>",
- "<fim_middle>",
- "<fim_suffix>",
- "<PRE>",
- "<SUF>",
- "<MID>",
- "[PREFIX]",
- "[SUFFIX]",
- ];
-
- for token in &end_tokens {
- if let Some(pos) = result.find(token) {
- result.truncate(pos);
- }
- }
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
- result
+ let ollama_response: OllamaGenerateResponse =
+ serde_json::from_str(&body).context("Failed to parse Ollama response")?;
+ Ok(ollama_response)
}
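
For reference, with the refactored `make_request` the body POSTed to `{api_url}/api/generate` serializes roughly as below (values are placeholders; the `options` fields are always present because `make_request` sets all three, while `None` fields would be skipped by the `skip_serializing_if` attributes).

```jsonc
{
  "model": "qwen2.5-coder:7b-base",
  "prompt": "<|fim_prefix|>…<|fim_suffix|>…<|fim_middle|>",
  "raw": true,
  "stream": false,
  "options": {
    "num_predict": 64,
    "temperature": 0.2,
    "stop": ["<|endoftext|>", "<PRE>", "[SUFFIX]"] // abbreviated stop-token list
  }
}
```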
@@ -9,7 +9,6 @@ use edit_prediction_types::{
use gpui::{App, Entity, prelude::*};
use language::{Buffer, ToPoint as _};
use project::Project;
-use ui::prelude::*;
use crate::{BufferEditPrediction, EditPredictionModel, EditPredictionStore};
@@ -63,22 +62,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
}
fn icons(&self, cx: &App) -> EditPredictionIconSet {
- match self.store.read(cx).edit_prediction_model {
- EditPredictionModel::Sweep => EditPredictionIconSet::new(IconName::SweepAi)
- .with_disabled(IconName::SweepAiDisabled)
- .with_up(IconName::SweepAiUp)
- .with_down(IconName::SweepAiDown)
- .with_error(IconName::SweepAiError),
- EditPredictionModel::Mercury => EditPredictionIconSet::new(IconName::Inception),
- EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
- EditPredictionIconSet::new(IconName::ZedPredict)
- .with_disabled(IconName::ZedPredictDisabled)
- .with_up(IconName::ZedPredictUp)
- .with_down(IconName::ZedPredictDown)
- .with_error(IconName::ZedPredictError)
- }
- EditPredictionModel::Ollama => EditPredictionIconSet::new(IconName::AiOllama),
- }
+ self.store.read(cx).icons(cx)
}
fn data_collection_state(&self, cx: &App) -> DataCollectionState {
@@ -0,0 +1,614 @@
+use crate::cursor_excerpt::{compute_excerpt_ranges, excerpt_ranges_to_byte_offsets};
+use crate::prediction::EditPredictionResult;
+use crate::{
+ CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
+ EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, ollama,
+};
+use anyhow::{Context as _, Result};
+use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
+use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
+use edit_prediction_types::PredictedCursorPosition;
+use futures::AsyncReadExt as _;
+use gpui::{App, AppContext as _, Task, http_client, prelude::*};
+use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings};
+use language::{BufferSnapshot, OffsetRangeExt as _, ToOffset as _, ToPoint, text_diff};
+use release_channel::AppVersion;
+use text::{Anchor, Bias};
+
+use std::env;
+use std::ops::Range;
+use std::{path::Path, sync::Arc, time::Instant};
+use zeta_prompt::{
+ CURSOR_MARKER, EditPredictionModelKind, ZetaFormat, clean_zeta2_model_output,
+ format_zeta_prompt, get_prefill, prompt_input_contains_special_tokens,
+ zeta1::{self, EDITABLE_REGION_END_MARKER},
+};
+
+pub const MAX_CONTEXT_TOKENS: usize = 350;
+
+pub fn max_editable_tokens(format: ZetaFormat) -> usize {
+ match format {
+ ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => 150,
+ ZetaFormat::V0114180EditableRegion => 180,
+ ZetaFormat::V0120GitMergeMarkers => 180,
+ ZetaFormat::V0131GitMergeMarkersPrefix => 180,
+ ZetaFormat::V0211Prefill => 180,
+ ZetaFormat::V0211SeedCoder => 180,
+ }
+}
+
+pub fn request_prediction_with_zeta(
+ store: &mut EditPredictionStore,
+ EditPredictionModelInput {
+ buffer,
+ snapshot,
+ position,
+ related_files,
+ events,
+ debug_tx,
+ trigger,
+ project,
+ ..
+ }: EditPredictionModelInput,
+ preferred_model: Option<EditPredictionModelKind>,
+ cx: &mut Context<EditPredictionStore>,
+) -> Task<Result<Option<EditPredictionResult>>> {
+ let settings = &all_language_settings(None, cx).edit_predictions;
+ let provider = settings.provider;
+ let custom_server_settings = match provider {
+ settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
+ settings::EditPredictionProvider::OpenAiCompatibleApi => {
+ settings.open_ai_compatible_api.clone()
+ }
+ _ => None,
+ };
+
+ let http_client = cx.http_client();
+ let buffer_snapshotted_at = Instant::now();
+ let raw_config = store.zeta2_raw_config().cloned();
+
+ let excerpt_path: Arc<Path> = snapshot
+ .file()
+ .map(|file| -> Arc<Path> { file.full_path(cx).into() })
+ .unwrap_or_else(|| Arc::from(Path::new("untitled")));
+
+ let client = store.client.clone();
+ let llm_token = store.llm_token.clone();
+ let app_version = AppVersion::global(cx);
+
+ let is_open_source = snapshot
+ .file()
+ .map_or(false, |file| store.is_file_open_source(&project, file, cx))
+ && events.iter().all(|event| event.in_open_source_repo())
+ && related_files.iter().all(|file| file.in_open_source_repo);
+
+ let can_collect_data = is_open_source && store.is_data_collection_enabled(cx);
+
+ let request_task = cx.background_spawn({
+ async move {
+ let zeta_version = raw_config
+ .as_ref()
+ .map(|config| config.format)
+ .unwrap_or(ZetaFormat::default());
+
+ let cursor_offset = position.to_offset(&snapshot);
+ let (editable_offset_range, prompt_input) = zeta2_prompt_input(
+ &snapshot,
+ related_files,
+ events,
+ excerpt_path,
+ cursor_offset,
+ zeta_version,
+ preferred_model,
+ is_open_source,
+ can_collect_data,
+ );
+
+ if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
+ return Ok((None, None));
+ }
+
+ let is_zeta1 = preferred_model == Some(EditPredictionModelKind::Zeta1);
+ let excerpt_ranges = prompt_input
+ .excerpt_ranges
+ .as_ref()
+ .ok_or_else(|| anyhow::anyhow!("excerpt_ranges missing from prompt input"))?;
+
+ if let Some(debug_tx) = &debug_tx {
+ let prompt = if is_zeta1 {
+ zeta1::format_zeta1_from_input(
+ &prompt_input,
+ excerpt_ranges.editable_350.clone(),
+ excerpt_ranges.editable_350_context_150.clone(),
+ )
+ } else {
+ format_zeta_prompt(&prompt_input, zeta_version)
+ };
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionStarted(
+ EditPredictionStartedDebugEvent {
+ buffer: buffer.downgrade(),
+ prompt: Some(prompt),
+ position,
+ },
+ ))
+ .ok();
+ }
+
+ log::trace!("Sending edit prediction request");
+
+ let (request_id, output_text, usage) =
+ if let Some(custom_settings) = &custom_server_settings {
+ let max_tokens = custom_settings.max_output_tokens * 4;
+
+ if is_zeta1 {
+ let ranges = excerpt_ranges;
+ let prompt = zeta1::format_zeta1_from_input(
+ &prompt_input,
+ ranges.editable_350.clone(),
+ ranges.editable_350_context_150.clone(),
+ );
+ let stop_tokens = vec![
+ EDITABLE_REGION_END_MARKER.to_string(),
+ format!("{EDITABLE_REGION_END_MARKER}\n"),
+ format!("{EDITABLE_REGION_END_MARKER}\n\n"),
+ format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
+ ];
+
+ let (response_text, request_id) = send_custom_server_request(
+ provider,
+ custom_settings,
+ prompt,
+ max_tokens,
+ stop_tokens,
+ &http_client,
+ )
+ .await?;
+
+ let request_id = EditPredictionId(request_id.into());
+ let output_text = zeta1::clean_zeta1_model_output(&response_text);
+
+ (request_id, output_text, None)
+ } else {
+ let prompt = format_zeta_prompt(&prompt_input, zeta_version);
+ let prefill = get_prefill(&prompt_input, zeta_version);
+ let prompt = format!("{prompt}{prefill}");
+
+ let (response_text, request_id) = send_custom_server_request(
+ provider,
+ custom_settings,
+ prompt,
+ max_tokens,
+ vec![],
+ &http_client,
+ )
+ .await?;
+
+ let request_id = EditPredictionId(request_id.into());
+ let output_text = if response_text.is_empty() {
+ None
+ } else {
+ let output = format!("{prefill}{response_text}");
+ Some(clean_zeta2_model_output(&output, zeta_version).to_string())
+ };
+
+ (request_id, output_text, None)
+ }
+ } else if let Some(config) = &raw_config {
+ let prompt = format_zeta_prompt(&prompt_input, config.format);
+ let prefill = get_prefill(&prompt_input, config.format);
+ let prompt = format!("{prompt}{prefill}");
+ let request = RawCompletionRequest {
+ model: config.model_id.clone().unwrap_or_default(),
+ prompt,
+ temperature: None,
+ stop: vec![],
+ max_tokens: Some(2048),
+ environment: Some(config.format.to_string().to_lowercase()),
+ };
+
+ let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
+ request,
+ client,
+ None,
+ llm_token,
+ app_version,
+ )
+ .await?;
+
+ let request_id = EditPredictionId(response.id.clone().into());
+ let output_text = response.choices.pop().map(|choice| {
+ let response = &choice.text;
+ let output = format!("{prefill}{response}");
+ clean_zeta2_model_output(&output, config.format).to_string()
+ });
+
+ (request_id, output_text, usage)
+ } else {
+ // Use V3 endpoint - server handles model/version selection and suffix stripping
+ let (response, usage) = EditPredictionStore::send_v3_request(
+ prompt_input.clone(),
+ client,
+ llm_token,
+ app_version,
+ trigger,
+ )
+ .await?;
+
+ let request_id = EditPredictionId(response.request_id.into());
+ let output_text = if response.output.is_empty() {
+ None
+ } else {
+ Some(response.output)
+ };
+ (request_id, output_text, usage)
+ };
+
+ let received_response_at = Instant::now();
+
+ log::trace!("Got edit prediction response");
+
+ let Some(mut output_text) = output_text else {
+ return Ok((Some((request_id, None)), usage));
+ };
+
+ // Client-side cursor marker processing (applies to both raw and v3 responses)
+ let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
+ if let Some(offset) = cursor_offset_in_output {
+ log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
+ output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
+ }
+
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionFinished(
+ EditPredictionFinishedDebugEvent {
+ buffer: buffer.downgrade(),
+ position,
+ model_output: Some(output_text.clone()),
+ },
+ ))
+ .ok();
+ }
+
+ let mut old_text = snapshot
+ .text_for_range(editable_offset_range.clone())
+ .collect::<String>();
+
+ if !output_text.is_empty() && !output_text.ends_with('\n') {
+ output_text.push('\n');
+ }
+ if !old_text.is_empty() && !old_text.ends_with('\n') {
+ old_text.push('\n');
+ }
+
+ let (edits, cursor_position) = compute_edits_and_cursor_position(
+ old_text,
+ &output_text,
+ editable_offset_range.start,
+ cursor_offset_in_output,
+ &snapshot,
+ );
+
+ anyhow::Ok((
+ Some((
+ request_id,
+ Some((
+ prompt_input,
+ buffer,
+ snapshot.clone(),
+ edits,
+ cursor_position,
+ received_response_at,
+ )),
+ )),
+ usage,
+ ))
+ }
+ });
+
+ cx.spawn(async move |this, cx| {
+ let Some((id, prediction)) =
+ EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
+ else {
+ return Ok(None);
+ };
+
+ let Some((
+ inputs,
+ edited_buffer,
+ edited_buffer_snapshot,
+ edits,
+ cursor_position,
+ received_response_at,
+ )) = prediction
+ else {
+ return Ok(Some(EditPredictionResult {
+ id,
+ prediction: Err(EditPredictionRejectReason::Empty),
+ }));
+ };
+
+ Ok(Some(
+ EditPredictionResult::new(
+ id,
+ &edited_buffer,
+ &edited_buffer_snapshot,
+ edits.into(),
+ cursor_position,
+ buffer_snapshotted_at,
+ received_response_at,
+ inputs,
+ cx,
+ )
+ .await,
+ ))
+ })
+}
+
+pub fn zeta2_prompt_input(
+ snapshot: &language::BufferSnapshot,
+ related_files: Vec<zeta_prompt::RelatedFile>,
+ events: Vec<Arc<zeta_prompt::Event>>,
+ excerpt_path: Arc<Path>,
+ cursor_offset: usize,
+ zeta_format: ZetaFormat,
+ preferred_model: Option<EditPredictionModelKind>,
+ is_open_source: bool,
+ can_collect_data: bool,
+) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
+ let cursor_point = cursor_offset.to_point(snapshot);
+
+ let (full_context, range_points) = compute_excerpt_ranges(cursor_point, snapshot);
+
+ let related_files = crate::filter_redundant_excerpts(
+ related_files,
+ excerpt_path.as_ref(),
+ full_context.start.row..full_context.end.row,
+ );
+
+ let full_context_start_offset = full_context.start.to_offset(snapshot);
+ let full_context_start_row = full_context.start.row;
+
+ let excerpt_ranges =
+ excerpt_ranges_to_byte_offsets(&range_points, full_context_start_offset, snapshot);
+
+ let editable_range = match preferred_model {
+ Some(EditPredictionModelKind::Zeta1) => &range_points.editable_350,
+ _ => match zeta_format {
+ ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => &range_points.editable_150,
+ _ => &range_points.editable_180,
+ },
+ };
+
+ let editable_offset_range = editable_range.to_offset(snapshot);
+ let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
+ let editable_range_in_excerpt = (editable_offset_range.start - full_context_start_offset)
+ ..(editable_offset_range.end - full_context_start_offset);
+
+ let prompt_input = zeta_prompt::ZetaPromptInput {
+ cursor_path: excerpt_path,
+ cursor_excerpt: snapshot
+ .text_for_range(full_context)
+ .collect::<String>()
+ .into(),
+ editable_range_in_excerpt,
+ cursor_offset_in_excerpt,
+ excerpt_start_row: Some(full_context_start_row),
+ events,
+ related_files,
+ excerpt_ranges: Some(excerpt_ranges),
+ preferred_model,
+ in_open_source_repo: is_open_source,
+ can_collect_data,
+ };
+ (editable_offset_range, prompt_input)
+}
+
+pub(crate) async fn send_custom_server_request(
+ provider: settings::EditPredictionProvider,
+ settings: &OpenAiCompatibleEditPredictionSettings,
+ prompt: String,
+ max_tokens: u32,
+ stop_tokens: Vec<String>,
+ http_client: &Arc<dyn http_client::HttpClient>,
+) -> Result<(String, String)> {
+ match provider {
+ settings::EditPredictionProvider::Ollama => {
+ let response =
+ ollama::make_request(settings.clone(), prompt, stop_tokens, http_client.clone())
+ .await?;
+ Ok((response.response, response.created_at))
+ }
+ _ => {
+ let request = RawCompletionRequest {
+ model: settings.model.clone(),
+ prompt,
+ max_tokens: Some(max_tokens),
+ temperature: None,
+ stop: stop_tokens
+ .into_iter()
+ .map(std::borrow::Cow::Owned)
+ .collect(),
+ environment: None,
+ };
+
+ let request_body = serde_json::to_string(&request)?;
+ let http_request = http_client::Request::builder()
+ .method(http_client::Method::POST)
+ .uri(settings.api_url.as_ref())
+ .header("Content-Type", "application/json")
+ .body(http_client::AsyncBody::from(request_body))?;
+
+ let mut response = http_client.send(http_request).await?;
+ let status = response.status();
+
+ if !status.is_success() {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ anyhow::bail!("custom server error: {} - {}", status, body);
+ }
+
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ let parsed: RawCompletionResponse =
+ serde_json::from_str(&body).context("Failed to parse completion response")?;
+ let text = parsed
+ .choices
+ .into_iter()
+ .next()
+ .map(|choice| choice.text)
+ .unwrap_or_default();
+ Ok((text, parsed.id))
+ }
+ }
+}
+
+pub(crate) fn edit_prediction_accepted(
+ store: &EditPredictionStore,
+ current_prediction: CurrentEditPrediction,
+ cx: &App,
+) {
+ let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
+ if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
+ return;
+ }
+
+ let request_id = current_prediction.prediction.id.to_string();
+ let require_auth = custom_accept_url.is_none();
+ let client = store.client.clone();
+ let llm_token = store.llm_token.clone();
+ let app_version = AppVersion::global(cx);
+
+ cx.background_spawn(async move {
+ let url = if let Some(accept_edits_url) = custom_accept_url {
+ gpui::http_client::Url::parse(&accept_edits_url)?
+ } else {
+ client
+ .http_client()
+ .build_zed_llm_url("/predict_edits/accept", &[])?
+ };
+
+ let response = EditPredictionStore::send_api_request::<()>(
+ move |builder| {
+ let req = builder.uri(url.as_ref()).body(
+ serde_json::to_string(&AcceptEditPredictionBody {
+ request_id: request_id.clone(),
+ })?
+ .into(),
+ );
+ Ok(req?)
+ },
+ client,
+ llm_token,
+ app_version,
+ require_auth,
+ )
+ .await;
+
+ response?;
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
+}
+
+pub fn compute_edits(
+ old_text: String,
+ new_text: &str,
+ offset: usize,
+ snapshot: &BufferSnapshot,
+) -> Vec<(Range<Anchor>, Arc<str>)> {
+ compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
+}
+
+pub fn compute_edits_and_cursor_position(
+ old_text: String,
+ new_text: &str,
+ offset: usize,
+ cursor_offset_in_new_text: Option<usize>,
+ snapshot: &BufferSnapshot,
+) -> (
+ Vec<(Range<Anchor>, Arc<str>)>,
+ Option<PredictedCursorPosition>,
+) {
+ let diffs = text_diff(&old_text, new_text);
+
+ // Delta represents the cumulative change in byte count from all preceding edits.
+ // new_offset = old_offset + delta, so old_offset = new_offset - delta
+ let mut delta: isize = 0;
+ let mut cursor_position: Option<PredictedCursorPosition> = None;
+ let buffer_len = snapshot.len();
+
+ let edits = diffs
+ .iter()
+ .map(|(raw_old_range, new_text)| {
+ // Compute cursor position if it falls within or before this edit.
+ if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
+ let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
+ let edit_end_in_new = edit_start_in_new + new_text.len();
+
+ if cursor_offset < edit_start_in_new {
+ let cursor_in_old = (cursor_offset as isize - delta) as usize;
+ let buffer_offset = (offset + cursor_in_old).min(buffer_len);
+ cursor_position = Some(PredictedCursorPosition::at_anchor(
+ snapshot.anchor_after(buffer_offset),
+ ));
+ } else if cursor_offset < edit_end_in_new {
+ let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
+ let offset_within_insertion = cursor_offset - edit_start_in_new;
+ cursor_position = Some(PredictedCursorPosition::new(
+ snapshot.anchor_before(buffer_offset),
+ offset_within_insertion,
+ ));
+ }
+
+ delta += new_text.len() as isize - raw_old_range.len() as isize;
+ }
+
+ // Compute the edit with prefix/suffix trimming.
+ let mut old_range = raw_old_range.clone();
+ let old_slice = &old_text[old_range.clone()];
+
+ let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
+ let suffix_len = common_prefix(
+ old_slice[prefix_len..].chars().rev(),
+ new_text[prefix_len..].chars().rev(),
+ );
+
+ old_range.start += offset;
+ old_range.end += offset;
+ old_range.start += prefix_len;
+ old_range.end -= suffix_len;
+
+ old_range.start = old_range.start.min(buffer_len);
+ old_range.end = old_range.end.min(buffer_len);
+
+ let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
+ let range = if old_range.is_empty() {
+ let anchor = snapshot.anchor_after(old_range.start);
+ anchor..anchor
+ } else {
+ snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
+ };
+ (range, new_text)
+ })
+ .collect();
+
+ if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
+ let cursor_in_old = (cursor_offset as isize - delta) as usize;
+ let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
+ cursor_position = Some(PredictedCursorPosition::at_anchor(
+ snapshot.anchor_after(buffer_offset),
+ ));
+ }
+
+ (edits, cursor_position)
+}
+
+fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
+ a.zip(b)
+ .take_while(|(a, b)| a == b)
+ .map(|(a, _)| a.len_utf8())
+ .sum()
+}
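
The prefix/suffix trimming inside `compute_edits_and_cursor_position` is easiest to see on a one-character change; a standalone sketch using the `common_prefix` helper above (no `BufferSnapshot` required):

```rust
// Illustration of the trimming step in compute_edits_and_cursor_position.
fn trimming_example() {
    let old = "let x = 42;";
    let new = "let x = 43;";

    // Shared prefix "let x = 4" is 9 bytes.
    let prefix_len = common_prefix(old.chars(), new.chars());
    assert_eq!(prefix_len, 9);

    // Shared suffix after the prefix: "2;" vs "3;" share the trailing ";".
    let suffix_len = common_prefix(
        old[prefix_len..].chars().rev(),
        new[prefix_len..].chars().rev(),
    );
    assert_eq!(suffix_len, 1);

    // Only the single differing byte is stored as an edit.
    assert_eq!(&old[prefix_len..old.len() - suffix_len], "2");
    assert_eq!(&new[prefix_len..new.len() - suffix_len], "3");
}
```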
@@ -1,451 +0,0 @@
-use std::{fmt::Write, ops::Range, sync::Arc};
-
-use crate::cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count};
-use anyhow::Result;
-use cloud_llm_client::PredictEditsBody;
-use edit_prediction_types::PredictedCursorPosition;
-use language::{Anchor, BufferSnapshot, Point, text_diff};
-use text::Bias;
-use zeta_prompt::{
- Event,
- zeta1::{
- CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
- START_OF_FILE_MARKER,
- },
-};
-
-pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
-pub(crate) const MAX_EVENT_TOKENS: usize = 500;
-
-pub(crate) fn parse_edits(
- output_excerpt: &str,
- editable_range: Range<usize>,
- snapshot: &BufferSnapshot,
-) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
- let content = output_excerpt.replace(CURSOR_MARKER, "");
-
- let start_markers = content
- .match_indices(EDITABLE_REGION_START_MARKER)
- .collect::<Vec<_>>();
- anyhow::ensure!(
- start_markers.len() <= 1,
- "expected at most one start marker, found {}",
- start_markers.len()
- );
-
- let end_markers = content
- .match_indices(EDITABLE_REGION_END_MARKER)
- .collect::<Vec<_>>();
- anyhow::ensure!(
- end_markers.len() <= 1,
- "expected at most one end marker, found {}",
- end_markers.len()
- );
-
- let sof_markers = content
- .match_indices(START_OF_FILE_MARKER)
- .collect::<Vec<_>>();
- anyhow::ensure!(
- sof_markers.len() <= 1,
- "expected at most one start-of-file marker, found {}",
- sof_markers.len()
- );
-
- let content_start = start_markers
- .first()
- .map(|e| e.0 + EDITABLE_REGION_START_MARKER.len())
- .map(|start| {
- if content.len() > start
- && content.is_char_boundary(start)
- && content[start..].starts_with('\n')
- {
- start + 1
- } else {
- start
- }
- })
- .unwrap_or(0);
- let content_end = end_markers
- .first()
- .map(|e| {
- if e.0 > 0 && content.is_char_boundary(e.0 - 1) && content[e.0 - 1..].starts_with('\n')
- {
- e.0 - 1
- } else {
- e.0
- }
- })
- .unwrap_or(content.strip_suffix("\n").unwrap_or(&content).len());
-
- // min to account for content_end and content_start both accounting for the same newline in the following case:
- // <|editable_region_start|>\n<|editable_region_end|>
- let new_text = &content[content_start.min(content_end)..content_end];
-
- let old_text = snapshot
- .text_for_range(editable_range.clone())
- .collect::<String>();
-
- Ok(compute_edits(
- old_text,
- new_text,
- editable_range.start,
- snapshot,
- ))
-}
-
-pub fn compute_edits(
- old_text: String,
- new_text: &str,
- offset: usize,
- snapshot: &BufferSnapshot,
-) -> Vec<(Range<Anchor>, Arc<str>)> {
- compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
-}
-
-pub fn compute_edits_and_cursor_position(
- old_text: String,
- new_text: &str,
- offset: usize,
- cursor_offset_in_new_text: Option<usize>,
- snapshot: &BufferSnapshot,
-) -> (
- Vec<(Range<Anchor>, Arc<str>)>,
- Option<PredictedCursorPosition>,
-) {
- let diffs = text_diff(&old_text, new_text);
-
- // Delta represents the cumulative change in byte count from all preceding edits.
- // new_offset = old_offset + delta, so old_offset = new_offset - delta
- let mut delta: isize = 0;
- let mut cursor_position: Option<PredictedCursorPosition> = None;
- let buffer_len = snapshot.len();
-
- let edits = diffs
- .iter()
- .map(|(raw_old_range, new_text)| {
- // Compute cursor position if it falls within or before this edit.
- if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
- let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
- let edit_end_in_new = edit_start_in_new + new_text.len();
-
- if cursor_offset < edit_start_in_new {
- let cursor_in_old = (cursor_offset as isize - delta) as usize;
- let buffer_offset = (offset + cursor_in_old).min(buffer_len);
- cursor_position = Some(PredictedCursorPosition::at_anchor(
- snapshot.anchor_after(buffer_offset),
- ));
- } else if cursor_offset < edit_end_in_new {
- let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
- let offset_within_insertion = cursor_offset - edit_start_in_new;
- cursor_position = Some(PredictedCursorPosition::new(
- snapshot.anchor_before(buffer_offset),
- offset_within_insertion,
- ));
- }
-
- delta += new_text.len() as isize - raw_old_range.len() as isize;
- }
-
- // Compute the edit with prefix/suffix trimming.
- let mut old_range = raw_old_range.clone();
- let old_slice = &old_text[old_range.clone()];
-
- let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
- let suffix_len = common_prefix(
- old_slice[prefix_len..].chars().rev(),
- new_text[prefix_len..].chars().rev(),
- );
-
- old_range.start += offset;
- old_range.end += offset;
- old_range.start += prefix_len;
- old_range.end -= suffix_len;
-
- old_range.start = old_range.start.min(buffer_len);
- old_range.end = old_range.end.min(buffer_len);
-
- let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
- let range = if old_range.is_empty() {
- let anchor = snapshot.anchor_after(old_range.start);
- anchor..anchor
- } else {
- snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
- };
- (range, new_text)
- })
- .collect();
-
- if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
- let cursor_in_old = (cursor_offset as isize - delta) as usize;
- let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
- cursor_position = Some(PredictedCursorPosition::at_anchor(
- snapshot.anchor_after(buffer_offset),
- ));
- }
-
- (edits, cursor_position)
-}
-
-fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
- a.zip(b)
- .take_while(|(a, b)| a == b)
- .map(|(a, _)| a.len_utf8())
- .sum()
-}
-
-pub struct GatherContextOutput {
- pub body: PredictEditsBody,
- pub context_range: Range<Point>,
- pub editable_range: Range<usize>,
- pub included_events_count: usize,
-}
-
-pub(crate) fn prompt_for_events(events: &[Arc<Event>], max_tokens: usize) -> String {
- prompt_for_events_impl(events, max_tokens).0
-}
-
-fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
- let mut result = String::new();
- for (ix, event) in events.iter().rev().enumerate() {
- let event_string = format_event(event.as_ref());
- let event_tokens = guess_token_count(event_string.len());
- if event_tokens > remaining_tokens {
- return (result, ix);
- }
-
- if !result.is_empty() {
- result.insert_str(0, "\n\n");
- }
- result.insert_str(0, &event_string);
- remaining_tokens -= event_tokens;
- }
- return (result, events.len());
-}
-
-pub fn format_event(event: &Event) -> String {
- match event {
- Event::BufferChange {
- path,
- old_path,
- diff,
- ..
- } => {
- let mut prompt = String::new();
-
- if old_path != path {
- writeln!(
- prompt,
- "User renamed {} to {}\n",
- old_path.display(),
- path.display()
- )
- .unwrap();
- }
-
- if !diff.is_empty() {
- write!(
- prompt,
- "User edited {}:\n```diff\n{}\n```",
- path.display(),
- diff
- )
- .unwrap();
- }
-
- prompt
- }
- }
-}
-
-#[derive(Debug)]
-pub struct InputExcerpt {
- pub context_range: Range<Point>,
- pub editable_range: Range<Point>,
- pub prompt: String,
-}
-
-pub fn excerpt_for_cursor_position(
- position: Point,
- path: &str,
- snapshot: &BufferSnapshot,
- editable_region_token_limit: usize,
- context_token_limit: usize,
-) -> InputExcerpt {
- let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
- position,
- snapshot,
- editable_region_token_limit,
- context_token_limit,
- );
-
- let mut prompt = String::new();
-
- writeln!(&mut prompt, "```{path}").unwrap();
- if context_range.start == Point::zero() {
- writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
- }
-
- for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
- prompt.push_str(chunk.text);
- }
-
- push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
-
- for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
- prompt.push_str(chunk.text);
- }
- write!(prompt, "\n```").unwrap();
-
- InputExcerpt {
- context_range,
- editable_range,
- prompt,
- }
-}
-
-fn push_editable_range(
- cursor_position: Point,
- snapshot: &BufferSnapshot,
- editable_range: Range<Point>,
- prompt: &mut String,
-) {
- writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
- for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
- prompt.push_str(chunk.text);
- }
- prompt.push_str(CURSOR_MARKER);
- for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
- prompt.push_str(chunk.text);
- }
- write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use gpui::{App, AppContext};
- use indoc::indoc;
- use language::Buffer;
- use text::OffsetRangeExt as _;
-
- #[gpui::test]
- fn test_excerpt_for_cursor_position(cx: &mut App) {
- let text = indoc! {r#"
- fn foo() {
- let x = 42;
- println!("Hello, world!");
- }
-
- fn bar() {
- let x = 42;
- let mut sum = 0;
- for i in 0..x {
- sum += i;
- }
- println!("Sum: {}", sum);
- return sum;
- }
-
- fn generate_random_numbers() -> Vec<i32> {
- let mut rng = rand::thread_rng();
- let mut numbers = Vec::new();
- for _ in 0..5 {
- numbers.push(rng.random_range(1..101));
- }
- numbers
- }
- "#};
- let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
- let snapshot = buffer.read(cx).snapshot();
-
- // The excerpt expands to syntax boundaries.
- // With 50 token editable limit, we get a region that expands to syntax nodes.
- let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
- assert_eq!(
- excerpt.prompt,
- indoc! {r#"
- ```main.rs
-
- fn bar() {
- let x = 42;
- <|editable_region_start|>
- let mut sum = 0;
- for i in 0..x {
- sum += i;
- }
- println!("Sum: {}", sum);
- r<|user_cursor_is_here|>eturn sum;
- }
-
- fn generate_random_numbers() -> Vec<i32> {
- <|editable_region_end|>
- let mut rng = rand::thread_rng();
- let mut numbers = Vec::new();
- ```"#}
- );
-
- // With smaller budget, the region expands to syntax boundaries but is tighter.
- let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
- assert_eq!(
- excerpt.prompt,
- indoc! {r#"
- ```main.rs
- fn bar() {
- let x = 42;
- let mut sum = 0;
- for i in 0..x {
- <|editable_region_start|>
- sum += i;
- }
- println!("Sum: {}", sum);
- r<|user_cursor_is_here|>eturn sum;
- }
-
- fn generate_random_numbers() -> Vec<i32> {
- <|editable_region_end|>
- let mut rng = rand::thread_rng();
- ```"#}
- );
- }
-
- #[gpui::test]
- fn test_parse_edits_empty_editable_region(cx: &mut App) {
- let text = "fn foo() {\n let x = 42;\n}\n";
- let buffer = cx.new(|cx| Buffer::local(text, cx));
- let snapshot = buffer.read(cx).snapshot();
-
- let output = "<|editable_region_start|>\n<|editable_region_end|>";
- let editable_range = 0..text.len();
- let edits = parse_edits(output, editable_range, &snapshot).unwrap();
- assert_eq!(edits.len(), 1);
- let (range, new_text) = &edits[0];
- assert_eq!(range.to_offset(&snapshot), 0..text.len(),);
- assert_eq!(new_text.as_ref(), "");
- }
-
- #[gpui::test]
- fn test_parse_edits_multibyte_char_before_end_marker(cx: &mut App) {
- let text = "// cafΓ©";
- let buffer = cx.new(|cx| Buffer::local(text, cx));
- let snapshot = buffer.read(cx).snapshot();
-
- let output = "<|editable_region_start|>\n// cafΓ©<|editable_region_end|>";
- let editable_range = 0..text.len();
-
- let edits = parse_edits(output, editable_range, &snapshot).unwrap();
- assert_eq!(edits, vec![]);
- }
-
- #[gpui::test]
- fn test_parse_edits_multibyte_char_after_start_marker(cx: &mut App) {
- let text = "Γ© is great";
- let buffer = cx.new(|cx| Buffer::local(text, cx));
- let snapshot = buffer.read(cx).snapshot();
-
- let output = "<|editable_region_start|>é is great\n<|editable_region_end|>";
- let editable_range = 0..text.len();
-
- let edits = parse_edits(output, editable_range, &snapshot).unwrap();
- assert!(edits.is_empty());
- }
-}
@@ -1,367 +0,0 @@
-use crate::cursor_excerpt::{compute_excerpt_ranges, excerpt_ranges_to_byte_offsets};
-use crate::prediction::EditPredictionResult;
-use crate::zeta1::compute_edits_and_cursor_position;
-use crate::{
- CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
- EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
-};
-use anyhow::Result;
-use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
-use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
-use gpui::{App, Task, prelude::*};
-use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
-use release_channel::AppVersion;
-
-use std::env;
-use std::{path::Path, sync::Arc, time::Instant};
-use zeta_prompt::{
- CURSOR_MARKER, EditPredictionModelKind, ZetaFormat, clean_zeta2_model_output,
- format_zeta_prompt, get_prefill, prompt_input_contains_special_tokens,
-};
-
-pub const MAX_CONTEXT_TOKENS: usize = 350;
-
-pub fn max_editable_tokens(format: ZetaFormat) -> usize {
- match format {
- ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => 150,
- ZetaFormat::V0114180EditableRegion => 180,
- ZetaFormat::V0120GitMergeMarkers => 180,
- ZetaFormat::V0131GitMergeMarkersPrefix => 180,
- ZetaFormat::V0211Prefill => 180,
- ZetaFormat::V0211SeedCoder => 180,
- }
-}
-
-pub fn request_prediction_with_zeta2(
- store: &mut EditPredictionStore,
- EditPredictionModelInput {
- buffer,
- snapshot,
- position,
- related_files,
- events,
- debug_tx,
- trigger,
- project,
- ..
- }: EditPredictionModelInput,
- preferred_model: Option<EditPredictionModelKind>,
- cx: &mut Context<EditPredictionStore>,
-) -> Task<Result<Option<EditPredictionResult>>> {
- let buffer_snapshotted_at = Instant::now();
- let raw_config = store.zeta2_raw_config().cloned();
-
- let excerpt_path: Arc<Path> = snapshot
- .file()
- .map(|file| -> Arc<Path> { file.full_path(cx).into() })
- .unwrap_or_else(|| Arc::from(Path::new("untitled")));
-
- let client = store.client.clone();
- let llm_token = store.llm_token.clone();
- let app_version = AppVersion::global(cx);
-
- let is_open_source = snapshot
- .file()
- .map_or(false, |file| store.is_file_open_source(&project, file, cx))
- && events.iter().all(|event| event.in_open_source_repo())
- && related_files.iter().all(|file| file.in_open_source_repo);
-
- let can_collect_data = is_open_source && store.is_data_collection_enabled(cx);
-
- let request_task = cx.background_spawn({
- async move {
- let zeta_version = raw_config
- .as_ref()
- .map(|config| config.format)
- .unwrap_or(ZetaFormat::default());
-
- let cursor_offset = position.to_offset(&snapshot);
- let (editable_offset_range, prompt_input) = zeta2_prompt_input(
- &snapshot,
- related_files,
- events,
- excerpt_path,
- cursor_offset,
- zeta_version,
- preferred_model,
- is_open_source,
- can_collect_data,
- );
-
- if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
- return Ok((None, None));
- }
-
- if let Some(debug_tx) = &debug_tx {
- let prompt = format_zeta_prompt(&prompt_input, zeta_version);
- debug_tx
- .unbounded_send(DebugEvent::EditPredictionStarted(
- EditPredictionStartedDebugEvent {
- buffer: buffer.downgrade(),
- prompt: Some(prompt),
- position,
- },
- ))
- .ok();
- }
-
- log::trace!("Sending edit prediction request");
-
- let (request_id, output_text, usage) = if let Some(config) = &raw_config {
- let prompt = format_zeta_prompt(&prompt_input, config.format);
- let prefill = get_prefill(&prompt_input, config.format);
- let prompt = format!("{prompt}{prefill}");
- let request = RawCompletionRequest {
- model: config.model_id.clone().unwrap_or_default(),
- prompt,
- temperature: None,
- stop: vec![],
- max_tokens: Some(2048),
- environment: Some(config.format.to_string().to_lowercase()),
- };
-
- let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
- request,
- client,
- None,
- llm_token,
- app_version,
- )
- .await?;
-
- let request_id = EditPredictionId(response.id.clone().into());
- let output_text = response.choices.pop().map(|choice| {
- let response = &choice.text;
- let output = format!("{prefill}{response}");
- clean_zeta2_model_output(&output, config.format).to_string()
- });
-
- (request_id, output_text, usage)
- } else {
- // Use V3 endpoint - server handles model/version selection and suffix stripping
- let (response, usage) = EditPredictionStore::send_v3_request(
- prompt_input.clone(),
- client,
- llm_token,
- app_version,
- trigger,
- )
- .await?;
-
- let request_id = EditPredictionId(response.request_id.into());
- let output_text = if response.output.is_empty() {
- None
- } else {
- Some(response.output)
- };
- (request_id, output_text, usage)
- };
-
- let received_response_at = Instant::now();
-
- log::trace!("Got edit prediction response");
-
- let Some(mut output_text) = output_text else {
- return Ok((Some((request_id, None)), usage));
- };
-
- // Client-side cursor marker processing (applies to both raw and v3 responses)
- let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
- if let Some(offset) = cursor_offset_in_output {
- log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
- output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
- }
-
- if let Some(debug_tx) = &debug_tx {
- debug_tx
- .unbounded_send(DebugEvent::EditPredictionFinished(
- EditPredictionFinishedDebugEvent {
- buffer: buffer.downgrade(),
- position,
- model_output: Some(output_text.clone()),
- },
- ))
- .ok();
- }
-
- let mut old_text = snapshot
- .text_for_range(editable_offset_range.clone())
- .collect::<String>();
-
- if !output_text.is_empty() && !output_text.ends_with('\n') {
- output_text.push('\n');
- }
- if !old_text.is_empty() && !old_text.ends_with('\n') {
- old_text.push('\n');
- }
-
- let (edits, cursor_position) = compute_edits_and_cursor_position(
- old_text,
- &output_text,
- editable_offset_range.start,
- cursor_offset_in_output,
- &snapshot,
- );
-
- anyhow::Ok((
- Some((
- request_id,
- Some((
- prompt_input,
- buffer,
- snapshot.clone(),
- edits,
- cursor_position,
- received_response_at,
- )),
- )),
- usage,
- ))
- }
- });
-
- cx.spawn(async move |this, cx| {
- let Some((id, prediction)) =
- EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
- else {
- return Ok(None);
- };
-
- let Some((
- inputs,
- edited_buffer,
- edited_buffer_snapshot,
- edits,
- cursor_position,
- received_response_at,
- )) = prediction
- else {
- return Ok(Some(EditPredictionResult {
- id,
- prediction: Err(EditPredictionRejectReason::Empty),
- }));
- };
-
- Ok(Some(
- EditPredictionResult::new(
- id,
- &edited_buffer,
- &edited_buffer_snapshot,
- edits.into(),
- cursor_position,
- buffer_snapshotted_at,
- received_response_at,
- inputs,
- cx,
- )
- .await,
- ))
- })
-}
-
-pub fn zeta2_prompt_input(
- snapshot: &language::BufferSnapshot,
- related_files: Vec<zeta_prompt::RelatedFile>,
- events: Vec<Arc<zeta_prompt::Event>>,
- excerpt_path: Arc<Path>,
- cursor_offset: usize,
- zeta_format: ZetaFormat,
- preferred_model: Option<EditPredictionModelKind>,
- is_open_source: bool,
- can_collect_data: bool,
-) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
- let cursor_point = cursor_offset.to_point(snapshot);
-
- let (full_context, range_points) = compute_excerpt_ranges(cursor_point, snapshot);
-
- let related_files = crate::filter_redundant_excerpts(
- related_files,
- excerpt_path.as_ref(),
- full_context.start.row..full_context.end.row,
- );
-
- let full_context_start_offset = full_context.start.to_offset(snapshot);
- let full_context_start_row = full_context.start.row;
-
- let excerpt_ranges =
- excerpt_ranges_to_byte_offsets(&range_points, full_context_start_offset, snapshot);
-
- let editable_range = match preferred_model {
- Some(EditPredictionModelKind::Zeta1) => &range_points.editable_350,
- _ => match zeta_format {
- ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => &range_points.editable_150,
- _ => &range_points.editable_180,
- },
- };
-
- let editable_offset_range = editable_range.to_offset(snapshot);
- let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
- let editable_range_in_excerpt = (editable_offset_range.start - full_context_start_offset)
- ..(editable_offset_range.end - full_context_start_offset);
-
- let prompt_input = zeta_prompt::ZetaPromptInput {
- cursor_path: excerpt_path,
- cursor_excerpt: snapshot
- .text_for_range(full_context)
- .collect::<String>()
- .into(),
- editable_range_in_excerpt,
- cursor_offset_in_excerpt,
- excerpt_start_row: Some(full_context_start_row),
- events,
- related_files,
- excerpt_ranges: Some(excerpt_ranges),
- preferred_model,
- in_open_source_repo: is_open_source,
- can_collect_data,
- };
- (editable_offset_range, prompt_input)
-}
-
-pub(crate) fn edit_prediction_accepted(
- store: &EditPredictionStore,
- current_prediction: CurrentEditPrediction,
- cx: &App,
-) {
- let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
- if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
- return;
- }
-
- let request_id = current_prediction.prediction.id.to_string();
- let require_auth = custom_accept_url.is_none();
- let client = store.client.clone();
- let llm_token = store.llm_token.clone();
- let app_version = AppVersion::global(cx);
-
- cx.background_spawn(async move {
- let url = if let Some(accept_edits_url) = custom_accept_url {
- gpui::http_client::Url::parse(&accept_edits_url)?
- } else {
- client
- .http_client()
- .build_zed_llm_url("/predict_edits/accept", &[])?
- };
-
- let response = EditPredictionStore::send_api_request::<()>(
- move |builder| {
- let req = builder.uri(url.as_ref()).body(
- serde_json::to_string(&AcceptEditPredictionBody {
- request_id: request_id.clone(),
- })?
- .into(),
- );
- Ok(req?)
- },
- client,
- llm_token,
- app_version,
- require_auth,
- )
- .await;
-
- response?;
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
-}
@@ -54,8 +54,8 @@ pub async fn run_format_prompt(
let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
cursor_point,
&snapshot,
- edit_prediction::zeta2::max_editable_tokens(ZetaFormat::default()),
- edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
+ edit_prediction::zeta::max_editable_tokens(ZetaFormat::default()),
+ edit_prediction::zeta::MAX_CONTEXT_TOKENS,
);
let editable_range = editable_range.to_offset(&snapshot);
let context_range = context_range.to_offset(&snapshot);
@@ -75,8 +75,8 @@ pub async fn run_format_prompt(
let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
cursor_point,
&snapshot,
- edit_prediction::zeta2::max_editable_tokens(version),
- edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
+ edit_prediction::zeta::max_editable_tokens(version),
+ edit_prediction::zeta::MAX_CONTEXT_TOKENS,
);
let editable_range = editable_range.to_offset(&snapshot);
let context_range = context_range.to_offset(&snapshot);
@@ -188,7 +188,6 @@ impl Render for EditPredictionButton {
.with_handle(self.popover_menu_handle.clone()),
)
}
-
EditPredictionProvider::Supermaven => {
let Some(supermaven) = Supermaven::global(cx) else {
return div();
@@ -284,7 +283,6 @@ impl Render for EditPredictionButton {
.with_handle(self.popover_menu_handle.clone()),
)
}
-
EditPredictionProvider::Codestral => {
let enabled = self.editor_enabled.unwrap_or(true);
let has_api_key = codestral::codestral_api_key(cx).is_some();
@@ -349,6 +347,36 @@ impl Render for EditPredictionButton {
.with_handle(self.popover_menu_handle.clone()),
)
}
+ EditPredictionProvider::OpenAiCompatibleApi => {
+ let enabled = self.editor_enabled.unwrap_or(true);
+ let this = cx.weak_entity();
+
+ div().child(
+ PopoverMenu::new("openai-compatible-api")
+ .menu(move |window, cx| {
+ this.update(cx, |this, cx| {
+ this.build_edit_prediction_context_menu(
+ EditPredictionProvider::OpenAiCompatibleApi,
+ window,
+ cx,
+ )
+ })
+ .ok()
+ })
+ .anchor(Corner::BottomRight)
+ .trigger(
+ IconButton::new("openai-compatible-api-icon", IconName::AiOpenAiCompat)
+ .shape(IconButtonShape::Square)
+ .when(!enabled, |this| {
+ this.indicator(Indicator::dot().color(Color::Ignored))
+ .indicator_border_color(Some(
+ cx.theme().colors().status_bar_background,
+ ))
+ }),
+ )
+ .with_handle(self.popover_menu_handle.clone()),
+ )
+ }
EditPredictionProvider::Ollama => {
let enabled = self.editor_enabled.unwrap_or(true);
let this = cx.weak_entity();
@@ -377,14 +405,9 @@ impl Render for EditPredictionButton {
}),
move |_window, cx| {
let settings = all_language_settings(None, cx);
- let tooltip_meta = match settings
- .edit_predictions
- .ollama
- .model
- .as_deref()
- {
- Some(model) if !model.trim().is_empty() => {
- format!("Powered by Ollama ({model})")
+ let tooltip_meta = match settings.edit_predictions.ollama.as_ref() {
+ Some(settings) if !settings.model.trim().is_empty() => {
+ format!("Powered by Ollama ({})", settings.model)
}
_ => {
"Ollama model not configured β configure a model before use"
@@ -1500,6 +1523,14 @@ pub fn get_available_providers(cx: &mut App) -> Vec<EditPredictionProvider> {
providers.push(EditPredictionProvider::Ollama);
}
+ if all_language_settings(None, cx)
+ .edit_predictions
+ .open_ai_compatible_api
+ .is_some()
+ {
+ providers.push(EditPredictionProvider::OpenAiCompatibleApi);
+ }
+
if edit_prediction::sweep_ai::sweep_api_token(cx)
.read(cx)
.has_key()
@@ -924,7 +924,7 @@ impl Render for RatePredictionsModal {
.flex_shrink_0()
.overflow_hidden()
.child({
- let icons = self.ep_store.read(cx).icons();
+ let icons = self.ep_store.read(cx).icons(cx);
h_flex()
.h_8()
.px_2()
@@ -12,9 +12,10 @@ use itertools::{Either, Itertools};
use settings::{DocumentFoldingRanges, DocumentSymbols, IntoGpui, SemanticTokens};
pub use settings::{
- CompletionSettingsContent, EditPredictionProvider, EditPredictionsMode, FormatOnSave,
- Formatter, FormatterList, InlayHintKind, LanguageSettingsContent, LspInsertMode,
- RewrapBehavior, ShowWhitespaceSetting, SoftWrap, WordsCompletionMode,
+ CompletionSettingsContent, EditPredictionPromptFormat, EditPredictionProvider,
+ EditPredictionsMode, FormatOnSave, Formatter, FormatterList, InlayHintKind,
+ LanguageSettingsContent, LspInsertMode, RewrapBehavior, ShowWhitespaceSetting, SoftWrap,
+ WordsCompletionMode,
};
use settings::{RegisterSetting, Settings, SettingsLocation, SettingsStore};
use shellexpand;
@@ -398,7 +399,8 @@ pub struct EditPredictionSettings {
/// Settings specific to Sweep.
pub sweep: SweepSettings,
/// Settings specific to Ollama.
- pub ollama: OllamaSettings,
+ pub ollama: Option<OpenAiCompatibleEditPredictionSettings>,
+ pub open_ai_compatible_api: Option<OpenAiCompatibleEditPredictionSettings>,
/// Whether edit predictions are enabled in the assistant panel.
/// This setting has no effect if globally disabled.
pub enabled_in_text_threads: bool,
@@ -457,13 +459,16 @@ pub struct SweepSettings {
}
#[derive(Clone, Debug, Default)]
-pub struct OllamaSettings {
+pub struct OpenAiCompatibleEditPredictionSettings {
/// Model to use for completions.
- pub model: Option<String>,
+ pub model: String,
/// Maximum tokens to generate.
pub max_output_tokens: u32,
/// Custom API URL to use for Ollama.
pub api_url: Arc<str>,
+    /// The prompt format to use for completions. When set to `Infer`, the
+    /// format will be derived from the model name at request time.
+ pub prompt_format: EditPredictionPromptFormat,
}
impl AllLanguageSettings {
@@ -700,11 +705,30 @@ impl settings::Settings for AllLanguageSettings {
privacy_mode: sweep.privacy_mode.unwrap(),
};
let ollama = edit_predictions.ollama.unwrap();
- let ollama_settings = OllamaSettings {
- model: ollama.model.map(|m| m.0),
- max_output_tokens: ollama.max_output_tokens.unwrap(),
- api_url: ollama.api_url.unwrap().into(),
- };
+ let ollama_settings = ollama
+ .model
+ .filter(|model| !model.0.is_empty())
+ .map(|model| OpenAiCompatibleEditPredictionSettings {
+ model: model.0,
+ max_output_tokens: ollama.max_output_tokens.unwrap(),
+ api_url: ollama.api_url.unwrap().into(),
+ prompt_format: ollama.prompt_format.unwrap(),
+ });
+ let openai_compatible_settings = edit_predictions.open_ai_compatible_api.unwrap();
+ let openai_compatible_settings = openai_compatible_settings
+ .model
+ .filter(|model| !model.is_empty())
+ .zip(
+ openai_compatible_settings
+ .api_url
+ .filter(|api_url| !api_url.is_empty()),
+ )
+ .map(|(model, api_url)| OpenAiCompatibleEditPredictionSettings {
+ model,
+ max_output_tokens: openai_compatible_settings.max_output_tokens.unwrap(),
+ api_url: api_url.into(),
+ prompt_format: openai_compatible_settings.prompt_format.unwrap(),
+ });
let enabled_in_text_threads = edit_predictions.enabled_in_text_threads.unwrap();
@@ -745,6 +769,7 @@ impl settings::Settings for AllLanguageSettings {
codestral: codestral_settings,
sweep: sweep_settings,
ollama: ollama_settings,
+ open_ai_compatible_api: openai_compatible_settings,
enabled_in_text_threads,
examples_dir: edit_predictions.examples_dir,
},
@@ -85,6 +85,7 @@ pub enum EditPredictionProvider {
Zed,
Codestral,
Ollama,
+ OpenAiCompatibleApi,
Sweep,
Mercury,
Experimental(&'static str),
@@ -106,6 +107,7 @@ impl<'de> Deserialize<'de> for EditPredictionProvider {
Zed,
Codestral,
Ollama,
+ OpenAiCompatibleApi,
Sweep,
Mercury,
Experimental(String),
@@ -118,6 +120,7 @@ impl<'de> Deserialize<'de> for EditPredictionProvider {
Content::Zed => EditPredictionProvider::Zed,
Content::Codestral => EditPredictionProvider::Codestral,
Content::Ollama => EditPredictionProvider::Ollama,
+ Content::OpenAiCompatibleApi => EditPredictionProvider::OpenAiCompatibleApi,
Content::Sweep => EditPredictionProvider::Sweep,
Content::Mercury => EditPredictionProvider::Mercury,
Content::Experimental(name)
@@ -146,6 +149,7 @@ impl EditPredictionProvider {
| EditPredictionProvider::Supermaven
| EditPredictionProvider::Codestral
| EditPredictionProvider::Ollama
+ | EditPredictionProvider::OpenAiCompatibleApi
| EditPredictionProvider::Sweep
| EditPredictionProvider::Mercury
| EditPredictionProvider::Experimental(_) => false,
@@ -165,6 +169,7 @@ impl EditPredictionProvider {
) => Some("Zeta2"),
EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => None,
EditPredictionProvider::Ollama => Some("Ollama"),
+ EditPredictionProvider::OpenAiCompatibleApi => Some("OpenAI-Compatible API"),
}
}
}
@@ -190,6 +195,8 @@ pub struct EditPredictionSettingsContent {
pub sweep: Option<SweepSettingsContent>,
/// Settings specific to Ollama.
pub ollama: Option<OllamaEditPredictionSettingsContent>,
+ /// Settings specific to using custom OpenAI-compatible servers for edit prediction.
+ pub open_ai_compatible_api: Option<CustomEditPredictionProviderSettingsContent>,
/// Whether edit predictions are enabled in the assistant prompt editor.
/// This has no effect if globally disabled.
pub enabled_in_text_threads: Option<bool>,
@@ -197,6 +204,56 @@ pub struct EditPredictionSettingsContent {
pub examples_dir: Option<Arc<Path>>,
}
+#[with_fallible_options]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, MergeFrom, PartialEq)]
+pub struct CustomEditPredictionProviderSettingsContent {
+ /// API URL to use for completions.
+ ///
+ /// Default: ""
+ pub api_url: Option<String>,
+ /// The prompt format to use for completions. Set to `"infer"` to derive the format from the model name.
+ ///
+ /// Default: "infer"
+ pub prompt_format: Option<EditPredictionPromptFormat>,
+ /// The name of the model.
+ ///
+ /// Default: ""
+ pub model: Option<String>,
+ /// Maximum tokens to generate for FIM models.
+ /// This setting does not apply to Sweep models.
+ ///
+ /// Default: 64
+ pub max_output_tokens: Option<u32>,
+}
+
+#[derive(
+ Copy,
+ Clone,
+ Debug,
+ Default,
+ PartialEq,
+ Eq,
+ Serialize,
+ Deserialize,
+ JsonSchema,
+ MergeFrom,
+ strum::VariantArray,
+ strum::VariantNames,
+)]
+#[serde(rename_all = "snake_case")]
+pub enum EditPredictionPromptFormat {
+ #[default]
+ Infer,
+ Zeta,
+ CodeLlama,
+ StarCoder,
+ DeepseekCoder,
+ Qwen,
+ CodeGemma,
+ Codestral,
+ Glm,
+}
+
#[with_fallible_options]
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, MergeFrom, PartialEq)]
pub struct CopilotSettingsContent {
@@ -287,6 +344,11 @@ pub struct OllamaEditPredictionSettingsContent {
///
/// Default: "http://localhost:11434"
pub api_url: Option<String>,
+
+ /// The prompt format to use for completions. Set to `"infer"` to derive the format from the model name.
+ ///
+ /// Default: "infer"
+ pub prompt_format: Option<EditPredictionPromptFormat>,
}
/// The mode in which edit predictions should be displayed.
@@ -64,7 +64,6 @@ pub(crate) fn render_edit_prediction_setup_page(
)
.into_any_element(),
),
- Some(render_ollama_provider(settings_window, window, cx).into_any_element()),
Some(
render_api_key_provider(
IconName::AiMistral,
@@ -87,6 +86,8 @@ pub(crate) fn render_edit_prediction_setup_page(
)
.into_any_element(),
),
+ Some(render_ollama_provider(settings_window, window, cx).into_any_element()),
+ Some(render_open_ai_compatible_provider(settings_window, window, cx).into_any_element()),
];
div()
@@ -420,6 +421,36 @@ fn ollama_settings() -> Box<[SettingsPageItem]> {
})),
files: USER,
}),
+ SettingsPageItem::SettingItem(SettingItem {
+ title: "Prompt Format",
+ description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred from the model name.",
+ field: Box::new(SettingField {
+ pick: |settings| {
+ settings
+ .project
+ .all_languages
+ .edit_predictions
+ .as_ref()?
+ .ollama
+ .as_ref()?
+ .prompt_format
+ .as_ref()
+ },
+ write: |settings, value| {
+ settings
+ .project
+ .all_languages
+ .edit_predictions
+ .get_or_insert_default()
+ .ollama
+ .get_or_insert_default()
+ .prompt_format = value;
+ },
+ json_path: Some("edit_predictions.ollama.prompt_format"),
+ }),
+ files: USER,
+ metadata: None,
+ }),
SettingsPageItem::SettingItem(SettingItem {
title: "Max Output Tokens",
description: "The maximum number of tokens to generate.",
@@ -453,6 +484,165 @@ fn ollama_settings() -> Box<[SettingsPageItem]> {
])
}
+fn render_open_ai_compatible_provider(
+ settings_window: &SettingsWindow,
+ window: &mut Window,
+ cx: &mut Context<SettingsWindow>,
+) -> impl IntoElement {
+ let open_ai_compatible_settings = open_ai_compatible_settings();
+ let additional_fields = settings_window
+ .render_sub_page_items_section(
+ open_ai_compatible_settings.iter().enumerate(),
+ true,
+ window,
+ cx,
+ )
+ .into_any_element();
+
+ v_flex()
+ .id("open-ai-compatible")
+ .min_w_0()
+ .pt_8()
+ .gap_1p5()
+ .child(
+ SettingsSectionHeader::new("OpenAI-Compatible API")
+ .icon(IconName::AiOpenAiCompat)
+ .no_padding(true),
+ )
+ .child(div().px_neg_8().child(additional_fields))
+}
+
+fn open_ai_compatible_settings() -> Box<[SettingsPageItem]> {
+ Box::new([
+ SettingsPageItem::SettingItem(SettingItem {
+ title: "API URL",
+ description: "The base URL of your OpenAI-compatible server.",
+ field: Box::new(SettingField {
+ pick: |settings| {
+ settings
+ .project
+ .all_languages
+ .edit_predictions
+ .as_ref()?
+ .open_ai_compatible_api
+ .as_ref()?
+ .api_url
+ .as_ref()
+ },
+ write: |settings, value| {
+ settings
+ .project
+ .all_languages
+ .edit_predictions
+ .get_or_insert_default()
+ .open_ai_compatible_api
+ .get_or_insert_default()
+ .api_url = value;
+ },
+ json_path: Some("edit_predictions.open_ai_compatible_api.api_url"),
+ }),
+ metadata: Some(Box::new(SettingsFieldMetadata {
+ placeholder: Some(OLLAMA_API_URL_PLACEHOLDER),
+ ..Default::default()
+ })),
+ files: USER,
+ }),
+ SettingsPageItem::SettingItem(SettingItem {
+ title: "Model",
+ description: "The model string to pass to the OpenAI-compatible server.",
+ field: Box::new(SettingField {
+ pick: |settings| {
+ settings
+ .project
+ .all_languages
+ .edit_predictions
+ .as_ref()?
+ .open_ai_compatible_api
+ .as_ref()?
+ .model
+ .as_ref()
+ },
+ write: |settings, value| {
+ settings
+ .project
+ .all_languages
+ .edit_predictions
+ .get_or_insert_default()
+ .open_ai_compatible_api
+ .get_or_insert_default()
+ .model = value;
+ },
+ json_path: Some("edit_predictions.open_ai_compatible_api.model"),
+ }),
+ metadata: Some(Box::new(SettingsFieldMetadata {
+ placeholder: Some(OLLAMA_MODEL_PLACEHOLDER),
+ ..Default::default()
+ })),
+ files: USER,
+ }),
+ SettingsPageItem::SettingItem(SettingItem {
+ title: "Prompt Format",
+ description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred from the model name.",
+ field: Box::new(SettingField {
+ pick: |settings| {
+ settings
+ .project
+ .all_languages
+ .edit_predictions
+ .as_ref()?
+ .open_ai_compatible_api
+ .as_ref()?
+ .prompt_format
+ .as_ref()
+ },
+ write: |settings, value| {
+ settings
+ .project
+ .all_languages
+ .edit_predictions
+ .get_or_insert_default()
+ .open_ai_compatible_api
+ .get_or_insert_default()
+ .prompt_format = value;
+ },
+ json_path: Some("edit_predictions.open_ai_compatible_api.prompt_format"),
+ }),
+ files: USER,
+ metadata: None,
+ }),
+ SettingsPageItem::SettingItem(SettingItem {
+ title: "Max Output Tokens",
+ description: "The maximum number of tokens to generate.",
+ field: Box::new(SettingField {
+ pick: |settings| {
+ settings
+ .project
+ .all_languages
+ .edit_predictions
+ .as_ref()?
+ .open_ai_compatible_api
+ .as_ref()?
+ .max_output_tokens
+ .as_ref()
+ },
+ write: |settings, value| {
+ settings
+ .project
+ .all_languages
+ .edit_predictions
+ .get_or_insert_default()
+ .open_ai_compatible_api
+ .get_or_insert_default()
+ .max_output_tokens = value;
+ },
+ json_path: Some("edit_predictions.open_ai_compatible_api.max_output_tokens"),
+ }),
+ metadata: None,
+ files: USER,
+ }),
+ ])
+}
+
fn codestral_settings() -> Box<[SettingsPageItem]> {
Box::new([
SettingsPageItem::SettingItem(SettingItem {
@@ -505,6 +505,7 @@ fn init_renderers(cx: &mut App) {
.add_basic_renderer::<settings::AlternateScroll>(render_dropdown)
.add_basic_renderer::<settings::TerminalBlink>(render_dropdown)
.add_basic_renderer::<settings::CursorShapeContent>(render_dropdown)
+ .add_basic_renderer::<settings::EditPredictionPromptFormat>(render_dropdown)
.add_basic_renderer::<f32>(render_number_field)
.add_basic_renderer::<u32>(render_number_field)
.add_basic_renderer::<u64>(render_number_field)
@@ -2,13 +2,15 @@ use client::{Client, UserStore};
use codestral::{CodestralEditPredictionDelegate, load_codestral_api_key};
use collections::HashMap;
use copilot::CopilotEditPredictionDelegate;
-use edit_prediction::{ZedEditPredictionDelegate, Zeta2FeatureFlag};
+use edit_prediction::{EditPredictionModel, ZedEditPredictionDelegate, Zeta2FeatureFlag};
use editor::Editor;
use feature_flags::FeatureFlagAppExt;
use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
use language::language_settings::{EditPredictionProvider, all_language_settings};
-use settings::{EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore};
+use settings::{
+ EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, EditPredictionPromptFormat, SettingsStore,
+};
use std::{cell::RefCell, rc::Rc, sync::Arc};
use supermaven::{Supermaven, SupermavenEditPredictionDelegate};
use ui::Window;
@@ -43,10 +45,10 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
editors
.borrow_mut()
.insert(editor_handle, window.window_handle());
- let provider = all_language_settings(None, cx).edit_predictions.provider;
+ let provider_config = edit_prediction_provider_config_for_settings(cx);
assign_edit_prediction_provider(
editor,
- provider,
+ provider_config,
&client,
user_store.clone(),
window,
@@ -58,14 +60,20 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
cx.on_action(clear_edit_prediction_store_edit_history);
- let mut provider = all_language_settings(None, cx).edit_predictions.provider;
+ let mut provider_config = edit_prediction_provider_config_for_settings(cx);
cx.subscribe(&user_store, {
let editors = editors.clone();
let client = client.clone();
move |user_store, event, cx| {
if let client::user::Event::PrivateUserInfoUpdated = event {
- assign_edit_prediction_providers(&editors, provider, &client, user_store, cx);
+ assign_edit_prediction_providers(
+ &editors,
+ provider_config,
+ &client,
+ user_store,
+ cx,
+ );
}
}
})
@@ -74,19 +82,19 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
cx.observe_global::<SettingsStore>({
let user_store = user_store.clone();
move |cx| {
- let new_provider = all_language_settings(None, cx).edit_predictions.provider;
+ let new_provider_config = edit_prediction_provider_config_for_settings(cx);
- if new_provider != provider {
+ if new_provider_config != provider_config {
telemetry::event!(
"Edit Prediction Provider Changed",
- from = provider,
- to = new_provider,
+ from = provider_config.map(|config| config.name()),
+ to = new_provider_config.map(|config| config.name())
);
- provider = new_provider;
+ provider_config = new_provider_config;
assign_edit_prediction_providers(
&editors,
- provider,
+ new_provider_config,
&client,
user_store.clone(),
cx,
@@ -97,6 +105,106 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
.detach();
}
+fn edit_prediction_provider_config_for_settings(cx: &App) -> Option<EditPredictionProviderConfig> {
+ let settings = &all_language_settings(None, cx).edit_predictions;
+ let provider = settings.provider;
+ match provider {
+ EditPredictionProvider::None => None,
+ EditPredictionProvider::Copilot => Some(EditPredictionProviderConfig::Copilot),
+ EditPredictionProvider::Supermaven => Some(EditPredictionProviderConfig::Supermaven),
+ EditPredictionProvider::Zed => Some(EditPredictionProviderConfig::Zed(
+ EditPredictionModel::Zeta1,
+ )),
+ EditPredictionProvider::Codestral => Some(EditPredictionProviderConfig::Codestral),
+ EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi => {
+ let custom_settings = if provider == EditPredictionProvider::Ollama {
+ settings.ollama.as_ref()?
+ } else {
+ settings.open_ai_compatible_api.as_ref()?
+ };
+
+ let mut format = custom_settings.prompt_format;
+ if format == EditPredictionPromptFormat::Infer {
+ if let Some(inferred_format) = infer_prompt_format(&custom_settings.model) {
+ format = inferred_format;
+ } else {
+ // todo: notify user that prompt format inference failed
+ return None;
+ }
+ }
+
+ if format == EditPredictionPromptFormat::Zeta {
+ Some(EditPredictionProviderConfig::Zed(
+ EditPredictionModel::Zeta1,
+ ))
+ } else {
+ Some(EditPredictionProviderConfig::Zed(
+ EditPredictionModel::Fim { format },
+ ))
+ }
+ }
+ EditPredictionProvider::Sweep => Some(EditPredictionProviderConfig::Zed(
+ EditPredictionModel::Sweep,
+ )),
+ EditPredictionProvider::Mercury => Some(EditPredictionProviderConfig::Zed(
+ EditPredictionModel::Mercury,
+ )),
+ EditPredictionProvider::Experimental(name) => {
+ if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME
+ && cx.has_flag::<Zeta2FeatureFlag>()
+ {
+ Some(EditPredictionProviderConfig::Zed(
+ EditPredictionModel::Zeta2,
+ ))
+ } else {
+ None
+ }
+ }
+ }
+}
+
+fn infer_prompt_format(model: &str) -> Option<EditPredictionPromptFormat> {
+ let model_base = model.split(':').next().unwrap_or(model);
+
+ Some(match model_base {
+ "codellama" | "code-llama" => EditPredictionPromptFormat::CodeLlama,
+ "starcoder" | "starcoder2" | "starcoderbase" => EditPredictionPromptFormat::StarCoder,
+ "deepseek-coder" | "deepseek-coder-v2" => EditPredictionPromptFormat::DeepseekCoder,
+ "qwen2.5-coder" | "qwen-coder" | "qwen" => EditPredictionPromptFormat::Qwen,
+ "codegemma" => EditPredictionPromptFormat::CodeGemma,
+ "codestral" | "mistral" => EditPredictionPromptFormat::Codestral,
+ "glm" | "glm-4" | "glm-4.5" => EditPredictionPromptFormat::Glm,
+ _ => {
+ return None;
+ }
+ })
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+enum EditPredictionProviderConfig {
+ Copilot,
+ Supermaven,
+ Codestral,
+ Zed(EditPredictionModel),
+}
+
+impl EditPredictionProviderConfig {
+ fn name(&self) -> &'static str {
+ match self {
+ EditPredictionProviderConfig::Copilot => "Copilot",
+ EditPredictionProviderConfig::Supermaven => "Supermaven",
+ EditPredictionProviderConfig::Codestral => "Codestral",
+ EditPredictionProviderConfig::Zed(model) => match model {
+ EditPredictionModel::Zeta1 => "Zeta1",
+ EditPredictionModel::Zeta2 => "Zeta2",
+ EditPredictionModel::Fim { .. } => "FIM",
+ EditPredictionModel::Sweep => "Sweep",
+ EditPredictionModel::Mercury => "Mercury",
+ },
+ }
+ }
+}
+
fn clear_edit_prediction_store_edit_history(_: &edit_prediction::ClearHistory, cx: &mut App) {
if let Some(ep_store) = edit_prediction::EditPredictionStore::try_global(cx) {
ep_store.update(cx, |ep_store, _| ep_store.clear_history());
@@ -105,12 +213,12 @@ fn clear_edit_prediction_store_edit_history(_: &edit_prediction::ClearHistory, c
fn assign_edit_prediction_providers(
editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>,
- provider: EditPredictionProvider,
+ provider_config: Option<EditPredictionProviderConfig>,
client: &Arc<Client>,
user_store: Entity<UserStore>,
cx: &mut App,
) {
- if provider == EditPredictionProvider::Codestral {
+ if provider_config == Some(EditPredictionProviderConfig::Codestral) {
load_codestral_api_key(cx).detach();
}
for (editor, window) in editors.borrow().iter() {
@@ -118,7 +226,7 @@ fn assign_edit_prediction_providers(
_ = editor.update(cx, |editor, cx| {
assign_edit_prediction_provider(
editor,
- provider,
+ provider_config,
client,
user_store.clone(),
window,
@@ -144,7 +252,7 @@ fn register_backward_compatible_actions(editor: &mut Editor, cx: &mut Context<Ed
fn assign_edit_prediction_provider(
editor: &mut Editor,
- provider: EditPredictionProvider,
+ provider_config: Option<EditPredictionProviderConfig>,
client: &Arc<Client>,
user_store: Entity<UserStore>,
window: &mut Window,
@@ -153,11 +261,11 @@ fn assign_edit_prediction_provider(
// TODO: Do we really want to collect data only for singleton buffers?
let singleton_buffer = editor.buffer().read(cx).as_singleton();
- match provider {
- EditPredictionProvider::None => {
+ match provider_config {
+ None => {
editor.set_edit_prediction_provider::<ZedEditPredictionDelegate>(None, window, cx);
}
- EditPredictionProvider::Copilot => {
+ Some(EditPredictionProviderConfig::Copilot) => {
let ep_store = edit_prediction::EditPredictionStore::global(client, &user_store, cx);
let Some(project) = editor.project().cloned() else {
return;
@@ -177,53 +285,22 @@ fn assign_edit_prediction_provider(
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
}
- EditPredictionProvider::Supermaven => {
+ Some(EditPredictionProviderConfig::Supermaven) => {
if let Some(supermaven) = Supermaven::global(cx) {
let provider = cx.new(|_| SupermavenEditPredictionDelegate::new(supermaven));
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
}
- EditPredictionProvider::Codestral => {
+ Some(EditPredictionProviderConfig::Codestral) => {
let http_client = client.http_client();
let provider = cx.new(|_| CodestralEditPredictionDelegate::new(http_client));
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
- value @ (EditPredictionProvider::Experimental(_)
- | EditPredictionProvider::Zed
- | EditPredictionProvider::Ollama
- | EditPredictionProvider::Sweep
- | EditPredictionProvider::Mercury) => {
+ Some(EditPredictionProviderConfig::Zed(model)) => {
let ep_store = edit_prediction::EditPredictionStore::global(client, &user_store, cx);
if let Some(project) = editor.project() {
let has_model = ep_store.update(cx, |ep_store, cx| {
- let model = match value {
- EditPredictionProvider::Sweep => {
- edit_prediction::EditPredictionModel::Sweep
- }
- EditPredictionProvider::Mercury => {
- edit_prediction::EditPredictionModel::Mercury
- }
- EditPredictionProvider::Ollama => {
- if !edit_prediction::ollama::is_available(cx) {
- return false;
- }
- edit_prediction::EditPredictionModel::Ollama
- }
- EditPredictionProvider::Experimental(name)
- if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME
- && cx.has_flag::<Zeta2FeatureFlag>() =>
- {
- edit_prediction::EditPredictionModel::Zeta2
- }
- EditPredictionProvider::Zed
- if user_store.read(cx).current_user().is_some() =>
- {
- edit_prediction::EditPredictionModel::Zeta1
- }
- _ => return false,
- };
-
ep_store.set_edit_prediction_model(model);
if let Some(buffer) = &singleton_buffer {
ep_store.register_buffer(buffer, project, cx);
@@ -18,8 +18,8 @@ Once signed in, predictions appear as you type.
You can confirm that Zeta is properly configured either by verifying whether you have the following code in your settings file:
```json [settings]
-"features": {
- "edit_prediction_provider": "zed"
+"edit_predictions": {
+ "provider": "zed"
},
```
@@ -350,8 +350,8 @@ After adding your API key, Sweep will appear in the provider dropdown in the sta
```json [settings]
{
- "features": {
- "edit_prediction_provider": "sweep"
+ "edit_predictions": {
+ "provider": "sweep"
}
}
```
@@ -400,6 +400,16 @@ After adding your API key, Codestral will appear in the provider dropdown in the
}
```
+### Self-hosted OpenAI-compatible servers
+
+To configure Zed to use an arbitrary server for edit predictions:
+
+1. Open the Settings Editor (`Cmd+,` on macOS, `Ctrl+,` on Linux/Windows)
+2. Search for "Edit Predictions" and click **Configure Providers**
+3. Find the "OpenAI-Compatible API" section and enter the URL and model name. You can also pick the prompt format Zed should use: Zed currently supports several FIM prompt formats as well as Zed's own Zeta prompt format. If you leave the format set to Infer, Zed will attempt to infer it from the model name. The resulting settings are sketched below.
+
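+With those fields filled in, your settings file ends up with a block like the following sketch (the URL and model are placeholders for your own server, and the `open_ai_compatible_api` provider name is assumed from the snake_case convention used by the other providers):
+
+```json [settings]
+"edit_predictions": {
+  "provider": "open_ai_compatible_api",
+  "open_ai_compatible_api": {
+    "api_url": "http://localhost:8000/v1",
+    "model": "qwen2.5-coder-7b",
+    "prompt_format": "infer",
+    "max_output_tokens": 64
+  }
+},
+```
+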
+The server at this URL must accept requests that conform to OpenAI's [Completions API](https://developers.openai.com/api/reference/resources/completions/methods/create).
+
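+As a quick compatibility check, the endpoint should accept a minimal completion request body of roughly this shape (values are illustrative only):
+
+```json
+{
+  "model": "qwen2.5-coder-7b",
+  "prompt": "def fib(n):",
+  "max_tokens": 64
+}
+```
+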
## See also
- [Agent Panel](./agent-panel.md): Agentic editing with file read/write and terminal access