ollama_edit_prediction_delegate.rs

  1use anyhow::{Context as _, Result};
  2use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
  3use edit_prediction_types::{EditPrediction, EditPredictionDelegate};
  4use futures::AsyncReadExt;
  5use gpui::{App, Context, Entity, Task};
  6use http_client::HttpClient;
  7use language::{
  8    Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint, language_settings::all_language_settings,
  9};
 10use language_model::{LanguageModelProviderId, LanguageModelRegistry};
 11use serde::{Deserialize, Serialize};
 12use std::{
 13    ops::Range,
 14    sync::Arc,
 15    time::{Duration, Instant},
 16};
 17use text::ToOffset;
 18
 19use crate::{OLLAMA_API_URL, get_models};
 20
/// How long `refresh` waits before sending a request when it is called with `debounce: true`.
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
 22
/// Size limits for the buffer excerpt around the cursor that `refresh` sends as FIM context.
///
/// The excerpt is split at the cursor into the FIM prefix and suffix.
// NOTE(review): field semantics are defined by `edit_prediction_context`; presumably the
// excerpt is kept between `min_bytes` and `max_bytes` with roughly two thirds of it before the
// cursor — confirm against `EditPredictionExcerptOptions`.
const EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 1050,
    min_bytes: 525,
    target_before_cursor_over_total_bytes: 0.66,
};
 28
/// Ollama models tried, in this order, when no edit prediction model is configured.
///
/// The first entry that is installed locally (exact name match, tag included) is used; see
/// `pick_recommended_edit_prediction_model`.
pub const RECOMMENDED_EDIT_PREDICTION_MODELS: [&str; 4] = [
    "qwen2.5-coder:3b-base",
    "qwen2.5-coder:3b",
    "qwen2.5-coder:7b-base",
    "qwen2.5-coder:7b",
];
 35
 36#[derive(Clone)]
 37struct CurrentCompletion {
 38    snapshot: BufferSnapshot,
 39    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 40    edit_preview: EditPreview,
 41}
 42
 43impl CurrentCompletion {
 44    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
 45        edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
 46    }
 47}
 48
