ollama.rs

  1use crate::{
  2    EditPredictionId, EditPredictionModelInput, cursor_excerpt,
  3    prediction::EditPredictionResult,
  4    zeta1::{
  5        self, MAX_CONTEXT_TOKENS as ZETA_MAX_CONTEXT_TOKENS,
  6        MAX_EVENT_TOKENS as ZETA_MAX_EVENT_TOKENS,
  7    },
  8};
  9use anyhow::{Context as _, Result};
 10use futures::AsyncReadExt as _;
 11use gpui::{App, AppContext as _, Entity, SharedString, Task, http_client};
 12use language::{
 13    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, ToOffset, ToPoint as _,
 14    language_settings::all_language_settings,
 15};
 16use language_model::{LanguageModelProviderId, LanguageModelRegistry};
 17use serde::{Deserialize, Serialize};
 18use std::{path::Path, sync::Arc, time::Instant};
 19use zeta_prompt::{
 20    ZetaPromptInput,
 21    zeta1::{EDITABLE_REGION_END_MARKER, format_zeta1_prompt},
 22};
 23
/// Token budget for the cursor-context excerpt used when building FIM prompts.
const FIM_CONTEXT_TOKENS: usize = 512;
 25
/// Edit-prediction backend that talks to a locally running Ollama server.
/// Stateless: per-request configuration is read from language settings.
pub struct Ollama;
 27
/// Request body for Ollama's `/api/generate` endpoint.
#[derive(Debug, Serialize)]
struct OllamaGenerateRequest {
    /// Name of the model to run, as listed by the Ollama server.
    model: String,
    /// Fully formatted prompt text (zeta or FIM, built by the caller).
    prompt: String,
    /// When true, the prompt is sent verbatim without server-side templating.
    raw: bool,
    /// Whether the server should stream the response; this client always
    /// sends `false` and reads a single JSON body.
    stream: bool,
    /// Optional generation parameters; omitted entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<OllamaGenerateOptions>,
}
 37
/// Generation options forwarded in the request's `options` field.
/// Every field is optional; `None` is skipped so the server/model default applies.
#[derive(Debug, Serialize)]
struct OllamaGenerateOptions {
    /// Maximum number of tokens to generate (Ollama's `num_predict`).
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<u32>,
    /// Sampling temperature; the caller passes a low value (0.2) to keep
    /// completions near-deterministic.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Stop sequences that terminate generation when emitted.
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
}
 47
/// The subset of the `/api/generate` JSON response this client consumes;
/// serde ignores any other fields in the payload.
#[derive(Debug, Deserialize)]
struct OllamaGenerateResponse {
    /// Server-side timestamp of the response; reused downstream as the
    /// prediction id.
    created_at: String,
    /// The generated completion text.
    response: String,
}
 53
/// Id under which the Ollama provider is registered in the [`LanguageModelRegistry`].
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
 55
 56pub fn is_available(cx: &App) -> bool {
 57    LanguageModelRegistry::read_global(cx)
 58        .provider(&PROVIDER_ID)
 59        .is_some_and(|provider| provider.is_authenticated(cx))
 60}
 61
 62pub fn ensure_authenticated(cx: &mut App) {
 63    if let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&PROVIDER_ID) {
 64        provider.authenticate(cx).detach_and_log_err(cx);
 65    }
 66}
 67
 68pub fn fetch_models(cx: &mut App) -> Vec<SharedString> {
 69    let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&PROVIDER_ID) else {
 70        return Vec::new();
 71    };
 72    provider.authenticate(cx).detach_and_log_err(cx);
 73    let mut models: Vec<SharedString> = provider
 74        .provided_models(cx)
 75        .into_iter()
 76        .map(|model| SharedString::from(model.id().0.to_string()))
 77        .collect();
 78    models.sort();
 79    models
 80}
 81
