1use anyhow::{Context as _, Result};
2use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
3use edit_prediction_types::{Direction, EditPrediction, EditPredictionDelegate};
4use futures::AsyncReadExt;
5use gpui::{App, Context, Entity, Task};
6use http_client::HttpClient;
7use language::{
8 language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint,
9};
10use language_models::MistralLanguageModelProvider;
11use mistral::CODESTRAL_API_URL;
12use serde::{Deserialize, Serialize};
13use std::{
14 ops::Range,
15 sync::Arc,
16 time::{Duration, Instant},
17};
18use text::ToOffset;
19
/// How long a refresh waits before sending a completion request when it is
/// invoked with `debounce == true`.
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
21
/// Options handed to `EditPredictionExcerpt::select_from_buffer` when picking
/// the text around the cursor that is split into the FIM prompt and suffix.
// NOTE(review): presumably this targets roughly two thirds of the excerpt
// before the cursor — confirm against `EditPredictionExcerptOptions`.
const EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 1050,
    min_bytes: 525,
    target_before_cursor_over_total_bytes: 0.66,
};
27
/// A completion that has been received from Codestral and turned into buffer edits.
///
/// It keeps everything needed to re-validate ("interpolate") the completion
/// against later buffer snapshots as the user keeps typing.
#[derive(Clone)]
struct CurrentCompletion {
    /// The buffer snapshot the completion request was built from.
    /// `interpolate` compares later snapshots against this one.
    snapshot: BufferSnapshot,
    /// The predicted edits: each entry is an anchored buffer range and the
    /// text that should replace it.
    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Preview of the buffer with `edits` applied; handed back to the editor
    /// together with the edits when a suggestion is requested.
    edit_preview: EditPreview,
}
41
42impl CurrentCompletion {
43 /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
44 /// Returns None if the user's edits conflict with the predicted edits.
45 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
46 edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
47 }
48}
49
/// Edit prediction provider backed by Codestral's Fill-in-the-Middle endpoint.
pub struct CodestralEditPredictionDelegate {
    /// HTTP client used to send completion requests.
    http_client: Arc<dyn HttpClient>,
    /// The task of the most recently started refresh, if any.
    /// `is_refreshing` reports `true` while this is `Some`.
    pending_request: Option<Task<Result<()>>>,
    /// The most recently received completion, if any. Cleared when a
    /// completion is accepted or discarded.
    current_completion: Option<CurrentCompletion>,
}
55
56impl CodestralEditPredictionDelegate {
57 pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
58 Self {
59 http_client,
60 pending_request: None,
61 current_completion: None,
62 }
63 }
64
65 pub fn has_api_key(cx: &App) -> bool {
66 Self::api_key(cx).is_some()
67 }
68
69 /// This is so we can immediately show Codestral as a provider users can
70 /// switch to in the edit prediction menu, if the API has been added
71 pub fn ensure_api_key_loaded(http_client: Arc<dyn HttpClient>, cx: &mut App) {
72 MistralLanguageModelProvider::global(http_client, cx)
73 .load_codestral_api_key(cx)
74 .detach();
75 }
76
77 fn api_key(cx: &App) -> Option<Arc<str>> {
78 MistralLanguageModelProvider::try_global(cx)
79 .and_then(|provider| provider.codestral_api_key(CODESTRAL_API_URL, cx))
80 }
81
82 /// Uses Codestral's Fill-in-the-Middle API for code completion.
83 async fn fetch_completion(
84 http_client: Arc<dyn HttpClient>,
85 api_key: &str,
86 prompt: String,
87 suffix: String,
88 model: String,
89 max_tokens: Option<u32>,
90 api_url: String,
91 ) -> Result<String> {
92 let start_time = Instant::now();
93
94 log::debug!(
95 "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
96 model,
97 max_tokens
98 );
99
100 let request = CodestralRequest {
101 model,
102 prompt,
103 suffix: if suffix.is_empty() {
104 None
105 } else {
106 Some(suffix)
107 },
108 max_tokens: max_tokens.or(Some(350)),
109 temperature: Some(0.2),
110 top_p: Some(1.0),
111 stream: Some(false),
112 stop: None,
113 random_seed: None,
114 min_tokens: None,
115 };
116
117 let request_body = serde_json::to_string(&request)?;
118
119 log::debug!("Codestral: Sending FIM request");
120
121 let http_request = http_client::Request::builder()
122 .method(http_client::Method::POST)
123 .uri(format!("{}/v1/fim/completions", api_url))
124 .header("Content-Type", "application/json")
125 .header("Authorization", format!("Bearer {}", api_key))
126 .body(http_client::AsyncBody::from(request_body))?;
127
128 let mut response = http_client.send(http_request).await?;
129 let status = response.status();
130
131 log::debug!("Codestral: Response status: {}", status);
132
133 if !status.is_success() {
134 let mut body = String::new();
135 response.body_mut().read_to_string(&mut body).await?;
136 return Err(anyhow::anyhow!(
137 "Codestral API error: {} - {}",
138 status,
139 body
140 ));
141 }
142
143 let mut body = String::new();
144 response.body_mut().read_to_string(&mut body).await?;
145
146 let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
147
148 let elapsed = start_time.elapsed();
149
150 if let Some(choice) = codestral_response.choices.first() {
151 let completion = &choice.message.content;
152
153 log::debug!(
154 "Codestral: Completion received ({} tokens, {:.2}s)",
155 codestral_response.usage.completion_tokens,
156 elapsed.as_secs_f64()
157 );
158
159 // Return just the completion text for insertion at cursor
160 Ok(completion.clone())
161 } else {
162 log::error!("Codestral: No completion returned in response");
163 Err(anyhow::anyhow!("No completion returned from Codestral"))
164 }
165 }
166}
167
168impl EditPredictionDelegate for CodestralEditPredictionDelegate {
169 fn name() -> &'static str {
170 "codestral"
171 }
172
173 fn display_name() -> &'static str {
174 "Codestral"
175 }
176
177 fn show_predictions_in_menu() -> bool {
178 true
179 }
180
181 fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
182 Self::api_key(cx).is_some()
183 }
184
185 fn is_refreshing(&self, _cx: &App) -> bool {
186 self.pending_request.is_some()
187 }
188
189 fn refresh(
190 &mut self,
191 buffer: Entity<Buffer>,
192 cursor_position: language::Anchor,
193 debounce: bool,
194 cx: &mut Context<Self>,
195 ) {
196 log::debug!("Codestral: Refresh called (debounce: {})", debounce);
197
198 let Some(api_key) = Self::api_key(cx) else {
199 log::warn!("Codestral: No API key configured, skipping refresh");
200 return;
201 };
202
203 let snapshot = buffer.read(cx).snapshot();
204
205 // Check if current completion is still valid
206 if let Some(current_completion) = self.current_completion.as_ref() {
207 if current_completion.interpolate(&snapshot).is_some() {
208 return;
209 }
210 }
211
212 let http_client = self.http_client.clone();
213
214 // Get settings
215 let settings = all_language_settings(None, cx);
216 let model = settings
217 .edit_predictions
218 .codestral
219 .model
220 .clone()
221 .unwrap_or_else(|| "codestral-latest".to_string());
222 let max_tokens = settings.edit_predictions.codestral.max_tokens;
223 let api_url = settings
224 .edit_predictions
225 .codestral
226 .api_url
227 .clone()
228 .unwrap_or_else(|| CODESTRAL_API_URL.to_string());
229
230 self.pending_request = Some(cx.spawn(async move |this, cx| {
231 if debounce {
232 log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
233 smol::Timer::after(DEBOUNCE_TIMEOUT).await;
234 }
235
236 let cursor_offset = cursor_position.to_offset(&snapshot);
237 let cursor_point = cursor_offset.to_point(&snapshot);
238 let excerpt = EditPredictionExcerpt::select_from_buffer(
239 cursor_point,
240 &snapshot,
241 &EXCERPT_OPTIONS,
242 )
243 .context("Line containing cursor doesn't fit in excerpt max bytes")?;
244
245 let excerpt_text = excerpt.text(&snapshot);
246 let cursor_within_excerpt = cursor_offset
247 .saturating_sub(excerpt.range.start)
248 .min(excerpt_text.body.len());
249 let prompt = excerpt_text.body[..cursor_within_excerpt].to_string();
250 let suffix = excerpt_text.body[cursor_within_excerpt..].to_string();
251
252 let completion_text = match Self::fetch_completion(
253 http_client,
254 &api_key,
255 prompt,
256 suffix,
257 model,
258 max_tokens,
259 api_url,
260 )
261 .await
262 {
263 Ok(completion) => completion,
264 Err(e) => {
265 log::error!("Codestral: Failed to fetch completion: {}", e);
266 this.update(cx, |this, cx| {
267 this.pending_request = None;
268 cx.notify();
269 })?;
270 return Err(e);
271 }
272 };
273
274 if completion_text.trim().is_empty() {
275 log::debug!("Codestral: Completion was empty after trimming; ignoring");
276 this.update(cx, |this, cx| {
277 this.pending_request = None;
278 cx.notify();
279 })?;
280 return Ok(());
281 }
282
283 let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
284 vec![(cursor_position..cursor_position, completion_text.into())].into();
285 let edit_preview = buffer
286 .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))?
287 .await;
288
289 this.update(cx, |this, cx| {
290 this.current_completion = Some(CurrentCompletion {
291 snapshot,
292 edits,
293 edit_preview,
294 });
295 this.pending_request = None;
296 cx.notify();
297 })?;
298
299 Ok(())
300 }));
301 }
302
303 fn cycle(
304 &mut self,
305 _buffer: Entity<Buffer>,
306 _cursor_position: Anchor,
307 _direction: Direction,
308 _cx: &mut Context<Self>,
309 ) {
310 // Codestral doesn't support multiple completions, so cycling does nothing
311 }
312
313 fn accept(&mut self, _cx: &mut Context<Self>) {
314 log::debug!("Codestral: Completion accepted");
315 self.pending_request = None;
316 self.current_completion = None;
317 }
318
319 fn discard(&mut self, _cx: &mut Context<Self>) {
320 log::debug!("Codestral: Completion discarded");
321 self.pending_request = None;
322 self.current_completion = None;
323 }
324
325 /// Returns the completion suggestion, adjusted or invalidated based on user edits
326 fn suggest(
327 &mut self,
328 buffer: &Entity<Buffer>,
329 _cursor_position: Anchor,
330 cx: &mut Context<Self>,
331 ) -> Option<EditPrediction> {
332 let current_completion = self.current_completion.as_ref()?;
333 let buffer = buffer.read(cx);
334 let edits = current_completion.interpolate(&buffer.snapshot())?;
335 if edits.is_empty() {
336 return None;
337 }
338 Some(EditPrediction::Local {
339 id: None,
340 edits,
341 edit_preview: Some(current_completion.edit_preview.clone()),
342 })
343 }
344}
345
/// JSON body of a Fill-in-the-Middle completion request.
///
/// `fetch_completion` serializes this and POSTs it to the
/// `/v1/fim/completions` endpoint. Every optional field is left out of the
/// serialized JSON when it is `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct CodestralRequest {
    /// Name of the model to query.
    pub model: String,
    /// Text preceding the cursor.
    pub prompt: String,
    /// Text following the cursor, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// Token limit sent with the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// `fetch_completion` always sends `false` and reads the whole response body at once.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
}
367
/// JSON body of a successful completion response.
///
/// All fields are required when deserializing. `fetch_completion` only reads
/// the first entry of `choices` and `usage.completion_tokens`.
#[derive(Debug, Deserialize)]
pub struct CodestralResponse {
    pub id: String,
    pub object: String,
    pub model: String,
    /// Token accounting for the request.
    pub usage: Usage,
    pub created: u64,
    /// Returned completions; an empty list is treated as an error by `fetch_completion`.
    pub choices: Vec<Choice>,
}
377
/// Token counts reported in a [`CodestralResponse`].
#[derive(Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    /// Logged by `fetch_completion` after a completion is received.
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
384
/// One entry of [`CodestralResponse::choices`].
#[derive(Debug, Deserialize)]
pub struct Choice {
    pub index: u32,
    /// Carries the completion text in its `content` field.
    pub message: Message,
    pub finish_reason: String,
}
391
/// Message payload of a [`Choice`].
#[derive(Debug, Deserialize)]
pub struct Message {
    /// The completion text that `fetch_completion` returns.
    pub content: String,
    pub role: String,
}