1use anyhow::{Context as _, Result};
2use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
3use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
4use futures::AsyncReadExt;
5use gpui::{App, Context, Entity, Task};
6use http_client::HttpClient;
7use language::{
8 language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint,
9};
10use language_models::MistralLanguageModelProvider;
11use mistral::CODESTRAL_API_URL;
12use serde::{Deserialize, Serialize};
13use std::{
14 ops::Range,
15 sync::Arc,
16 time::{Duration, Instant},
17};
18use text::ToOffset;
19
/// Delay applied inside a refresh task before the request is built, when `refresh`
/// is called with `debounce == true`.
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
21
22const EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
23 max_bytes: 1050,
24 min_bytes: 525,
25 target_before_cursor_over_total_bytes: 0.66,
26};
27
/// A completion received from Codestral, together with the state needed to keep
/// showing it while the user types.
///
/// `interpolate` re-anchors `edits` from `snapshot` to a newer buffer snapshot.
#[derive(Clone)]
struct CurrentCompletion {
    /// Buffer snapshot the completion was computed against. It is the baseline
    /// used to detect later user edits and to interpolate the predicted edits.
    snapshot: BufferSnapshot,
    /// Edits that turn the original text into the predicted text. Each one is a
    /// buffer range paired with its replacement text.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Preview of how the buffer looks after applying `edits`.
    edit_preview: EditPreview,
}
41
42impl CurrentCompletion {
43 /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
44 /// Returns None if the user's edits conflict with the predicted edits.
45 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
46 edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
47 }
48}
49
/// Edit prediction provider backed by Codestral's Fill-in-the-Middle (FIM) API.
///
/// Holds at most one refresh task and at most one received completion at a time.
pub struct CodestralCompletionProvider {
    /// Client used to send FIM requests.
    http_client: Arc<dyn HttpClient>,
    /// Task spawned by the latest `refresh`, if any; `is_refreshing` reports
    /// whether this is set.
    pending_request: Option<Task<Result<()>>>,
    /// Most recently received completion, kept so it can be interpolated as the
    /// user keeps typing.
    current_completion: Option<CurrentCompletion>,
}
55
56impl CodestralCompletionProvider {
57 pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
58 Self {
59 http_client,
60 pending_request: None,
61 current_completion: None,
62 }
63 }
64
65 pub fn has_api_key(cx: &App) -> bool {
66 Self::api_key(cx).is_some()
67 }
68
69 fn api_key(cx: &App) -> Option<Arc<str>> {
70 MistralLanguageModelProvider::try_global(cx)
71 .and_then(|provider| provider.codestral_api_key(CODESTRAL_API_URL, cx))
72 }
73
74 /// Uses Codestral's Fill-in-the-Middle API for code completion.
75 async fn fetch_completion(
76 http_client: Arc<dyn HttpClient>,
77 api_key: &str,
78 prompt: String,
79 suffix: String,
80 model: String,
81 max_tokens: Option<u32>,
82 api_url: String,
83 ) -> Result<String> {
84 let start_time = Instant::now();
85
86 log::debug!(
87 "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
88 model,
89 max_tokens
90 );
91
92 let request = CodestralRequest {
93 model,
94 prompt,
95 suffix: if suffix.is_empty() {
96 None
97 } else {
98 Some(suffix)
99 },
100 max_tokens: max_tokens.or(Some(350)),
101 temperature: Some(0.2),
102 top_p: Some(1.0),
103 stream: Some(false),
104 stop: None,
105 random_seed: None,
106 min_tokens: None,
107 };
108
109 let request_body = serde_json::to_string(&request)?;
110
111 log::debug!("Codestral: Sending FIM request");
112
113 let http_request = http_client::Request::builder()
114 .method(http_client::Method::POST)
115 .uri(format!("{}/v1/fim/completions", api_url))
116 .header("Content-Type", "application/json")
117 .header("Authorization", format!("Bearer {}", api_key))
118 .body(http_client::AsyncBody::from(request_body))?;
119
120 let mut response = http_client.send(http_request).await?;
121 let status = response.status();
122
123 log::debug!("Codestral: Response status: {}", status);
124
125 if !status.is_success() {
126 let mut body = String::new();
127 response.body_mut().read_to_string(&mut body).await?;
128 return Err(anyhow::anyhow!(
129 "Codestral API error: {} - {}",
130 status,
131 body
132 ));
133 }
134
135 let mut body = String::new();
136 response.body_mut().read_to_string(&mut body).await?;
137
138 let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
139
140 let elapsed = start_time.elapsed();
141
142 if let Some(choice) = codestral_response.choices.first() {
143 let completion = &choice.message.content;
144
145 log::debug!(
146 "Codestral: Completion received ({} tokens, {:.2}s)",
147 codestral_response.usage.completion_tokens,
148 elapsed.as_secs_f64()
149 );
150
151 // Return just the completion text for insertion at cursor
152 Ok(completion.clone())
153 } else {
154 log::error!("Codestral: No completion returned in response");
155 Err(anyhow::anyhow!("No completion returned from Codestral"))
156 }
157 }
158}
159
160impl EditPredictionProvider for CodestralCompletionProvider {
161 fn name() -> &'static str {
162 "codestral"
163 }
164
165 fn display_name() -> &'static str {
166 "Codestral"
167 }
168
169 fn show_completions_in_menu() -> bool {
170 true
171 }
172
173 fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
174 Self::api_key(cx).is_some()
175 }
176
177 fn is_refreshing(&self) -> bool {
178 self.pending_request.is_some()
179 }
180
181 fn refresh(
182 &mut self,
183 buffer: Entity<Buffer>,
184 cursor_position: language::Anchor,
185 debounce: bool,
186 cx: &mut Context<Self>,
187 ) {
188 log::debug!("Codestral: Refresh called (debounce: {})", debounce);
189
190 let Some(api_key) = Self::api_key(cx) else {
191 log::warn!("Codestral: No API key configured, skipping refresh");
192 return;
193 };
194
195 let snapshot = buffer.read(cx).snapshot();
196
197 // Check if current completion is still valid
198 if let Some(current_completion) = self.current_completion.as_ref() {
199 if current_completion.interpolate(&snapshot).is_some() {
200 return;
201 }
202 }
203
204 let http_client = self.http_client.clone();
205
206 // Get settings
207 let settings = all_language_settings(None, cx);
208 let model = settings
209 .edit_predictions
210 .codestral
211 .model
212 .clone()
213 .unwrap_or_else(|| "codestral-latest".to_string());
214 let max_tokens = settings.edit_predictions.codestral.max_tokens;
215 let api_url = settings
216 .edit_predictions
217 .codestral
218 .api_url
219 .clone()
220 .unwrap_or_else(|| CODESTRAL_API_URL.to_string());
221
222 self.pending_request = Some(cx.spawn(async move |this, cx| {
223 if debounce {
224 log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
225 smol::Timer::after(DEBOUNCE_TIMEOUT).await;
226 }
227
228 let cursor_offset = cursor_position.to_offset(&snapshot);
229 let cursor_point = cursor_offset.to_point(&snapshot);
230 let excerpt = EditPredictionExcerpt::select_from_buffer(
231 cursor_point,
232 &snapshot,
233 &EXCERPT_OPTIONS,
234 None,
235 )
236 .context("Line containing cursor doesn't fit in excerpt max bytes")?;
237
238 let excerpt_text = excerpt.text(&snapshot);
239 let cursor_within_excerpt = cursor_offset
240 .saturating_sub(excerpt.range.start)
241 .min(excerpt_text.body.len());
242 let prompt = excerpt_text.body[..cursor_within_excerpt].to_string();
243 let suffix = excerpt_text.body[cursor_within_excerpt..].to_string();
244
245 let completion_text = match Self::fetch_completion(
246 http_client,
247 &api_key,
248 prompt,
249 suffix,
250 model,
251 max_tokens,
252 api_url,
253 )
254 .await
255 {
256 Ok(completion) => completion,
257 Err(e) => {
258 log::error!("Codestral: Failed to fetch completion: {}", e);
259 this.update(cx, |this, cx| {
260 this.pending_request = None;
261 cx.notify();
262 })?;
263 return Err(e);
264 }
265 };
266
267 if completion_text.trim().is_empty() {
268 log::debug!("Codestral: Completion was empty after trimming; ignoring");
269 this.update(cx, |this, cx| {
270 this.pending_request = None;
271 cx.notify();
272 })?;
273 return Ok(());
274 }
275
276 let edits: Arc<[(Range<Anchor>, String)]> =
277 vec![(cursor_position..cursor_position, completion_text)].into();
278 let edit_preview = buffer
279 .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))?
280 .await;
281
282 this.update(cx, |this, cx| {
283 this.current_completion = Some(CurrentCompletion {
284 snapshot,
285 edits,
286 edit_preview,
287 });
288 this.pending_request = None;
289 cx.notify();
290 })?;
291
292 Ok(())
293 }));
294 }
295
296 fn cycle(
297 &mut self,
298 _buffer: Entity<Buffer>,
299 _cursor_position: Anchor,
300 _direction: Direction,
301 _cx: &mut Context<Self>,
302 ) {
303 // Codestral doesn't support multiple completions, so cycling does nothing
304 }
305
306 fn accept(&mut self, _cx: &mut Context<Self>) {
307 log::debug!("Codestral: Completion accepted");
308 self.pending_request = None;
309 self.current_completion = None;
310 }
311
312 fn discard(&mut self, _cx: &mut Context<Self>) {
313 log::debug!("Codestral: Completion discarded");
314 self.pending_request = None;
315 self.current_completion = None;
316 }
317
318 /// Returns the completion suggestion, adjusted or invalidated based on user edits
319 fn suggest(
320 &mut self,
321 buffer: &Entity<Buffer>,
322 _cursor_position: Anchor,
323 cx: &mut Context<Self>,
324 ) -> Option<EditPrediction> {
325 let current_completion = self.current_completion.as_ref()?;
326 let buffer = buffer.read(cx);
327 let edits = current_completion.interpolate(&buffer.snapshot())?;
328 if edits.is_empty() {
329 return None;
330 }
331 Some(EditPrediction::Local {
332 id: None,
333 edits,
334 edit_preview: Some(current_completion.edit_preview.clone()),
335 })
336 }
337}
338
/// JSON body of a Codestral Fill-in-the-Middle request.
///
/// Every `Option` field is omitted from the serialized JSON when it is `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct CodestralRequest {
    /// Model identifier; the provider falls back to `codestral-latest`.
    pub model: String,
    /// Text that precedes the cursor.
    pub prompt: String,
    /// Text that follows the cursor; `fetch_completion` sends `None` when it is empty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// Token limit for the completion; `fetch_completion` falls back to 350.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// `fetch_completion` always sends `0.2`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// `fetch_completion` always sends `1.0`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// `fetch_completion` always sends `false` and parses a single JSON response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Left unset by `fetch_completion`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Left unset by `fetch_completion`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u32>,
    /// Left unset by `fetch_completion`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
}
360
/// Parsed body of a successful Codestral FIM response.
///
/// `fetch_completion` reads only `usage` and the first entry of `choices`; the
/// remaining fields are deserialized but unused here.
#[derive(Debug, Deserialize)]
pub struct CodestralResponse {
    pub id: String,
    pub object: String,
    pub model: String,
    pub usage: Usage,
    /// NOTE(review): presumably a Unix timestamp; confirm against the API reference.
    pub created: u64,
    pub choices: Vec<Choice>,
}
370
/// Token counts reported in a Codestral response.
///
/// Only `completion_tokens` is used, in a debug log.
#[derive(Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
377
/// One generated completion in a Codestral response; only the first is used.
#[derive(Debug, Deserialize)]
pub struct Choice {
    pub index: u32,
    /// The generated message; its `content` is the text inserted at the cursor.
    pub message: Message,
    /// Why generation stopped, as reported by the API (not inspected here).
    pub finish_reason: String,
}
384
/// Message carried by a [`Choice`].
#[derive(Debug, Deserialize)]
pub struct Message {
    /// Completion text returned to the editor for insertion at the cursor.
    pub content: String,
    /// Author role as reported by the API (not inspected here).
    pub role: String,
}