/// Output from the Ollama HTTP request, containing all data needed to create the prediction result.
struct OllamaRequestOutput {
    /// `created_at` timestamp from the Ollama response; reused as the prediction id.
    created_at: String,
    /// Parsed edits as anchor ranges in the snapshot plus replacement text.
    edits: Vec<(std::ops::Range<Anchor>, Arc<str>)>,
    /// The buffer snapshot the prompt (and the edit anchors) were computed against.
    snapshot: BufferSnapshot,
    /// Captured after the response body was read and parsed.
    response_received_at: Instant,
    /// Prompt inputs, retained for constructing the prediction result.
    inputs: ZetaPromptInput,
    /// The buffer the prediction applies to.
    buffer: Entity<Buffer>,
    /// Captured when the snapshot was taken, before the request was sent.
    buffer_snapshotted_at: Instant,
}
 92
impl Ollama {
    /// Creates the (stateless) Ollama edit-prediction backend.
    pub fn new() -> Self {
        Self
    }

    /// Requests an edit prediction from a locally running Ollama server.
    ///
    /// Builds either a zeta-style rewrite prompt (for models whose name
    /// contains "zeta") or a fill-in-the-middle prompt (for other models),
    /// POSTs it to `{api_url}/api/generate`, and parses the completion into
    /// buffer edits.
    ///
    /// Returns `Ok(None)` when no model is configured in settings; otherwise
    /// resolves to an [`EditPredictionResult`] (which may carry an empty edit
    /// list if the model output could not be parsed).
    pub fn request_prediction(
        &self,
        EditPredictionModelInput {
            buffer,
            snapshot,
            position,
            events,
            ..
        }: EditPredictionModelInput,
        cx: &mut App,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        let settings = &all_language_settings(None, cx).edit_predictions.ollama;
        // No configured model means predictions are disabled for this provider.
        let Some(model) = settings.model.clone() else {
            return Task::ready(Ok(None));
        };
        let api_url = settings.api_url.clone();

        log::debug!("Ollama: Requesting completion (model: {})", model);

        // Path shown to the model inside the prompt; falls back for unsaved buffers.
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|file| file.full_path(cx))
            .unwrap_or_else(|| "untitled".into())
            .into();

        let http_client = cx.http_client();
        let cursor_point = position.to_point(&snapshot);
        let buffer_snapshotted_at = Instant::now();

        let is_zeta = is_zeta_model(&model);

        // Zeta generates more tokens than FIM models. Ideally, we'd use MAX_REWRITE_TOKENS,
        // but this might be too slow for local deployments. So we make it configurable,
        // but we also have this hardcoded multiplier for now.
        let max_output_tokens = if is_zeta {
            settings.max_output_tokens * 4
        } else {
            settings.max_output_tokens
        };