/// Edit prediction delegate that requests fill-in-the-middle (FIM) completions from an Ollama
/// server over HTTP.
pub struct OllamaEditPredictionDelegate {
    /// Client used for every request to the Ollama API.
    http_client: Arc<dyn HttpClient>,
    /// Task spawned by the most recent `refresh`; `is_refreshing` reports whether it is set.
    pending_request: Option<Task<Result<()>>>,
    /// Latest completion, kept until `accept` or `discard` clears it.
    current_completion: Option<CurrentCompletion>,
}
 54
 55impl OllamaEditPredictionDelegate {
 56    pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
 57        Self {
 58            http_client,
 59            pending_request: None,
 60            current_completion: None,
 61        }
 62    }
 63
 64    pub fn is_available(cx: &App) -> bool {
 65        let ollama_provider_id = LanguageModelProviderId::new("ollama");
 66        LanguageModelRegistry::read_global(cx)
 67            .provider(&ollama_provider_id)
 68            .is_some_and(|provider| provider.is_authenticated(cx))
 69    }
 70
 71    async fn fetch_completion(
 72        http_client: Arc<dyn HttpClient>,
 73        prompt: String,
 74        suffix: String,
 75        model: String,
 76        api_url: String,
 77    ) -> Result<String> {
 78        let start_time = Instant::now();
 79
 80        log::debug!("Ollama: Requesting completion (model: {})", model);
 81
 82        let fim_prompt = format_fim_prompt(&model, &prompt, &suffix);
 83
 84        let request = OllamaGenerateRequest {
 85            model,
 86            prompt: fim_prompt,
 87            raw: true,
 88            stream: false,
 89            options: Some(OllamaGenerateOptions {
 90                num_predict: Some(64),
 91                temperature: Some(0.2),
 92                stop: Some(get_stop_tokens()),
 93            }),
 94        };
 95
 96        let request_body = serde_json::to_string(&request)?;
 97
 98        log::debug!("Ollama: Sending FIM request");
 99
100        let http_request = http_client::Request::builder()
101            .method(http_client::Method::POST)
102            .uri(format!("{}/api/generate", api_url))
103            .header("Content-Type", "application/json")
104            .body(http_client::AsyncBody::from(request_body))?;
105
106        let mut response = http_client.send(http_request).await?;
107        let status = response.status();
108
109        log::debug!("Ollama: Response status: {}", status);
110
111        if !status.is_success() {
112            let mut body = String::new();
113            response.body_mut().read_to_string(&mut body).await?;
114            return Err(anyhow::anyhow!("Ollama API error: {} - {}", status, body));
115        }
116
117        let mut body = String::new();
118        response.body_mut().read_to_string(&mut body).await?;
119
120        let ollama_response: OllamaGenerateResponse =
121            serde_json::from_str(&body).context("Failed to parse Ollama response")?;
122
123        let elapsed = start_time.elapsed();
124
125        log::debug!(
126            "Ollama: Completion received ({:.2}s)",
127            elapsed.as_secs_f64()
128        );
129
130        let completion = clean_completion(&ollama_response.response);
131        Ok(completion)
132    }
133}
134
135impl EditPredictionDelegate for OllamaEditPredictionDelegate {
136    fn name() -> &'static str {
137        "ollama"
138    }
139
140    fn display_name() -> &'static str {
141        "Ollama"
142    }
143
144    fn show_predictions_in_menu() -> bool {
145        true
146    }
147
148    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
149        Self::is_available(cx)
150    }
151
152    fn is_refreshing(&self, _cx: &App) -> bool {
153        self.pending_request.is_some()
154    }
155
156    fn refresh(
157        &mut self,
158        buffer: Entity<Buffer>,
159        cursor_position: Anchor,
160        debounce: bool,
161        cx: &mut Context<Self>,
162    ) {
163        log::debug!("Ollama: Refresh called (debounce: {})", debounce);
164
165        let snapshot = buffer.read(cx).snapshot();
166
167        if let Some(current_completion) = self.current_completion.as_ref() {
168            if current_completion.interpolate(&snapshot).is_some() {
169                return;
170            }
171        }
172
173        let http_client = self.http_client.clone();
174
175        let settings = all_language_settings(None, cx);
176        let configured_model = settings.edit_predictions.ollama.model.clone();
177        let api_url = settings
178            .edit_predictions
179            .ollama
180            .api_url
181            .clone()
182            .unwrap_or_else(|| OLLAMA_API_URL.to_string());
183
184        self.pending_request = Some(cx.spawn(async move |this, cx| {
185            if debounce {
186                log::debug!("Ollama: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
187                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
188            }
189
190            let model = if let Some(model) = configured_model
191                .as_deref()
192                .map(str::trim)
193                .filter(|model| !model.is_empty())
194            {
195                model.to_string()
196            } else {
197                let local_models = get_models(http_client.as_ref(), &api_url, None).await?;
198                let available_model_names = local_models.iter().map(|model| model.name.as_str());
199
200                match pick_recommended_edit_prediction_model(available_model_names) {
201                    Some(recommended) => recommended.to_string(),
202                    None => {
203                        log::debug!(
204                            "Ollama: No model configured and no recommended local model found; skipping edit prediction"
205                        );
206                        this.update(cx, |this, cx| {
207                            this.pending_request = None;
208                            cx.notify();
209                        })?;
210                        return Ok(());
211                    }
212                }
213            };
214
215            let cursor_offset = cursor_position.to_offset(&snapshot);
216            let cursor_point = cursor_offset.to_point(&snapshot);
217            let excerpt = EditPredictionExcerpt::select_from_buffer(
218                cursor_point,
219                &snapshot,
220                &EXCERPT_OPTIONS,
221            )
222            .context("Line containing cursor doesn't fit in excerpt max bytes")?;
223
224            let excerpt_text = excerpt.text(&snapshot);
225            let cursor_within_excerpt = cursor_offset
226                .saturating_sub(excerpt.range.start)
227                .min(excerpt_text.body.len());
228            let prompt = excerpt_text.body[..cursor_within_excerpt].to_string();
229            let suffix = excerpt_text.body[cursor_within_excerpt..].to_string();
230
231            let completion_text =
232                match Self::fetch_completion(http_client, prompt, suffix, model, api_url).await {
233                    Ok(completion) => completion,
234                    Err(e) => {
235                        log::error!("Ollama: Failed to fetch completion: {}", e);
236                        this.update(cx, |this, cx| {
237                            this.pending_request = None;
238                            cx.notify();
239                        })?;
240                        return Err(e);
241                    }
242                };
243
244            if completion_text.trim().is_empty() {
245                log::debug!("Ollama: Completion was empty after trimming; ignoring");
246                this.update(cx, |this, cx| {
247                    this.pending_request = None;
248                    cx.notify();
249                })?;
250                return Ok(());
251            }
252
253            let edits: Arc<[(Range<Anchor>, Arc<str>)]> = buffer.read_with(cx, |buffer, _cx| {
254                // Clamp the requested offset to the current buffer snapshot length.
255                //
256                // `anchor_after` ultimately asserts that the offset is within the rope bounds
257                // (in debug builds), and our `cursor_position` may be stale vs. the snapshot
258                // we used to compute `cursor_offset`.
259                let snapshot = buffer.snapshot();
260                let clamped_cursor_offset = cursor_offset.min(snapshot.len());
261
262                // Use anchor_after (Right bias) so the cursor stays before the completion text,
263                // not at the end of it. This matches how Copilot handles edit predictions.
264                let position = buffer.anchor_after(clamped_cursor_offset);
265                vec![(position..position, completion_text.into())].into()
266            })?;
267            let edit_preview = buffer
268                .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))?
269                .await;
270
271            this.update(cx, |this, cx| {
272                this.current_completion = Some(CurrentCompletion {
273                    snapshot,
274                    edits,
275                    edit_preview,
276                });
277                this.pending_request = None;
278                cx.notify();
279            })?;
280
281            Ok(())
282        }));
283    }
284
285    fn accept(&mut self, _cx: &mut Context<Self>) {
286        log::debug!("Ollama: Completion accepted");
287        self.pending_request = None;
288        self.current_completion = None;
289    }
290
291    fn discard(&mut self, _cx: &mut Context<Self>) {
292        log::debug!("Ollama: Completion discarded");
293        self.pending_request = None;
294        self.current_completion = None;
295    }
296
297    fn suggest(
298        &mut self,
299        buffer: &Entity<Buffer>,
300        _cursor_position: Anchor,
301        cx: &mut Context<Self>,
302    ) -> Option<EditPrediction> {
303        let current_completion = self.current_completion.as_ref()?;
304        let buffer = buffer.read(cx);
305        let edits = current_completion.interpolate(&buffer.snapshot())?;
306        if edits.is_empty() {
307            return None;
308        }
309        Some(EditPrediction::Local {
310            id: None,
311            edits,
312            edit_preview: Some(current_completion.edit_preview.clone()),
313        })
314    }
315}
316
/// Builds the raw fill-in-the-middle (FIM) prompt for `model`.
///
/// The model family is the part of the Ollama model name before the first `:` (for example,
/// `qwen2.5-coder:7b-base` becomes `qwen2.5-coder`). The family selects the FIM special tokens
/// that wrap `prefix` (text before the cursor) and `suffix` (text after it). Unknown families
/// fall back to StarCoder-style `<fim_prefix>`/`<fim_suffix>`/`<fim_middle>` tokens.
fn format_fim_prompt(model: &str, prefix: &str, suffix: &str) -> String {
    // `split` always yields at least one item, so this is just the text before any tag.
    let model_base = model.split(':').next().unwrap_or(model);

    match model_base {
        "codellama" | "code-llama" => format!("<PRE> {prefix} <SUF>{suffix} <MID>"),
        "starcoder" | "starcoder2" | "starcoderbase" => {
            format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
        }
        "deepseek-coder" | "deepseek-coder-v2" => {
            // DeepSeek's FIM tokens use FULLWIDTH VERTICAL LINE (U+FF5C) and LOWER ONE EIGHTH
            // BLOCK (U+2581). An ASCII `|` spelling does not match the model's special tokens.
            format!("<｜fim▁begin｜>{prefix}<｜fim▁hole｜>{suffix}<｜fim▁end｜>")
        }
        "qwen2.5-coder" | "qwen-coder" | "qwen" | "codegemma" => {
            format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
        }
        "codestral" | "mistral" => format!("[SUFFIX]{suffix}[PREFIX]{prefix}"),
        "glm" | "glm-4" | "glm-4.5" => {
            format!("<|code_prefix|>{prefix}<|code_suffix|>{suffix}<|code_middle|>")
        }
        _ => format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"),
    }
}
348
/// Special tokens used by the supported models' FIM formats, plus end-of-text markers.
///
/// [`get_stop_tokens`] sends them to Ollama as stop sequences, and [`clean_completion`] uses
/// them to truncate the returned text. Both read this one list so the two cannot drift apart.
const FIM_STOP_TOKENS: &[&str] = &[
    "<|endoftext|>",
    "<|file_separator|>",
    "<|fim_pad|>",
    "<|fim_prefix|>",
    "<|fim_middle|>",
    "<|fim_suffix|>",
    "<fim_prefix>",
    "<fim_middle>",
    "<fim_suffix>",
    "<PRE>",
    "<SUF>",
    "<MID>",
    "[PREFIX]",
    "[SUFFIX]",
    // DeepSeek Coder (FULLWIDTH VERTICAL LINE U+FF5C, LOWER ONE EIGHTH BLOCK U+2581).
    "<｜fim▁begin｜>",
    "<｜fim▁hole｜>",
    "<｜fim▁end｜>",
    // GLM code-completion tokens.
    "<|code_prefix|>",
    "<|code_suffix|>",
    "<|code_middle|>",
];

