codestral.rs

  1use anyhow::Result;
  2use edit_prediction::cursor_excerpt;
  3use edit_prediction_types::{
  4    EditPrediction, EditPredictionDelegate, EditPredictionDismissReason, EditPredictionIconSet,
  5};
  6use futures::AsyncReadExt;
  7use gpui::{App, Context, Entity, Task};
  8use http_client::HttpClient;
  9use icons::IconName;
 10use language::{
 11    Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint, language_settings::all_language_settings,
 12};
 13use language_models::MistralLanguageModelProvider;
 14use mistral::CODESTRAL_API_URL;
 15use serde::{Deserialize, Serialize};
 16use std::{
 17    ops::Range,
 18    sync::Arc,
 19    time::{Duration, Instant},
 20};
 21use text::{OffsetRangeExt as _, ToOffset};
 22
/// How long a refresh waits before sending its completion request when it is
/// invoked with `debounce` set.
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
 24
/// A completion that has been received from Codestral and turned into buffer edits.
///
/// The struct keeps everything needed to re-map (interpolate) the completion onto
/// newer versions of the buffer as the user keeps typing.
#[derive(Clone)]
struct CurrentCompletion {
    /// The buffer snapshot captured when the completion was requested.
    /// It is the baseline that later snapshots are compared against.
    snapshot: BufferSnapshot,
    /// The predicted edits, expressed against `snapshot`.
    /// Each entry is a buffer range and the text that should replace it.
    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Preview of the buffer with `edits` applied, handed out with each suggestion.
    edit_preview: EditPreview,
}
 38
 39impl CurrentCompletion {
 40    /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
 41    /// Returns None if the user's edits conflict with the predicted edits.
 42    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
 43        edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
 44    }
 45}
 46
