1use anyhow::{Context as _, Result};
2use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
3use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
4use futures::AsyncReadExt;
5use gpui::{App, Context, Entity, Task};
6use http_client::HttpClient;
7use language::{
8 language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint,
9};
10use language_models::MistralLanguageModelProvider;
11use mistral::CODESTRAL_API_URL;
12use serde::{Deserialize, Serialize};
13use std::{
14 ops::Range,
15 sync::Arc,
16 time::{Duration, Instant},
17};
18use text::ToOffset;
19
/// Delay `refresh` sleeps for before issuing a request when it is asked to debounce.
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
21
22const EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
23 max_bytes: 1050,
24 min_bytes: 525,
25 target_before_cursor_over_total_bytes: 0.66,
26};
27
28/// Represents a completion that has been received and processed from Codestral.
29/// This struct maintains the state needed to interpolate the completion as the user types.
30#[derive(Clone)]
31struct CurrentCompletion {
32 /// The buffer snapshot at the time the completion was generated.
33 /// Used to detect changes and interpolate edits.
34 snapshot: BufferSnapshot,
35 /// The edits that should be applied to transform the original text into the predicted text.
36 /// Each edit is a range in the buffer and the text to replace it with.
37 edits: Arc<[(Range<Anchor>, String)]>,
38 /// Preview of how the buffer will look after applying the edits.
39 edit_preview: EditPreview,
40}
41
42impl CurrentCompletion {
43 /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
44 /// Returns None if the user's edits conflict with the predicted edits.
45 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
46 edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
47 }
48}
49
50pub struct CodestralCompletionProvider {
51 http_client: Arc<dyn HttpClient>,
52 pending_request: Option<Task<Result<()>>>,
53 current_completion: Option<CurrentCompletion>,
54}
55
56impl CodestralCompletionProvider {
57 pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
58 Self {
59 http_client,
60 pending_request: None,
61 current_completion: None,
62 }
63 }
64
65 pub fn has_api_key(cx: &App) -> bool {
66 Self::api_key(cx).is_some()
67 }
68
69 /// This is so we can immediately show Codestral as a provider users can
70 /// switch to in the edit prediction menu, if the API has been added
71 pub fn ensure_api_key_loaded(http_client: Arc<dyn HttpClient>, cx: &mut App) {
72 MistralLanguageModelProvider::global(http_client, cx)
73 .load_codestral_api_key(cx)
74 .detach();
75 }
76
77 fn api_key(cx: &App) -> Option<Arc<str>> {
78 MistralLanguageModelProvider::try_global(cx)
79 .and_then(|provider| provider.codestral_api_key(CODESTRAL_API_URL, cx))
80 }
81
82 /// Uses Codestral's Fill-in-the-Middle API for code completion.
83 async fn fetch_completion(
84 http_client: Arc<dyn HttpClient>,
85 api_key: &str,
86 prompt: String,
87 suffix: String,
88 model: String,
89 max_tokens: Option<u32>,
90 api_url: String,
91 ) -> Result<String> {
92 let start_time = Instant::now();
93
94 log::debug!(
95 "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
96 model,
97 max_tokens
98 );
99
100 let request = CodestralRequest {
101 model,
102 prompt,
103 suffix: if suffix.is_empty() {
104 None
105 } else {
106 Some(suffix)
107 },
108 max_tokens: max_tokens.or(Some(350)),
109 temperature: Some(0.2),
110 top_p: Some(1.0),
111 stream: Some(false),
112 stop: None,
113 random_seed: None,
114 min_tokens: None,
115 };
116
117 let request_body = serde_json::to_string(&request)?;
118
119 log::debug!("Codestral: Sending FIM request");
120
121 let http_request = http_client::Request::builder()
122 .method(http_client::Method::POST)
123 .uri(format!("{}/v1/fim/completions", api_url))
124 .header("Content-Type", "application/json")
125 .header("Authorization", format!("Bearer {}", api_key))
126 .body(http_client::AsyncBody::from(request_body))?;
127
128 let mut response = http_client.send(http_request).await?;
129 let status = response.status();
130
131 log::debug!("Codestral: Response status: {}", status);
132
133 if !status.is_success() {
134 let mut body = String::new();
135 response.body_mut().read_to_string(&mut body).await?;
136 return Err(anyhow::anyhow!(
137 "Codestral API error: {} - {}",
138 status,
139 body
140 ));
141 }
142
143 let mut body = String::new();
144 response.body_mut().read_to_string(&mut body).await?;
145
146 let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
147
148 let elapsed = start_time.elapsed();
149
150 if let Some(choice) = codestral_response.choices.first() {
151 let completion = &choice.message.content;
152
153 log::debug!(
154 "Codestral: Completion received ({} tokens, {:.2}s)",
155 codestral_response.usage.completion_tokens,
156 elapsed.as_secs_f64()
157 );
158
159 // Return just the completion text for insertion at cursor
160 Ok(completion.clone())
161 } else {
162 log::error!("Codestral: No completion returned in response");
163 Err(anyhow::anyhow!("No completion returned from Codestral"))
164 }
165 }
166}
167
168impl EditPredictionProvider for CodestralCompletionProvider {
169 fn name() -> &'static str {
170 "codestral"
171 }
172
173 fn display_name() -> &'static str {
174 "Codestral"
175 }
176
177 fn show_completions_in_menu() -> bool {
178 true
179 }
180
181 fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
182 Self::api_key(cx).is_some()
183 }
184
185 fn is_refreshing(&self) -> bool {
186 self.pending_request.is_some()
187 }
188
189 fn refresh(
190 &mut self,
191 buffer: Entity<Buffer>,
192 cursor_position: language::Anchor,
193 debounce: bool,
194 cx: &mut Context<Self>,
195 ) {
196 log::debug!("Codestral: Refresh called (debounce: {})", debounce);
197
198 let Some(api_key) = Self::api_key(cx) else {
199 log::warn!("Codestral: No API key configured, skipping refresh");
200 return;
201 };
202
203 let snapshot = buffer.read(cx).snapshot();
204
205 // Check if current completion is still valid
206 if let Some(current_completion) = self.current_completion.as_ref() {
207 if current_completion.interpolate(&snapshot).is_some() {
208 return;
209 }
210 }
211
212 let http_client = self.http_client.clone();
213
214 // Get settings
215 let settings = all_language_settings(None, cx);
216 let model = settings
217 .edit_predictions
218 .codestral
219 .model
220 .clone()
221 .unwrap_or_else(|| "codestral-latest".to_string());
222 let max_tokens = settings.edit_predictions.codestral.max_tokens;
223 let api_url = settings
224 .edit_predictions
225 .codestral
226 .api_url
227 .clone()
228 .unwrap_or_else(|| CODESTRAL_API_URL.to_string());
229
230 self.pending_request = Some(cx.spawn(async move |this, cx| {
231 if debounce {
232 log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
233 smol::Timer::after(DEBOUNCE_TIMEOUT).await;
234 }
235
236 let cursor_offset = cursor_position.to_offset(&snapshot);
237 let cursor_point = cursor_offset.to_point(&snapshot);
238 let excerpt = EditPredictionExcerpt::select_from_buffer(
239 cursor_point,
240 &snapshot,
241 &EXCERPT_OPTIONS,
242 None,
243 )
244 .context("Line containing cursor doesn't fit in excerpt max bytes")?;
245
246 let excerpt_text = excerpt.text(&snapshot);
247 let cursor_within_excerpt = cursor_offset
248 .saturating_sub(excerpt.range.start)
249 .min(excerpt_text.body.len());
250 let prompt = excerpt_text.body[..cursor_within_excerpt].to_string();
251 let suffix = excerpt_text.body[cursor_within_excerpt..].to_string();
252
253 let completion_text = match Self::fetch_completion(
254 http_client,
255 &api_key,
256 prompt,
257 suffix,
258 model,
259 max_tokens,
260 api_url,
261 )
262 .await
263 {
264 Ok(completion) => completion,
265 Err(e) => {
266 log::error!("Codestral: Failed to fetch completion: {}", e);
267 this.update(cx, |this, cx| {
268 this.pending_request = None;
269 cx.notify();
270 })?;
271 return Err(e);
272 }
273 };
274
275 if completion_text.trim().is_empty() {
276 log::debug!("Codestral: Completion was empty after trimming; ignoring");
277 this.update(cx, |this, cx| {
278 this.pending_request = None;
279 cx.notify();
280 })?;
281 return Ok(());
282 }
283
284 let edits: Arc<[(Range<Anchor>, String)]> =
285 vec![(cursor_position..cursor_position, completion_text)].into();
286 let edit_preview = buffer
287 .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))?
288 .await;
289
290 this.update(cx, |this, cx| {
291 this.current_completion = Some(CurrentCompletion {
292 snapshot,
293 edits,
294 edit_preview,
295 });
296 this.pending_request = None;
297 cx.notify();
298 })?;
299
300 Ok(())
301 }));
302 }
303
304 fn cycle(
305 &mut self,
306 _buffer: Entity<Buffer>,
307 _cursor_position: Anchor,
308 _direction: Direction,
309 _cx: &mut Context<Self>,
310 ) {
311 // Codestral doesn't support multiple completions, so cycling does nothing
312 }
313
314 fn accept(&mut self, _cx: &mut Context<Self>) {
315 log::debug!("Codestral: Completion accepted");
316 self.pending_request = None;
317 self.current_completion = None;
318 }
319
320 fn discard(&mut self, _cx: &mut Context<Self>) {
321 log::debug!("Codestral: Completion discarded");
322 self.pending_request = None;
323 self.current_completion = None;
324 }
325
326 /// Returns the completion suggestion, adjusted or invalidated based on user edits
327 fn suggest(
328 &mut self,
329 buffer: &Entity<Buffer>,
330 _cursor_position: Anchor,
331 cx: &mut Context<Self>,
332 ) -> Option<EditPrediction> {
333 let current_completion = self.current_completion.as_ref()?;
334 let buffer = buffer.read(cx);
335 let edits = current_completion.interpolate(&buffer.snapshot())?;
336 if edits.is_empty() {
337 return None;
338 }
339 Some(EditPrediction::Local {
340 id: None,
341 edits,
342 edit_preview: Some(current_completion.edit_preview.clone()),
343 })
344 }
345}
346
347#[derive(Debug, Serialize, Deserialize)]
348pub struct CodestralRequest {
349 pub model: String,
350 pub prompt: String,
351 #[serde(skip_serializing_if = "Option::is_none")]
352 pub suffix: Option<String>,
353 #[serde(skip_serializing_if = "Option::is_none")]
354 pub max_tokens: Option<u32>,
355 #[serde(skip_serializing_if = "Option::is_none")]
356 pub temperature: Option<f32>,
357 #[serde(skip_serializing_if = "Option::is_none")]
358 pub top_p: Option<f32>,
359 #[serde(skip_serializing_if = "Option::is_none")]
360 pub stream: Option<bool>,
361 #[serde(skip_serializing_if = "Option::is_none")]
362 pub stop: Option<Vec<String>>,
363 #[serde(skip_serializing_if = "Option::is_none")]
364 pub random_seed: Option<u32>,
365 #[serde(skip_serializing_if = "Option::is_none")]
366 pub min_tokens: Option<u32>,
367}
368
369#[derive(Debug, Deserialize)]
370pub struct CodestralResponse {
371 pub id: String,
372 pub object: String,
373 pub model: String,
374 pub usage: Usage,
375 pub created: u64,
376 pub choices: Vec<Choice>,
377}
378
379#[derive(Debug, Deserialize)]
380pub struct Usage {
381 pub prompt_tokens: u32,
382 pub completion_tokens: u32,
383 pub total_tokens: u32,
384}
385
386#[derive(Debug, Deserialize)]
387pub struct Choice {
388 pub index: u32,
389 pub message: Message,
390 pub finish_reason: String,
391}
392
393#[derive(Debug, Deserialize)]
394pub struct Message {
395 pub content: String,
396 pub role: String,
397}