1use anyhow::Result;
2use edit_prediction::cursor_excerpt;
3use edit_prediction_types::{
4 EditPrediction, EditPredictionDelegate, EditPredictionDiscardReason, EditPredictionIconSet,
5};
6use futures::AsyncReadExt;
7use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
8use http_client::HttpClient;
9use icons::IconName;
10use language::{
11 Anchor, Buffer, BufferSnapshot, EditPreview, language_settings::all_language_settings,
12};
13use language_model::{ApiKeyState, AuthenticateError, EnvVar, env_var};
14use serde::{Deserialize, Serialize};
15
16use std::{
17 ops::Range,
18 sync::Arc,
19 time::{Duration, Instant},
20};
21use text::ToOffset;
22
/// Base URL of the hosted Codestral API; `codestral_api_url` falls back to it when the
/// settings do not provide an override.
pub const CODESTRAL_API_URL: &str = "https://codestral.mistral.ai";
/// How long a debounced refresh waits before it sends its request.
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
25
26static CODESTRAL_API_KEY_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("CODESTRAL_API_KEY");
27
28struct GlobalCodestralApiKey(Entity<ApiKeyState>);
29
30impl Global for GlobalCodestralApiKey {}
31
32pub fn codestral_api_key_state(cx: &mut App) -> Entity<ApiKeyState> {
33 if let Some(global) = cx.try_global::<GlobalCodestralApiKey>() {
34 return global.0.clone();
35 }
36 let entity =
37 cx.new(|cx| ApiKeyState::new(codestral_api_url(cx), CODESTRAL_API_KEY_ENV_VAR.clone()));
38 cx.set_global(GlobalCodestralApiKey(entity.clone()));
39 entity
40}
41
42pub fn codestral_api_key(cx: &App) -> Option<Arc<str>> {
43 let url = codestral_api_url(cx);
44 cx.try_global::<GlobalCodestralApiKey>()?
45 .0
46 .read(cx)
47 .key(&url)
48}
49
50pub fn load_codestral_api_key(cx: &mut App) -> Task<Result<(), AuthenticateError>> {
51 let credentials_provider = zed_credentials_provider::global(cx);
52 let api_url = codestral_api_url(cx);
53 codestral_api_key_state(cx).update(cx, |key_state, cx| {
54 key_state.load_if_needed(api_url, |s| s, credentials_provider, cx)
55 })
56}
57
58pub fn codestral_api_url(cx: &App) -> SharedString {
59 all_language_settings(None, cx)
60 .edit_predictions
61 .codestral
62 .api_url
63 .clone()
64 .unwrap_or_else(|| CODESTRAL_API_URL.to_string())
65 .into()
66}
67
68/// Represents a completion that has been received and processed from Codestral.
69/// This struct maintains the state needed to interpolate the completion as the user types.
70#[derive(Clone)]
71struct CurrentCompletion {
72 /// The buffer snapshot at the time the completion was generated.
73 /// Used to detect changes and interpolate edits.
74 snapshot: BufferSnapshot,
75 /// The edits that should be applied to transform the original text into the predicted text.
76 /// Each edit is a range in the buffer and the text to replace it with.
77 edits: Arc<[(Range<Anchor>, Arc<str>)]>,
78 /// Preview of how the buffer will look after applying the edits.
79 edit_preview: EditPreview,
80}
81
82impl CurrentCompletion {
83 /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
84 /// Returns None if the user's edits conflict with the predicted edits.
85 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
86 edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
87 }
88}
89
90pub struct CodestralEditPredictionDelegate {
91 http_client: Arc<dyn HttpClient>,
92 pending_request: Option<Task<Result<()>>>,
93 current_completion: Option<CurrentCompletion>,
94}
95
96impl CodestralEditPredictionDelegate {
97 pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
98 Self {
99 http_client,
100 pending_request: None,
101 current_completion: None,
102 }
103 }
104
105 pub fn ensure_api_key_loaded(cx: &mut App) {
106 load_codestral_api_key(cx).detach();
107 }
108
109 /// Uses Codestral's Fill-in-the-Middle API for code completion.
110 async fn fetch_completion(
111 http_client: Arc<dyn HttpClient>,
112 api_key: &str,
113 prompt: String,
114 suffix: String,
115 model: String,
116 max_tokens: Option<u32>,
117 api_url: String,
118 ) -> Result<String> {
119 let start_time = Instant::now();
120
121 log::debug!(
122 "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
123 model,
124 max_tokens
125 );
126
127 let request = CodestralRequest {
128 model,
129 prompt,
130 suffix: if suffix.is_empty() {
131 None
132 } else {
133 Some(suffix)
134 },
135 max_tokens: max_tokens.or(Some(350)),
136 temperature: Some(0.2),
137 top_p: Some(1.0),
138 stream: Some(false),
139 stop: None,
140 random_seed: None,
141 min_tokens: None,
142 };
143
144 let request_body = serde_json::to_string(&request)?;
145
146 log::debug!("Codestral: Sending FIM request");
147
148 let http_request = http_client::Request::builder()
149 .method(http_client::Method::POST)
150 .uri(format!("{}/v1/fim/completions", api_url))
151 .header("Content-Type", "application/json")
152 .header("Authorization", format!("Bearer {}", api_key))
153 .body(http_client::AsyncBody::from(request_body))?;
154
155 let mut response = http_client.send(http_request).await?;
156 let status = response.status();
157
158 log::debug!("Codestral: Response status: {}", status);
159
160 if !status.is_success() {
161 let mut body = String::new();
162 response.body_mut().read_to_string(&mut body).await?;
163 return Err(anyhow::anyhow!(
164 "Codestral API error: {} - {}",
165 status,
166 body
167 ));
168 }
169
170 let mut body = String::new();
171 response.body_mut().read_to_string(&mut body).await?;
172
173 let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
174
175 let elapsed = start_time.elapsed();
176
177 if let Some(choice) = codestral_response.choices.first() {
178 let completion = &choice.message.content;
179
180 log::debug!(
181 "Codestral: Completion received ({} tokens, {:.2}s)",
182 codestral_response.usage.completion_tokens,
183 elapsed.as_secs_f64()
184 );
185
186 // Return just the completion text for insertion at cursor
187 Ok(completion.clone())
188 } else {
189 log::error!("Codestral: No completion returned in response");
190 Err(anyhow::anyhow!("No completion returned from Codestral"))
191 }
192 }
193}
194
195impl EditPredictionDelegate for CodestralEditPredictionDelegate {
196 fn name() -> &'static str {
197 "codestral"
198 }
199
200 fn display_name() -> &'static str {
201 "Codestral"
202 }
203
204 fn show_predictions_in_menu() -> bool {
205 true
206 }
207
208 fn icons(&self, _cx: &App) -> EditPredictionIconSet {
209 EditPredictionIconSet::new(IconName::AiMistral)
210 }
211
212 fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
213 codestral_api_key(cx).is_some()
214 }
215
216 fn is_refreshing(&self, _cx: &App) -> bool {
217 self.pending_request.is_some()
218 }
219
220 fn refresh(
221 &mut self,
222 buffer: Entity<Buffer>,
223 cursor_position: language::Anchor,
224 debounce: bool,
225 cx: &mut Context<Self>,
226 ) {
227 log::debug!("Codestral: Refresh called (debounce: {})", debounce);
228
229 let Some(api_key) = codestral_api_key(cx) else {
230 log::warn!("Codestral: No API key configured, skipping refresh");
231 return;
232 };
233
234 let snapshot = buffer.read(cx).snapshot();
235
236 // Check if current completion is still valid
237 if let Some(current_completion) = self.current_completion.as_ref() {
238 if current_completion.interpolate(&snapshot).is_some() {
239 return;
240 }
241 }
242
243 let http_client = self.http_client.clone();
244
245 // Get settings
246 let settings = all_language_settings(None, cx);
247 let model = settings
248 .edit_predictions
249 .codestral
250 .model
251 .clone()
252 .unwrap_or_else(|| "codestral-latest".to_string());
253 let max_tokens = settings.edit_predictions.codestral.max_tokens;
254 let api_url = codestral_api_url(cx).to_string();
255
256 self.pending_request = Some(cx.spawn(async move |this, cx| {
257 if debounce {
258 log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
259 cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
260 }
261
262 let cursor_offset = cursor_position.to_offset(&snapshot);
263
264 const MAX_EDITABLE_TOKENS: usize = 350;
265 const MAX_CONTEXT_TOKENS: usize = 150;
266
267 let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
268 cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
269 let syntax_ranges = cursor_excerpt::compute_syntax_ranges(
270 &snapshot,
271 cursor_offset,
272 &excerpt_offset_range,
273 );
274 let excerpt_text: String = snapshot.text_for_range(excerpt_point_range).collect();
275 let (_, context_range) = zeta_prompt::compute_editable_and_context_ranges(
276 &excerpt_text,
277 cursor_offset_in_excerpt,
278 &syntax_ranges,
279 MAX_EDITABLE_TOKENS,
280 MAX_CONTEXT_TOKENS,
281 );
282 let context_text = &excerpt_text[context_range.clone()];
283 let cursor_within_excerpt = cursor_offset_in_excerpt
284 .saturating_sub(context_range.start)
285 .min(context_text.len());
286 let prompt = context_text[..cursor_within_excerpt].to_string();
287 let suffix = context_text[cursor_within_excerpt..].to_string();
288
289 let completion_text = match Self::fetch_completion(
290 http_client,
291 &api_key,
292 prompt,
293 suffix,
294 model,
295 max_tokens,
296 api_url,
297 )
298 .await
299 {
300 Ok(completion) => completion,
301 Err(e) => {
302 log::error!("Codestral: Failed to fetch completion: {}", e);
303 this.update(cx, |this, cx| {
304 this.pending_request = None;
305 cx.notify();
306 })?;
307 return Err(e);
308 }
309 };
310
311 if completion_text.trim().is_empty() {
312 log::debug!("Codestral: Completion was empty after trimming; ignoring");
313 this.update(cx, |this, cx| {
314 this.pending_request = None;
315 cx.notify();
316 })?;
317 return Ok(());
318 }
319
320 let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
321 vec![(cursor_position..cursor_position, completion_text.into())].into();
322 let edit_preview = buffer
323 .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))
324 .await;
325
326 this.update(cx, |this, cx| {
327 this.current_completion = Some(CurrentCompletion {
328 snapshot,
329 edits,
330 edit_preview,
331 });
332 this.pending_request = None;
333 cx.notify();
334 })?;
335
336 Ok(())
337 }));
338 }
339
340 fn accept(&mut self, _cx: &mut Context<Self>) {
341 log::debug!("Codestral: Completion accepted");
342 self.pending_request = None;
343 self.current_completion = None;
344 }
345
346 fn discard(&mut self, _reason: EditPredictionDiscardReason, _cx: &mut Context<Self>) {
347 log::debug!("Codestral: Completion discarded");
348 self.pending_request = None;
349 self.current_completion = None;
350 }
351
352 /// Returns the completion suggestion, adjusted or invalidated based on user edits
353 fn suggest(
354 &mut self,
355 buffer: &Entity<Buffer>,
356 _cursor_position: Anchor,
357 cx: &mut Context<Self>,
358 ) -> Option<EditPrediction> {
359 let current_completion = self.current_completion.as_ref()?;
360 let buffer = buffer.read(cx);
361 let edits = current_completion.interpolate(&buffer.snapshot())?;
362 if edits.is_empty() {
363 return None;
364 }
365 Some(EditPrediction::Local {
366 id: None,
367 edits,
368 cursor_position: None,
369 edit_preview: Some(current_completion.edit_preview.clone()),
370 })
371 }
372}
373
374#[derive(Debug, Serialize, Deserialize)]
375pub struct CodestralRequest {
376 pub model: String,
377 pub prompt: String,
378 #[serde(skip_serializing_if = "Option::is_none")]
379 pub suffix: Option<String>,
380 #[serde(skip_serializing_if = "Option::is_none")]
381 pub max_tokens: Option<u32>,
382 #[serde(skip_serializing_if = "Option::is_none")]
383 pub temperature: Option<f32>,
384 #[serde(skip_serializing_if = "Option::is_none")]
385 pub top_p: Option<f32>,
386 #[serde(skip_serializing_if = "Option::is_none")]
387 pub stream: Option<bool>,
388 #[serde(skip_serializing_if = "Option::is_none")]
389 pub stop: Option<Vec<String>>,
390 #[serde(skip_serializing_if = "Option::is_none")]
391 pub random_seed: Option<u32>,
392 #[serde(skip_serializing_if = "Option::is_none")]
393 pub min_tokens: Option<u32>,
394}
395
396#[derive(Debug, Deserialize)]
397pub struct CodestralResponse {
398 pub id: String,
399 pub object: String,
400 pub model: String,
401 pub usage: Usage,
402 pub created: u64,
403 pub choices: Vec<Choice>,
404}
405
406#[derive(Debug, Deserialize)]
407pub struct Usage {
408 pub prompt_tokens: u32,
409 pub completion_tokens: u32,
410 pub total_tokens: u32,
411}
412
413#[derive(Debug, Deserialize)]
414pub struct Choice {
415 pub index: u32,
416 pub message: Message,
417 pub finish_reason: String,
418}
419
420#[derive(Debug, Deserialize)]
421pub struct Message {
422 pub content: String,
423 pub role: String,
424}