1use anyhow::Result;
2use edit_prediction::cursor_excerpt;
3use edit_prediction_types::{
4 EditPrediction, EditPredictionDelegate, EditPredictionDiscardReason, EditPredictionIconSet,
5};
6use futures::AsyncReadExt;
7use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
8use http_client::HttpClient;
9use icons::IconName;
10use language::{
11 Anchor, Buffer, BufferSnapshot, EditPreview, language_settings::all_language_settings,
12};
13use language_model::{ApiKeyState, AuthenticateError, EnvVar, env_var};
14use serde::{Deserialize, Serialize};
15
16use std::{
17 ops::Range,
18 sync::Arc,
19 time::{Duration, Instant},
20};
21use text::ToOffset;
22use zed_credentials_provider::global as global_credentials_provider;
23
/// Default base URL of the Codestral API. `codestral_api_url` falls back to this
/// value when the language settings do not provide an `api_url` override.
pub const CODESTRAL_API_URL: &str = "https://codestral.mistral.ai";
/// Delay awaited before a completion request is issued when `refresh` is called
/// with `debounce == true`.
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);

/// Handle for the `CODESTRAL_API_KEY` environment variable; a clone of it is passed
/// to `ApiKeyState::new` when the shared key state is created.
// NOTE(review): presumably this lets the API key be supplied through the environment
// instead of the credentials provider — confirm against `env_var!` / `ApiKeyState`.
static CODESTRAL_API_KEY_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("CODESTRAL_API_KEY");
28
/// Newtype around the shared `ApiKeyState` entity so that it can be stored as a
/// gpui global. It is installed by `codestral_api_key_state` and looked up by
/// `codestral_api_key`.
struct GlobalCodestralApiKey(Entity<ApiKeyState>);

