1use anyhow::Result;
2use edit_prediction::cursor_excerpt;
3use edit_prediction_types::{EditPrediction, EditPredictionDelegate, EditPredictionIconSet};
4use futures::AsyncReadExt;
5use gpui::{App, Context, Entity, Task};
6use http_client::HttpClient;
7use icons::IconName;
8use language::{
9 Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint, language_settings::all_language_settings,
10};
11use language_models::MistralLanguageModelProvider;
12use mistral::CODESTRAL_API_URL;
13use serde::{Deserialize, Serialize};
14use std::{
15 ops::Range,
16 sync::Arc,
17 time::{Duration, Instant},
18};
19use text::{OffsetRangeExt as _, ToOffset};
20
/// How long to wait after the user stops typing before issuing a completion
/// request, so that rapid consecutive keystrokes don't each trigger an API call.
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
22
/// Represents a completion that has been received and processed from Codestral.
/// This struct maintains the state needed to interpolate the completion as the user types.
#[derive(Clone)]
struct CurrentCompletion {
    /// The buffer snapshot at the time the completion was generated.
    /// Used to detect changes and interpolate edits.
    snapshot: BufferSnapshot,
    /// The edits that should be applied to transform the original text into the predicted text.
    /// Each edit is a range in the buffer and the text to replace it with.
    /// Stored behind `Arc` so cloning the completion is cheap.
    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Preview of how the buffer will look after applying the edits.
    edit_preview: EditPreview,
}
36
37impl CurrentCompletion {
38 /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
39 /// Returns None if the user's edits conflict with the predicted edits.
40 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
41 edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
42 }
43}
44
/// Edit-prediction provider backed by Codestral's Fill-in-the-Middle API.
pub struct CodestralEditPredictionDelegate {
    /// Client used to send requests to the Codestral HTTP endpoint.
    http_client: Arc<dyn HttpClient>,
    /// The in-flight completion request, if any. Replaced on each refresh and
    /// cleared when a completion arrives, fails, is accepted, or is discarded.
    pending_request: Option<Task<Result<()>>>,
    /// The most recently received completion, kept until it is accepted,
    /// discarded, or invalidated by conflicting user edits.
    current_completion: Option<CurrentCompletion>,
}
50
51impl CodestralEditPredictionDelegate {
52 pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
53 Self {
54 http_client,
55 pending_request: None,
56 current_completion: None,
57 }
58 }
59
60 pub fn has_api_key(cx: &App) -> bool {
61 Self::api_key(cx).is_some()
62 }
63
64 /// This is so we can immediately show Codestral as a provider users can
65 /// switch to in the edit prediction menu, if the API has been added
66 pub fn ensure_api_key_loaded(http_client: Arc<dyn HttpClient>, cx: &mut App) {
67 MistralLanguageModelProvider::global(http_client, cx)
68 .load_codestral_api_key(cx)
69 .detach();
70 }
71
72 fn api_key(cx: &App) -> Option<Arc<str>> {
73 MistralLanguageModelProvider::try_global(cx)
74 .and_then(|provider| provider.codestral_api_key(CODESTRAL_API_URL, cx))
75 }
76
77 /// Uses Codestral's Fill-in-the-Middle API for code completion.
78 async fn fetch_completion(
79 http_client: Arc<dyn HttpClient>,
80 api_key: &str,
81 prompt: String,
82 suffix: String,
83 model: String,
84 max_tokens: Option<u32>,
85 api_url: String,
86 ) -> Result<String> {
87 let start_time = Instant::now();
88
89 log::debug!(
90 "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
91 model,
92 max_tokens
93 );
94
95 let request = CodestralRequest {
96 model,
97 prompt,
98 suffix: if suffix.is_empty() {
99 None
100 } else {
101 Some(suffix)
102 },
103 max_tokens: max_tokens.or(Some(350)),
104 temperature: Some(0.2),
105 top_p: Some(1.0),
106 stream: Some(false),
107 stop: None,
108 random_seed: None,
109 min_tokens: None,
110 };
111
112 let request_body = serde_json::to_string(&request)?;
113
114 log::debug!("Codestral: Sending FIM request");
115
116 let http_request = http_client::Request::builder()
117 .method(http_client::Method::POST)
118 .uri(format!("{}/v1/fim/completions", api_url))
119 .header("Content-Type", "application/json")
120 .header("Authorization", format!("Bearer {}", api_key))
121 .body(http_client::AsyncBody::from(request_body))?;
122
123 let mut response = http_client.send(http_request).await?;
124 let status = response.status();
125
126 log::debug!("Codestral: Response status: {}", status);
127
128 if !status.is_success() {
129 let mut body = String::new();
130 response.body_mut().read_to_string(&mut body).await?;
131 return Err(anyhow::anyhow!(
132 "Codestral API error: {} - {}",
133 status,
134 body
135 ));
136 }
137
138 let mut body = String::new();
139 response.body_mut().read_to_string(&mut body).await?;
140
141 let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
142
143 let elapsed = start_time.elapsed();
144
145 if let Some(choice) = codestral_response.choices.first() {
146 let completion = &choice.message.content;
147
148 log::debug!(
149 "Codestral: Completion received ({} tokens, {:.2}s)",
150 codestral_response.usage.completion_tokens,
151 elapsed.as_secs_f64()
152 );
153
154 // Return just the completion text for insertion at cursor
155 Ok(completion.clone())
156 } else {
157 log::error!("Codestral: No completion returned in response");
158 Err(anyhow::anyhow!("No completion returned from Codestral"))
159 }
160 }
161}
162
impl EditPredictionDelegate for CodestralEditPredictionDelegate {
    /// Stable machine-readable identifier for this provider.
    fn name() -> &'static str {
        "codestral"
    }

    /// Human-readable name shown in the UI.
    fn display_name() -> &'static str {
        "Codestral"
    }

    fn show_predictions_in_menu() -> bool {
        true
    }

    fn icons(&self, _cx: &App) -> EditPredictionIconSet {
        EditPredictionIconSet::new(IconName::AiMistral)
    }

    /// Predictions are only enabled once a Codestral API key is configured.
    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
        Self::api_key(cx).is_some()
    }

    fn is_refreshing(&self, _cx: &App) -> bool {
        self.pending_request.is_some()
    }

    /// Kicks off a (possibly debounced) background request for a new
    /// completion at `cursor_position`, unless the current completion still
    /// applies to the buffer's present state. On success, stores the result
    /// in `current_completion`; `suggest` later surfaces it.
    fn refresh(
        &mut self,
        buffer: Entity<Buffer>,
        cursor_position: language::Anchor,
        debounce: bool,
        cx: &mut Context<Self>,
    ) {
        log::debug!("Codestral: Refresh called (debounce: {})", debounce);

        let Some(api_key) = Self::api_key(cx) else {
            log::warn!("Codestral: No API key configured, skipping refresh");
            return;
        };

        let snapshot = buffer.read(cx).snapshot();

        // Check if current completion is still valid; if its edits still
        // interpolate cleanly onto the new snapshot there is no need to
        // request a fresh prediction.
        if let Some(current_completion) = self.current_completion.as_ref() {
            if current_completion.interpolate(&snapshot).is_some() {
                return;
            }
        }

        let http_client = self.http_client.clone();

        // Get settings (model, token cap, and endpoint URL, with defaults).
        let settings = all_language_settings(None, cx);
        let model = settings
            .edit_predictions
            .codestral
            .model
            .clone()
            .unwrap_or_else(|| "codestral-latest".to_string());
        let max_tokens = settings.edit_predictions.codestral.max_tokens;
        let api_url = settings
            .edit_predictions
            .codestral
            .api_url
            .clone()
            .unwrap_or_else(|| CODESTRAL_API_URL.to_string());

        // Replacing `pending_request` drops any previously spawned task.
        self.pending_request = Some(cx.spawn(async move |this, cx| {
            if debounce {
                log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
            }

            let cursor_offset = cursor_position.to_offset(&snapshot);
            let cursor_point = cursor_offset.to_point(&snapshot);

            const MAX_CONTEXT_TOKENS: usize = 150;
            const MAX_REWRITE_TOKENS: usize = 350;

            // Only the surrounding context range is used here; the editable
            // range returned first is not needed for FIM prompting.
            let (_, context_range) =
                cursor_excerpt::editable_and_context_ranges_for_cursor_position(
                    cursor_point,
                    &snapshot,
                    MAX_REWRITE_TOKENS,
                    MAX_CONTEXT_TOKENS,
                );

            // Split the excerpt at the cursor: text before becomes the FIM
            // prompt, text after becomes the suffix.
            let context_range = context_range.to_offset(&snapshot);
            let excerpt_text = snapshot
                .text_for_range(context_range.clone())
                .collect::<String>();
            let cursor_within_excerpt = cursor_offset
                .saturating_sub(context_range.start)
                .min(excerpt_text.len());
            let prompt = excerpt_text[..cursor_within_excerpt].to_string();
            let suffix = excerpt_text[cursor_within_excerpt..].to_string();

            let completion_text = match Self::fetch_completion(
                http_client,
                &api_key,
                prompt,
                suffix,
                model,
                max_tokens,
                api_url,
            )
            .await
            {
                Ok(completion) => completion,
                Err(e) => {
                    log::error!("Codestral: Failed to fetch completion: {}", e);
                    // Clear the pending flag so the UI stops showing the
                    // refreshing state before propagating the error.
                    this.update(cx, |this, cx| {
                        this.pending_request = None;
                        cx.notify();
                    })?;
                    return Err(e);
                }
            };

            if completion_text.trim().is_empty() {
                log::debug!("Codestral: Completion was empty after trimming; ignoring");
                this.update(cx, |this, cx| {
                    this.pending_request = None;
                    cx.notify();
                })?;
                return Ok(());
            }

            // Represent the completion as a single insertion at the cursor.
            let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
                vec![(cursor_position..cursor_position, completion_text.into())].into();
            let edit_preview = buffer
                .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))
                .await;

            this.update(cx, |this, cx| {
                this.current_completion = Some(CurrentCompletion {
                    snapshot,
                    edits,
                    edit_preview,
                });
                this.pending_request = None;
                cx.notify();
            })?;

            Ok(())
        }));
    }

    /// Clears all completion state once the user accepts the suggestion.
    fn accept(&mut self, _cx: &mut Context<Self>) {
        log::debug!("Codestral: Completion accepted");
        self.pending_request = None;
        self.current_completion = None;
    }

    /// Clears all completion state when the user dismisses the suggestion.
    fn discard(&mut self, _cx: &mut Context<Self>) {
        log::debug!("Codestral: Completion discarded");
        self.pending_request = None;
        self.current_completion = None;
    }

    /// Returns the completion suggestion, adjusted or invalidated based on user edits
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        _cursor_position: Anchor,
        cx: &mut Context<Self>,
    ) -> Option<EditPrediction> {
        let current_completion = self.current_completion.as_ref()?;
        let buffer = buffer.read(cx);
        // Re-map the stored edits onto the buffer's current state; bail if the
        // user's edits conflict with the prediction.
        let edits = current_completion.interpolate(&buffer.snapshot())?;
        if edits.is_empty() {
            return None;
        }
        Some(EditPrediction::Local {
            id: None,
            edits,
            cursor_position: None,
            edit_preview: Some(current_completion.edit_preview.clone()),
        })
    }
}
343
/// Request body for Codestral's `/v1/fim/completions` endpoint.
/// `None` fields are omitted from the serialized JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct CodestralRequest {
    /// Model identifier, e.g. "codestral-latest".
    pub model: String,
    /// Text before the cursor (the fill-in-the-middle prefix).
    pub prompt: String,
    /// Text after the cursor; omitted when there is none.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// Upper bound on the number of tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus-sampling probability mass.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Whether the server should stream the response; this client sends `false`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Stop sequences that terminate generation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Seed for reproducible sampling.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u32>,
    /// Lower bound on the number of tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
}
365
/// Top-level response payload from Codestral's FIM completions endpoint.
#[derive(Debug, Deserialize)]
pub struct CodestralResponse {
    /// Server-assigned identifier for this completion.
    pub id: String,
    /// API object type tag.
    pub object: String,
    /// Model that produced the completion.
    pub model: String,
    /// Token accounting for the request/response pair.
    pub usage: Usage,
    /// Creation timestamp (seconds since the Unix epoch, per API convention —
    /// not interpreted by this client).
    pub created: u64,
    /// Candidate completions; this client uses only the first entry.
    pub choices: Vec<Choice>,
}
375
/// Token counts reported by the API for a single completion request.
#[derive(Debug, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt (and suffix).
    pub prompt_tokens: u32,
    /// Tokens generated for the completion.
    pub completion_tokens: u32,
    /// Sum of prompt and completion tokens.
    pub total_tokens: u32,
}
382
/// A single candidate completion within a [`CodestralResponse`].
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// Position of this choice in the returned list.
    pub index: u32,
    /// The generated content and its role.
    pub message: Message,
    /// Why generation stopped (e.g. length limit or natural stop),
    /// as reported by the API; not interpreted by this client.
    pub finish_reason: String,
}
389
/// The generated text carried by a [`Choice`].
#[derive(Debug, Deserialize)]
pub struct Message {
    /// The completion text to insert at the cursor.
    pub content: String,
    /// Role tag attached by the API; not interpreted by this client.
    pub role: String,
}