WIP: Initial Ollama edit prediction provider implementation

Created by Oleksiy Syvokon

Change summary

Cargo.lock                                              |   7 
assets/settings/default.json                            |   4 
crates/agent_ui/src/agent_ui.rs                         |   1 
crates/edit_prediction_ui/src/edit_prediction_button.rs |  52 +
crates/language/src/language_settings.rs                |  17 
crates/ollama/Cargo.toml                                |   6 
crates/ollama/src/ollama.rs                             |   4 
crates/ollama/src/ollama_edit_prediction_delegate.rs    | 376 +++++++++++
crates/settings/src/settings_content/language.rs        |  19 
crates/zed/Cargo.toml                                   |   1 
crates/zed/src/zed/edit_prediction_registry.rs          |   6 
11 files changed, 493 insertions(+)

Detailed changes

Cargo.lock

@@ -10880,12 +10880,18 @@ name = "ollama"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "edit_prediction_context",
+ "edit_prediction_types",
  "futures 0.3.31",
+ "gpui",
  "http_client",
+ "language",
+ "log",
  "schemars",
  "serde",
  "serde_json",
  "settings",
+ "text",
 ]
 
 [[package]]
@@ -20697,6 +20703,7 @@ dependencies = [
  "nc",
  "node_runtime",
  "notifications",
+ "ollama",
  "onboarding",
  "outline",
  "outline_panel",

assets/settings/default.json

@@ -1422,6 +1422,10 @@
       "model": "codestral-latest",
       "max_tokens": 150,
     },
+    "ollama": {
+      "api_url": "http://localhost:11434",
+      "model": "qwen2.5-coder:3b-base",
+    },
     // Whether edit predictions are enabled when editing text threads in the agent panel.
     // This setting has no effect if globally disabled.
     "enabled_in_text_threads": true,

crates/agent_ui/src/agent_ui.rs

@@ -321,6 +321,7 @@ fn update_command_palette_filter(cx: &mut App) {
                 }
                 EditPredictionProvider::Zed
                 | EditPredictionProvider::Codestral
+                | EditPredictionProvider::Ollama
                 | EditPredictionProvider::Experimental(_) => {
                     filter.show_namespace("edit_prediction");
                     filter.hide_namespace("copilot");

crates/edit_prediction_ui/src/edit_prediction_button.rs

@@ -293,6 +293,40 @@ impl Render for EditPredictionButton {
                         .with_handle(self.popover_menu_handle.clone()),
                 )
             }
+            EditPredictionProvider::Ollama => {
+                let enabled = self.editor_enabled.unwrap_or(true);
+                let this = cx.weak_entity();
+
+                let tooltip_meta = "Powered by Ollama";
+
+                div().child(
+                    PopoverMenu::new("ollama")
+                        .menu(move |window, cx| {
+                            this.update(cx, |this, cx| this.build_ollama_context_menu(window, cx))
+                                .ok()
+                        })
+                        .anchor(Corner::BottomRight)
+                        .trigger_with_tooltip(
+                            IconButton::new("ollama-icon", IconName::ZedPredict)
+                                .shape(IconButtonShape::Square)
+                                .when(!enabled, |this| {
+                                    this.indicator(Indicator::dot().color(Color::Ignored))
+                                        .indicator_border_color(Some(
+                                            cx.theme().colors().status_bar_background,
+                                        ))
+                                }),
+                            move |_window, cx| {
+                                Tooltip::with_meta(
+                                    "Edit Prediction",
+                                    Some(&ToggleMenu),
+                                    tooltip_meta,
+                                    cx,
+                                )
+                            },
+                        )
+                        .with_handle(self.popover_menu_handle.clone()),
+                )
+            }
             provider @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
                 let enabled = self.editor_enabled.unwrap_or(true);
 
@@ -547,6 +581,9 @@ impl EditPredictionButton {
             providers.push(EditPredictionProvider::Codestral);
         }
 
+        // Ollama is always available as it runs locally
+        providers.push(EditPredictionProvider::Ollama);
+
         if cx.has_flag::<SweepFeatureFlag>()
             && edit_prediction::sweep_ai::sweep_api_token(cx)
                 .read(cx)
@@ -595,6 +632,7 @@ impl EditPredictionButton {
                     EditPredictionProvider::Copilot => "GitHub Copilot",
                     EditPredictionProvider::Supermaven => "Supermaven",
                     EditPredictionProvider::Codestral => "Codestral",
+                    EditPredictionProvider::Ollama => "Ollama",
                     EditPredictionProvider::Experimental(
                         EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
                     ) => "Sweep",
@@ -985,6 +1023,20 @@ impl EditPredictionButton {
         })
     }
 
+    fn build_ollama_context_menu(
+        &self,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) -> Entity<ContextMenu> {
+        ContextMenu::build(window, cx, |menu, window, cx| {
+            let menu = self.build_language_settings_menu(menu, window, cx);
+            let menu =
+                self.add_provider_switching_section(menu, EditPredictionProvider::Ollama, cx);
+
+            menu
+        })
+    }
+
     fn build_edit_prediction_context_menu(
         &self,
         provider: EditPredictionProvider,

crates/language/src/language_settings.rs

@@ -385,6 +385,8 @@ pub struct EditPredictionSettings {
     pub copilot: CopilotSettings,
     /// Settings specific to Codestral.
     pub codestral: CodestralSettings,
+    /// Settings specific to Ollama.
+    pub ollama: OllamaSettings,
     /// Whether edit predictions are enabled in the assistant panel.
     /// This setting has no effect if globally disabled.
     pub enabled_in_text_threads: bool,
@@ -430,6 +432,14 @@ pub struct CodestralSettings {
     pub api_url: Option<String>,
 }
 
+#[derive(Clone, Debug, Default)]
+pub struct OllamaSettings {
+    /// Model to use for completions.
+    pub model: Option<String>,
+    /// Custom API URL to use for Ollama.
+    pub api_url: Option<String>,
+}
+
 impl AllLanguageSettings {
     /// Returns the [`LanguageSettings`] for the language with the specified name.
     pub fn language<'a>(
@@ -654,6 +664,12 @@ impl settings::Settings for AllLanguageSettings {
             api_url: codestral.api_url,
         };
 
+        let ollama = edit_predictions.ollama.unwrap();
+        let ollama_settings = OllamaSettings {
+            model: ollama.model,
+            api_url: ollama.api_url,
+        };
+
         let enabled_in_text_threads = edit_predictions.enabled_in_text_threads.unwrap();
 
         let mut file_types: FxHashMap<Arc<str>, (GlobSet, Vec<String>)> = FxHashMap::default();
@@ -692,6 +708,7 @@ impl settings::Settings for AllLanguageSettings {
                 mode: edit_predictions_mode,
                 copilot: copilot_settings,
                 codestral: codestral_settings,
+                ollama: ollama_settings,
                 enabled_in_text_threads,
             },
             defaults: default_language_settings,

crates/ollama/Cargo.toml

@@ -17,9 +17,15 @@ schemars = ["dep:schemars"]
 
 [dependencies]
 anyhow.workspace = true
+edit_prediction_context.workspace = true
+edit_prediction_types.workspace = true
 futures.workspace = true
+gpui.workspace = true
 http_client.workspace = true
+language.workspace = true
+log.workspace = true
 schemars = { workspace = true, optional = true }
 serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true
+text.workspace = true

crates/ollama/src/ollama.rs

@@ -1,3 +1,7 @@
+mod ollama_edit_prediction_delegate;
+
+pub use ollama_edit_prediction_delegate::OllamaEditPredictionDelegate;
+
 use anyhow::{Context as _, Result};
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
 use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};

crates/ollama/src/ollama_edit_prediction_delegate.rs

@@ -0,0 +1,376 @@
+use anyhow::{Context as _, Result};
+use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
+use edit_prediction_types::{EditPrediction, EditPredictionDelegate};
+use futures::AsyncReadExt;
+use gpui::{App, Context, Entity, Task};
+use http_client::HttpClient;
+use language::{
+    Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint, language_settings::all_language_settings,
+};
+use serde::{Deserialize, Serialize};
+use std::{
+    ops::Range,
+    sync::Arc,
+    time::{Duration, Instant},
+};
+use text::ToOffset;
+
+use crate::OLLAMA_API_URL;
+
+pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
+
+const EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
+    max_bytes: 1050,
+    min_bytes: 525,
+    target_before_cursor_over_total_bytes: 0.66,
+};
+
+#[derive(Clone)]
+struct CurrentCompletion {
+    snapshot: BufferSnapshot,
+    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
+    edit_preview: EditPreview,
+}
+
+impl CurrentCompletion {
+    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
+        edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
+    }
+}
+
+pub struct OllamaEditPredictionDelegate {
+    http_client: Arc<dyn HttpClient>,
+    pending_request: Option<Task<Result<()>>>,
+    current_completion: Option<CurrentCompletion>,
+}
+
+impl OllamaEditPredictionDelegate {
+    pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
+        Self {
+            http_client,
+            pending_request: None,
+            current_completion: None,
+        }
+    }
+
+    async fn fetch_completion(
+        http_client: Arc<dyn HttpClient>,
+        prompt: String,
+        suffix: String,
+        model: String,
+        api_url: String,
+    ) -> Result<String> {
+        let start_time = Instant::now();
+
+        log::debug!("Ollama: Requesting completion (model: {})", model);
+
+        let fim_prompt = format_fim_prompt(&model, &prompt, &suffix);
+
+        let request = OllamaGenerateRequest {
+            model,
+            prompt: fim_prompt,
+            raw: true,
+            stream: false,
+            options: Some(OllamaGenerateOptions {
+                num_predict: Some(256),
+                temperature: Some(0.2),
+                stop: Some(get_stop_tokens()),
+            }),
+        };
+
+        let request_body = serde_json::to_string(&request)?;
+
+        log::debug!("Ollama: Sending FIM request");
+
+        let http_request = http_client::Request::builder()
+            .method(http_client::Method::POST)
+            .uri(format!("{}/api/generate", api_url))
+            .header("Content-Type", "application/json")
+            .body(http_client::AsyncBody::from(request_body))?;
+
+        let mut response = http_client.send(http_request).await?;
+        let status = response.status();
+
+        log::debug!("Ollama: Response status: {}", status);
+
+        if !status.is_success() {
+            let mut body = String::new();
+            response.body_mut().read_to_string(&mut body).await?;
+            return Err(anyhow::anyhow!("Ollama API error: {} - {}", status, body));
+        }
+
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        let ollama_response: OllamaGenerateResponse =
+            serde_json::from_str(&body).context("Failed to parse Ollama response")?;
+
+        let elapsed = start_time.elapsed();
+
+        log::debug!(
+            "Ollama: Completion received ({:.2}s)",
+            elapsed.as_secs_f64()
+        );
+
+        let completion = clean_completion(&ollama_response.response);
+        Ok(completion)
+    }
+}
+
+impl EditPredictionDelegate for OllamaEditPredictionDelegate {
+    fn name() -> &'static str {
+        "ollama"
+    }
+
+    fn display_name() -> &'static str {
+        "Ollama"
+    }
+
+    fn show_predictions_in_menu() -> bool {
+        true
+    }
+
+    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, _cx: &App) -> bool {
+        true
+    }
+
+    fn is_refreshing(&self, _cx: &App) -> bool {
+        self.pending_request.is_some()
+    }
+
+    fn refresh(
+        &mut self,
+        buffer: Entity<Buffer>,
+        cursor_position: Anchor,
+        debounce: bool,
+        cx: &mut Context<Self>,
+    ) {
+        log::debug!("Ollama: Refresh called (debounce: {})", debounce);
+
+        let snapshot = buffer.read(cx).snapshot();
+
+        if let Some(current_completion) = self.current_completion.as_ref() {
+            if current_completion.interpolate(&snapshot).is_some() {
+                return;
+            }
+        }
+
+        let http_client = self.http_client.clone();
+
+        let settings = all_language_settings(None, cx);
+        let model = settings
+            .edit_predictions
+            .ollama
+            .model
+            .clone()
+            .unwrap_or_else(|| "qwen2.5-coder:1.5b".to_string());
+        let api_url = settings
+            .edit_predictions
+            .ollama
+            .api_url
+            .clone()
+            .unwrap_or_else(|| OLLAMA_API_URL.to_string());
+
+        self.pending_request = Some(cx.spawn(async move |this, cx| {
+            if debounce {
+                log::debug!("Ollama: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
+                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
+            }
+
+            let cursor_offset = cursor_position.to_offset(&snapshot);
+            let cursor_point = cursor_offset.to_point(&snapshot);
+            let excerpt = EditPredictionExcerpt::select_from_buffer(
+                cursor_point,
+                &snapshot,
+                &EXCERPT_OPTIONS,
+            )
+            .context("Line containing cursor doesn't fit in excerpt max bytes")?;
+
+            let excerpt_text = excerpt.text(&snapshot);
+            let cursor_within_excerpt = cursor_offset
+                .saturating_sub(excerpt.range.start)
+                .min(excerpt_text.body.len());
+            let prompt = excerpt_text.body[..cursor_within_excerpt].to_string();
+            let suffix = excerpt_text.body[cursor_within_excerpt..].to_string();
+
+            let completion_text =
+                match Self::fetch_completion(http_client, prompt, suffix, model, api_url).await {
+                    Ok(completion) => completion,
+                    Err(e) => {
+                        log::error!("Ollama: Failed to fetch completion: {}", e);
+                        this.update(cx, |this, cx| {
+                            this.pending_request = None;
+                            cx.notify();
+                        })?;
+                        return Err(e);
+                    }
+                };
+
+            if completion_text.trim().is_empty() {
+                log::debug!("Ollama: Completion was empty after trimming; ignoring");
+                this.update(cx, |this, cx| {
+                    this.pending_request = None;
+                    cx.notify();
+                })?;
+                return Ok(());
+            }
+
+            let edits: Arc<[(Range<Anchor>, Arc<str>)]> = buffer.read_with(cx, |buffer, _cx| {
+                // Use anchor_after (Right bias) so the cursor stays before the completion text,
+                // not at the end of it. This matches how Copilot handles edit predictions.
+                let position = buffer.anchor_after(cursor_offset);
+                vec![(position..position, completion_text.into())].into()
+            })?;
+            let edit_preview = buffer
+                .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))?
+                .await;
+
+            this.update(cx, |this, cx| {
+                this.current_completion = Some(CurrentCompletion {
+                    snapshot,
+                    edits,
+                    edit_preview,
+                });
+                this.pending_request = None;
+                cx.notify();
+            })?;
+
+            Ok(())
+        }));
+    }
+
+    fn accept(&mut self, _cx: &mut Context<Self>) {
+        log::debug!("Ollama: Completion accepted");
+        self.pending_request = None;
+        self.current_completion = None;
+    }
+
+    fn discard(&mut self, _cx: &mut Context<Self>) {
+        log::debug!("Ollama: Completion discarded");
+        self.pending_request = None;
+        self.current_completion = None;
+    }
+
+    fn suggest(
+        &mut self,
+        buffer: &Entity<Buffer>,
+        _cursor_position: Anchor,
+        cx: &mut Context<Self>,
+    ) -> Option<EditPrediction> {
+        let current_completion = self.current_completion.as_ref()?;
+        let buffer = buffer.read(cx);
+        let edits = current_completion.interpolate(&buffer.snapshot())?;
+        if edits.is_empty() {
+            return None;
+        }
+        Some(EditPrediction::Local {
+            id: None,
+            edits,
+            edit_preview: Some(current_completion.edit_preview.clone()),
+        })
+    }
+}
+
+fn format_fim_prompt(model: &str, prefix: &str, suffix: &str) -> String {
+    let model_base = model.split(':').next().unwrap_or(model);
+
+    match model_base {
+        "codellama" | "code-llama" => {
+            format!("<PRE> {prefix} <SUF>{suffix} <MID>")
+        }
+        "starcoder" | "starcoder2" | "starcoderbase" => {
+            format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
+        }
+        "deepseek-coder" | "deepseek-coder-v2" => {
+            // DeepSeek uses special Unicode characters for FIM tokens
+            format!(
+                "<\u{ff5c}fim\u{2581}begin\u{ff5c}>{prefix}<\u{ff5c}fim\u{2581}hole\u{ff5c}>{suffix}<\u{ff5c}fim\u{2581}end\u{ff5c}>"
+            )
+        }
+        "qwen2.5-coder" | "qwen-coder" => {
+            format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
+        }
+        "codegemma" => {
+            format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
+        }
+        "codestral" | "mistral" => {
+            format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
+        }
+        _ => {
+            format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
+        }
+    }
+}
+
+fn get_stop_tokens() -> Vec<String> {
+    vec![
+        "<|endoftext|>".to_string(),
+        "<|file_separator|>".to_string(),
+        "<|fim_pad|>".to_string(),
+        "<|fim_prefix|>".to_string(),
+        "<|fim_middle|>".to_string(),
+        "<|fim_suffix|>".to_string(),
+        "<fim_prefix>".to_string(),
+        "<fim_middle>".to_string(),
+        "<fim_suffix>".to_string(),
+        "<PRE>".to_string(),
+        "<SUF>".to_string(),
+        "<MID>".to_string(),
+        "[PREFIX]".to_string(),
+        "[SUFFIX]".to_string(),
+    ]
+}
+
+fn clean_completion(response: &str) -> String {
+    let mut result = response.to_string();
+
+    let end_tokens = [
+        "<|endoftext|>",
+        "<|file_separator|>",
+        "<|fim_pad|>",
+        "<|fim_prefix|>",
+        "<|fim_middle|>",
+        "<|fim_suffix|>",
+        "<fim_prefix>",
+        "<fim_middle>",
+        "<fim_suffix>",
+        "<PRE>",
+        "<SUF>",
+        "<MID>",
+        "[PREFIX]",
+        "[SUFFIX]",
+    ];
+
+    for token in &end_tokens {
+        if let Some(pos) = result.find(token) {
+            result.truncate(pos);
+        }
+    }
+
+    result
+}
+
+#[derive(Debug, Serialize)]
+struct OllamaGenerateRequest {
+    model: String,
+    prompt: String,
+    raw: bool,
+    stream: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    options: Option<OllamaGenerateOptions>,
+}
+
+#[derive(Debug, Serialize)]
+struct OllamaGenerateOptions {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    num_predict: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    temperature: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    stop: Option<Vec<String>>,
+}
+
+#[derive(Debug, Deserialize)]
+struct OllamaGenerateResponse {
+    response: String,
+}
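
To make the request shape concrete: with the qwen2.5-coder template above, the delegate POSTs a body like the following to {api_url}/api/generate (illustrative values; the stop list is abbreviated, and "raw": true tells Ollama to skip its own prompt templating so the FIM tokens reach the model verbatim):

    {
      "model": "qwen2.5-coder:3b-base",
      "prompt": "<|fim_prefix|>fn add(a: i32, b: i32) -> i32 {\n    <|fim_suffix|>\n}<|fim_middle|>",
      "raw": true,
      "stream": false,
      "options": {
        "num_predict": 256,
        "temperature": 0.2,
        "stop": ["<|endoftext|>", "<|fim_prefix|>", "<|fim_middle|>", "<|fim_suffix|>"]
      }
    }

The generated middle comes back in the response field of Ollama's JSON reply; clean_completion truncates it at the first FIM/end token before the text is turned into an insertion edit at the cursor anchor.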

crates/settings/src/settings_content/language.rs

@@ -76,6 +76,7 @@ pub enum EditPredictionProvider {
     Supermaven,
     Zed,
     Codestral,
+    Ollama,
     Experimental(&'static str),
 }
 
@@ -96,6 +97,7 @@ impl<'de> Deserialize<'de> for EditPredictionProvider {
             Supermaven,
             Zed,
             Codestral,
+            Ollama,
             Experimental(String),
         }
 
@@ -105,6 +107,7 @@ impl<'de> Deserialize<'de> for EditPredictionProvider {
             Content::Supermaven => EditPredictionProvider::Supermaven,
             Content::Zed => EditPredictionProvider::Zed,
             Content::Codestral => EditPredictionProvider::Codestral,
+            Content::Ollama => EditPredictionProvider::Ollama,
             Content::Experimental(name)
                 if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME =>
             {
@@ -144,6 +147,7 @@ impl EditPredictionProvider {
             | EditPredictionProvider::Copilot
             | EditPredictionProvider::Supermaven
             | EditPredictionProvider::Codestral
+            | EditPredictionProvider::Ollama
             | EditPredictionProvider::Experimental(_) => false,
         }
     }
@@ -164,6 +168,8 @@ pub struct EditPredictionSettingsContent {
     pub copilot: Option<CopilotSettingsContent>,
     /// Settings specific to Codestral.
     pub codestral: Option<CodestralSettingsContent>,
+    /// Settings specific to Ollama.
+    pub ollama: Option<OllamaEditPredictionSettingsContent>,
     /// Whether edit predictions are enabled in the assistant prompt editor.
     /// This has no effect if globally disabled.
     pub enabled_in_text_threads: Option<bool>,
@@ -203,6 +209,19 @@ pub struct CodestralSettingsContent {
     pub api_url: Option<String>,
 }
 
+#[with_fallible_options]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, MergeFrom, PartialEq)]
+pub struct OllamaEditPredictionSettingsContent {
+    /// Model to use for completions.
+    ///
+    /// Default: "qwen2.5-coder:3b-base"
+    pub model: Option<String>,
+    /// API URL to use for completions.
+    ///
+    /// Default: "http://localhost:11434"
+    pub api_url: Option<String>,
+}
+
 /// The mode in which edit predictions should be displayed.
 #[derive(
     Copy,

crates/zed/Cargo.toml

@@ -42,6 +42,7 @@ cli.workspace = true
 client.workspace = true
 codestral.workspace = true
 collab_ui.workspace = true
+ollama.workspace = true
 collections.workspace = true
 command_palette.workspace = true
 component.workspace = true

crates/zed/src/zed/edit_prediction_registry.rs

@@ -8,6 +8,7 @@ use feature_flags::FeatureFlagAppExt;
 use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
 use language::language_settings::{EditPredictionProvider, all_language_settings};
 use language_models::MistralLanguageModelProvider;
+use ollama::OllamaEditPredictionDelegate;
 use settings::{
     EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
     EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
@@ -186,6 +187,11 @@ fn assign_edit_prediction_provider(
             let provider = cx.new(|_| CodestralEditPredictionDelegate::new(http_client));
             editor.set_edit_prediction_provider(Some(provider), window, cx);
         }
+        EditPredictionProvider::Ollama => {
+            let http_client = client.http_client();
+            let provider = cx.new(|_| OllamaEditPredictionDelegate::new(http_client));
+            editor.set_edit_prediction_provider(Some(provider), window, cx);
+        }
         value @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
             let ep_store = edit_prediction::EditPredictionStore::global(client, &user_store, cx);