// Marker impl: allows the wrapper to be used with `set_global` / `try_global`.
impl Global for GlobalCodestralApiKey {}
32
33pub fn codestral_api_key_state(cx: &mut App) -> Entity<ApiKeyState> {
34 if let Some(global) = cx.try_global::<GlobalCodestralApiKey>() {
35 return global.0.clone();
36 }
37 let entity =
38 cx.new(|cx| ApiKeyState::new(codestral_api_url(cx), CODESTRAL_API_KEY_ENV_VAR.clone()));
39 cx.set_global(GlobalCodestralApiKey(entity.clone()));
40 entity
41}
42
43pub fn codestral_api_key(cx: &App) -> Option<Arc<str>> {
44 let url = codestral_api_url(cx);
45 cx.try_global::<GlobalCodestralApiKey>()?
46 .0
47 .read(cx)
48 .key(&url)
49}
50
51pub fn load_codestral_api_key(cx: &mut App) -> Task<Result<(), AuthenticateError>> {
52 let credentials_provider = global_credentials_provider(cx);
53 let api_url = codestral_api_url(cx);
54 codestral_api_key_state(cx).update(cx, |key_state, cx| {
55 key_state.load_if_needed(api_url, |s| s, credentials_provider, cx)
56 })
57}
58
59pub fn codestral_api_url(cx: &App) -> SharedString {
60 all_language_settings(None, cx)
61 .edit_predictions
62 .codestral
63 .api_url
64 .clone()
65 .unwrap_or_else(|| CODESTRAL_API_URL.to_string())
66 .into()
67}
68
/// Represents a completion that has been received and processed from Codestral.
/// This struct maintains the state needed to interpolate the completion as the user types.
///
/// An instance is created at the end of a successful `refresh` task and is cleared
/// again by `accept` and `discard`.
#[derive(Clone)]
struct CurrentCompletion {
    /// The buffer snapshot at the time the completion was requested.
    /// It is the baseline that `interpolate` compares newer snapshots against.
    snapshot: BufferSnapshot,
    /// The edits that should be applied to transform the original text into the predicted text.
    /// Each edit is a range in the buffer and the text to replace it with.
    /// `refresh` always stores a single edit: an insertion at the cursor anchor.
    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Preview of how the buffer will look after applying the edits, as produced by
    /// `Buffer::preview_edits`; handed back to the editor from `suggest`.
    edit_preview: EditPreview,
}
82
83impl CurrentCompletion {
84 /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
85 /// Returns None if the user's edits conflict with the predicted edits.
86 fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
87 edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
88 }
89}
90
/// Edit-prediction delegate that obtains completions from the Codestral
/// fill-in-the-middle endpoint.
pub struct CodestralEditPredictionDelegate {
    /// HTTP client used to send completion requests.
    http_client: Arc<dyn HttpClient>,
    /// Task of the refresh currently in flight, if any. It is `Some` while
    /// `is_refreshing` reports `true`; the task resets it to `None` when it finishes,
    /// and `accept` / `discard` clear it as well.
    pending_request: Option<Task<Result<()>>>,
    /// Most recently fetched completion, if one is available. Cleared by `accept`
    /// and `discard`.
    current_completion: Option<CurrentCompletion>,
}
96
97impl CodestralEditPredictionDelegate {
98 pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
99 Self {
100 http_client,
101 pending_request: None,
102 current_completion: None,
103 }
104 }
105
106 pub fn ensure_api_key_loaded(cx: &mut App) {
107 load_codestral_api_key(cx).detach();
108 }
109
110 /// Uses Codestral's Fill-in-the-Middle API for code completion.
111 async fn fetch_completion(
112 http_client: Arc<dyn HttpClient>,
113 api_key: &str,
114 prompt: String,
115 suffix: String,
116 model: String,
117 max_tokens: Option<u32>,
118 api_url: String,
119 ) -> Result<String> {
120 let start_time = Instant::now();
121
122 log::debug!(
123 "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
124 model,
125 max_tokens
126 );
127
128 let request = CodestralRequest {
129 model,
130 prompt,
131 suffix: if suffix.is_empty() {
132 None
133 } else {
134 Some(suffix)
135 },
136 max_tokens: max_tokens.or(Some(350)),
137 temperature: Some(0.2),
138 top_p: Some(1.0),
139 stream: Some(false),
140 stop: None,
141 random_seed: None,
142 min_tokens: None,
143 };
144
145 let request_body = serde_json::to_string(&request)?;
146
147 log::debug!("Codestral: Sending FIM request");
148
149 let http_request = http_client::Request::builder()
150 .method(http_client::Method::POST)
151 .uri(format!("{}/v1/fim/completions", api_url))
152 .header("Content-Type", "application/json")
153 .header("Authorization", format!("Bearer {}", api_key))
154 .body(http_client::AsyncBody::from(request_body))?;
155
156 let mut response = http_client.send(http_request).await?;
157 let status = response.status();
158
159 log::debug!("Codestral: Response status: {}", status);
160
161 if !status.is_success() {
162 let mut body = String::new();
163 response.body_mut().read_to_string(&mut body).await?;
164 return Err(anyhow::anyhow!(
165 "Codestral API error: {} - {}",
166 status,
167 body
168 ));
169 }
170
171 let mut body = String::new();
172 response.body_mut().read_to_string(&mut body).await?;
173
174 let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
175
176 let elapsed = start_time.elapsed();
177
178 if let Some(choice) = codestral_response.choices.first() {
179 let completion = &choice.message.content;
180
181 log::debug!(
182 "Codestral: Completion received ({} tokens, {:.2}s)",
183 codestral_response.usage.completion_tokens,
184 elapsed.as_secs_f64()
185 );
186
187 // Return just the completion text for insertion at cursor
188 Ok(completion.clone())
189 } else {
190 log::error!("Codestral: No completion returned in response");
191 Err(anyhow::anyhow!("No completion returned from Codestral"))
192 }
193 }
194}
195
196impl EditPredictionDelegate for CodestralEditPredictionDelegate {
197 fn name() -> &'static str {
198 "codestral"
199 }
200
201 fn display_name() -> &'static str {
202 "Codestral"
203 }
204
205 fn show_predictions_in_menu() -> bool {
206 true
207 }
208
209 fn icons(&self, _cx: &App) -> EditPredictionIconSet {
210 EditPredictionIconSet::new(IconName::AiMistral)
211 }
212
213 fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
214 codestral_api_key(cx).is_some()
215 }
216
217 fn is_refreshing(&self, _cx: &App) -> bool {
218 self.pending_request.is_some()
219 }
220
221 fn refresh(
222 &mut self,
223 buffer: Entity<Buffer>,
224 cursor_position: language::Anchor,
225 debounce: bool,
226 cx: &mut Context<Self>,
227 ) {
228 log::debug!("Codestral: Refresh called (debounce: {})", debounce);
229
230 let Some(api_key) = codestral_api_key(cx) else {
231 log::warn!("Codestral: No API key configured, skipping refresh");
232 return;
233 };
234
235 let snapshot = buffer.read(cx).snapshot();
236
237 // Check if current completion is still valid
238 if let Some(current_completion) = self.current_completion.as_ref() {
239 if current_completion.interpolate(&snapshot).is_some() {
240 return;
241 }
242 }
243
244 let http_client = self.http_client.clone();
245
246 // Get settings
247 let settings = all_language_settings(None, cx);
248 let model = settings
249 .edit_predictions
250 .codestral
251 .model
252 .clone()
253 .unwrap_or_else(|| "codestral-latest".to_string());
254 let max_tokens = settings.edit_predictions.codestral.max_tokens;
255 let api_url = codestral_api_url(cx).to_string();
256
257 self.pending_request = Some(cx.spawn(async move |this, cx| {
258 if debounce {
259 log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
260 cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
261 }
262
263 let cursor_offset = cursor_position.to_offset(&snapshot);
264
265 const MAX_EDITABLE_TOKENS: usize = 350;
266 const MAX_CONTEXT_TOKENS: usize = 150;
267
268 let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
269 cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
270 let syntax_ranges = cursor_excerpt::compute_syntax_ranges(
271 &snapshot,
272 cursor_offset,
273 &excerpt_offset_range,
274 );
275 let excerpt_text: String = snapshot.text_for_range(excerpt_point_range).collect();
276 let (_, context_range) = zeta_prompt::compute_editable_and_context_ranges(
277 &excerpt_text,
278 cursor_offset_in_excerpt,
279 &syntax_ranges,
280 MAX_EDITABLE_TOKENS,
281 MAX_CONTEXT_TOKENS,
282 );
283 let context_text = &excerpt_text[context_range.clone()];
284 let cursor_within_excerpt = cursor_offset_in_excerpt
285 .saturating_sub(context_range.start)
286 .min(context_text.len());
287 let prompt = context_text[..cursor_within_excerpt].to_string();
288 let suffix = context_text[cursor_within_excerpt..].to_string();
289
290 let completion_text = match Self::fetch_completion(
291 http_client,
292 &api_key,
293 prompt,
294 suffix,
295 model,
296 max_tokens,
297 api_url,
298 )
299 .await
300 {
301 Ok(completion) => completion,
302 Err(e) => {
303 log::error!("Codestral: Failed to fetch completion: {}", e);
304 this.update(cx, |this, cx| {
305 this.pending_request = None;
306 cx.notify();
307 })?;
308 return Err(e);
309 }
310 };
311
312 if completion_text.trim().is_empty() {
313 log::debug!("Codestral: Completion was empty after trimming; ignoring");
314 this.update(cx, |this, cx| {
315 this.pending_request = None;
316 cx.notify();
317 })?;
318 return Ok(());
319 }
320
321 let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
322 vec![(cursor_position..cursor_position, completion_text.into())].into();
323 let edit_preview = buffer
324 .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))
325 .await;
326
327 this.update(cx, |this, cx| {
328 this.current_completion = Some(CurrentCompletion {
329 snapshot,
330 edits,
331 edit_preview,
332 });
333 this.pending_request = None;
334 cx.notify();
335 })?;
336
337 Ok(())
338 }));
339 }
340
341 fn accept(&mut self, _cx: &mut Context<Self>) {
342 log::debug!("Codestral: Completion accepted");
343 self.pending_request = None;
344 self.current_completion = None;
345 }
346
347 fn discard(&mut self, _reason: EditPredictionDiscardReason, _cx: &mut Context<Self>) {
348 log::debug!("Codestral: Completion discarded");
349 self.pending_request = None;
350 self.current_completion = None;
351 }
352
353 /// Returns the completion suggestion, adjusted or invalidated based on user edits
354 fn suggest(
355 &mut self,
356 buffer: &Entity<Buffer>,
357 _cursor_position: Anchor,
358 cx: &mut Context<Self>,
359 ) -> Option<EditPrediction> {
360 let current_completion = self.current_completion.as_ref()?;
361 let buffer = buffer.read(cx);
362 let edits = current_completion.interpolate(&buffer.snapshot())?;
363 if edits.is_empty() {
364 return None;
365 }
366 Some(EditPrediction::Local {
367 id: None,
368 edits,
369 cursor_position: None,
370 edit_preview: Some(current_completion.edit_preview.clone()),
371 })
372 }
373}
374
/// JSON body sent to the `/v1/fim/completions` endpoint.
///
/// Every optional field is left out of the serialized JSON when it is `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct CodestralRequest {
    /// Model identifier; `refresh` falls back to `"codestral-latest"` when the
    /// settings do not name one.
    pub model: String,
    /// Text preceding the cursor.
    pub prompt: String,
    /// Text following the cursor; omitted when there is none.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    // NOTE(review): the fields below look like generation/sampling parameters of the
    // Codestral API — confirm their exact semantics against the API reference.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
}
396
/// JSON body of a successful completion response.
///
/// None of the fields carries a serde default, so all of them must be present for
/// deserialization to succeed. `fetch_completion` only reads `choices` and
/// `usage.completion_tokens`.
#[derive(Debug, Deserialize)]
pub struct CodestralResponse {
    pub id: String,
    pub object: String,
    pub model: String,
    /// Token accounting reported for the request.
    pub usage: Usage,
    pub created: u64,
    /// Returned completions; only the first entry is used.
    pub choices: Vec<Choice>,
}
406
/// Token counts attached to a `CodestralResponse`.
#[derive(Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    /// Logged by `fetch_completion` once a completion has been received.
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
413
/// One entry of `CodestralResponse::choices`.
#[derive(Debug, Deserialize)]
pub struct Choice {
    pub index: u32,
    /// Message whose `content` is used as the completion text.
    pub message: Message,
    pub finish_reason: String,
}
420
/// Message payload of a `Choice`.
#[derive(Debug, Deserialize)]
pub struct Message {
    /// Completion text; `fetch_completion` returns this for insertion at the cursor.
    pub content: String,
    pub role: String,
}