        // Prompt construction, the HTTP round trip, and response parsing all
        // happen off the main thread.
        let result = cx.background_spawn(async move {
            let zeta_editable_region_tokens = max_output_tokens as usize;

            // For zeta models, use the dedicated zeta1 functions which handle their own
            // range computation with the correct token limits.
            let (prompt, stop_tokens, editable_range_override, inputs) = if is_zeta {
                let path_str = full_path.to_string_lossy();
                let input_excerpt = zeta1::excerpt_for_cursor_position(
                    cursor_point,
                    &path_str,
                    &snapshot,
                    zeta_editable_region_tokens,
                    ZETA_MAX_CONTEXT_TOKENS,
                );
                let input_events = zeta1::prompt_for_events(&events, ZETA_MAX_EVENT_TOKENS);
                let prompt = format_zeta1_prompt(&input_events, &input_excerpt.prompt);
                let editable_offset_range = input_excerpt.editable_range.to_offset(&snapshot);
                let context_offset_range = input_excerpt.context_range.to_offset(&snapshot);
                let stop_tokens = get_zeta_stop_tokens();

                // All `*_in_excerpt` offsets below are rebased onto the start
                // of the context excerpt by subtracting `context_offset_range.start`.
                let inputs = ZetaPromptInput {
                    events,
                    related_files: Vec::new(),
                    cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
                        - context_offset_range.start,
                    cursor_path: full_path.clone(),
                    cursor_excerpt: snapshot
                        .text_for_range(input_excerpt.context_range.clone())
                        .collect::<String>()
                        .into(),
                    editable_range_in_excerpt: (editable_offset_range.start
                        - context_offset_range.start)
                        ..(editable_offset_range.end - context_offset_range.start),
                    excerpt_start_row: Some(input_excerpt.context_range.start.row),
                    excerpt_ranges: None,
                    preferred_model: None,
                    in_open_source_repo: false,
                };

                (prompt, stop_tokens, Some(editable_offset_range), inputs)
            } else {
                // FIM path: a context budget of FIM_CONTEXT_TOKENS and no
                // editable region (the model inserts at the cursor only).
                let (excerpt_range, _) =
                    cursor_excerpt::editable_and_context_ranges_for_cursor_position(
                        cursor_point,
                        &snapshot,
                        FIM_CONTEXT_TOKENS,
                        0,
                    );
                let excerpt_offset_range = excerpt_range.to_offset(&snapshot);
                let cursor_offset = cursor_point.to_offset(&snapshot);

                // `editable_range_in_excerpt` is empty (cursor..cursor):
                // FIM completions only insert at the cursor position.
                let inputs = ZetaPromptInput {
                    events,
                    related_files: Vec::new(),
                    cursor_offset_in_excerpt: cursor_offset - excerpt_offset_range.start,
                    editable_range_in_excerpt: cursor_offset - excerpt_offset_range.start
                        ..cursor_offset - excerpt_offset_range.start,
                    cursor_path: full_path.clone(),
                    excerpt_start_row: Some(excerpt_range.start.row),
                    cursor_excerpt: snapshot
                        .text_for_range(excerpt_range)
                        .collect::<String>()
                        .into(),
                    excerpt_ranges: None,
                    preferred_model: None,
                    in_open_source_repo: false,
                };

                // Split the excerpt at the cursor into the FIM prefix/suffix.
                let prefix = inputs.cursor_excerpt[..inputs.cursor_offset_in_excerpt].to_string();
                let suffix = inputs.cursor_excerpt[inputs.cursor_offset_in_excerpt..].to_string();
                let prompt = format_fim_prompt(&model, &prefix, &suffix);
                let stop_tokens = get_fim_stop_tokens();

                (prompt, stop_tokens, None, inputs)
            };

            // `raw: true` sends the prompt verbatim (no server-side template);
            // `stream: false` yields a single JSON body.
            let request = OllamaGenerateRequest {
                model: model.clone(),
                prompt,
                raw: true,
                stream: false,
                options: Some(OllamaGenerateOptions {
                    num_predict: Some(max_output_tokens),
                    temperature: Some(0.2),
                    stop: Some(stop_tokens),
                }),
            };

            let request_body = serde_json::to_string(&request)?;
            let http_request = http_client::Request::builder()
                .method(http_client::Method::POST)
                .uri(format!("{}/api/generate", api_url))
                .header("Content-Type", "application/json")
                .body(http_client::AsyncBody::from(request_body))?;

            let mut response = http_client.send(http_request).await?;
            let status = response.status();

            log::debug!("Ollama: Response status: {}", status);

            if !status.is_success() {
                // Include the error body in the failure message for debuggability.
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                return Err(anyhow::anyhow!("Ollama API error: {} - {}", status, body));
            }

            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;

            let ollama_response: OllamaGenerateResponse =
                serde_json::from_str(&body).context("Failed to parse Ollama response")?;

            let response_received_at = Instant::now();

            log::debug!(
                "Ollama: Completion received ({:.2}s)",
                (response_received_at - buffer_snapshotted_at).as_secs_f64()
            );

            let edits = if is_zeta {
                // Set on this branch above; absence here would be a bug.
                let editable_range =
                    editable_range_override.expect("zeta model should have editable range");

                log::trace!("ollama response: {}", ollama_response.response);

                // A parse failure degrades to "no edits" rather than an error.
                let response = clean_zeta_completion(&ollama_response.response);
                match zeta1::parse_edits(&response, editable_range, &snapshot) {
                    Ok(edits) => edits,
                    Err(err) => {
                        log::warn!("Ollama zeta: Failed to parse response: {}", err);
                        vec![]
                    }
                }
            } else {
                let completion: Arc<str> = clean_fim_completion(&ollama_response.response).into();
                if completion.is_empty() {
                    vec![]
                } else {
                    // A FIM completion is a pure insertion at the cursor
                    // (empty anchor range).
                    let cursor_offset = cursor_point.to_offset(&snapshot);
                    let anchor = snapshot.anchor_after(cursor_offset);
                    vec![(anchor..anchor, completion)]
                }
            };

            anyhow::Ok(OllamaRequestOutput {
                created_at: ollama_response.created_at,
                edits,
                snapshot,
                response_received_at,
                inputs,
                buffer,
                buffer_snapshotted_at,
            })
        });