/// Edit prediction provider backed by Codestral's Fill-in-the-Middle endpoint.
pub struct CodestralEditPredictionDelegate {
    /// Client used to send completion requests.
    http_client: Arc<dyn HttpClient>,
    /// The in-flight refresh task, if any. Reset to `None` when the task finishes
    /// and when a completion is accepted or discarded.
    pending_request: Option<Task<Result<()>>>,
    /// The most recently received completion, if one is still being offered.
    current_completion: Option<CurrentCompletion>,
}
 52
 53impl CodestralEditPredictionDelegate {
 54    pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
 55        Self {
 56            http_client,
 57            pending_request: None,
 58            current_completion: None,
 59        }
 60    }
 61
 62    pub fn has_api_key(cx: &App) -> bool {
 63        Self::api_key(cx).is_some()
 64    }
 65
 66    /// This is so we can immediately show Codestral as a provider users can
 67    /// switch to in the edit prediction menu, if the API has been added
 68    pub fn ensure_api_key_loaded(http_client: Arc<dyn HttpClient>, cx: &mut App) {
 69        MistralLanguageModelProvider::global(http_client, cx)
 70            .load_codestral_api_key(cx)
 71            .detach();
 72    }
 73
 74    fn api_key(cx: &App) -> Option<Arc<str>> {
 75        MistralLanguageModelProvider::try_global(cx)
 76            .and_then(|provider| provider.codestral_api_key(CODESTRAL_API_URL, cx))
 77    }
 78
 79    /// Uses Codestral's Fill-in-the-Middle API for code completion.
 80    async fn fetch_completion(
 81        http_client: Arc<dyn HttpClient>,
 82        api_key: &str,
 83        prompt: String,
 84        suffix: String,
 85        model: String,
 86        max_tokens: Option<u32>,
 87        api_url: String,
 88    ) -> Result<String> {
 89        let start_time = Instant::now();
 90
 91        log::debug!(
 92            "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
 93            model,
 94            max_tokens
 95        );
 96
 97        let request = CodestralRequest {
 98            model,
 99            prompt,
100            suffix: if suffix.is_empty() {
101                None
102            } else {
103                Some(suffix)
104            },
105            max_tokens: max_tokens.or(Some(350)),
106            temperature: Some(0.2),
107            top_p: Some(1.0),
108            stream: Some(false),
109            stop: None,
110            random_seed: None,
111            min_tokens: None,
112        };
113
114        let request_body = serde_json::to_string(&request)?;
115
116        log::debug!("Codestral: Sending FIM request");
117
118        let http_request = http_client::Request::builder()
119            .method(http_client::Method::POST)
120            .uri(format!("{}/v1/fim/completions", api_url))
121            .header("Content-Type", "application/json")
122            .header("Authorization", format!("Bearer {}", api_key))
123            .body(http_client::AsyncBody::from(request_body))?;
124
125        let mut response = http_client.send(http_request).await?;
126        let status = response.status();
127
128        log::debug!("Codestral: Response status: {}", status);
129
130        if !status.is_success() {
131            let mut body = String::new();
132            response.body_mut().read_to_string(&mut body).await?;
133            return Err(anyhow::anyhow!(
134                "Codestral API error: {} - {}",
135                status,
136                body
137            ));
138        }
139
140        let mut body = String::new();
141        response.body_mut().read_to_string(&mut body).await?;
142
143        let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
144
145        let elapsed = start_time.elapsed();
146
147        if let Some(choice) = codestral_response.choices.first() {
148            let completion = &choice.message.content;
149
150            log::debug!(
151                "Codestral: Completion received ({} tokens, {:.2}s)",
152                codestral_response.usage.completion_tokens,
153                elapsed.as_secs_f64()
154            );
155
156            // Return just the completion text for insertion at cursor
157            Ok(completion.clone())
158        } else {
159            log::error!("Codestral: No completion returned in response");
160            Err(anyhow::anyhow!("No completion returned from Codestral"))
161        }
162    }
163}
164
165impl EditPredictionDelegate for CodestralEditPredictionDelegate {
166    fn name() -> &'static str {
167        "codestral"
168    }
169
170    fn display_name() -> &'static str {
171        "Codestral"
172    }
173
174    fn show_predictions_in_menu() -> bool {
175        true
176    }
177
178    fn icons(&self, _cx: &App) -> EditPredictionIconSet {
179        EditPredictionIconSet::new(IconName::AiMistral)
180    }
181
182    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
183        Self::api_key(cx).is_some()
184    }
185
186    fn is_refreshing(&self, _cx: &App) -> bool {
187        self.pending_request.is_some()
188    }
189
190    fn refresh(
191        &mut self,
192        buffer: Entity<Buffer>,
193        cursor_position: language::Anchor,
194        debounce: bool,
195        cx: &mut Context<Self>,
196    ) {
197        log::debug!("Codestral: Refresh called (debounce: {})", debounce);
198
199        let Some(api_key) = Self::api_key(cx) else {
200            log::warn!("Codestral: No API key configured, skipping refresh");
201            return;
202        };
203
204        let snapshot = buffer.read(cx).snapshot();
205
206        // Check if current completion is still valid
207        if let Some(current_completion) = self.current_completion.as_ref() {
208            if current_completion.interpolate(&snapshot).is_some() {
209                return;
210            }
211        }
212
213        let http_client = self.http_client.clone();
214
215        // Get settings
216        let settings = all_language_settings(None, cx);
217        let model = settings
218            .edit_predictions
219            .codestral
220            .model
221            .clone()
222            .unwrap_or_else(|| "codestral-latest".to_string());
223        let max_tokens = settings.edit_predictions.codestral.max_tokens;
224        let api_url = settings
225            .edit_predictions
226            .codestral
227            .api_url
228            .clone()
229            .unwrap_or_else(|| CODESTRAL_API_URL.to_string());
230
231        self.pending_request = Some(cx.spawn(async move |this, cx| {
232            if debounce {
233                log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
234                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
235            }
236
237            let cursor_offset = cursor_position.to_offset(&snapshot);
238            let cursor_point = cursor_offset.to_point(&snapshot);
239
240            const MAX_CONTEXT_TOKENS: usize = 150;
241            const MAX_REWRITE_TOKENS: usize = 350;
242
243            let (_, context_range) =
244                cursor_excerpt::editable_and_context_ranges_for_cursor_position(
245                    cursor_point,
246                    &snapshot,
247                    MAX_REWRITE_TOKENS,
248                    MAX_CONTEXT_TOKENS,
249                );
250
251            let context_range = context_range.to_offset(&snapshot);
252            let excerpt_text = snapshot
253                .text_for_range(context_range.clone())
254                .collect::<String>();
255            let cursor_within_excerpt = cursor_offset
256                .saturating_sub(context_range.start)
257                .min(excerpt_text.len());
258            let prompt = excerpt_text[..cursor_within_excerpt].to_string();
259            let suffix = excerpt_text[cursor_within_excerpt..].to_string();
260
261            let completion_text = match Self::fetch_completion(
262                http_client,
263                &api_key,
264                prompt,
265                suffix,
266                model,
267                max_tokens,
268                api_url,
269            )
270            .await
271            {
272                Ok(completion) => completion,
273                Err(e) => {
274                    log::error!("Codestral: Failed to fetch completion: {}", e);
275                    this.update(cx, |this, cx| {
276                        this.pending_request = None;
277                        cx.notify();
278                    })?;
279                    return Err(e);
280                }
281            };
282
283            if completion_text.trim().is_empty() {
284                log::debug!("Codestral: Completion was empty after trimming; ignoring");
285                this.update(cx, |this, cx| {
286                    this.pending_request = None;
287                    cx.notify();
288                })?;
289                return Ok(());
290            }
291
292            let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
293                vec![(cursor_position..cursor_position, completion_text.into())].into();
294            let edit_preview = buffer
295                .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))
296                .await;
297
298            this.update(cx, |this, cx| {
299                this.current_completion = Some(CurrentCompletion {
300                    snapshot,
301                    edits,
302                    edit_preview,
303                });
304                this.pending_request = None;
305                cx.notify();
306            })?;
307
308            Ok(())
309        }));
310    }
311
312    fn accept(&mut self, _cx: &mut Context<Self>) {
313        log::debug!("Codestral: Completion accepted");
314        self.pending_request = None;
315        self.current_completion = None;
316    }
317
318    fn discard(&mut self, _reason: EditPredictionDismissReason, _cx: &mut Context<Self>) {
319        log::debug!("Codestral: Completion discarded");
320        self.pending_request = None;
321        self.current_completion = None;
322    }
323
324    /// Returns the completion suggestion, adjusted or invalidated based on user edits
325    fn suggest(
326        &mut self,
327        buffer: &Entity<Buffer>,
328        _cursor_position: Anchor,
329        cx: &mut Context<Self>,
330    ) -> Option<EditPrediction> {
331        let current_completion = self.current_completion.as_ref()?;
332        let buffer = buffer.read(cx);
333        let edits = current_completion.interpolate(&buffer.snapshot())?;
334        if edits.is_empty() {
335            return None;
336        }
337        Some(EditPrediction::Local {
338            id: None,
339            edits,
340            cursor_position: None,
341            edit_preview: Some(current_completion.edit_preview.clone()),
342        })
343    }
344}
345
/// JSON body sent to the `/v1/fim/completions` endpoint.
///
/// Every optional field is left out of the serialized JSON when it is `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct CodestralRequest {
    /// Name of the model to query.
    pub model: String,
    /// Text preceding the insertion point.
    pub prompt: String,
    /// Text following the insertion point, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// Upper bound on generated tokens — presumably enforced by the API; confirm
    /// against the Mistral FIM reference.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Whether the response should be streamed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
}
367
/// JSON body of a successful completion response.
///
/// All fields are required: deserialization fails if any of them is missing.
#[derive(Debug, Deserialize)]
pub struct CodestralResponse {
    pub id: String,
    pub object: String,
    pub model: String,
    /// Token accounting reported for the request.
    pub usage: Usage,
    pub created: u64,
    /// Generated completions.
    pub choices: Vec<Choice>,
}
377
/// Token counts reported in a `CodestralResponse`.
#[derive(Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
384
/// One entry of `CodestralResponse::choices`.
#[derive(Debug, Deserialize)]
pub struct Choice {
    pub index: u32,
    /// Holds the generated text in its `content` field.
    pub message: Message,
    pub finish_reason: String,
}
391
/// Message payload carried by a `Choice`.
#[derive(Debug, Deserialize)]
pub struct Message {
    /// The generated completion text.
    pub content: String,
    pub role: String,
}