codestral.rs

  1use anyhow::Result;
  2use edit_prediction::cursor_excerpt;
  3use edit_prediction_types::{EditPrediction, EditPredictionDelegate, EditPredictionIconSet};
  4use futures::AsyncReadExt;
  5use gpui::{App, Context, Entity, Task};
  6use http_client::HttpClient;
  7use icons::IconName;
  8use language::{
  9    Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint, language_settings::all_language_settings,
 10};
 11use language_models::MistralLanguageModelProvider;
 12use mistral::CODESTRAL_API_URL;
 13use serde::{Deserialize, Serialize};
 14use std::{
 15    ops::Range,
 16    sync::Arc,
 17    time::{Duration, Instant},
 18};
 19use text::{OffsetRangeExt as _, ToOffset};
 20
 21pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
 22
 23/// Represents a completion that has been received and processed from Codestral.
 24/// This struct maintains the state needed to interpolate the completion as the user types.
 25#[derive(Clone)]
 26struct CurrentCompletion {
 27    /// The buffer snapshot at the time the completion was generated.
 28    /// Used to detect changes and interpolate edits.
 29    snapshot: BufferSnapshot,
 30    /// The edits that should be applied to transform the original text into the predicted text.
 31    /// Each edit is a range in the buffer and the text to replace it with.
 32    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 33    /// Preview of how the buffer will look after applying the edits.
 34    edit_preview: EditPreview,
 35}
 36
 37impl CurrentCompletion {
 38    /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
 39    /// Returns None if the user's edits conflict with the predicted edits.
 40    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
 41        edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
 42    }
 43}
 44
 45pub struct CodestralEditPredictionDelegate {
 46    http_client: Arc<dyn HttpClient>,
 47    pending_request: Option<Task<Result<()>>>,
 48    current_completion: Option<CurrentCompletion>,
 49}
 50
 51impl CodestralEditPredictionDelegate {
 52    pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
 53        Self {
 54            http_client,
 55            pending_request: None,
 56            current_completion: None,
 57        }
 58    }
 59
 60    pub fn has_api_key(cx: &App) -> bool {
 61        Self::api_key(cx).is_some()
 62    }
 63
 64    /// This is so we can immediately show Codestral as a provider users can
 65    /// switch to in the edit prediction menu, if the API has been added
 66    pub fn ensure_api_key_loaded(http_client: Arc<dyn HttpClient>, cx: &mut App) {
 67        MistralLanguageModelProvider::global(http_client, cx)
 68            .load_codestral_api_key(cx)
 69            .detach();
 70    }
 71
 72    fn api_key(cx: &App) -> Option<Arc<str>> {
 73        MistralLanguageModelProvider::try_global(cx)
 74            .and_then(|provider| provider.codestral_api_key(CODESTRAL_API_URL, cx))
 75    }
 76
 77    /// Uses Codestral's Fill-in-the-Middle API for code completion.
 78    async fn fetch_completion(
 79        http_client: Arc<dyn HttpClient>,
 80        api_key: &str,
 81        prompt: String,
 82        suffix: String,
 83        model: String,
 84        max_tokens: Option<u32>,
 85        api_url: String,
 86    ) -> Result<String> {
 87        let start_time = Instant::now();
 88
 89        log::debug!(
 90            "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
 91            model,
 92            max_tokens
 93        );
 94
 95        let request = CodestralRequest {
 96            model,
 97            prompt,
 98            suffix: if suffix.is_empty() {
 99                None
100            } else {
101                Some(suffix)
102            },
103            max_tokens: max_tokens.or(Some(350)),
104            temperature: Some(0.2),
105            top_p: Some(1.0),
106            stream: Some(false),
107            stop: None,
108            random_seed: None,
109            min_tokens: None,
110        };
111
112        let request_body = serde_json::to_string(&request)?;
113
114        log::debug!("Codestral: Sending FIM request");
115
116        let http_request = http_client::Request::builder()
117            .method(http_client::Method::POST)
118            .uri(format!("{}/v1/fim/completions", api_url))
119            .header("Content-Type", "application/json")
120            .header("Authorization", format!("Bearer {}", api_key))
121            .body(http_client::AsyncBody::from(request_body))?;
122
123        let mut response = http_client.send(http_request).await?;
124        let status = response.status();
125
126        log::debug!("Codestral: Response status: {}", status);
127
128        if !status.is_success() {
129            let mut body = String::new();
130            response.body_mut().read_to_string(&mut body).await?;
131            return Err(anyhow::anyhow!(
132                "Codestral API error: {} - {}",
133                status,
134                body
135            ));
136        }
137
138        let mut body = String::new();
139        response.body_mut().read_to_string(&mut body).await?;
140
141        let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
142
143        let elapsed = start_time.elapsed();
144
145        if let Some(choice) = codestral_response.choices.first() {
146            let completion = &choice.message.content;
147
148            log::debug!(
149                "Codestral: Completion received ({} tokens, {:.2}s)",
150                codestral_response.usage.completion_tokens,
151                elapsed.as_secs_f64()
152            );
153
154            // Return just the completion text for insertion at cursor
155            Ok(completion.clone())
156        } else {
157            log::error!("Codestral: No completion returned in response");
158            Err(anyhow::anyhow!("No completion returned from Codestral"))
159        }
160    }
161}
162
163impl EditPredictionDelegate for CodestralEditPredictionDelegate {
164    fn name() -> &'static str {
165        "codestral"
166    }
167
168    fn display_name() -> &'static str {
169        "Codestral"
170    }
171
172    fn show_predictions_in_menu() -> bool {
173        true
174    }
175
176    fn icons(&self, _cx: &App) -> EditPredictionIconSet {
177        EditPredictionIconSet::new(IconName::AiMistral)
178    }
179
180    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
181        Self::api_key(cx).is_some()
182    }
183
184    fn is_refreshing(&self, _cx: &App) -> bool {
185        self.pending_request.is_some()
186    }
187
188    fn refresh(
189        &mut self,
190        buffer: Entity<Buffer>,
191        cursor_position: language::Anchor,
192        debounce: bool,
193        cx: &mut Context<Self>,
194    ) {
195        log::debug!("Codestral: Refresh called (debounce: {})", debounce);
196
197        let Some(api_key) = Self::api_key(cx) else {
198            log::warn!("Codestral: No API key configured, skipping refresh");
199            return;
200        };
201
202        let snapshot = buffer.read(cx).snapshot();
203
204        // Check if current completion is still valid
205        if let Some(current_completion) = self.current_completion.as_ref() {
206            if current_completion.interpolate(&snapshot).is_some() {
207                return;
208            }
209        }
210
211        let http_client = self.http_client.clone();
212
213        // Get settings
214        let settings = all_language_settings(None, cx);
215        let model = settings
216            .edit_predictions
217            .codestral
218            .model
219            .clone()
220            .unwrap_or_else(|| "codestral-latest".to_string());
221        let max_tokens = settings.edit_predictions.codestral.max_tokens;
222        let api_url = settings
223            .edit_predictions
224            .codestral
225            .api_url
226            .clone()
227            .unwrap_or_else(|| CODESTRAL_API_URL.to_string());
228
229        self.pending_request = Some(cx.spawn(async move |this, cx| {
230            if debounce {
231                log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
232                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
233            }
234
235            let cursor_offset = cursor_position.to_offset(&snapshot);
236            let cursor_point = cursor_offset.to_point(&snapshot);
237
238            const MAX_CONTEXT_TOKENS: usize = 150;
239            const MAX_REWRITE_TOKENS: usize = 350;
240
241            let (_, context_range) =
242                cursor_excerpt::editable_and_context_ranges_for_cursor_position(
243                    cursor_point,
244                    &snapshot,
245                    MAX_REWRITE_TOKENS,
246                    MAX_CONTEXT_TOKENS,
247                );
248
249            let context_range = context_range.to_offset(&snapshot);
250            let excerpt_text = snapshot
251                .text_for_range(context_range.clone())
252                .collect::<String>();
253            let cursor_within_excerpt = cursor_offset
254                .saturating_sub(context_range.start)
255                .min(excerpt_text.len());
256            let prompt = excerpt_text[..cursor_within_excerpt].to_string();
257            let suffix = excerpt_text[cursor_within_excerpt..].to_string();
258
259            let completion_text = match Self::fetch_completion(
260                http_client,
261                &api_key,
262                prompt,
263                suffix,
264                model,
265                max_tokens,
266                api_url,
267            )
268            .await
269            {
270                Ok(completion) => completion,
271                Err(e) => {
272                    log::error!("Codestral: Failed to fetch completion: {}", e);
273                    this.update(cx, |this, cx| {
274                        this.pending_request = None;
275                        cx.notify();
276                    })?;
277                    return Err(e);
278                }
279            };
280
281            if completion_text.trim().is_empty() {
282                log::debug!("Codestral: Completion was empty after trimming; ignoring");
283                this.update(cx, |this, cx| {
284                    this.pending_request = None;
285                    cx.notify();
286                })?;
287                return Ok(());
288            }
289
290            let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
291                vec![(cursor_position..cursor_position, completion_text.into())].into();
292            let edit_preview = buffer
293                .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))
294                .await;
295
296            this.update(cx, |this, cx| {
297                this.current_completion = Some(CurrentCompletion {
298                    snapshot,
299                    edits,
300                    edit_preview,
301                });
302                this.pending_request = None;
303                cx.notify();
304            })?;
305
306            Ok(())
307        }));
308    }
309
310    fn accept(&mut self, _cx: &mut Context<Self>) {
311        log::debug!("Codestral: Completion accepted");
312        self.pending_request = None;
313        self.current_completion = None;
314    }
315
316    fn discard(&mut self, _cx: &mut Context<Self>) {
317        log::debug!("Codestral: Completion discarded");
318        self.pending_request = None;
319        self.current_completion = None;
320    }
321
322    /// Returns the completion suggestion, adjusted or invalidated based on user edits
323    fn suggest(
324        &mut self,
325        buffer: &Entity<Buffer>,
326        _cursor_position: Anchor,
327        cx: &mut Context<Self>,
328    ) -> Option<EditPrediction> {
329        let current_completion = self.current_completion.as_ref()?;
330        let buffer = buffer.read(cx);
331        let edits = current_completion.interpolate(&buffer.snapshot())?;
332        if edits.is_empty() {
333            return None;
334        }
335        Some(EditPrediction::Local {
336            id: None,
337            edits,
338            cursor_position: None,
339            edit_preview: Some(current_completion.edit_preview.clone()),
340        })
341    }
342}
343
344#[derive(Debug, Serialize, Deserialize)]
345pub struct CodestralRequest {
346    pub model: String,
347    pub prompt: String,
348    #[serde(skip_serializing_if = "Option::is_none")]
349    pub suffix: Option<String>,
350    #[serde(skip_serializing_if = "Option::is_none")]
351    pub max_tokens: Option<u32>,
352    #[serde(skip_serializing_if = "Option::is_none")]
353    pub temperature: Option<f32>,
354    #[serde(skip_serializing_if = "Option::is_none")]
355    pub top_p: Option<f32>,
356    #[serde(skip_serializing_if = "Option::is_none")]
357    pub stream: Option<bool>,
358    #[serde(skip_serializing_if = "Option::is_none")]
359    pub stop: Option<Vec<String>>,
360    #[serde(skip_serializing_if = "Option::is_none")]
361    pub random_seed: Option<u32>,
362    #[serde(skip_serializing_if = "Option::is_none")]
363    pub min_tokens: Option<u32>,
364}
365
366#[derive(Debug, Deserialize)]
367pub struct CodestralResponse {
368    pub id: String,
369    pub object: String,
370    pub model: String,
371    pub usage: Usage,
372    pub created: u64,
373    pub choices: Vec<Choice>,
374}
375
376#[derive(Debug, Deserialize)]
377pub struct Usage {
378    pub prompt_tokens: u32,
379    pub completion_tokens: u32,
380    pub total_tokens: u32,
381}
382
383#[derive(Debug, Deserialize)]
384pub struct Choice {
385    pub index: u32,
386    pub message: Message,
387    pub finish_reason: String,
388}
389
390#[derive(Debug, Deserialize)]
391pub struct Message {
392    pub content: String,
393    pub role: String,
394}