        // Back on the foreground: convert the raw output into the prediction result.
        cx.spawn(async move |cx: &mut gpui::AsyncApp| {
            let output = result.await.context("Ollama edit prediction failed")?;
            anyhow::Ok(Some(
                EditPredictionResult::new(
                    EditPredictionId(output.created_at.into()),
                    &output.buffer,
                    &output.snapshot,
                    output.edits.into(),
                    None,
                    output.buffer_snapshotted_at,
                    output.response_received_at,
                    output.inputs,
                    cx,
                )
                .await,
            ))
        })
    }
}
312
/// Heuristic: a model is treated as a zeta rewrite model when its name
/// contains "zeta" (case-insensitive).
fn is_zeta_model(model: &str) -> bool {
    let lowered = model.to_lowercase();
    lowered.contains("zeta")
}
316
317fn get_zeta_stop_tokens() -> Vec<String> {
318    vec![EDITABLE_REGION_END_MARKER.to_string(), "```".to_string()]
319}
320
/// Builds a fill-in-the-middle prompt using the control tokens expected by
/// the given model family.
///
/// The family is the model name before any `:` tag (e.g. `codellama:7b` →
/// `codellama`); unrecognized families fall back to starcoder-style tokens.
fn format_fim_prompt(model: &str, prefix: &str, suffix: &str) -> String {
    let family = match model.split_once(':') {
        Some((base, _tag)) => base,
        None => model,
    };

    match family {
        "codellama" | "code-llama" => format!("<PRE> {prefix} <SUF>{suffix} <MID>"),
        "starcoder" | "starcoder2" | "starcoderbase" => {
            format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
        }
        "deepseek-coder" | "deepseek-coder-v2" => {
            format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
        }
        // Qwen and CodeGemma share the same FIM token set.
        "qwen2.5-coder" | "qwen-coder" | "qwen" | "codegemma" => {
            format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
        }
        // Codestral expects suffix-first ordering.
        "codestral" | "mistral" => format!("[SUFFIX]{suffix}[PREFIX]{prefix}"),
        "glm" | "glm-4" | "glm-4.5" => {
            format!("<|code_prefix|>{prefix}<|code_suffix|>{suffix}<|code_middle|>")
        }
        _ => format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"),
    }
}
351
/// Stop sequences covering the FIM control tokens of every supported model
/// family, so generation halts if the model echoes one of them.
fn get_fim_stop_tokens() -> Vec<String> {
    [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ]
    .into_iter()
    .map(str::to_string)
    .collect()
}
370
371fn clean_zeta_completion(mut response: &str) -> &str {
372    if let Some(last_newline_ix) = response.rfind('\n') {
373        let last_line = &response[last_newline_ix + 1..];
374        if EDITABLE_REGION_END_MARKER.starts_with(&last_line) {
375            response = &response[..last_newline_ix]
376        }
377    }
378    response
379}
380
/// Truncates a FIM completion at the earliest occurrence of any known
/// control/end token, returning the clean prefix (whole input if none occur).
fn clean_fim_completion(response: &str) -> String {
    const END_TOKENS: [&str; 14] = [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ];

    // Cut at the earliest token occurrence; equivalent to repeatedly
    // truncating at each token, since truncation only ever shortens the tail.
    let cut = END_TOKENS
        .iter()
        .filter_map(|token| response.find(token))
        .min()
        .unwrap_or(response.len());

    response[..cut].to_string()
}