codestral.rs

  1use anyhow::Result;
  2use edit_prediction::cursor_excerpt;
  3use edit_prediction_types::{
  4    EditPrediction, EditPredictionDelegate, EditPredictionDiscardReason, EditPredictionIconSet,
  5};
  6use futures::AsyncReadExt;
  7use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
  8use http_client::HttpClient;
  9use icons::IconName;
 10use language::{
 11    Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint, language_settings::all_language_settings,
 12};
 13use language_model::{ApiKeyState, AuthenticateError, EnvVar, env_var};
 14use serde::{Deserialize, Serialize};
 15
 16use std::{
 17    ops::Range,
 18    sync::Arc,
 19    time::{Duration, Instant},
 20};
 21use text::{OffsetRangeExt as _, ToOffset};
 22
 23pub const CODESTRAL_API_URL: &str = "https://codestral.mistral.ai";
 24pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
 25
 26static CODESTRAL_API_KEY_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("CODESTRAL_API_KEY");
 27
 28struct GlobalCodestralApiKey(Entity<ApiKeyState>);
 29
 30impl Global for GlobalCodestralApiKey {}
 31
 32pub fn codestral_api_key_state(cx: &mut App) -> Entity<ApiKeyState> {
 33    if let Some(global) = cx.try_global::<GlobalCodestralApiKey>() {
 34        return global.0.clone();
 35    }
 36    let entity =
 37        cx.new(|cx| ApiKeyState::new(codestral_api_url(cx), CODESTRAL_API_KEY_ENV_VAR.clone()));
 38    cx.set_global(GlobalCodestralApiKey(entity.clone()));
 39    entity
 40}
 41
 42pub fn codestral_api_key(cx: &App) -> Option<Arc<str>> {
 43    let url = codestral_api_url(cx);
 44    cx.try_global::<GlobalCodestralApiKey>()?
 45        .0
 46        .read(cx)
 47        .key(&url)
 48}
 49
 50pub fn load_codestral_api_key(cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 51    let api_url = codestral_api_url(cx);
 52    codestral_api_key_state(cx).update(cx, |key_state, cx| {
 53        key_state.load_if_needed(api_url, |s| s, cx)
 54    })
 55}
 56
 57pub fn codestral_api_url(cx: &App) -> SharedString {
 58    all_language_settings(None, cx)
 59        .edit_predictions
 60        .codestral
 61        .api_url
 62        .clone()
 63        .unwrap_or_else(|| CODESTRAL_API_URL.to_string())
 64        .into()
 65}
 66
 67/// Represents a completion that has been received and processed from Codestral.
 68/// This struct maintains the state needed to interpolate the completion as the user types.
 69#[derive(Clone)]
 70struct CurrentCompletion {
 71    /// The buffer snapshot at the time the completion was generated.
 72    /// Used to detect changes and interpolate edits.
 73    snapshot: BufferSnapshot,
 74    /// The edits that should be applied to transform the original text into the predicted text.
 75    /// Each edit is a range in the buffer and the text to replace it with.
 76    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 77    /// Preview of how the buffer will look after applying the edits.
 78    edit_preview: EditPreview,
 79}
 80
 81impl CurrentCompletion {
 82    /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
 83    /// Returns None if the user's edits conflict with the predicted edits.
 84    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
 85        edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
 86    }
 87}
 88
 89pub struct CodestralEditPredictionDelegate {
 90    http_client: Arc<dyn HttpClient>,
 91    pending_request: Option<Task<Result<()>>>,
 92    current_completion: Option<CurrentCompletion>,
 93}
 94
 95impl CodestralEditPredictionDelegate {
 96    pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
 97        Self {
 98            http_client,
 99            pending_request: None,
100            current_completion: None,
101        }
102    }
103
104    pub fn ensure_api_key_loaded(cx: &mut App) {
105        load_codestral_api_key(cx).detach();
106    }
107
108    /// Uses Codestral's Fill-in-the-Middle API for code completion.
109    async fn fetch_completion(
110        http_client: Arc<dyn HttpClient>,
111        api_key: &str,
112        prompt: String,
113        suffix: String,
114        model: String,
115        max_tokens: Option<u32>,
116        api_url: String,
117    ) -> Result<String> {
118        let start_time = Instant::now();
119
120        log::debug!(
121            "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
122            model,
123            max_tokens
124        );
125
126        let request = CodestralRequest {
127            model,
128            prompt,
129            suffix: if suffix.is_empty() {
130                None
131            } else {
132                Some(suffix)
133            },
134            max_tokens: max_tokens.or(Some(350)),
135            temperature: Some(0.2),
136            top_p: Some(1.0),
137            stream: Some(false),
138            stop: None,
139            random_seed: None,
140            min_tokens: None,
141        };
142
143        let request_body = serde_json::to_string(&request)?;
144
145        log::debug!("Codestral: Sending FIM request");
146
147        let http_request = http_client::Request::builder()
148            .method(http_client::Method::POST)
149            .uri(format!("{}/v1/fim/completions", api_url))
150            .header("Content-Type", "application/json")
151            .header("Authorization", format!("Bearer {}", api_key))
152            .body(http_client::AsyncBody::from(request_body))?;
153
154        let mut response = http_client.send(http_request).await?;
155        let status = response.status();
156
157        log::debug!("Codestral: Response status: {}", status);
158
159        if !status.is_success() {
160            let mut body = String::new();
161            response.body_mut().read_to_string(&mut body).await?;
162            return Err(anyhow::anyhow!(
163                "Codestral API error: {} - {}",
164                status,
165                body
166            ));
167        }
168
169        let mut body = String::new();
170        response.body_mut().read_to_string(&mut body).await?;
171
172        let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
173
174        let elapsed = start_time.elapsed();
175
176        if let Some(choice) = codestral_response.choices.first() {
177            let completion = &choice.message.content;
178
179            log::debug!(
180                "Codestral: Completion received ({} tokens, {:.2}s)",
181                codestral_response.usage.completion_tokens,
182                elapsed.as_secs_f64()
183            );
184
185            // Return just the completion text for insertion at cursor
186            Ok(completion.clone())
187        } else {
188            log::error!("Codestral: No completion returned in response");
189            Err(anyhow::anyhow!("No completion returned from Codestral"))
190        }
191    }
192}
193
194impl EditPredictionDelegate for CodestralEditPredictionDelegate {
195    fn name() -> &'static str {
196        "codestral"
197    }
198
199    fn display_name() -> &'static str {
200        "Codestral"
201    }
202
203    fn show_predictions_in_menu() -> bool {
204        true
205    }
206
207    fn icons(&self, _cx: &App) -> EditPredictionIconSet {
208        EditPredictionIconSet::new(IconName::AiMistral)
209    }
210
211    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
212        codestral_api_key(cx).is_some()
213    }
214
215    fn is_refreshing(&self, _cx: &App) -> bool {
216        self.pending_request.is_some()
217    }
218
219    fn refresh(
220        &mut self,
221        buffer: Entity<Buffer>,
222        cursor_position: language::Anchor,
223        debounce: bool,
224        cx: &mut Context<Self>,
225    ) {
226        log::debug!("Codestral: Refresh called (debounce: {})", debounce);
227
228        let Some(api_key) = codestral_api_key(cx) else {
229            log::warn!("Codestral: No API key configured, skipping refresh");
230            return;
231        };
232
233        let snapshot = buffer.read(cx).snapshot();
234
235        // Check if current completion is still valid
236        if let Some(current_completion) = self.current_completion.as_ref() {
237            if current_completion.interpolate(&snapshot).is_some() {
238                return;
239            }
240        }
241
242        let http_client = self.http_client.clone();
243
244        // Get settings
245        let settings = all_language_settings(None, cx);
246        let model = settings
247            .edit_predictions
248            .codestral
249            .model
250            .clone()
251            .unwrap_or_else(|| "codestral-latest".to_string());
252        let max_tokens = settings.edit_predictions.codestral.max_tokens;
253        let api_url = codestral_api_url(cx).to_string();
254
255        self.pending_request = Some(cx.spawn(async move |this, cx| {
256            if debounce {
257                log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
258                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
259            }
260
261            let cursor_offset = cursor_position.to_offset(&snapshot);
262            let cursor_point = cursor_offset.to_point(&snapshot);
263
264            const MAX_CONTEXT_TOKENS: usize = 150;
265            const MAX_REWRITE_TOKENS: usize = 350;
266
267            let (_, context_range) =
268                cursor_excerpt::editable_and_context_ranges_for_cursor_position(
269                    cursor_point,
270                    &snapshot,
271                    MAX_REWRITE_TOKENS,
272                    MAX_CONTEXT_TOKENS,
273                );
274
275            let context_range = context_range.to_offset(&snapshot);
276            let excerpt_text = snapshot
277                .text_for_range(context_range.clone())
278                .collect::<String>();
279            let cursor_within_excerpt = cursor_offset
280                .saturating_sub(context_range.start)
281                .min(excerpt_text.len());
282            let prompt = excerpt_text[..cursor_within_excerpt].to_string();
283            let suffix = excerpt_text[cursor_within_excerpt..].to_string();
284
285            let completion_text = match Self::fetch_completion(
286                http_client,
287                &api_key,
288                prompt,
289                suffix,
290                model,
291                max_tokens,
292                api_url,
293            )
294            .await
295            {
296                Ok(completion) => completion,
297                Err(e) => {
298                    log::error!("Codestral: Failed to fetch completion: {}", e);
299                    this.update(cx, |this, cx| {
300                        this.pending_request = None;
301                        cx.notify();
302                    })?;
303                    return Err(e);
304                }
305            };
306
307            if completion_text.trim().is_empty() {
308                log::debug!("Codestral: Completion was empty after trimming; ignoring");
309                this.update(cx, |this, cx| {
310                    this.pending_request = None;
311                    cx.notify();
312                })?;
313                return Ok(());
314            }
315
316            let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
317                vec![(cursor_position..cursor_position, completion_text.into())].into();
318            let edit_preview = buffer
319                .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))
320                .await;
321
322            this.update(cx, |this, cx| {
323                this.current_completion = Some(CurrentCompletion {
324                    snapshot,
325                    edits,
326                    edit_preview,
327                });
328                this.pending_request = None;
329                cx.notify();
330            })?;
331
332            Ok(())
333        }));
334    }
335
336    fn accept(&mut self, _cx: &mut Context<Self>) {
337        log::debug!("Codestral: Completion accepted");
338        self.pending_request = None;
339        self.current_completion = None;
340    }
341
342    fn discard(&mut self, _reason: EditPredictionDiscardReason, _cx: &mut Context<Self>) {
343        log::debug!("Codestral: Completion discarded");
344        self.pending_request = None;
345        self.current_completion = None;
346    }
347
348    /// Returns the completion suggestion, adjusted or invalidated based on user edits
349    fn suggest(
350        &mut self,
351        buffer: &Entity<Buffer>,
352        _cursor_position: Anchor,
353        cx: &mut Context<Self>,
354    ) -> Option<EditPrediction> {
355        let current_completion = self.current_completion.as_ref()?;
356        let buffer = buffer.read(cx);
357        let edits = current_completion.interpolate(&buffer.snapshot())?;
358        if edits.is_empty() {
359            return None;
360        }
361        Some(EditPrediction::Local {
362            id: None,
363            edits,
364            cursor_position: None,
365            edit_preview: Some(current_completion.edit_preview.clone()),
366        })
367    }
368}
369
370#[derive(Debug, Serialize, Deserialize)]
371pub struct CodestralRequest {
372    pub model: String,
373    pub prompt: String,
374    #[serde(skip_serializing_if = "Option::is_none")]
375    pub suffix: Option<String>,
376    #[serde(skip_serializing_if = "Option::is_none")]
377    pub max_tokens: Option<u32>,
378    #[serde(skip_serializing_if = "Option::is_none")]
379    pub temperature: Option<f32>,
380    #[serde(skip_serializing_if = "Option::is_none")]
381    pub top_p: Option<f32>,
382    #[serde(skip_serializing_if = "Option::is_none")]
383    pub stream: Option<bool>,
384    #[serde(skip_serializing_if = "Option::is_none")]
385    pub stop: Option<Vec<String>>,
386    #[serde(skip_serializing_if = "Option::is_none")]
387    pub random_seed: Option<u32>,
388    #[serde(skip_serializing_if = "Option::is_none")]
389    pub min_tokens: Option<u32>,
390}
391
392#[derive(Debug, Deserialize)]
393pub struct CodestralResponse {
394    pub id: String,
395    pub object: String,
396    pub model: String,
397    pub usage: Usage,
398    pub created: u64,
399    pub choices: Vec<Choice>,
400}
401
402#[derive(Debug, Deserialize)]
403pub struct Usage {
404    pub prompt_tokens: u32,
405    pub completion_tokens: u32,
406    pub total_tokens: u32,
407}
408
409#[derive(Debug, Deserialize)]
410pub struct Choice {
411    pub index: u32,
412    pub message: Message,
413    pub finish_reason: String,
414}
415
416#[derive(Debug, Deserialize)]
417pub struct Message {
418    pub content: String,
419    pub role: String,
420}