1use anyhow::Result;
2use edit_prediction::cursor_excerpt;
3use edit_prediction_types::{
4 EditPrediction, EditPredictionDelegate, EditPredictionDiscardReason, EditPredictionIconSet,
5};
6use futures::AsyncReadExt;
7use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
8use http_client::HttpClient;
9use icons::IconName;
10use language::{
11 Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint, language_settings::all_language_settings,
12};
13use language_model::{ApiKeyState, AuthenticateError, EnvVar, env_var};
14use serde::{Deserialize, Serialize};
15
16use std::{
17 ops::Range,
18 sync::Arc,
19 time::{Duration, Instant},
20};
21use text::{OffsetRangeExt as _, ToOffset};
22
23pub const CODESTRAL_API_URL: &str = "https://codestral.mistral.ai";
24pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
25
26static CODESTRAL_API_KEY_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("CODESTRAL_API_KEY");
27
28struct GlobalCodestralApiKey(Entity<ApiKeyState>);
29
30impl Global for GlobalCodestralApiKey {}
31
32pub fn codestral_api_key_state(cx: &mut App) -> Entity<ApiKeyState> {
33 if let Some(global) = cx.try_global::<GlobalCodestralApiKey>() {
34 return global.0.clone();
35 }
36 let entity =
37 cx.new(|cx| ApiKeyState::new(codestral_api_url(cx), CODESTRAL_API_KEY_ENV_VAR.clone()));
38 cx.set_global(GlobalCodestralApiKey(entity.clone()));
39 entity
40}
41
42pub fn codestral_api_key(cx: &App) -> Option<Arc<str>> {
43 let url = codestral_api_url(cx);
44 cx.try_global::<GlobalCodestralApiKey>()?
45 .0
46 .read(cx)
47 .key(&url)
48}
49
50pub fn load_codestral_api_key(cx: &mut App) -> Task<Result<(), AuthenticateError>> {
51 let api_url = codestral_api_url(cx);
52 codestral_api_key_state(cx).update(cx, |key_state, cx| {
53 key_state.load_if_needed(api_url, |s| s, cx)
54 })
55}
56
57pub fn codestral_api_url(cx: &App) -> SharedString {
58 all_language_settings(None, cx)
59 .edit_predictions
60 .codestral
61 .api_url
62 .clone()
63 .unwrap_or_else(|| CODESTRAL_API_URL.to_string())
64 .into()
65}
66
67/// Represents a completion that has been received and processed from Codestral.
68/// This struct maintains the state needed to interpolate the completion as the user types.
69#[derive(Clone)]
70struct CurrentCompletion {
71 /// The buffer snapshot at the time the completion was generated.
72 /// Used to detect changes and interpolate edits.
73 snapshot: BufferSnapshot,
74 /// The edits that should be applied to transform the original text into the predicted text.
75 /// Each edit is a range in the buffer and the text to replace it with.
76 edits: Arc<[(Range<Anchor>, Arc<str>)]>,
77 /// Preview of how the buffer will look after applying the edits.
78 edit_preview: EditPreview,
79}
80
81impl CurrentCompletion {
82 /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
83 /// Returns None if the user's edits conflict with the predicted edits.
84 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
85 edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
86 }
87}
88
89pub struct CodestralEditPredictionDelegate {
90 http_client: Arc<dyn HttpClient>,
91 pending_request: Option<Task<Result<()>>>,
92 current_completion: Option<CurrentCompletion>,
93}
94
95impl CodestralEditPredictionDelegate {
96 pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
97 Self {
98 http_client,
99 pending_request: None,
100 current_completion: None,
101 }
102 }
103
104 pub fn ensure_api_key_loaded(cx: &mut App) {
105 load_codestral_api_key(cx).detach();
106 }
107
108 /// Uses Codestral's Fill-in-the-Middle API for code completion.
109 async fn fetch_completion(
110 http_client: Arc<dyn HttpClient>,
111 api_key: &str,
112 prompt: String,
113 suffix: String,
114 model: String,
115 max_tokens: Option<u32>,
116 api_url: String,
117 ) -> Result<String> {
118 let start_time = Instant::now();
119
120 log::debug!(
121 "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
122 model,
123 max_tokens
124 );
125
126 let request = CodestralRequest {
127 model,
128 prompt,
129 suffix: if suffix.is_empty() {
130 None
131 } else {
132 Some(suffix)
133 },
134 max_tokens: max_tokens.or(Some(350)),
135 temperature: Some(0.2),
136 top_p: Some(1.0),
137 stream: Some(false),
138 stop: None,
139 random_seed: None,
140 min_tokens: None,
141 };
142
143 let request_body = serde_json::to_string(&request)?;
144
145 log::debug!("Codestral: Sending FIM request");
146
147 let http_request = http_client::Request::builder()
148 .method(http_client::Method::POST)
149 .uri(format!("{}/v1/fim/completions", api_url))
150 .header("Content-Type", "application/json")
151 .header("Authorization", format!("Bearer {}", api_key))
152 .body(http_client::AsyncBody::from(request_body))?;
153
154 let mut response = http_client.send(http_request).await?;
155 let status = response.status();
156
157 log::debug!("Codestral: Response status: {}", status);
158
159 if !status.is_success() {
160 let mut body = String::new();
161 response.body_mut().read_to_string(&mut body).await?;
162 return Err(anyhow::anyhow!(
163 "Codestral API error: {} - {}",
164 status,
165 body
166 ));
167 }
168
169 let mut body = String::new();
170 response.body_mut().read_to_string(&mut body).await?;
171
172 let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
173
174 let elapsed = start_time.elapsed();
175
176 if let Some(choice) = codestral_response.choices.first() {
177 let completion = &choice.message.content;
178
179 log::debug!(
180 "Codestral: Completion received ({} tokens, {:.2}s)",
181 codestral_response.usage.completion_tokens,
182 elapsed.as_secs_f64()
183 );
184
185 // Return just the completion text for insertion at cursor
186 Ok(completion.clone())
187 } else {
188 log::error!("Codestral: No completion returned in response");
189 Err(anyhow::anyhow!("No completion returned from Codestral"))
190 }
191 }
192}
193
194impl EditPredictionDelegate for CodestralEditPredictionDelegate {
195 fn name() -> &'static str {
196 "codestral"
197 }
198
199 fn display_name() -> &'static str {
200 "Codestral"
201 }
202
203 fn show_predictions_in_menu() -> bool {
204 true
205 }
206
207 fn icons(&self, _cx: &App) -> EditPredictionIconSet {
208 EditPredictionIconSet::new(IconName::AiMistral)
209 }
210
211 fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
212 codestral_api_key(cx).is_some()
213 }
214
215 fn is_refreshing(&self, _cx: &App) -> bool {
216 self.pending_request.is_some()
217 }
218
219 fn refresh(
220 &mut self,
221 buffer: Entity<Buffer>,
222 cursor_position: language::Anchor,
223 debounce: bool,
224 cx: &mut Context<Self>,
225 ) {
226 log::debug!("Codestral: Refresh called (debounce: {})", debounce);
227
228 let Some(api_key) = codestral_api_key(cx) else {
229 log::warn!("Codestral: No API key configured, skipping refresh");
230 return;
231 };
232
233 let snapshot = buffer.read(cx).snapshot();
234
235 // Check if current completion is still valid
236 if let Some(current_completion) = self.current_completion.as_ref() {
237 if current_completion.interpolate(&snapshot).is_some() {
238 return;
239 }
240 }
241
242 let http_client = self.http_client.clone();
243
244 // Get settings
245 let settings = all_language_settings(None, cx);
246 let model = settings
247 .edit_predictions
248 .codestral
249 .model
250 .clone()
251 .unwrap_or_else(|| "codestral-latest".to_string());
252 let max_tokens = settings.edit_predictions.codestral.max_tokens;
253 let api_url = codestral_api_url(cx).to_string();
254
255 self.pending_request = Some(cx.spawn(async move |this, cx| {
256 if debounce {
257 log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
258 cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
259 }
260
261 let cursor_offset = cursor_position.to_offset(&snapshot);
262 let cursor_point = cursor_offset.to_point(&snapshot);
263
264 const MAX_CONTEXT_TOKENS: usize = 150;
265 const MAX_REWRITE_TOKENS: usize = 350;
266
267 let (_, context_range) =
268 cursor_excerpt::editable_and_context_ranges_for_cursor_position(
269 cursor_point,
270 &snapshot,
271 MAX_REWRITE_TOKENS,
272 MAX_CONTEXT_TOKENS,
273 );
274
275 let context_range = context_range.to_offset(&snapshot);
276 let excerpt_text = snapshot
277 .text_for_range(context_range.clone())
278 .collect::<String>();
279 let cursor_within_excerpt = cursor_offset
280 .saturating_sub(context_range.start)
281 .min(excerpt_text.len());
282 let prompt = excerpt_text[..cursor_within_excerpt].to_string();
283 let suffix = excerpt_text[cursor_within_excerpt..].to_string();
284
285 let completion_text = match Self::fetch_completion(
286 http_client,
287 &api_key,
288 prompt,
289 suffix,
290 model,
291 max_tokens,
292 api_url,
293 )
294 .await
295 {
296 Ok(completion) => completion,
297 Err(e) => {
298 log::error!("Codestral: Failed to fetch completion: {}", e);
299 this.update(cx, |this, cx| {
300 this.pending_request = None;
301 cx.notify();
302 })?;
303 return Err(e);
304 }
305 };
306
307 if completion_text.trim().is_empty() {
308 log::debug!("Codestral: Completion was empty after trimming; ignoring");
309 this.update(cx, |this, cx| {
310 this.pending_request = None;
311 cx.notify();
312 })?;
313 return Ok(());
314 }
315
316 let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
317 vec![(cursor_position..cursor_position, completion_text.into())].into();
318 let edit_preview = buffer
319 .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))
320 .await;
321
322 this.update(cx, |this, cx| {
323 this.current_completion = Some(CurrentCompletion {
324 snapshot,
325 edits,
326 edit_preview,
327 });
328 this.pending_request = None;
329 cx.notify();
330 })?;
331
332 Ok(())
333 }));
334 }
335
336 fn accept(&mut self, _cx: &mut Context<Self>) {
337 log::debug!("Codestral: Completion accepted");
338 self.pending_request = None;
339 self.current_completion = None;
340 }
341
342 fn discard(&mut self, _reason: EditPredictionDiscardReason, _cx: &mut Context<Self>) {
343 log::debug!("Codestral: Completion discarded");
344 self.pending_request = None;
345 self.current_completion = None;
346 }
347
348 /// Returns the completion suggestion, adjusted or invalidated based on user edits
349 fn suggest(
350 &mut self,
351 buffer: &Entity<Buffer>,
352 _cursor_position: Anchor,
353 cx: &mut Context<Self>,
354 ) -> Option<EditPrediction> {
355 let current_completion = self.current_completion.as_ref()?;
356 let buffer = buffer.read(cx);
357 let edits = current_completion.interpolate(&buffer.snapshot())?;
358 if edits.is_empty() {
359 return None;
360 }
361 Some(EditPrediction::Local {
362 id: None,
363 edits,
364 cursor_position: None,
365 edit_preview: Some(current_completion.edit_preview.clone()),
366 })
367 }
368}
369
370#[derive(Debug, Serialize, Deserialize)]
371pub struct CodestralRequest {
372 pub model: String,
373 pub prompt: String,
374 #[serde(skip_serializing_if = "Option::is_none")]
375 pub suffix: Option<String>,
376 #[serde(skip_serializing_if = "Option::is_none")]
377 pub max_tokens: Option<u32>,
378 #[serde(skip_serializing_if = "Option::is_none")]
379 pub temperature: Option<f32>,
380 #[serde(skip_serializing_if = "Option::is_none")]
381 pub top_p: Option<f32>,
382 #[serde(skip_serializing_if = "Option::is_none")]
383 pub stream: Option<bool>,
384 #[serde(skip_serializing_if = "Option::is_none")]
385 pub stop: Option<Vec<String>>,
386 #[serde(skip_serializing_if = "Option::is_none")]
387 pub random_seed: Option<u32>,
388 #[serde(skip_serializing_if = "Option::is_none")]
389 pub min_tokens: Option<u32>,
390}
391
392#[derive(Debug, Deserialize)]
393pub struct CodestralResponse {
394 pub id: String,
395 pub object: String,
396 pub model: String,
397 pub usage: Usage,
398 pub created: u64,
399 pub choices: Vec<Choice>,
400}
401
402#[derive(Debug, Deserialize)]
403pub struct Usage {
404 pub prompt_tokens: u32,
405 pub completion_tokens: u32,
406 pub total_tokens: u32,
407}
408
409#[derive(Debug, Deserialize)]
410pub struct Choice {
411 pub index: u32,
412 pub message: Message,
413 pub finish_reason: String,
414}
415
416#[derive(Debug, Deserialize)]
417pub struct Message {
418 pub content: String,
419 pub role: String,
420}