use anyhow::{Context as _, Result};
use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
use edit_prediction_types::{EditPrediction, EditPredictionDelegate};
use futures::AsyncReadExt;
use gpui::{App, Context, Entity, Task};
use http_client::HttpClient;
use language::{
    Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint, language_settings::all_language_settings,
};
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use serde::{Deserialize, Serialize};
use std::{
    ops::Range,
    sync::Arc,
    time::{Duration, Instant},
};
use text::ToOffset;

use crate::{OLLAMA_API_URL, get_models};

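/// How long to wait after a refresh is requested before querying Ollama, so that
/// rapid keystrokes do not each trigger a completion request.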
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);

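/// Sizing of the buffer excerpt used as fill-in-the-middle context; roughly two
/// thirds of the byte budget goes to text before the cursor.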
const EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 1050,
    min_bytes: 525,
    target_before_cursor_over_total_bytes: 0.66,
};

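/// Models tried, in order, when no edit prediction model is configured and one of
/// them is available on the local Ollama server.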
pub const RECOMMENDED_EDIT_PREDICTION_MODELS: [&str; 4] = [
    "qwen2.5-coder:3b-base",
    "qwen2.5-coder:3b",
    "qwen2.5-coder:7b-base",
    "qwen2.5-coder:7b",
];

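/// The most recent completion returned by Ollama, kept together with the snapshot
/// it was computed against so it can be interpolated onto newer buffer states.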
#[derive(Clone)]
struct CurrentCompletion {
    snapshot: BufferSnapshot,
    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    edit_preview: EditPreview,
}

impl CurrentCompletion {
    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
        edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
    }
}

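/// Edit prediction delegate that requests fill-in-the-middle completions from a
/// local Ollama server.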
pub struct OllamaEditPredictionDelegate {
    http_client: Arc<dyn HttpClient>,
    pending_request: Option<Task<Result<()>>>,
    current_completion: Option<CurrentCompletion>,
}

impl OllamaEditPredictionDelegate {
    pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
        Self {
            http_client,
            pending_request: None,
            current_completion: None,
        }
    }

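    /// The delegate is available when the "ollama" language model provider is
    /// registered and authenticated.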
    pub fn is_available(cx: &App) -> bool {
        let ollama_provider_id = LanguageModelProviderId::new("ollama");
        LanguageModelRegistry::read_global(cx)
            .provider(&ollama_provider_id)
            .is_some_and(|provider| provider.is_authenticated(cx))
    }

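    /// Sends a single non-streaming request to Ollama's `/api/generate` endpoint
    /// with a model-specific FIM prompt and returns the cleaned completion text.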
    async fn fetch_completion(
        http_client: Arc<dyn HttpClient>,
        prompt: String,
        suffix: String,
        model: String,
        api_url: String,
    ) -> Result<String> {
        let start_time = Instant::now();

        log::debug!("Ollama: Requesting completion (model: {})", model);

        let fim_prompt = format_fim_prompt(&model, &prompt, &suffix);

        let request = OllamaGenerateRequest {
            model,
            prompt: fim_prompt,
            raw: true,
            stream: false,
            options: Some(OllamaGenerateOptions {
                num_predict: Some(64),
                temperature: Some(0.2),
                stop: Some(get_stop_tokens()),
            }),
        };

        let request_body = serde_json::to_string(&request)?;

        log::debug!("Ollama: Sending FIM request");

        let http_request = http_client::Request::builder()
            .method(http_client::Method::POST)
            .uri(format!("{}/api/generate", api_url))
            .header("Content-Type", "application/json")
            .body(http_client::AsyncBody::from(request_body))?;

        let mut response = http_client.send(http_request).await?;
        let status = response.status();

        log::debug!("Ollama: Response status: {}", status);

        if !status.is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow::anyhow!("Ollama API error: {} - {}", status, body));
        }

        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        let ollama_response: OllamaGenerateResponse =
            serde_json::from_str(&body).context("Failed to parse Ollama response")?;

        let elapsed = start_time.elapsed();

        log::debug!(
            "Ollama: Completion received ({:.2}s)",
            elapsed.as_secs_f64()
        );

        let completion = clean_completion(&ollama_response.response);
        Ok(completion)
    }
}

impl EditPredictionDelegate for OllamaEditPredictionDelegate {
    fn name() -> &'static str {
        "ollama"
    }

    fn display_name() -> &'static str {
        "Ollama"
    }

    fn show_predictions_in_menu() -> bool {
        true
    }

    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
        Self::is_available(cx)
    }

    fn is_refreshing(&self, _cx: &App) -> bool {
        self.pending_request.is_some()
    }

    fn refresh(
        &mut self,
        buffer: Entity<Buffer>,
        cursor_position: Anchor,
        debounce: bool,
        cx: &mut Context<Self>,
    ) {
        log::debug!("Ollama: Refresh called (debounce: {})", debounce);

        let snapshot = buffer.read(cx).snapshot();

        if let Some(current_completion) = self.current_completion.as_ref() {
            if current_completion.interpolate(&snapshot).is_some() {
                return;
            }
        }

        let http_client = self.http_client.clone();

        let settings = all_language_settings(None, cx);
        let configured_model = settings.edit_predictions.ollama.model.clone();
        let api_url = settings
            .edit_predictions
            .ollama
            .api_url
            .clone()
            .unwrap_or_else(|| OLLAMA_API_URL.to_string());

        self.pending_request = Some(cx.spawn(async move |this, cx| {
            if debounce {
                log::debug!("Ollama: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
            }

            let model = if let Some(model) = configured_model
                .as_deref()
                .map(str::trim)
                .filter(|model| !model.is_empty())
            {
                model.to_string()
            } else {
                let local_models = get_models(http_client.as_ref(), &api_url, None).await?;
                let available_model_names = local_models.iter().map(|model| model.name.as_str());

                match pick_recommended_edit_prediction_model(available_model_names) {
                    Some(recommended) => recommended.to_string(),
                    None => {
                        log::debug!(
                            "Ollama: No model configured and no recommended local model found; skipping edit prediction"
                        );
                        this.update(cx, |this, cx| {
                            this.pending_request = None;
                            cx.notify();
                        })?;
                        return Ok(());
                    }
                }
            };

            let cursor_offset = cursor_position.to_offset(&snapshot);
            let cursor_point = cursor_offset.to_point(&snapshot);
            let excerpt = EditPredictionExcerpt::select_from_buffer(
                cursor_point,
                &snapshot,
                &EXCERPT_OPTIONS,
            )
            .context("Line containing cursor doesn't fit in excerpt max bytes")?;

            let excerpt_text = excerpt.text(&snapshot);
            let cursor_within_excerpt = cursor_offset
                .saturating_sub(excerpt.range.start)
                .min(excerpt_text.body.len());
            let prompt = excerpt_text.body[..cursor_within_excerpt].to_string();
            let suffix = excerpt_text.body[cursor_within_excerpt..].to_string();

            let completion_text =
                match Self::fetch_completion(http_client, prompt, suffix, model, api_url).await {
                    Ok(completion) => completion,
                    Err(e) => {
                        log::error!("Ollama: Failed to fetch completion: {}", e);
                        this.update(cx, |this, cx| {
                            this.pending_request = None;
                            cx.notify();
                        })?;
                        return Err(e);
                    }
                };

            if completion_text.trim().is_empty() {
                log::debug!("Ollama: Completion was empty after trimming; ignoring");
                this.update(cx, |this, cx| {
                    this.pending_request = None;
                    cx.notify();
                })?;
                return Ok(());
            }

            let edits: Arc<[(Range<Anchor>, Arc<str>)]> = buffer.read_with(cx, |buffer, _cx| {
                // Clamp the requested offset to the current buffer snapshot length.
                //
                // `anchor_after` ultimately asserts that the offset is within the rope bounds
                // (in debug builds), and our `cursor_position` may be stale vs. the snapshot
                // we used to compute `cursor_offset`.
                let snapshot = buffer.snapshot();
                let clamped_cursor_offset = cursor_offset.min(snapshot.len());

                // Use anchor_after (Right bias) so the cursor stays before the completion text,
                // not at the end of it. This matches how Copilot handles edit predictions.
                let position = buffer.anchor_after(clamped_cursor_offset);
                vec![(position..position, completion_text.into())].into()
            })?;
            let edit_preview = buffer
                .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))?
                .await;

            this.update(cx, |this, cx| {
                this.current_completion = Some(CurrentCompletion {
                    snapshot,
                    edits,
                    edit_preview,
                });
                this.pending_request = None;
                cx.notify();
            })?;

            Ok(())
        }));
    }

    fn accept(&mut self, _cx: &mut Context<Self>) {
        log::debug!("Ollama: Completion accepted");
        self.pending_request = None;
        self.current_completion = None;
    }

    fn discard(&mut self, _cx: &mut Context<Self>) {
        log::debug!("Ollama: Completion discarded");
        self.pending_request = None;
        self.current_completion = None;
    }

    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        _cursor_position: Anchor,
        cx: &mut Context<Self>,
    ) -> Option<EditPrediction> {
        let current_completion = self.current_completion.as_ref()?;
        let buffer = buffer.read(cx);
        let edits = current_completion.interpolate(&buffer.snapshot())?;
        if edits.is_empty() {
            return None;
        }
        Some(EditPrediction::Local {
            id: None,
            edits,
            edit_preview: Some(current_completion.edit_preview.clone()),
        })
    }
}

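/// Builds a fill-in-the-middle prompt in the token format expected by the given
/// model family; the tag after `:` in the model name is ignored.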
fn format_fim_prompt(model: &str, prefix: &str, suffix: &str) -> String {
    let model_base = model.split(':').next().unwrap_or(model);

    match model_base {
        "codellama" | "code-llama" => {
            format!("<PRE> {prefix} <SUF>{suffix} <MID>")
        }
        "starcoder" | "starcoder2" | "starcoderbase" => {
            format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
        }
        "deepseek-coder" | "deepseek-coder-v2" => {
            // DeepSeek uses special Unicode characters for FIM tokens
            format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
        }
        "qwen2.5-coder" | "qwen-coder" | "qwen" => {
            format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
        }
        "codegemma" => {
            format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
        }
        "codestral" | "mistral" => {
            format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
        }
        "glm" | "glm-4" | "glm-4.5" => {
            format!("<|code_prefix|>{prefix}<|code_suffix|>{suffix}<|code_middle|>")
        }
        _ => {
            format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
        }
    }
}

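/// Stop sequences covering the FIM and end-of-text markers of the supported model
/// families, so generation ends rather than echoing template tokens.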
fn get_stop_tokens() -> Vec<String> {
    vec![
        "<|endoftext|>".to_string(),
        "<|file_separator|>".to_string(),
        "<|fim_pad|>".to_string(),
        "<|fim_prefix|>".to_string(),
        "<|fim_middle|>".to_string(),
        "<|fim_suffix|>".to_string(),
        "<fim_prefix>".to_string(),
        "<fim_middle>".to_string(),
        "<fim_suffix>".to_string(),
        "<PRE>".to_string(),
        "<SUF>".to_string(),
        "<MID>".to_string(),
        "[PREFIX]".to_string(),
        "[SUFFIX]".to_string(),
    ]
}

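/// Truncates the model output at the first FIM or end-of-text marker, in case one
/// slipped past the stop sequences.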
fn clean_completion(response: &str) -> String {
    let mut result = response.to_string();

    let end_tokens = [
        "<|endoftext|>",
        "<|file_separator|>",
        "<|fim_pad|>",
        "<|fim_prefix|>",
        "<|fim_middle|>",
        "<|fim_suffix|>",
        "<fim_prefix>",
        "<fim_middle>",
        "<fim_suffix>",
        "<PRE>",
        "<SUF>",
        "<MID>",
        "[PREFIX]",
        "[SUFFIX]",
    ];

    for token in &end_tokens {
        if let Some(pos) = result.find(token) {
            result.truncate(pos);
        }
    }

    result
}

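/// Request body for Ollama's `/api/generate` endpoint.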
#[derive(Debug, Serialize)]
struct OllamaGenerateRequest {
    model: String,
    prompt: String,
    raw: bool,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<OllamaGenerateOptions>,
}

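/// Generation options forwarded to Ollama; only the fields this delegate actually
/// sets are modeled here.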
#[derive(Debug, Serialize)]
struct OllamaGenerateOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
}

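/// The portion of the `/api/generate` response that this delegate reads.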
#[derive(Debug, Deserialize)]
struct OllamaGenerateResponse {
    response: String,
}

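/// Returns the first entry of [`RECOMMENDED_EDIT_PREDICTION_MODELS`] that appears in
/// `available_models`, if any.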
pub fn pick_recommended_edit_prediction_model<'a>(
    available_models: impl IntoIterator<Item = &'a str>,
) -> Option<&'static str> {
    let available: std::collections::HashSet<&str> = available_models.into_iter().collect();

    RECOMMENDED_EDIT_PREDICTION_MODELS
        .into_iter()
        .find(|recommended| available.contains(recommended))
}
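
// A minimal sketch of unit tests for the pure helpers in this module (FIM prompt
// formatting, stop-token cleanup, and recommended-model selection). These tests are
// an illustrative addition rather than part of the original file; they only exercise
// functions defined above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn format_fim_prompt_picks_template_by_model_family() {
        // The tag after ':' in the model name does not affect the template.
        assert_eq!(
            format_fim_prompt("qwen2.5-coder:7b-base", "let x = ", ";"),
            "<|fim_prefix|>let x = <|fim_suffix|>;<|fim_middle|>"
        );
        assert_eq!(
            format_fim_prompt("codellama:13b", "fn main() {", "}"),
            "<PRE> fn main() { <SUF>} <MID>"
        );
    }

    #[test]
    fn clean_completion_truncates_at_stop_markers() {
        assert_eq!(clean_completion("return x;<|endoftext|>garbage"), "return x;");
        assert_eq!(clean_completion("no markers here"), "no markers here");
    }

    #[test]
    fn pick_recommended_model_follows_list_order() {
        // The first recommended model that is locally available wins.
        assert_eq!(
            pick_recommended_edit_prediction_model(["qwen2.5-coder:7b", "qwen2.5-coder:3b-base"]),
            Some("qwen2.5-coder:3b-base")
        );
        assert_eq!(pick_recommended_edit_prediction_model(["llama3:8b"]), None);
    }
}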