/// Returns the stop sequences sent with every Ollama FIM request.
fn get_stop_tokens() -> Vec<String> {
    FIM_STOP_TOKENS.iter().copied().map(String::from).collect()
}

/// Truncates `response` at the earliest occurrence of any FIM stop token.
///
/// This is a defensive second pass over text that was requested with the same tokens as stop
/// sequences, in case one still appears in the output. Text without any token is returned
/// unchanged.
fn clean_completion(response: &str) -> String {
    let end = FIM_STOP_TOKENS
        .iter()
        .filter_map(|token| response.find(token))
        .min()
        .unwrap_or(response.len());
    response[..end].to_string()
}
396
/// JSON body of the `POST /api/generate` request built by `fetch_completion`.
#[derive(Debug, Serialize)]
struct OllamaGenerateRequest {
    /// Ollama model name including its tag (e.g. `qwen2.5-coder:3b-base`).
    model: String,
    /// Full prompt text; `fetch_completion` passes the output of `format_fim_prompt`.
    prompt: String,
    /// `fetch_completion` sets this to `true` because the prompt already contains the FIM
    /// tokens. Per the Ollama API docs, raw mode skips the model's prompt template.
    raw: bool,
    /// `fetch_completion` sets this to `false` and parses the reply as a single JSON object.
    stream: bool,
    /// Generation options; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<OllamaGenerateOptions>,
}
406
/// Per-request generation options; each field is omitted from the JSON when `None`.
#[derive(Debug, Serialize)]
struct OllamaGenerateOptions {
    /// Cap on generated tokens (Ollama's `num_predict` option).
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<u32>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Sequences at which Ollama is asked to stop generating.
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
}
416
/// The part of Ollama's non-streaming `/api/generate` reply that this delegate reads.
///
/// Other fields in the JSON are ignored, because the struct does not deny unknown fields.
#[derive(Debug, Deserialize)]
struct OllamaGenerateResponse {
    /// Generated text; `fetch_completion` passes it through `clean_completion`.
    response: String,
}
421pub fn pick_recommended_edit_prediction_model<'a>(
422    available_models: impl IntoIterator<Item = &'a str>,
423) -> Option<&'static str> {
424    let available: std::collections::HashSet<&str> = available_models.into_iter().collect();
425
426    RECOMMENDED_EDIT_PREDICTION_MODELS
427        .into_iter()
428        .find(|recommended| available.contains(recommended))
429}