1use anyhow::Result;
2use edit_prediction::cursor_excerpt;
3use edit_prediction_types::{
4 EditPrediction, EditPredictionDelegate, EditPredictionDiscardReason, EditPredictionIconSet,
5};
6use futures::AsyncReadExt;
7use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
8use http_client::HttpClient;
9use icons::IconName;
10use language::{
11 Anchor, Buffer, BufferSnapshot, EditPreview, language_settings::all_language_settings,
12};
13use language_model::{ApiKeyState, AuthenticateError, EnvVar, env_var};
14use serde::{Deserialize, Serialize};
15
16use std::{
17 ops::Range,
18 sync::Arc,
19 time::{Duration, Instant},
20};
21use text::ToOffset;
22
/// Default Codestral API base URL, used when the `edit_predictions.codestral.api_url`
/// setting is not configured.
pub const CODESTRAL_API_URL: &str = "https://codestral.mistral.ai";
/// How long a debounced `refresh` waits before building and sending its request.
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);

/// The `CODESTRAL_API_KEY` environment variable, handed to [`ApiKeyState`] as the
/// key's env-var source (presumably consulted when the key is loaded — see `ApiKeyState`).
static CODESTRAL_API_KEY_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("CODESTRAL_API_KEY");
27
/// App-global wrapper around the single shared Codestral [`ApiKeyState`] entity, so
/// every lookup and load goes through the same key state.
struct GlobalCodestralApiKey(Entity<ApiKeyState>);

impl Global for GlobalCodestralApiKey {}
31
32pub fn codestral_api_key_state(cx: &mut App) -> Entity<ApiKeyState> {
33 if let Some(global) = cx.try_global::<GlobalCodestralApiKey>() {
34 return global.0.clone();
35 }
36 let entity =
37 cx.new(|cx| ApiKeyState::new(codestral_api_url(cx), CODESTRAL_API_KEY_ENV_VAR.clone()));
38 cx.set_global(GlobalCodestralApiKey(entity.clone()));
39 entity
40}
41
42pub fn codestral_api_key(cx: &App) -> Option<Arc<str>> {
43 let url = codestral_api_url(cx);
44 cx.try_global::<GlobalCodestralApiKey>()?
45 .0
46 .read(cx)
47 .key(&url)
48}
49
50pub fn load_codestral_api_key(cx: &mut App) -> Task<Result<(), AuthenticateError>> {
51 let api_url = codestral_api_url(cx);
52 codestral_api_key_state(cx).update(cx, |key_state, cx| {
53 key_state.load_if_needed(api_url, |s| s, cx)
54 })
55}
56
57pub fn codestral_api_url(cx: &App) -> SharedString {
58 all_language_settings(None, cx)
59 .edit_predictions
60 .codestral
61 .api_url
62 .clone()
63 .unwrap_or_else(|| CODESTRAL_API_URL.to_string())
64 .into()
65}
66
/// A completion received from Codestral, together with the state needed to keep showing
/// it (interpolated) while the user continues to edit the buffer.
#[derive(Clone)]
struct CurrentCompletion {
    /// Buffer snapshot captured when the completion request was started (before any
    /// debounce delay). Edits made after this snapshot are what `interpolate` adjusts for.
    snapshot: BufferSnapshot,
    /// Edits to apply, each a buffer range and its replacement text. For Codestral this
    /// is a single insertion of the completion text at the cursor position.
    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Preview of the buffer with `edits` applied, produced by `Buffer::preview_edits`.
    edit_preview: EditPreview,
}

