1use anyhow::Result;
2use edit_prediction::cursor_excerpt;
3use edit_prediction_types::{
4 EditPrediction, EditPredictionDelegate, EditPredictionDismissReason, EditPredictionIconSet,
5};
6use futures::AsyncReadExt;
7use gpui::{App, Context, Entity, Task};
8use http_client::HttpClient;
9use icons::IconName;
10use language::{
11 Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint, language_settings::all_language_settings,
12};
13use language_models::MistralLanguageModelProvider;
14use mistral::CODESTRAL_API_URL;
15use serde::{Deserialize, Serialize};
16use std::{
17 ops::Range,
18 sync::Arc,
19 time::{Duration, Instant},
20};
21use text::{OffsetRangeExt as _, ToOffset};
22
/// Delay between the triggering edit and the actual completion request, so
/// rapid typing doesn't fire one HTTP request per keystroke.
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
24
/// Represents a completion that has been received and processed from Codestral.
/// This struct maintains the state needed to interpolate the completion as the user types.
#[derive(Clone)]
struct CurrentCompletion {
    /// The buffer snapshot at the time the completion was generated.
    /// Used to detect changes and interpolate edits.
    snapshot: BufferSnapshot,
    /// The edits that should be applied to transform the original text into the predicted text.
    /// Each edit is a range in the buffer and the text to replace it with.
    /// Shared as `Arc<[_]>` so clones of this struct are cheap.
    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Preview of how the buffer will look after applying the edits.
    edit_preview: EditPreview,
}
38
39impl CurrentCompletion {
40 /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
41 /// Returns None if the user's edits conflict with the predicted edits.
42 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
43 edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
44 }
45}
46
/// Edit prediction delegate backed by Mistral's Codestral fill-in-the-middle API.
pub struct CodestralEditPredictionDelegate {
    /// Client used to send requests to the Codestral endpoint.
    http_client: Arc<dyn HttpClient>,
    /// The in-flight completion request, if any; `Some` while refreshing.
    pending_request: Option<Task<Result<()>>>,
    /// The most recently received completion, kept until accepted, discarded,
    /// or invalidated by conflicting user edits.
    current_completion: Option<CurrentCompletion>,
}
52
53impl CodestralEditPredictionDelegate {
54 pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
55 Self {
56 http_client,
57 pending_request: None,
58 current_completion: None,
59 }
60 }
61
62 pub fn has_api_key(cx: &App) -> bool {
63 Self::api_key(cx).is_some()
64 }
65
66 /// This is so we can immediately show Codestral as a provider users can
67 /// switch to in the edit prediction menu, if the API has been added
68 pub fn ensure_api_key_loaded(http_client: Arc<dyn HttpClient>, cx: &mut App) {
69 MistralLanguageModelProvider::global(http_client, cx)
70 .load_codestral_api_key(cx)
71 .detach();
72 }
73
74 fn api_key(cx: &App) -> Option<Arc<str>> {
75 MistralLanguageModelProvider::try_global(cx)
76 .and_then(|provider| provider.codestral_api_key(CODESTRAL_API_URL, cx))
77 }
78
79 /// Uses Codestral's Fill-in-the-Middle API for code completion.
80 async fn fetch_completion(
81 http_client: Arc<dyn HttpClient>,
82 api_key: &str,
83 prompt: String,
84 suffix: String,
85 model: String,
86 max_tokens: Option<u32>,
87 api_url: String,
88 ) -> Result<String> {
89 let start_time = Instant::now();
90
91 log::debug!(
92 "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
93 model,
94 max_tokens
95 );
96
97 let request = CodestralRequest {
98 model,
99 prompt,
100 suffix: if suffix.is_empty() {
101 None
102 } else {
103 Some(suffix)
104 },
105 max_tokens: max_tokens.or(Some(350)),
106 temperature: Some(0.2),
107 top_p: Some(1.0),
108 stream: Some(false),
109 stop: None,
110 random_seed: None,
111 min_tokens: None,
112 };
113
114 let request_body = serde_json::to_string(&request)?;
115
116 log::debug!("Codestral: Sending FIM request");
117
118 let http_request = http_client::Request::builder()
119 .method(http_client::Method::POST)
120 .uri(format!("{}/v1/fim/completions", api_url))
121 .header("Content-Type", "application/json")
122 .header("Authorization", format!("Bearer {}", api_key))
123 .body(http_client::AsyncBody::from(request_body))?;
124
125 let mut response = http_client.send(http_request).await?;
126 let status = response.status();
127
128 log::debug!("Codestral: Response status: {}", status);
129
130 if !status.is_success() {
131 let mut body = String::new();
132 response.body_mut().read_to_string(&mut body).await?;
133 return Err(anyhow::anyhow!(
134 "Codestral API error: {} - {}",
135 status,
136 body
137 ));
138 }
139
140 let mut body = String::new();
141 response.body_mut().read_to_string(&mut body).await?;
142
143 let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
144
145 let elapsed = start_time.elapsed();
146
147 if let Some(choice) = codestral_response.choices.first() {
148 let completion = &choice.message.content;
149
150 log::debug!(
151 "Codestral: Completion received ({} tokens, {:.2}s)",
152 codestral_response.usage.completion_tokens,
153 elapsed.as_secs_f64()
154 );
155
156 // Return just the completion text for insertion at cursor
157 Ok(completion.clone())
158 } else {
159 log::error!("Codestral: No completion returned in response");
160 Err(anyhow::anyhow!("No completion returned from Codestral"))
161 }
162 }
163}
164
impl EditPredictionDelegate for CodestralEditPredictionDelegate {
    /// Internal identifier for this provider.
    fn name() -> &'static str {
        "codestral"
    }

    /// Human-readable name shown in the UI.
    fn display_name() -> &'static str {
        "Codestral"
    }

    fn show_predictions_in_menu() -> bool {
        true
    }

    fn icons(&self, _cx: &App) -> EditPredictionIconSet {
        EditPredictionIconSet::new(IconName::AiMistral)
    }

    /// Predictions are available only when a Codestral API key is configured.
    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
        Self::api_key(cx).is_some()
    }

    fn is_refreshing(&self, _cx: &App) -> bool {
        self.pending_request.is_some()
    }

    /// Starts a new completion request for the given cursor position.
    ///
    /// Bails out early when no API key is configured, or when the previously
    /// received completion still interpolates cleanly over the user's edits
    /// (in which case a new request would be redundant).
    fn refresh(
        &mut self,
        buffer: Entity<Buffer>,
        cursor_position: language::Anchor,
        debounce: bool,
        cx: &mut Context<Self>,
    ) {
        log::debug!("Codestral: Refresh called (debounce: {})", debounce);

        // NOTE(review): the early returns below leave any previously set
        // `pending_request` in place rather than cancelling it — confirm
        // that is intended.
        let Some(api_key) = Self::api_key(cx) else {
            log::warn!("Codestral: No API key configured, skipping refresh");
            return;
        };

        let snapshot = buffer.read(cx).snapshot();

        // Check if current completion is still valid
        if let Some(current_completion) = self.current_completion.as_ref() {
            if current_completion.interpolate(&snapshot).is_some() {
                return;
            }
        }

        let http_client = self.http_client.clone();

        // Get settings
        let settings = all_language_settings(None, cx);
        let model = settings
            .edit_predictions
            .codestral
            .model
            .clone()
            .unwrap_or_else(|| "codestral-latest".to_string());
        let max_tokens = settings.edit_predictions.codestral.max_tokens;
        let api_url = settings
            .edit_predictions
            .codestral
            .api_url
            .clone()
            .unwrap_or_else(|| CODESTRAL_API_URL.to_string());

        // Replacing `pending_request` drops (and thereby cancels) any
        // previously spawned task.
        self.pending_request = Some(cx.spawn(async move |this, cx| {
            // Wait out the debounce window so rapid keystrokes coalesce into
            // a single request.
            if debounce {
                log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
            }

            let cursor_offset = cursor_position.to_offset(&snapshot);
            let cursor_point = cursor_offset.to_point(&snapshot);

            // Token budgets handed to the excerpt helper; they bound how much
            // of the buffer is sent as context around the cursor.
            const MAX_CONTEXT_TOKENS: usize = 150;
            const MAX_REWRITE_TOKENS: usize = 350;

            let (_, context_range) =
                cursor_excerpt::editable_and_context_ranges_for_cursor_position(
                    cursor_point,
                    &snapshot,
                    MAX_REWRITE_TOKENS,
                    MAX_CONTEXT_TOKENS,
                );

            // Split the excerpt at the cursor: text before becomes the FIM
            // prompt, text after becomes the suffix.
            let context_range = context_range.to_offset(&snapshot);
            let excerpt_text = snapshot
                .text_for_range(context_range.clone())
                .collect::<String>();
            // NOTE(review): `cursor_within_excerpt` is a byte offset; the
            // slicing below assumes it always lands on a UTF-8 char boundary
            // (clamping to `len()` is a boundary, but a mid-excerpt cursor
            // relies on offsets being char-aligned) — confirm.
            let cursor_within_excerpt = cursor_offset
                .saturating_sub(context_range.start)
                .min(excerpt_text.len());
            let prompt = excerpt_text[..cursor_within_excerpt].to_string();
            let suffix = excerpt_text[cursor_within_excerpt..].to_string();

            let completion_text = match Self::fetch_completion(
                http_client,
                &api_key,
                prompt,
                suffix,
                model,
                max_tokens,
                api_url,
            )
            .await
            {
                Ok(completion) => completion,
                Err(e) => {
                    log::error!("Codestral: Failed to fetch completion: {}", e);
                    // Clear the in-flight marker so the UI stops showing a
                    // refresh in progress.
                    this.update(cx, |this, cx| {
                        this.pending_request = None;
                        cx.notify();
                    })?;
                    return Err(e);
                }
            };

            // A whitespace-only completion is useless; drop it silently.
            if completion_text.trim().is_empty() {
                log::debug!("Codestral: Completion was empty after trimming; ignoring");
                this.update(cx, |this, cx| {
                    this.pending_request = None;
                    cx.notify();
                })?;
                return Ok(());
            }

            // Model the completion as a single insertion at the cursor
            // (an empty range replaced by the completion text).
            let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
                vec![(cursor_position..cursor_position, completion_text.into())].into();
            let edit_preview = buffer
                .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))
                .await;

            // Store the result against the snapshot it was computed from, so
            // later user edits can be interpolated (or detected as conflicts).
            this.update(cx, |this, cx| {
                this.current_completion = Some(CurrentCompletion {
                    snapshot,
                    edits,
                    edit_preview,
                });
                this.pending_request = None;
                cx.notify();
            })?;

            Ok(())
        }));
    }

    /// Clears all completion state once the user accepts the prediction.
    fn accept(&mut self, _cx: &mut Context<Self>) {
        log::debug!("Codestral: Completion accepted");
        self.pending_request = None;
        self.current_completion = None;
    }

    /// Clears all completion state when the prediction is dismissed.
    fn discard(&mut self, _reason: EditPredictionDismissReason, _cx: &mut Context<Self>) {
        log::debug!("Codestral: Completion discarded");
        self.pending_request = None;
        self.current_completion = None;
    }

    /// Returns the completion suggestion, adjusted or invalidated based on user edits
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        _cursor_position: Anchor,
        cx: &mut Context<Self>,
    ) -> Option<EditPrediction> {
        let current_completion = self.current_completion.as_ref()?;
        let buffer = buffer.read(cx);
        // Interpolation fails (returns None) when the user's edits conflict
        // with the prediction; yielding no suggestion in that case.
        let edits = current_completion.interpolate(&buffer.snapshot())?;
        if edits.is_empty() {
            return None;
        }
        Some(EditPrediction::Local {
            id: None,
            edits,
            cursor_position: None,
            edit_preview: Some(current_completion.edit_preview.clone()),
        })
    }
}
345
/// Request body for Codestral's `/v1/fim/completions` endpoint.
///
/// `None` optional fields are omitted from the serialized JSON entirely.
#[derive(Debug, Serialize, Deserialize)]
pub struct CodestralRequest {
    /// Model identifier, e.g. "codestral-latest".
    pub model: String,
    /// Text before the cursor (the fill-in-the-middle prefix).
    pub prompt: String,
    /// Text after the cursor; omitted when there is none.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// Upper bound on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus sampling parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Whether to stream the response; this client sends `Some(false)`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Stop sequences that end generation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Seed for reproducible sampling.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u32>,
    /// Lower bound on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
}
367
/// Top-level response from Codestral's FIM completions endpoint.
#[derive(Debug, Deserialize)]
pub struct CodestralResponse {
    /// Server-assigned identifier for this completion.
    pub id: String,
    /// Response object type tag.
    pub object: String,
    /// Model that produced the completion.
    pub model: String,
    /// Token accounting for this request.
    pub usage: Usage,
    /// Creation timestamp (Unix seconds, per the API's convention — TODO confirm).
    pub created: u64,
    /// Generated completions; this client reads only the first.
    pub choices: Vec<Choice>,
}
377
/// Token counts reported by the API for a single request.
#[derive(Debug, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt (and suffix).
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Sum of prompt and completion tokens.
    pub total_tokens: u32,
}
384
/// One generated completion within a [`CodestralResponse`].
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// Position of this choice within the response's `choices` array.
    pub index: u32,
    /// The generated content and its role.
    pub message: Message,
    /// Why generation stopped (e.g. length limit or stop sequence — TODO confirm values).
    pub finish_reason: String,
}
391
/// The generated text payload of a [`Choice`].
#[derive(Debug, Deserialize)]
pub struct Message {
    /// The completion text to insert at the cursor.
    pub content: String,
    /// Role tag supplied by the API (chat-style field).
    pub role: String,
}