// codestral.rs

use anyhow::{Context as _, Result};
use edit_prediction::cursor_excerpt;
use edit_prediction_types::{
    EditPrediction, EditPredictionDelegate, EditPredictionDiscardReason, EditPredictionIconSet,
};
use futures::AsyncReadExt;
use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
use http_client::HttpClient;
use icons::IconName;
use language::{
    Anchor, Buffer, BufferSnapshot, EditPreview, language_settings::all_language_settings,
};
use language_model::{ApiKeyState, AuthenticateError, EnvVar, env_var};
use serde::{Deserialize, Serialize};

use std::{
    ops::Range,
    sync::Arc,
    time::{Duration, Instant},
};
use text::ToOffset;
use zed_credentials_provider::global as global_credentials_provider;
 24pub const CODESTRAL_API_URL: &str = "https://codestral.mistral.ai";
 25pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
 26
 27static CODESTRAL_API_KEY_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("CODESTRAL_API_KEY");
 28
 29struct GlobalCodestralApiKey(Entity<ApiKeyState>);
 30
 31impl Global for GlobalCodestralApiKey {}
 32
 33pub fn codestral_api_key_state(cx: &mut App) -> Entity<ApiKeyState> {
 34    if let Some(global) = cx.try_global::<GlobalCodestralApiKey>() {
 35        return global.0.clone();
 36    }
 37    let entity =
 38        cx.new(|cx| ApiKeyState::new(codestral_api_url(cx), CODESTRAL_API_KEY_ENV_VAR.clone()));
 39    cx.set_global(GlobalCodestralApiKey(entity.clone()));
 40    entity
 41}
 42
 43pub fn codestral_api_key(cx: &App) -> Option<Arc<str>> {
 44    let url = codestral_api_url(cx);
 45    cx.try_global::<GlobalCodestralApiKey>()?
 46        .0
 47        .read(cx)
 48        .key(&url)
 49}
 50
 51pub fn load_codestral_api_key(cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 52    let credentials_provider = global_credentials_provider(cx);
 53    let api_url = codestral_api_url(cx);
 54    codestral_api_key_state(cx).update(cx, |key_state, cx| {
 55        key_state.load_if_needed(api_url, |s| s, credentials_provider, cx)
 56    })
 57}
 58
 59pub fn codestral_api_url(cx: &App) -> SharedString {
 60    all_language_settings(None, cx)
 61        .edit_predictions
 62        .codestral
 63        .api_url
 64        .clone()
 65        .unwrap_or_else(|| CODESTRAL_API_URL.to_string())
 66        .into()
 67}
 68
 69/// Represents a completion that has been received and processed from Codestral.
 70/// This struct maintains the state needed to interpolate the completion as the user types.
 71#[derive(Clone)]
 72struct CurrentCompletion {
 73    /// The buffer snapshot at the time the completion was generated.
 74    /// Used to detect changes and interpolate edits.
 75    snapshot: BufferSnapshot,
 76    /// The edits that should be applied to transform the original text into the predicted text.
 77    /// Each edit is a range in the buffer and the text to replace it with.
 78    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 79    /// Preview of how the buffer will look after applying the edits.
 80    edit_preview: EditPreview,
 81}
 82
 83impl CurrentCompletion {
 84    /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
 85    /// Returns None if the user's edits conflict with the predicted edits.
 86    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
 87        edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
 88    }
 89}
 90
 91pub struct CodestralEditPredictionDelegate {
 92    http_client: Arc<dyn HttpClient>,
 93    pending_request: Option<Task<Result<()>>>,
 94    current_completion: Option<CurrentCompletion>,
 95}
 96
 97impl CodestralEditPredictionDelegate {
 98    pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
 99        Self {
100            http_client,
101            pending_request: None,
102            current_completion: None,
103        }
104    }
105
106    pub fn ensure_api_key_loaded(cx: &mut App) {
107        load_codestral_api_key(cx).detach();
108    }
109
110    /// Uses Codestral's Fill-in-the-Middle API for code completion.
111    async fn fetch_completion(
112        http_client: Arc<dyn HttpClient>,
113        api_key: &str,
114        prompt: String,
115        suffix: String,
116        model: String,
117        max_tokens: Option<u32>,
118        api_url: String,
119    ) -> Result<String> {
120        let start_time = Instant::now();
121
122        log::debug!(
123            "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
124            model,
125            max_tokens
126        );
127
128        let request = CodestralRequest {
129            model,
130            prompt,
131            suffix: if suffix.is_empty() {
132                None
133            } else {
134                Some(suffix)
135            },
136            max_tokens: max_tokens.or(Some(350)),
137            temperature: Some(0.2),
138            top_p: Some(1.0),
139            stream: Some(false),
140            stop: None,
141            random_seed: None,
142            min_tokens: None,
143        };
144
145        let request_body = serde_json::to_string(&request)?;
146
147        log::debug!("Codestral: Sending FIM request");
148
149        let http_request = http_client::Request::builder()
150            .method(http_client::Method::POST)
151            .uri(format!("{}/v1/fim/completions", api_url))
152            .header("Content-Type", "application/json")
153            .header("Authorization", format!("Bearer {}", api_key))
154            .body(http_client::AsyncBody::from(request_body))?;
155
156        let mut response = http_client.send(http_request).await?;
157        let status = response.status();
158
159        log::debug!("Codestral: Response status: {}", status);
160
161        if !status.is_success() {
162            let mut body = String::new();
163            response.body_mut().read_to_string(&mut body).await?;
164            return Err(anyhow::anyhow!(
165                "Codestral API error: {} - {}",
166                status,
167                body
168            ));
169        }
170
171        let mut body = String::new();
172        response.body_mut().read_to_string(&mut body).await?;
173
174        let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
175
176        let elapsed = start_time.elapsed();
177
178        if let Some(choice) = codestral_response.choices.first() {
179            let completion = &choice.message.content;
180
181            log::debug!(
182                "Codestral: Completion received ({} tokens, {:.2}s)",
183                codestral_response.usage.completion_tokens,
184                elapsed.as_secs_f64()
185            );
186
187            // Return just the completion text for insertion at cursor
188            Ok(completion.clone())
189        } else {
190            log::error!("Codestral: No completion returned in response");
191            Err(anyhow::anyhow!("No completion returned from Codestral"))
192        }
193    }
194}
195
196impl EditPredictionDelegate for CodestralEditPredictionDelegate {
197    fn name() -> &'static str {
198        "codestral"
199    }
200
201    fn display_name() -> &'static str {
202        "Codestral"
203    }
204
205    fn show_predictions_in_menu() -> bool {
206        true
207    }
208
209    fn icons(&self, _cx: &App) -> EditPredictionIconSet {
210        EditPredictionIconSet::new(IconName::AiMistral)
211    }
212
213    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
214        codestral_api_key(cx).is_some()
215    }
216
217    fn is_refreshing(&self, _cx: &App) -> bool {
218        self.pending_request.is_some()
219    }
220
221    fn refresh(
222        &mut self,
223        buffer: Entity<Buffer>,
224        cursor_position: language::Anchor,
225        debounce: bool,
226        cx: &mut Context<Self>,
227    ) {
228        log::debug!("Codestral: Refresh called (debounce: {})", debounce);
229
230        let Some(api_key) = codestral_api_key(cx) else {
231            log::warn!("Codestral: No API key configured, skipping refresh");
232            return;
233        };
234
235        let snapshot = buffer.read(cx).snapshot();
236
237        // Check if current completion is still valid
238        if let Some(current_completion) = self.current_completion.as_ref() {
239            if current_completion.interpolate(&snapshot).is_some() {
240                return;
241            }
242        }
243
244        let http_client = self.http_client.clone();
245
246        // Get settings
247        let settings = all_language_settings(None, cx);
248        let model = settings
249            .edit_predictions
250            .codestral
251            .model
252            .clone()
253            .unwrap_or_else(|| "codestral-latest".to_string());
254        let max_tokens = settings.edit_predictions.codestral.max_tokens;
255        let api_url = codestral_api_url(cx).to_string();
256
257        self.pending_request = Some(cx.spawn(async move |this, cx| {
258            if debounce {
259                log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
260                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
261            }
262
263            let cursor_offset = cursor_position.to_offset(&snapshot);
264
265            const MAX_EDITABLE_TOKENS: usize = 350;
266            const MAX_CONTEXT_TOKENS: usize = 150;
267
268            let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
269                cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
270            let syntax_ranges = cursor_excerpt::compute_syntax_ranges(
271                &snapshot,
272                cursor_offset,
273                &excerpt_offset_range,
274            );
275            let excerpt_text: String = snapshot.text_for_range(excerpt_point_range).collect();
276            let (_, context_range) = zeta_prompt::compute_editable_and_context_ranges(
277                &excerpt_text,
278                cursor_offset_in_excerpt,
279                &syntax_ranges,
280                MAX_EDITABLE_TOKENS,
281                MAX_CONTEXT_TOKENS,
282            );
283            let context_text = &excerpt_text[context_range.clone()];
284            let cursor_within_excerpt = cursor_offset_in_excerpt
285                .saturating_sub(context_range.start)
286                .min(context_text.len());
287            let prompt = context_text[..cursor_within_excerpt].to_string();
288            let suffix = context_text[cursor_within_excerpt..].to_string();
289
290            let completion_text = match Self::fetch_completion(
291                http_client,
292                &api_key,
293                prompt,
294                suffix,
295                model,
296                max_tokens,
297                api_url,
298            )
299            .await
300            {
301                Ok(completion) => completion,
302                Err(e) => {
303                    log::error!("Codestral: Failed to fetch completion: {}", e);
304                    this.update(cx, |this, cx| {
305                        this.pending_request = None;
306                        cx.notify();
307                    })?;
308                    return Err(e);
309                }
310            };
311
312            if completion_text.trim().is_empty() {
313                log::debug!("Codestral: Completion was empty after trimming; ignoring");
314                this.update(cx, |this, cx| {
315                    this.pending_request = None;
316                    cx.notify();
317                })?;
318                return Ok(());
319            }
320
321            let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
322                vec![(cursor_position..cursor_position, completion_text.into())].into();
323            let edit_preview = buffer
324                .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))
325                .await;
326
327            this.update(cx, |this, cx| {
328                this.current_completion = Some(CurrentCompletion {
329                    snapshot,
330                    edits,
331                    edit_preview,
332                });
333                this.pending_request = None;
334                cx.notify();
335            })?;
336
337            Ok(())
338        }));
339    }
340
341    fn accept(&mut self, _cx: &mut Context<Self>) {
342        log::debug!("Codestral: Completion accepted");
343        self.pending_request = None;
344        self.current_completion = None;
345    }
346
347    fn discard(&mut self, _reason: EditPredictionDiscardReason, _cx: &mut Context<Self>) {
348        log::debug!("Codestral: Completion discarded");
349        self.pending_request = None;
350        self.current_completion = None;
351    }
352
353    /// Returns the completion suggestion, adjusted or invalidated based on user edits
354    fn suggest(
355        &mut self,
356        buffer: &Entity<Buffer>,
357        _cursor_position: Anchor,
358        cx: &mut Context<Self>,
359    ) -> Option<EditPrediction> {
360        let current_completion = self.current_completion.as_ref()?;
361        let buffer = buffer.read(cx);
362        let edits = current_completion.interpolate(&buffer.snapshot())?;
363        if edits.is_empty() {
364            return None;
365        }
366        Some(EditPrediction::Local {
367            id: None,
368            edits,
369            cursor_position: None,
370            edit_preview: Some(current_completion.edit_preview.clone()),
371        })
372    }
373}
374
375#[derive(Debug, Serialize, Deserialize)]
376pub struct CodestralRequest {
377    pub model: String,
378    pub prompt: String,
379    #[serde(skip_serializing_if = "Option::is_none")]
380    pub suffix: Option<String>,
381    #[serde(skip_serializing_if = "Option::is_none")]
382    pub max_tokens: Option<u32>,
383    #[serde(skip_serializing_if = "Option::is_none")]
384    pub temperature: Option<f32>,
385    #[serde(skip_serializing_if = "Option::is_none")]
386    pub top_p: Option<f32>,
387    #[serde(skip_serializing_if = "Option::is_none")]
388    pub stream: Option<bool>,
389    #[serde(skip_serializing_if = "Option::is_none")]
390    pub stop: Option<Vec<String>>,
391    #[serde(skip_serializing_if = "Option::is_none")]
392    pub random_seed: Option<u32>,
393    #[serde(skip_serializing_if = "Option::is_none")]
394    pub min_tokens: Option<u32>,
395}
396
397#[derive(Debug, Deserialize)]
398pub struct CodestralResponse {
399    pub id: String,
400    pub object: String,
401    pub model: String,
402    pub usage: Usage,
403    pub created: u64,
404    pub choices: Vec<Choice>,
405}
406
407#[derive(Debug, Deserialize)]
408pub struct Usage {
409    pub prompt_tokens: u32,
410    pub completion_tokens: u32,
411    pub total_tokens: u32,
412}
413
414#[derive(Debug, Deserialize)]
415pub struct Choice {
416    pub index: u32,
417    pub message: Message,
418    pub finish_reason: String,
419}
420
421#[derive(Debug, Deserialize)]
422pub struct Message {
423    pub content: String,
424    pub role: String,
425}