impl CurrentCompletion {
    /// Re-expresses `edits` against `new_snapshot`, accounting for user edits made since
    /// `snapshot`, by delegating to `edit_prediction_types::interpolate_edits`.
    ///
    /// Callers treat `None` as "this completion no longer applies". A `Some` result may
    /// still contain no edits, which callers treat as nothing left to suggest.
    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
        edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
    }
}
88
/// Edit prediction provider backed by Codestral's Fill-in-the-Middle completions API.
pub struct CodestralEditPredictionDelegate {
    /// Client used to send completion requests.
    http_client: Arc<dyn HttpClient>,
    /// Task for the in-flight refresh, if any. `is_refreshing` reports whether it is set;
    /// the task clears it when it finishes, and accept/discard clear it as well.
    pending_request: Option<Task<Result<()>>>,
    /// Most recently received completion. Replaced when a newer one arrives and cleared
    /// on accept/discard.
    current_completion: Option<CurrentCompletion>,
}
94
95impl CodestralEditPredictionDelegate {
96 pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
97 Self {
98 http_client,
99 pending_request: None,
100 current_completion: None,
101 }
102 }
103
104 pub fn ensure_api_key_loaded(cx: &mut App) {
105 load_codestral_api_key(cx).detach();
106 }
107
108 /// Uses Codestral's Fill-in-the-Middle API for code completion.
109 async fn fetch_completion(
110 http_client: Arc<dyn HttpClient>,
111 api_key: &str,
112 prompt: String,
113 suffix: String,
114 model: String,
115 max_tokens: Option<u32>,
116 api_url: String,
117 ) -> Result<String> {
118 let start_time = Instant::now();
119
120 log::debug!(
121 "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
122 model,
123 max_tokens
124 );
125
126 let request = CodestralRequest {
127 model,
128 prompt,
129 suffix: if suffix.is_empty() {
130 None
131 } else {
132 Some(suffix)
133 },
134 max_tokens: max_tokens.or(Some(350)),
135 temperature: Some(0.2),
136 top_p: Some(1.0),
137 stream: Some(false),
138 stop: None,
139 random_seed: None,
140 min_tokens: None,
141 };
142
143 let request_body = serde_json::to_string(&request)?;
144
145 log::debug!("Codestral: Sending FIM request");
146
147 let http_request = http_client::Request::builder()
148 .method(http_client::Method::POST)
149 .uri(format!("{}/v1/fim/completions", api_url))
150 .header("Content-Type", "application/json")
151 .header("Authorization", format!("Bearer {}", api_key))
152 .body(http_client::AsyncBody::from(request_body))?;
153
154 let mut response = http_client.send(http_request).await?;
155 let status = response.status();
156
157 log::debug!("Codestral: Response status: {}", status);
158
159 if !status.is_success() {
160 let mut body = String::new();
161 response.body_mut().read_to_string(&mut body).await?;
162 return Err(anyhow::anyhow!(
163 "Codestral API error: {} - {}",
164 status,
165 body
166 ));
167 }
168
169 let mut body = String::new();
170 response.body_mut().read_to_string(&mut body).await?;
171
172 let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
173
174 let elapsed = start_time.elapsed();
175
176 if let Some(choice) = codestral_response.choices.first() {
177 let completion = &choice.message.content;
178
179 log::debug!(
180 "Codestral: Completion received ({} tokens, {:.2}s)",
181 codestral_response.usage.completion_tokens,
182 elapsed.as_secs_f64()
183 );
184
185 // Return just the completion text for insertion at cursor
186 Ok(completion.clone())
187 } else {
188 log::error!("Codestral: No completion returned in response");
189 Err(anyhow::anyhow!("No completion returned from Codestral"))
190 }
191 }
192}
193
194impl EditPredictionDelegate for CodestralEditPredictionDelegate {
195 fn name() -> &'static str {
196 "codestral"
197 }
198
199 fn display_name() -> &'static str {
200 "Codestral"
201 }
202
203 fn show_predictions_in_menu() -> bool {
204 true
205 }
206
207 fn icons(&self, _cx: &App) -> EditPredictionIconSet {
208 EditPredictionIconSet::new(IconName::AiMistral)
209 }
210
211 fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
212 codestral_api_key(cx).is_some()
213 }
214
215 fn is_refreshing(&self, _cx: &App) -> bool {
216 self.pending_request.is_some()
217 }
218
219 fn refresh(
220 &mut self,
221 buffer: Entity<Buffer>,
222 cursor_position: language::Anchor,
223 debounce: bool,
224 cx: &mut Context<Self>,
225 ) {
226 log::debug!("Codestral: Refresh called (debounce: {})", debounce);
227
228 let Some(api_key) = codestral_api_key(cx) else {
229 log::warn!("Codestral: No API key configured, skipping refresh");
230 return;
231 };
232
233 let snapshot = buffer.read(cx).snapshot();
234
235 // Check if current completion is still valid
236 if let Some(current_completion) = self.current_completion.as_ref() {
237 if current_completion.interpolate(&snapshot).is_some() {
238 return;
239 }
240 }
241
242 let http_client = self.http_client.clone();
243
244 // Get settings
245 let settings = all_language_settings(None, cx);
246 let model = settings
247 .edit_predictions
248 .codestral
249 .model
250 .clone()
251 .unwrap_or_else(|| "codestral-latest".to_string());
252 let max_tokens = settings.edit_predictions.codestral.max_tokens;
253 let api_url = codestral_api_url(cx).to_string();
254
255 self.pending_request = Some(cx.spawn(async move |this, cx| {
256 if debounce {
257 log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
258 cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
259 }
260
261 let cursor_offset = cursor_position.to_offset(&snapshot);
262
263 const MAX_EDITABLE_TOKENS: usize = 350;
264 const MAX_CONTEXT_TOKENS: usize = 150;
265
266 let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
267 cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
268 let syntax_ranges = cursor_excerpt::compute_syntax_ranges(
269 &snapshot,
270 cursor_offset,
271 &excerpt_offset_range,
272 );
273 let excerpt_text: String = snapshot.text_for_range(excerpt_point_range).collect();
274 let (_, context_range) = zeta_prompt::compute_editable_and_context_ranges(
275 &excerpt_text,
276 cursor_offset_in_excerpt,
277 &syntax_ranges,
278 MAX_EDITABLE_TOKENS,
279 MAX_CONTEXT_TOKENS,
280 );
281 let context_text = &excerpt_text[context_range.clone()];
282 let cursor_within_excerpt = cursor_offset_in_excerpt
283 .saturating_sub(context_range.start)
284 .min(context_text.len());
285 let prompt = context_text[..cursor_within_excerpt].to_string();
286 let suffix = context_text[cursor_within_excerpt..].to_string();
287
288 let completion_text = match Self::fetch_completion(
289 http_client,
290 &api_key,
291 prompt,
292 suffix,
293 model,
294 max_tokens,
295 api_url,
296 )
297 .await
298 {
299 Ok(completion) => completion,
300 Err(e) => {
301 log::error!("Codestral: Failed to fetch completion: {}", e);
302 this.update(cx, |this, cx| {
303 this.pending_request = None;
304 cx.notify();
305 })?;
306 return Err(e);
307 }
308 };
309
310 if completion_text.trim().is_empty() {
311 log::debug!("Codestral: Completion was empty after trimming; ignoring");
312 this.update(cx, |this, cx| {
313 this.pending_request = None;
314 cx.notify();
315 })?;
316 return Ok(());
317 }
318
319 let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
320 vec![(cursor_position..cursor_position, completion_text.into())].into();
321 let edit_preview = buffer
322 .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))
323 .await;
324
325 this.update(cx, |this, cx| {
326 this.current_completion = Some(CurrentCompletion {
327 snapshot,
328 edits,
329 edit_preview,
330 });
331 this.pending_request = None;
332 cx.notify();
333 })?;
334
335 Ok(())
336 }));
337 }
338
339 fn accept(&mut self, _cx: &mut Context<Self>) {
340 log::debug!("Codestral: Completion accepted");
341 self.pending_request = None;
342 self.current_completion = None;
343 }
344
345 fn discard(&mut self, _reason: EditPredictionDiscardReason, _cx: &mut Context<Self>) {
346 log::debug!("Codestral: Completion discarded");
347 self.pending_request = None;
348 self.current_completion = None;
349 }
350
351 /// Returns the completion suggestion, adjusted or invalidated based on user edits
352 fn suggest(
353 &mut self,
354 buffer: &Entity<Buffer>,
355 _cursor_position: Anchor,
356 cx: &mut Context<Self>,
357 ) -> Option<EditPrediction> {
358 let current_completion = self.current_completion.as_ref()?;
359 let buffer = buffer.read(cx);
360 let edits = current_completion.interpolate(&buffer.snapshot())?;
361 if edits.is_empty() {
362 return None;
363 }
364 Some(EditPrediction::Local {
365 id: None,
366 edits,
367 cursor_position: None,
368 edit_preview: Some(current_completion.edit_preview.clone()),
369 })
370 }
371}
372
/// JSON request body sent to the `/v1/fim/completions` endpoint.
///
/// Every `Option` field is omitted from the serialized JSON when it is `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct CodestralRequest {
    /// Model identifier; the delegate defaults this to `codestral-latest`.
    pub model: String,
    /// Text preceding the cursor.
    pub prompt: String,
    /// Text following the cursor; the delegate sends `None` when it is empty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    // The remaining fields are generation parameters forwarded to the API as-is.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// The delegate always sends `Some(false)` and parses a single JSON response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
}
394
/// JSON response body from the FIM completions endpoint.
///
/// All fields are required, so a response missing any of them fails to deserialize.
/// The delegate reads only the first entry of `choices` and `usage.completion_tokens`.
#[derive(Debug, Deserialize)]
pub struct CodestralResponse {
    pub id: String,
    pub object: String,
    pub model: String,
    pub usage: Usage,
    pub created: u64,
    pub choices: Vec<Choice>,
}
404
/// Token counts reported in a [`CodestralResponse`]; `completion_tokens` is logged after
/// each successful request.
#[derive(Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
411
/// One completion candidate in a [`CodestralResponse`]; only the first is used.
#[derive(Debug, Deserialize)]
pub struct Choice {
    pub index: u32,
    pub message: Message,
    pub finish_reason: String,
}
418
/// Message payload of a [`Choice`]; `content` is the generated text that the delegate
/// inserts at the cursor.
#[derive(Debug, Deserialize)]
pub struct Message {
    pub content: String,
    pub role: String,
}