codestral.rs

  1use anyhow::Result;
  2use edit_prediction::cursor_excerpt;
  3use edit_prediction_types::{
  4    EditPrediction, EditPredictionDelegate, EditPredictionDiscardReason, EditPredictionIconSet,
  5};
  6use futures::AsyncReadExt;
  7use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
  8use http_client::HttpClient;
  9use icons::IconName;
 10use language::{
 11    Anchor, Buffer, BufferSnapshot, EditPreview, language_settings::all_language_settings,
 12};
 13use language_model::{ApiKeyState, AuthenticateError, EnvVar, env_var};
 14use serde::{Deserialize, Serialize};
 15
 16use std::{
 17    ops::Range,
 18    sync::Arc,
 19    time::{Duration, Instant},
 20};
 21use text::ToOffset;
 22
 23pub const CODESTRAL_API_URL: &str = "https://codestral.mistral.ai";
 24pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
 25
 26static CODESTRAL_API_KEY_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("CODESTRAL_API_KEY");
 27
 28struct GlobalCodestralApiKey(Entity<ApiKeyState>);
 29
 30impl Global for GlobalCodestralApiKey {}
 31
 32pub fn codestral_api_key_state(cx: &mut App) -> Entity<ApiKeyState> {
 33    if let Some(global) = cx.try_global::<GlobalCodestralApiKey>() {
 34        return global.0.clone();
 35    }
 36    let entity =
 37        cx.new(|cx| ApiKeyState::new(codestral_api_url(cx), CODESTRAL_API_KEY_ENV_VAR.clone()));
 38    cx.set_global(GlobalCodestralApiKey(entity.clone()));
 39    entity
 40}
 41
 42pub fn codestral_api_key(cx: &App) -> Option<Arc<str>> {
 43    let url = codestral_api_url(cx);
 44    cx.try_global::<GlobalCodestralApiKey>()?
 45        .0
 46        .read(cx)
 47        .key(&url)
 48}
 49
 50pub fn load_codestral_api_key(cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 51    let api_url = codestral_api_url(cx);
 52    codestral_api_key_state(cx).update(cx, |key_state, cx| {
 53        key_state.load_if_needed(api_url, |s| s, cx)
 54    })
 55}
 56
 57pub fn codestral_api_url(cx: &App) -> SharedString {
 58    all_language_settings(None, cx)
 59        .edit_predictions
 60        .codestral
 61        .api_url
 62        .clone()
 63        .unwrap_or_else(|| CODESTRAL_API_URL.to_string())
 64        .into()
 65}
 66
 67/// Represents a completion that has been received and processed from Codestral.
 68/// This struct maintains the state needed to interpolate the completion as the user types.
 69#[derive(Clone)]
 70struct CurrentCompletion {
 71    /// The buffer snapshot at the time the completion was generated.
 72    /// Used to detect changes and interpolate edits.
 73    snapshot: BufferSnapshot,
 74    /// The edits that should be applied to transform the original text into the predicted text.
 75    /// Each edit is a range in the buffer and the text to replace it with.
 76    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 77    /// Preview of how the buffer will look after applying the edits.
 78    edit_preview: EditPreview,
 79}
 80
 81impl CurrentCompletion {
 82    /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
 83    /// Returns None if the user's edits conflict with the predicted edits.
 84    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
 85        edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
 86    }
 87}
 88
 89pub struct CodestralEditPredictionDelegate {
 90    http_client: Arc<dyn HttpClient>,
 91    pending_request: Option<Task<Result<()>>>,
 92    current_completion: Option<CurrentCompletion>,
 93}
 94
 95impl CodestralEditPredictionDelegate {
 96    pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
 97        Self {
 98            http_client,
 99            pending_request: None,
100            current_completion: None,
101        }
102    }
103
104    pub fn ensure_api_key_loaded(cx: &mut App) {
105        load_codestral_api_key(cx).detach();
106    }
107
108    /// Uses Codestral's Fill-in-the-Middle API for code completion.
109    async fn fetch_completion(
110        http_client: Arc<dyn HttpClient>,
111        api_key: &str,
112        prompt: String,
113        suffix: String,
114        model: String,
115        max_tokens: Option<u32>,
116        api_url: String,
117    ) -> Result<String> {
118        let start_time = Instant::now();
119
120        log::debug!(
121            "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
122            model,
123            max_tokens
124        );
125
126        let request = CodestralRequest {
127            model,
128            prompt,
129            suffix: if suffix.is_empty() {
130                None
131            } else {
132                Some(suffix)
133            },
134            max_tokens: max_tokens.or(Some(350)),
135            temperature: Some(0.2),
136            top_p: Some(1.0),
137            stream: Some(false),
138            stop: None,
139            random_seed: None,
140            min_tokens: None,
141        };
142
143        let request_body = serde_json::to_string(&request)?;
144
145        log::debug!("Codestral: Sending FIM request");
146
147        let http_request = http_client::Request::builder()
148            .method(http_client::Method::POST)
149            .uri(format!("{}/v1/fim/completions", api_url))
150            .header("Content-Type", "application/json")
151            .header("Authorization", format!("Bearer {}", api_key))
152            .body(http_client::AsyncBody::from(request_body))?;
153
154        let mut response = http_client.send(http_request).await?;
155        let status = response.status();
156
157        log::debug!("Codestral: Response status: {}", status);
158
159        if !status.is_success() {
160            let mut body = String::new();
161            response.body_mut().read_to_string(&mut body).await?;
162            return Err(anyhow::anyhow!(
163                "Codestral API error: {} - {}",
164                status,
165                body
166            ));
167        }
168
169        let mut body = String::new();
170        response.body_mut().read_to_string(&mut body).await?;
171
172        let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
173
174        let elapsed = start_time.elapsed();
175
176        if let Some(choice) = codestral_response.choices.first() {
177            let completion = &choice.message.content;
178
179            log::debug!(
180                "Codestral: Completion received ({} tokens, {:.2}s)",
181                codestral_response.usage.completion_tokens,
182                elapsed.as_secs_f64()
183            );
184
185            // Return just the completion text for insertion at cursor
186            Ok(completion.clone())
187        } else {
188            log::error!("Codestral: No completion returned in response");
189            Err(anyhow::anyhow!("No completion returned from Codestral"))
190        }
191    }
192}
193
194impl EditPredictionDelegate for CodestralEditPredictionDelegate {
195    fn name() -> &'static str {
196        "codestral"
197    }
198
199    fn display_name() -> &'static str {
200        "Codestral"
201    }
202
203    fn show_predictions_in_menu() -> bool {
204        true
205    }
206
207    fn icons(&self, _cx: &App) -> EditPredictionIconSet {
208        EditPredictionIconSet::new(IconName::AiMistral)
209    }
210
211    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
212        codestral_api_key(cx).is_some()
213    }
214
215    fn is_refreshing(&self, _cx: &App) -> bool {
216        self.pending_request.is_some()
217    }
218
219    fn refresh(
220        &mut self,
221        buffer: Entity<Buffer>,
222        cursor_position: language::Anchor,
223        debounce: bool,
224        cx: &mut Context<Self>,
225    ) {
226        log::debug!("Codestral: Refresh called (debounce: {})", debounce);
227
228        let Some(api_key) = codestral_api_key(cx) else {
229            log::warn!("Codestral: No API key configured, skipping refresh");
230            return;
231        };
232
233        let snapshot = buffer.read(cx).snapshot();
234
235        // Check if current completion is still valid
236        if let Some(current_completion) = self.current_completion.as_ref() {
237            if current_completion.interpolate(&snapshot).is_some() {
238                return;
239            }
240        }
241
242        let http_client = self.http_client.clone();
243
244        // Get settings
245        let settings = all_language_settings(None, cx);
246        let model = settings
247            .edit_predictions
248            .codestral
249            .model
250            .clone()
251            .unwrap_or_else(|| "codestral-latest".to_string());
252        let max_tokens = settings.edit_predictions.codestral.max_tokens;
253        let api_url = codestral_api_url(cx).to_string();
254
255        self.pending_request = Some(cx.spawn(async move |this, cx| {
256            if debounce {
257                log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
258                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
259            }
260
261            let cursor_offset = cursor_position.to_offset(&snapshot);
262
263            const MAX_EDITABLE_TOKENS: usize = 350;
264            const MAX_CONTEXT_TOKENS: usize = 150;
265
266            let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
267                cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
268            let syntax_ranges = cursor_excerpt::compute_syntax_ranges(
269                &snapshot,
270                cursor_offset,
271                &excerpt_offset_range,
272            );
273            let excerpt_text: String = snapshot.text_for_range(excerpt_point_range).collect();
274            let (_, context_range) = zeta_prompt::compute_editable_and_context_ranges(
275                &excerpt_text,
276                cursor_offset_in_excerpt,
277                &syntax_ranges,
278                MAX_EDITABLE_TOKENS,
279                MAX_CONTEXT_TOKENS,
280            );
281            let context_text = &excerpt_text[context_range.clone()];
282            let cursor_within_excerpt = cursor_offset_in_excerpt
283                .saturating_sub(context_range.start)
284                .min(context_text.len());
285            let prompt = context_text[..cursor_within_excerpt].to_string();
286            let suffix = context_text[cursor_within_excerpt..].to_string();
287
288            let completion_text = match Self::fetch_completion(
289                http_client,
290                &api_key,
291                prompt,
292                suffix,
293                model,
294                max_tokens,
295                api_url,
296            )
297            .await
298            {
299                Ok(completion) => completion,
300                Err(e) => {
301                    log::error!("Codestral: Failed to fetch completion: {}", e);
302                    this.update(cx, |this, cx| {
303                        this.pending_request = None;
304                        cx.notify();
305                    })?;
306                    return Err(e);
307                }
308            };
309
310            if completion_text.trim().is_empty() {
311                log::debug!("Codestral: Completion was empty after trimming; ignoring");
312                this.update(cx, |this, cx| {
313                    this.pending_request = None;
314                    cx.notify();
315                })?;
316                return Ok(());
317            }
318
319            let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
320                vec![(cursor_position..cursor_position, completion_text.into())].into();
321            let edit_preview = buffer
322                .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))
323                .await;
324
325            this.update(cx, |this, cx| {
326                this.current_completion = Some(CurrentCompletion {
327                    snapshot,
328                    edits,
329                    edit_preview,
330                });
331                this.pending_request = None;
332                cx.notify();
333            })?;
334
335            Ok(())
336        }));
337    }
338
339    fn accept(&mut self, _cx: &mut Context<Self>) {
340        log::debug!("Codestral: Completion accepted");
341        self.pending_request = None;
342        self.current_completion = None;
343    }
344
345    fn discard(&mut self, _reason: EditPredictionDiscardReason, _cx: &mut Context<Self>) {
346        log::debug!("Codestral: Completion discarded");
347        self.pending_request = None;
348        self.current_completion = None;
349    }
350
351    /// Returns the completion suggestion, adjusted or invalidated based on user edits
352    fn suggest(
353        &mut self,
354        buffer: &Entity<Buffer>,
355        _cursor_position: Anchor,
356        cx: &mut Context<Self>,
357    ) -> Option<EditPrediction> {
358        let current_completion = self.current_completion.as_ref()?;
359        let buffer = buffer.read(cx);
360        let edits = current_completion.interpolate(&buffer.snapshot())?;
361        if edits.is_empty() {
362            return None;
363        }
364        Some(EditPrediction::Local {
365            id: None,
366            edits,
367            cursor_position: None,
368            edit_preview: Some(current_completion.edit_preview.clone()),
369        })
370    }
371}
372
373#[derive(Debug, Serialize, Deserialize)]
374pub struct CodestralRequest {
375    pub model: String,
376    pub prompt: String,
377    #[serde(skip_serializing_if = "Option::is_none")]
378    pub suffix: Option<String>,
379    #[serde(skip_serializing_if = "Option::is_none")]
380    pub max_tokens: Option<u32>,
381    #[serde(skip_serializing_if = "Option::is_none")]
382    pub temperature: Option<f32>,
383    #[serde(skip_serializing_if = "Option::is_none")]
384    pub top_p: Option<f32>,
385    #[serde(skip_serializing_if = "Option::is_none")]
386    pub stream: Option<bool>,
387    #[serde(skip_serializing_if = "Option::is_none")]
388    pub stop: Option<Vec<String>>,
389    #[serde(skip_serializing_if = "Option::is_none")]
390    pub random_seed: Option<u32>,
391    #[serde(skip_serializing_if = "Option::is_none")]
392    pub min_tokens: Option<u32>,
393}
394
395#[derive(Debug, Deserialize)]
396pub struct CodestralResponse {
397    pub id: String,
398    pub object: String,
399    pub model: String,
400    pub usage: Usage,
401    pub created: u64,
402    pub choices: Vec<Choice>,
403}
404
405#[derive(Debug, Deserialize)]
406pub struct Usage {
407    pub prompt_tokens: u32,
408    pub completion_tokens: u32,
409    pub total_tokens: u32,
410}
411
412#[derive(Debug, Deserialize)]
413pub struct Choice {
414    pub index: u32,
415    pub message: Message,
416    pub finish_reason: String,
417}
418
419#[derive(Debug, Deserialize)]
420pub struct Message {
421    pub content: String,
422    pub role: String,
423}