use crate::{
    EditPredictionId, EditPredictionModelInput, cursor_excerpt,
    prediction::EditPredictionResult,
    zeta1::{
        self, MAX_CONTEXT_TOKENS as ZETA_MAX_CONTEXT_TOKENS,
        MAX_EVENT_TOKENS as ZETA_MAX_EVENT_TOKENS,
    },
};
use anyhow::{Context as _, Result};
use futures::AsyncReadExt as _;
use gpui::{App, AppContext as _, Entity, SharedString, Task, http_client};
use language::{
    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, ToOffset, ToPoint as _,
    language_settings::all_language_settings,
};
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use serde::{Deserialize, Serialize};
use std::{path::Path, sync::Arc, time::Instant};
use zeta_prompt::{
    ZetaPromptInput,
    zeta1::{EDITABLE_REGION_END_MARKER, format_zeta1_prompt},
};
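
/// Token budget for the excerpt around the cursor when building
/// fill-in-the-middle (FIM) prompts for non-zeta models.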
const FIM_CONTEXT_TOKENS: usize = 512;

pub struct Ollama;
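
/// Request body for Ollama's `/api/generate` endpoint. With `raw: true`, the
/// prompt is sent to the model verbatim, bypassing Ollama's prompt templating.
///
/// An illustrative serialized payload (values are examples, not defaults):
///
/// ```text
/// { "model": "qwen2.5-coder", "prompt": "...", "raw": true, "stream": false,
///   "options": { "num_predict": 64, "temperature": 0.2, "stop": ["<|endoftext|>"] } }
/// ```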
#[derive(Debug, Serialize)]
struct OllamaGenerateRequest {
    model: String,
    prompt: String,
    raw: bool,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<OllamaGenerateOptions>,
}
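
/// Generation options forwarded to Ollama; `None` fields are omitted from the
/// serialized JSON.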
#[derive(Debug, Serialize)]
struct OllamaGenerateOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
}
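
/// The subset of Ollama's `/api/generate` response that this module consumes.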
#[derive(Debug, Deserialize)]
struct OllamaGenerateResponse {
    created_at: String,
    response: String,
}

const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
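
/// Reports whether the Ollama language model provider is registered and
/// authenticated, i.e. whether edit predictions can be requested at all.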
pub fn is_available(cx: &App) -> bool {
    LanguageModelRegistry::read_global(cx)
        .provider(&PROVIDER_ID)
        .is_some_and(|provider| provider.is_authenticated(cx))
}
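
/// Kicks off authentication for the Ollama provider if it is registered,
/// logging (rather than surfacing) any error.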
pub fn ensure_authenticated(cx: &mut App) {
    if let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&PROVIDER_ID) {
        provider.authenticate(cx).detach_and_log_err(cx);
    }
}
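
/// Returns the IDs of all models the Ollama provider currently exposes,
/// sorted alphabetically, triggering authentication as a side effect.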
pub fn fetch_models(cx: &mut App) -> Vec<SharedString> {
    let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&PROVIDER_ID) else {
        return Vec::new();
    };
    provider.authenticate(cx).detach_and_log_err(cx);
    let mut models: Vec<SharedString> = provider
        .provided_models(cx)
        .into_iter()
        .map(|model| SharedString::from(model.id().0.to_string()))
        .collect();
    models.sort();
    models
}

/// Output from the Ollama HTTP request, containing all data needed to create the prediction result.
struct OllamaRequestOutput {
    created_at: String,
    edits: Vec<(std::ops::Range<Anchor>, Arc<str>)>,
    snapshot: BufferSnapshot,
    response_received_at: Instant,
    inputs: ZetaPromptInput,
    buffer: Entity<Buffer>,
    buffer_snapshotted_at: Instant,
}

impl Ollama {
    pub fn new() -> Self {
        Self
    }
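
    /// Requests an edit prediction from the configured Ollama model.
    ///
    /// Returns `Ok(None)` when no Ollama model is set in the edit prediction
    /// settings.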
    pub fn request_prediction(
        &self,
        EditPredictionModelInput {
            buffer,
            snapshot,
            position,
            events,
            ..
        }: EditPredictionModelInput,
        cx: &mut App,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        let settings = &all_language_settings(None, cx).edit_predictions.ollama;
        let Some(model) = settings.model.clone() else {
            return Task::ready(Ok(None));
        };
        let api_url = settings.api_url.clone();

        log::debug!("Ollama: Requesting completion (model: {})", model);

        let full_path: Arc<Path> = snapshot
            .file()
            .map(|file| file.full_path(cx))
            .unwrap_or_else(|| "untitled".into())
            .into();

        let http_client = cx.http_client();
        let cursor_point = position.to_point(&snapshot);
        let buffer_snapshotted_at = Instant::now();

        let is_zeta = is_zeta_model(&model);

        // Zeta generates more tokens than FIM models. Ideally we'd use
        // MAX_REWRITE_TOKENS, but that might be too slow for local
        // deployments, so the limit is user-configurable and scaled by this
        // hardcoded multiplier for now.
        let max_output_tokens = if is_zeta {
            settings.max_output_tokens * 4
        } else {
            settings.max_output_tokens
        };

        let result = cx.background_spawn(async move {
            let zeta_editable_region_tokens = max_output_tokens as usize;

            // For zeta models, use the dedicated zeta1 functions, which handle
            // their own range computation with the correct token limits.
            let (prompt, stop_tokens, editable_range_override, inputs) = if is_zeta {
                let path_str = full_path.to_string_lossy();
                let input_excerpt = zeta1::excerpt_for_cursor_position(
                    cursor_point,
                    &path_str,
                    &snapshot,
                    zeta_editable_region_tokens,
                    ZETA_MAX_CONTEXT_TOKENS,
                );
                let input_events = zeta1::prompt_for_events(&events, ZETA_MAX_EVENT_TOKENS);
                let prompt = format_zeta1_prompt(&input_events, &input_excerpt.prompt);
                let editable_offset_range = input_excerpt.editable_range.to_offset(&snapshot);
                let context_offset_range = input_excerpt.context_range.to_offset(&snapshot);
                let stop_tokens = get_zeta_stop_tokens();

                let inputs = ZetaPromptInput {
                    events,
                    related_files: Vec::new(),
                    cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
                        - context_offset_range.start,
                    cursor_path: full_path.clone(),
                    cursor_excerpt: snapshot
                        .text_for_range(input_excerpt.context_range.clone())
                        .collect::<String>()
                        .into(),
                    editable_range_in_excerpt: (editable_offset_range.start
                        - context_offset_range.start)
                        ..(editable_offset_range.end - context_offset_range.start),
                    excerpt_start_row: Some(input_excerpt.context_range.start.row),
                    excerpt_ranges: None,
                    preferred_model: None,
                    in_open_source_repo: false,
                    can_collect_data: false,
                };

                (prompt, stop_tokens, Some(editable_offset_range), inputs)
            } else {
                let (excerpt_range, _) =
                    cursor_excerpt::editable_and_context_ranges_for_cursor_position(
                        cursor_point,
                        &snapshot,
                        FIM_CONTEXT_TOKENS,
                        0,
                    );
                let excerpt_offset_range = excerpt_range.to_offset(&snapshot);
                let cursor_offset = cursor_point.to_offset(&snapshot);

                let inputs = ZetaPromptInput {
                    events,
                    related_files: Vec::new(),
                    cursor_offset_in_excerpt: cursor_offset - excerpt_offset_range.start,
                    editable_range_in_excerpt: cursor_offset - excerpt_offset_range.start
                        ..cursor_offset - excerpt_offset_range.start,
                    cursor_path: full_path.clone(),
                    excerpt_start_row: Some(excerpt_range.start.row),
                    cursor_excerpt: snapshot
                        .text_for_range(excerpt_range)
                        .collect::<String>()
                        .into(),
                    excerpt_ranges: None,
                    preferred_model: None,
                    in_open_source_repo: false,
                    can_collect_data: false,
                };

                let prefix = inputs.cursor_excerpt[..inputs.cursor_offset_in_excerpt].to_string();
                let suffix = inputs.cursor_excerpt[inputs.cursor_offset_in_excerpt..].to_string();
                let prompt = format_fim_prompt(&model, &prefix, &suffix);
                let stop_tokens = get_fim_stop_tokens();

                (prompt, stop_tokens, None, inputs)
            };

            let request = OllamaGenerateRequest {
                model: model.clone(),
                prompt,
                raw: true,
                stream: false,
                options: Some(OllamaGenerateOptions {
                    num_predict: Some(max_output_tokens),
                    temperature: Some(0.2),
                    stop: Some(stop_tokens),
                }),
            };

            let request_body = serde_json::to_string(&request)?;
            let http_request = http_client::Request::builder()
                .method(http_client::Method::POST)
                .uri(format!("{}/api/generate", api_url))
                .header("Content-Type", "application/json")
                .body(http_client::AsyncBody::from(request_body))?;

            let mut response = http_client.send(http_request).await?;
            let status = response.status();

            log::debug!("Ollama: Response status: {}", status);

            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;

            if !status.is_success() {
                return Err(anyhow::anyhow!("Ollama API error: {} - {}", status, body));
            }

            let ollama_response: OllamaGenerateResponse =
                serde_json::from_str(&body).context("Failed to parse Ollama response")?;

            let response_received_at = Instant::now();

            log::debug!(
                "Ollama: Completion received ({:.2}s)",
                (response_received_at - buffer_snapshotted_at).as_secs_f64()
            );
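
            // Convert the raw model output into anchored buffer edits: zeta
            // responses are parsed as a rewrite of the editable region, while
            // FIM responses become a single insertion at the cursor.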
            let edits = if is_zeta {
                let editable_range =
                    editable_range_override.expect("zeta model should have editable range");

                log::trace!("ollama response: {}", ollama_response.response);

                let response = clean_zeta_completion(&ollama_response.response);
                match zeta1::parse_edits(&response, editable_range, &snapshot) {
                    Ok(edits) => edits,
                    Err(err) => {
                        log::warn!("Ollama zeta: Failed to parse response: {}", err);
                        vec![]
                    }
                }
            } else {
                let completion: Arc<str> = clean_fim_completion(&ollama_response.response).into();
                if completion.is_empty() {
                    vec![]
                } else {
                    let cursor_offset = cursor_point.to_offset(&snapshot);
                    let anchor = snapshot.anchor_after(cursor_offset);
                    vec![(anchor..anchor, completion)]
                }
            };

            anyhow::Ok(OllamaRequestOutput {
                created_at: ollama_response.created_at,
                edits,
                snapshot,
                response_received_at,
                inputs,
                buffer,
                buffer_snapshotted_at,
            })
        });

        cx.spawn(async move |cx: &mut gpui::AsyncApp| {
            let output = result.await.context("Ollama edit prediction failed")?;
            anyhow::Ok(Some(
                EditPredictionResult::new(
                    EditPredictionId(output.created_at.into()),
                    &output.buffer,
                    &output.snapshot,
                    output.edits.into(),
                    None,
                    output.buffer_snapshotted_at,
                    output.response_received_at,
                    output.inputs,
                    cx,
                )
                .await,
            ))
        })
    }
}
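
/// Heuristic: a model whose name contains "zeta" (case-insensitively) is
/// treated as a zeta rewrite model; everything else is treated as a FIM model.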
fn is_zeta_model(model: &str) -> bool {
    model.to_lowercase().contains("zeta")
}
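
/// Stop sequences for zeta models: the editable-region end marker and a
/// closing code fence, either of which ends the rewritten region.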
fn get_zeta_stop_tokens() -> Vec<String> {
    vec![EDITABLE_REGION_END_MARKER.to_string(), "```".to_string()]
}
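
/// Formats a fill-in-the-middle prompt using the sentinel tokens of the
/// model's family, detected from the model name before the first `:` tag.
/// Unknown models fall back to the StarCoder-style format.
///
/// For example, `format_fim_prompt("codellama:7b", "fn main() {", "}")` yields
/// `"<PRE> fn main() { <SUF>} <MID>"`.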
fn format_fim_prompt(model: &str, prefix: &str, suffix: &str) -> String {
    let model_base = model.split(':').next().unwrap_or(model);

    match model_base {
        "codellama" | "code-llama" => {
            format!("<PRE> {prefix} <SUF>{suffix} <MID>")
        }
        "starcoder" | "starcoder2" | "starcoderbase" => {
            format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
        }
        "deepseek-coder" | "deepseek-coder-v2" => {
            format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
        }
        "qwen2.5-coder" | "qwen-coder" | "qwen" | "codegemma" => {
            format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
        }
        "codestral" | "mistral" => {
            format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
        }
        "glm" | "glm-4" | "glm-4.5" => {
            format!("<|code_prefix|>{prefix}<|code_suffix|>{suffix}<|code_middle|>")
        }
        _ => {
            format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
        }
    }
}
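
/// Stop sequences covering the common FIM sentinel tokens across the supported
/// model families, so generation halts before a model emits another region
/// marker.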
fn get_fim_stop_tokens() -> Vec<String> {
    vec![
        "<|endoftext|>".to_string(),
        "<|file_separator|>".to_string(),
        "<|fim_pad|>".to_string(),
        "<|fim_prefix|>".to_string(),
        "<|fim_middle|>".to_string(),
        "<|fim_suffix|>".to_string(),
        "<fim_prefix>".to_string(),
        "<fim_middle>".to_string(),
        "<fim_suffix>".to_string(),
        "<PRE>".to_string(),
        "<SUF>".to_string(),
        "<MID>".to_string(),
        "[PREFIX]".to_string(),
        "[SUFFIX]".to_string(),
    ]
}
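
/// Trims a trailing line that is a prefix of `EDITABLE_REGION_END_MARKER`,
/// which can be left behind when generation stops partway through emitting the
/// marker (e.g. on hitting the output token limit).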
fn clean_zeta_completion(mut response: &str) -> &str {
    if let Some(last_newline_ix) = response.rfind('\n') {
        let last_line = &response[last_newline_ix + 1..];
        if EDITABLE_REGION_END_MARKER.starts_with(last_line) {
            response = &response[..last_newline_ix];
        }
    }
    response
}
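
/// Truncates a FIM completion at the first occurrence of any known sentinel
/// token, for models that echo a marker instead of stopping at it.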
fn clean_fim_completion(response: &str) -> String {
    let mut result = response.to_string();

    let end_tokens = [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ];

    for token in &end_tokens {
        if let Some(pos) = result.find(token) {
            result.truncate(pos);
        }
    }

    result
}
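
// A minimal sketch of unit tests for the pure helpers in this module. These
// only exercise string-level behavior; the zeta test assumes
// EDITABLE_REGION_END_MARKER is a non-empty ASCII marker (so byte-slicing a
// prefix of it is valid).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zeta_models_are_detected_by_name() {
        assert!(is_zeta_model("zeta2:latest"));
        assert!(is_zeta_model("my-Zeta-finetune"));
        assert!(!is_zeta_model("qwen2.5-coder:7b"));
    }

    #[test]
    fn fim_prompts_use_family_specific_sentinels() {
        assert_eq!(
            format_fim_prompt("codellama:7b", "fn main() {", "}"),
            "<PRE> fn main() { <SUF>} <MID>"
        );
        // Unknown models fall back to the StarCoder-style sentinels.
        assert_eq!(
            format_fim_prompt("some-unknown-model", "a", "b"),
            "<fim_prefix>a<fim_suffix>b<fim_middle>"
        );
    }

    #[test]
    fn fim_completions_are_truncated_at_sentinel_tokens() {
        assert_eq!(clean_fim_completion("foo<|endoftext|>bar"), "foo");
        assert_eq!(clean_fim_completion("plain completion"), "plain completion");
    }

    #[test]
    fn partial_end_markers_are_stripped_from_zeta_output() {
        // A completion cut off mid-marker: the final line is a strict prefix
        // of EDITABLE_REGION_END_MARKER.
        let truncated = format!("new code\n{}", &EDITABLE_REGION_END_MARKER[..3]);
        assert_eq!(clean_zeta_completion(&truncated), "new code");
        assert_eq!(clean_zeta_completion("no marker here"), "no marker here");
    }
}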