1use anyhow::{Context as _, Result};
2use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
3use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
4use futures::AsyncReadExt;
5use gpui::{App, Context, Entity, Task};
6use http_client::HttpClient;
7use language::{
8 language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint,
9};
10use language_models::MistralLanguageModelProvider;
11use mistral::CODESTRAL_API_URL;
12use serde::{Deserialize, Serialize};
13use std::{
14 ops::Range,
15 sync::Arc,
16 time::{Duration, Instant},
17};
18use text::ToOffset;
19
/// Delay that `refresh` sleeps for before issuing a request when it is called
/// with `debounce: true`, so bursts of keystrokes collapse into one request.
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
21
22const EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
23 max_bytes: 1050,
24 min_bytes: 525,
25 target_before_cursor_over_total_bytes: 0.66,
26};
27
/// A completion received from Codestral and turned into buffer edits.
///
/// The snapshot the edits were computed against is kept so the edits can be
/// re-interpolated as the user keeps typing.
#[derive(Clone)]
struct CurrentCompletion {
    /// Buffer snapshot captured when the completion was requested.
    /// Serves as the base for interpolating `edits` against newer snapshots.
    snapshot: BufferSnapshot,
    /// Edits that turn the snapshot's text into the predicted text.
    /// Each entry is a buffer range and the text that replaces it; for
    /// Codestral this is a single insertion at the cursor.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Preview of the buffer with `edits` applied.
    edit_preview: EditPreview,
}
41
42impl CurrentCompletion {
43 /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
44 /// Returns None if the user's edits conflict with the predicted edits.
45 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
46 edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
47 }
48}
49
/// Edit-prediction provider backed by Mistral's Codestral Fill-in-the-Middle API.
pub struct CodestralCompletionProvider {
    /// Client used to send FIM requests.
    http_client: Arc<dyn HttpClient>,
    /// Task for the request currently in flight, if any; `is_refreshing`
    /// reports whether this is set. Overwriting or clearing it drops the old
    /// task (presumably cancelling it, per gpui `Task` semantics — confirm).
    pending_request: Option<Task<Result<()>>>,
    /// Most recently received completion, interpolated against the buffer
    /// whenever a suggestion is requested.
    current_completion: Option<CurrentCompletion>,
}
55
56impl CodestralCompletionProvider {
57 pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
58 Self {
59 http_client,
60 pending_request: None,
61 current_completion: None,
62 }
63 }
64
65 pub fn has_api_key(cx: &App) -> bool {
66 Self::api_key(cx).is_some()
67 }
68
69 fn api_key(cx: &App) -> Option<Arc<str>> {
70 MistralLanguageModelProvider::try_global(cx)
71 .and_then(|provider| provider.codestral_api_key(CODESTRAL_API_URL, cx))
72 }
73
74 /// Uses Codestral's Fill-in-the-Middle API for code completion.
75 async fn fetch_completion(
76 http_client: Arc<dyn HttpClient>,
77 api_key: &str,
78 prompt: String,
79 suffix: String,
80 model: String,
81 max_tokens: Option<u32>,
82 ) -> Result<String> {
83 let start_time = Instant::now();
84
85 log::debug!(
86 "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
87 model,
88 max_tokens
89 );
90
91 let request = CodestralRequest {
92 model,
93 prompt,
94 suffix: if suffix.is_empty() {
95 None
96 } else {
97 Some(suffix)
98 },
99 max_tokens: max_tokens.or(Some(350)),
100 temperature: Some(0.2),
101 top_p: Some(1.0),
102 stream: Some(false),
103 stop: None,
104 random_seed: None,
105 min_tokens: None,
106 };
107
108 let request_body = serde_json::to_string(&request)?;
109
110 log::debug!("Codestral: Sending FIM request");
111
112 let http_request = http_client::Request::builder()
113 .method(http_client::Method::POST)
114 .uri(format!("{}/v1/fim/completions", CODESTRAL_API_URL))
115 .header("Content-Type", "application/json")
116 .header("Authorization", format!("Bearer {}", api_key))
117 .body(http_client::AsyncBody::from(request_body))?;
118
119 let mut response = http_client.send(http_request).await?;
120 let status = response.status();
121
122 log::debug!("Codestral: Response status: {}", status);
123
124 if !status.is_success() {
125 let mut body = String::new();
126 response.body_mut().read_to_string(&mut body).await?;
127 return Err(anyhow::anyhow!(
128 "Codestral API error: {} - {}",
129 status,
130 body
131 ));
132 }
133
134 let mut body = String::new();
135 response.body_mut().read_to_string(&mut body).await?;
136
137 let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
138
139 let elapsed = start_time.elapsed();
140
141 if let Some(choice) = codestral_response.choices.first() {
142 let completion = &choice.message.content;
143
144 log::debug!(
145 "Codestral: Completion received ({} tokens, {:.2}s)",
146 codestral_response.usage.completion_tokens,
147 elapsed.as_secs_f64()
148 );
149
150 // Return just the completion text for insertion at cursor
151 Ok(completion.clone())
152 } else {
153 log::error!("Codestral: No completion returned in response");
154 Err(anyhow::anyhow!("No completion returned from Codestral"))
155 }
156 }
157}
158
159impl EditPredictionProvider for CodestralCompletionProvider {
160 fn name() -> &'static str {
161 "codestral"
162 }
163
164 fn display_name() -> &'static str {
165 "Codestral"
166 }
167
168 fn show_completions_in_menu() -> bool {
169 true
170 }
171
172 fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
173 Self::api_key(cx).is_some()
174 }
175
176 fn is_refreshing(&self) -> bool {
177 self.pending_request.is_some()
178 }
179
180 fn refresh(
181 &mut self,
182 buffer: Entity<Buffer>,
183 cursor_position: language::Anchor,
184 debounce: bool,
185 cx: &mut Context<Self>,
186 ) {
187 log::debug!("Codestral: Refresh called (debounce: {})", debounce);
188
189 let Some(api_key) = Self::api_key(cx) else {
190 log::warn!("Codestral: No API key configured, skipping refresh");
191 return;
192 };
193
194 let snapshot = buffer.read(cx).snapshot();
195
196 // Check if current completion is still valid
197 if let Some(current_completion) = self.current_completion.as_ref() {
198 if current_completion.interpolate(&snapshot).is_some() {
199 return;
200 }
201 }
202
203 let http_client = self.http_client.clone();
204
205 // Get settings
206 let settings = all_language_settings(None, cx);
207 let model = settings
208 .edit_predictions
209 .codestral
210 .model
211 .clone()
212 .unwrap_or_else(|| "codestral-latest".to_string());
213 let max_tokens = settings.edit_predictions.codestral.max_tokens;
214
215 self.pending_request = Some(cx.spawn(async move |this, cx| {
216 if debounce {
217 log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
218 smol::Timer::after(DEBOUNCE_TIMEOUT).await;
219 }
220
221 let cursor_offset = cursor_position.to_offset(&snapshot);
222 let cursor_point = cursor_offset.to_point(&snapshot);
223 let excerpt = EditPredictionExcerpt::select_from_buffer(
224 cursor_point,
225 &snapshot,
226 &EXCERPT_OPTIONS,
227 None,
228 )
229 .context("Line containing cursor doesn't fit in excerpt max bytes")?;
230
231 let excerpt_text = excerpt.text(&snapshot);
232 let cursor_within_excerpt = cursor_offset
233 .saturating_sub(excerpt.range.start)
234 .min(excerpt_text.body.len());
235 let prompt = excerpt_text.body[..cursor_within_excerpt].to_string();
236 let suffix = excerpt_text.body[cursor_within_excerpt..].to_string();
237
238 let completion_text = match Self::fetch_completion(
239 http_client,
240 &api_key,
241 prompt,
242 suffix,
243 model,
244 max_tokens,
245 )
246 .await
247 {
248 Ok(completion) => completion,
249 Err(e) => {
250 log::error!("Codestral: Failed to fetch completion: {}", e);
251 this.update(cx, |this, cx| {
252 this.pending_request = None;
253 cx.notify();
254 })?;
255 return Err(e);
256 }
257 };
258
259 if completion_text.trim().is_empty() {
260 log::debug!("Codestral: Completion was empty after trimming; ignoring");
261 this.update(cx, |this, cx| {
262 this.pending_request = None;
263 cx.notify();
264 })?;
265 return Ok(());
266 }
267
268 let edits: Arc<[(Range<Anchor>, String)]> =
269 vec![(cursor_position..cursor_position, completion_text)].into();
270 let edit_preview = buffer
271 .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))?
272 .await;
273
274 this.update(cx, |this, cx| {
275 this.current_completion = Some(CurrentCompletion {
276 snapshot,
277 edits,
278 edit_preview,
279 });
280 this.pending_request = None;
281 cx.notify();
282 })?;
283
284 Ok(())
285 }));
286 }
287
288 fn cycle(
289 &mut self,
290 _buffer: Entity<Buffer>,
291 _cursor_position: Anchor,
292 _direction: Direction,
293 _cx: &mut Context<Self>,
294 ) {
295 // Codestral doesn't support multiple completions, so cycling does nothing
296 }
297
298 fn accept(&mut self, _cx: &mut Context<Self>) {
299 log::debug!("Codestral: Completion accepted");
300 self.pending_request = None;
301 self.current_completion = None;
302 }
303
304 fn discard(&mut self, _cx: &mut Context<Self>) {
305 log::debug!("Codestral: Completion discarded");
306 self.pending_request = None;
307 self.current_completion = None;
308 }
309
310 /// Returns the completion suggestion, adjusted or invalidated based on user edits
311 fn suggest(
312 &mut self,
313 buffer: &Entity<Buffer>,
314 _cursor_position: Anchor,
315 cx: &mut Context<Self>,
316 ) -> Option<EditPrediction> {
317 let current_completion = self.current_completion.as_ref()?;
318 let buffer = buffer.read(cx);
319 let edits = current_completion.interpolate(&buffer.snapshot())?;
320 if edits.is_empty() {
321 return None;
322 }
323 Some(EditPrediction::Local {
324 id: None,
325 edits,
326 edit_preview: Some(current_completion.edit_preview.clone()),
327 })
328 }
329}
330
/// Request body for Codestral's `/v1/fim/completions` endpoint.
///
/// Field order is the JSON key order on the wire. Every `None` optional field
/// is left out of the serialized JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct CodestralRequest {
    /// Model identifier, e.g. `codestral-latest`.
    pub model: String,
    /// Text before the cursor.
    pub prompt: String,
    /// Text after the cursor, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// Upper bound on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus-sampling probability mass.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Whether to stream the response; this file always sends `Some(false)`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Stop sequences (presumably generation halts at any of them — confirm with API docs).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Seed for sampling (presumably for reproducible output — confirm with API docs).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u32>,
    /// Lower bound on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
}
352
/// Response body returned by the FIM completions endpoint.
///
/// Every field is required, so deserialization fails if any is missing.
#[derive(Debug, Deserialize)]
pub struct CodestralResponse {
    /// Identifier of this response.
    pub id: String,
    /// Object type reported by the API.
    pub object: String,
    /// Model that produced the completion.
    pub model: String,
    /// Token accounting for the request.
    pub usage: Usage,
    /// Creation timestamp (presumably Unix seconds — confirm with API docs).
    pub created: u64,
    /// Generated completions; `fetch_completion` uses only the first.
    pub choices: Vec<Choice>,
}
362
/// Token counts reported for a completion request.
#[derive(Debug, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt (and suffix, presumably — confirm).
    pub prompt_tokens: u32,
    /// Tokens generated in the completion; logged by `fetch_completion`.
    pub completion_tokens: u32,
    /// Total tokens reported by the API.
    pub total_tokens: u32,
}
369
/// One generated alternative in a `CodestralResponse`.
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// Position of this choice within `choices`.
    pub index: u32,
    /// Generated message; its `content` is the completion text.
    pub message: Message,
    /// Reason generation stopped, as reported by the API (required here).
    pub finish_reason: String,
}
376
/// Message payload of a `Choice`.
#[derive(Debug, Deserialize)]
pub struct Message {
    /// Generated text; returned verbatim by `fetch_completion`.
    pub content: String,
    /// Author role reported by the API.
    pub role: String,
}