use anyhow::Result;
use edit_prediction::cursor_excerpt;
use edit_prediction_types::{EditPrediction, EditPredictionDelegate};
use futures::AsyncReadExt;
use gpui::{App, Context, Entity, Task};
use http_client::HttpClient;
use language::{
    language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint,
};
use language_models::MistralLanguageModelProvider;
use mistral::CODESTRAL_API_URL;
use serde::{Deserialize, Serialize};
use std::{
    ops::Range,
    sync::Arc,
    time::{Duration, Instant},
};
use text::{OffsetRangeExt as _, ToOffset};

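/// How long to wait after the last keystroke before requesting a completion,
/// so that rapid typing does not trigger one request per character.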
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);

/// Represents a completion that has been received and processed from Codestral.
/// This struct maintains the state needed to interpolate the completion as the user types.
#[derive(Clone)]
struct CurrentCompletion {
    /// The buffer snapshot at the time the completion was generated.
    /// Used to detect changes and interpolate edits.
    snapshot: BufferSnapshot,
    /// The edits that should be applied to transform the original text into the predicted text.
    /// Each edit is a range in the buffer and the text to replace it with.
    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Preview of how the buffer will look after applying the edits.
    edit_preview: EditPreview,
}

impl CurrentCompletion {
    /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
    /// Returns None if the user's edits conflict with the predicted edits.
    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
        edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
    }
}

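/// Edit prediction delegate backed by Codestral's Fill-in-the-Middle API.
/// Holds the HTTP client, the in-flight request (if any), and the most recent completion.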
pub struct CodestralEditPredictionDelegate {
    http_client: Arc<dyn HttpClient>,
    pending_request: Option<Task<Result<()>>>,
    current_completion: Option<CurrentCompletion>,
}

impl CodestralEditPredictionDelegate {
    pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
        Self {
            http_client,
            pending_request: None,
            current_completion: None,
        }
    }

    pub fn has_api_key(cx: &App) -> bool {
        Self::api_key(cx).is_some()
    }

    /// Eagerly loads the Codestral API key so that, once a key has been added,
    /// Codestral immediately shows up as a provider users can switch to in the
    /// edit prediction menu.
    pub fn ensure_api_key_loaded(http_client: Arc<dyn HttpClient>, cx: &mut App) {
        MistralLanguageModelProvider::global(http_client, cx)
            .load_codestral_api_key(cx)
            .detach();
    }

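    /// Reads the Codestral API key from the global Mistral language model provider, if configured.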
    fn api_key(cx: &App) -> Option<Arc<str>> {
        MistralLanguageModelProvider::try_global(cx)
            .and_then(|provider| provider.codestral_api_key(CODESTRAL_API_URL, cx))
    }

    /// Uses Codestral's Fill-in-the-Middle API for code completion.
    async fn fetch_completion(
        http_client: Arc<dyn HttpClient>,
        api_key: &str,
        prompt: String,
        suffix: String,
        model: String,
        max_tokens: Option<u32>,
        api_url: String,
    ) -> Result<String> {
        let start_time = Instant::now();

        log::debug!(
            "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
            model,
            max_tokens
        );

        let request = CodestralRequest {
            model,
            prompt,
            suffix: if suffix.is_empty() {
                None
            } else {
                Some(suffix)
            },
            max_tokens: max_tokens.or(Some(350)),
            temperature: Some(0.2),
            top_p: Some(1.0),
            stream: Some(false),
            stop: None,
            random_seed: None,
            min_tokens: None,
        };

        let request_body = serde_json::to_string(&request)?;

        log::debug!("Codestral: Sending FIM request");

        let http_request = http_client::Request::builder()
            .method(http_client::Method::POST)
            .uri(format!("{}/v1/fim/completions", api_url))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key))
            .body(http_client::AsyncBody::from(request_body))?;

        let mut response = http_client.send(http_request).await?;
        let status = response.status();

        log::debug!("Codestral: Response status: {}", status);

        if !status.is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow::anyhow!(
                "Codestral API error: {} - {}",
                status,
                body
            ));
        }

        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        let codestral_response: CodestralResponse = serde_json::from_str(&body)?;

        let elapsed = start_time.elapsed();

        if let Some(choice) = codestral_response.choices.first() {
            let completion = &choice.message.content;

            log::debug!(
                "Codestral: Completion received ({} tokens, {:.2}s)",
                codestral_response.usage.completion_tokens,
                elapsed.as_secs_f64()
            );

            // Return just the completion text for insertion at cursor
            Ok(completion.clone())
        } else {
            log::error!("Codestral: No completion returned in response");
            Err(anyhow::anyhow!("No completion returned from Codestral"))
        }
    }
}

impl EditPredictionDelegate for CodestralEditPredictionDelegate {
    fn name() -> &'static str {
        "codestral"
    }

    fn display_name() -> &'static str {
        "Codestral"
    }

    fn show_predictions_in_menu() -> bool {
        true
    }

    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
        Self::api_key(cx).is_some()
    }

    fn is_refreshing(&self, _cx: &App) -> bool {
        self.pending_request.is_some()
    }

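    /// Requests a fresh prediction from Codestral, debouncing keystrokes and
    /// reusing the current completion when it still applies to the buffer.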
    fn refresh(
        &mut self,
        buffer: Entity<Buffer>,
        cursor_position: language::Anchor,
        debounce: bool,
        cx: &mut Context<Self>,
    ) {
        log::debug!("Codestral: Refresh called (debounce: {})", debounce);

        let Some(api_key) = Self::api_key(cx) else {
            log::warn!("Codestral: No API key configured, skipping refresh");
            return;
        };

        let snapshot = buffer.read(cx).snapshot();

        // Check if current completion is still valid
        if let Some(current_completion) = self.current_completion.as_ref() {
            if current_completion.interpolate(&snapshot).is_some() {
                return;
            }
        }

        let http_client = self.http_client.clone();

        // Get settings
        let settings = all_language_settings(None, cx);
        let model = settings
            .edit_predictions
            .codestral
            .model
            .clone()
            .unwrap_or_else(|| "codestral-latest".to_string());
        let max_tokens = settings.edit_predictions.codestral.max_tokens;
        let api_url = settings
            .edit_predictions
            .codestral
            .api_url
            .clone()
            .unwrap_or_else(|| CODESTRAL_API_URL.to_string());

        self.pending_request = Some(cx.spawn(async move |this, cx| {
            if debounce {
                log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
            }

            let cursor_offset = cursor_position.to_offset(&snapshot);
            let cursor_point = cursor_offset.to_point(&snapshot);

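            // Split an excerpt around the cursor into a prefix (prompt) and a
            // suffix for the fill-in-the-middle request. The token budgets
            // below bound how much surrounding text is sent.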
            const MAX_CONTEXT_TOKENS: usize = 150;
            const MAX_REWRITE_TOKENS: usize = 350;

            let (_, context_range) =
                cursor_excerpt::editable_and_context_ranges_for_cursor_position(
                    cursor_point,
                    &snapshot,
                    MAX_REWRITE_TOKENS,
                    MAX_CONTEXT_TOKENS,
                );

            let context_range = context_range.to_offset(&snapshot);
            let excerpt_text = snapshot
                .text_for_range(context_range.clone())
                .collect::<String>();
            let cursor_within_excerpt = cursor_offset
                .saturating_sub(context_range.start)
                .min(excerpt_text.len());
            let prompt = excerpt_text[..cursor_within_excerpt].to_string();
            let suffix = excerpt_text[cursor_within_excerpt..].to_string();

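            // Request the completion; on failure, clear the pending request so
            // the UI stops reporting a refresh in progress.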
            let completion_text = match Self::fetch_completion(
                http_client,
                &api_key,
                prompt,
                suffix,
                model,
                max_tokens,
                api_url,
            )
            .await
            {
                Ok(completion) => completion,
                Err(e) => {
                    log::error!("Codestral: Failed to fetch completion: {}", e);
                    this.update(cx, |this, cx| {
                        this.pending_request = None;
                        cx.notify();
                    })?;
                    return Err(e);
                }
            };

            if completion_text.trim().is_empty() {
                log::debug!("Codestral: Completion was empty after trimming; ignoring");
                this.update(cx, |this, cx| {
                    this.pending_request = None;
                    cx.notify();
                })?;
                return Ok(());
            }

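            // The model's output is inserted as a single edit at the cursor
            // position; an empty anchor range means a pure insertion.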
            let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
                vec![(cursor_position..cursor_position, completion_text.into())].into();
            let edit_preview = buffer
                .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))?
                .await;

            this.update(cx, |this, cx| {
                this.current_completion = Some(CurrentCompletion {
                    snapshot,
                    edits,
                    edit_preview,
                });
                this.pending_request = None;
                cx.notify();
            })?;

            Ok(())
        }));
    }

    fn accept(&mut self, _cx: &mut Context<Self>) {
        log::debug!("Codestral: Completion accepted");
        self.pending_request = None;
        self.current_completion = None;
    }

    fn discard(&mut self, _cx: &mut Context<Self>) {
        log::debug!("Codestral: Completion discarded");
        self.pending_request = None;
        self.current_completion = None;
    }

    /// Returns the completion suggestion, adjusted or invalidated based on user edits.
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        _cursor_position: Anchor,
        cx: &mut Context<Self>,
    ) -> Option<EditPrediction> {
        let current_completion = self.current_completion.as_ref()?;
        let buffer = buffer.read(cx);
        let edits = current_completion.interpolate(&buffer.snapshot())?;
        if edits.is_empty() {
            return None;
        }
        Some(EditPrediction::Local {
            id: None,
            edits,
            edit_preview: Some(current_completion.edit_preview.clone()),
        })
    }
}

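/// Request body for Codestral's `/v1/fim/completions` endpoint. Optional
/// fields are omitted from the JSON payload when unset.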
#[derive(Debug, Serialize, Deserialize)]
pub struct CodestralRequest {
    pub model: String,
    pub prompt: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
}

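/// Response payload returned by the FIM completions endpoint.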
#[derive(Debug, Deserialize)]
pub struct CodestralResponse {
    pub id: String,
    pub object: String,
    pub model: String,
    pub usage: Usage,
    pub created: u64,
    pub choices: Vec<Choice>,
}

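/// Token usage reported by the API for a single request.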
#[derive(Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

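/// A single completion choice; only the first choice is used here.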
#[derive(Debug, Deserialize)]
pub struct Choice {
    pub index: u32,
    pub message: Message,
    pub finish_reason: String,
}

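/// The generated completion text and the role it was produced under.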
#[derive(Debug, Deserialize)]
pub struct Message {
    pub content: String,
    pub role: String,
}