codestral.rs

  1use anyhow::Result;
  2use edit_prediction::cursor_excerpt;
  3use edit_prediction_types::{EditPrediction, EditPredictionDelegate};
  4use futures::AsyncReadExt;
  5use gpui::{App, Context, Entity, Task};
  6use http_client::HttpClient;
  7use language::{
  8    language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint,
  9};
 10use language_models::MistralLanguageModelProvider;
 11use mistral::CODESTRAL_API_URL;
 12use serde::{Deserialize, Serialize};
 13use std::{
 14    ops::Range,
 15    sync::Arc,
 16    time::{Duration, Instant},
 17};
 18use text::{OffsetRangeExt as _, ToOffset};
 19
 20pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
 21
 22/// Represents a completion that has been received and processed from Codestral.
 23/// This struct maintains the state needed to interpolate the completion as the user types.
 24#[derive(Clone)]
 25struct CurrentCompletion {
 26    /// The buffer snapshot at the time the completion was generated.
 27    /// Used to detect changes and interpolate edits.
 28    snapshot: BufferSnapshot,
 29    /// The edits that should be applied to transform the original text into the predicted text.
 30    /// Each edit is a range in the buffer and the text to replace it with.
 31    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 32    /// Preview of how the buffer will look after applying the edits.
 33    edit_preview: EditPreview,
 34}
 35
 36impl CurrentCompletion {
 37    /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
 38    /// Returns None if the user's edits conflict with the predicted edits.
 39    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
 40        edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
 41    }
 42}
 43
 44pub struct CodestralEditPredictionDelegate {
 45    http_client: Arc<dyn HttpClient>,
 46    pending_request: Option<Task<Result<()>>>,
 47    current_completion: Option<CurrentCompletion>,
 48}
 49
 50impl CodestralEditPredictionDelegate {
 51    pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
 52        Self {
 53            http_client,
 54            pending_request: None,
 55            current_completion: None,
 56        }
 57    }
 58
 59    pub fn has_api_key(cx: &App) -> bool {
 60        Self::api_key(cx).is_some()
 61    }
 62
 63    /// This is so we can immediately show Codestral as a provider users can
 64    /// switch to in the edit prediction menu, if the API has been added
 65    pub fn ensure_api_key_loaded(http_client: Arc<dyn HttpClient>, cx: &mut App) {
 66        MistralLanguageModelProvider::global(http_client, cx)
 67            .load_codestral_api_key(cx)
 68            .detach();
 69    }
 70
 71    fn api_key(cx: &App) -> Option<Arc<str>> {
 72        MistralLanguageModelProvider::try_global(cx)
 73            .and_then(|provider| provider.codestral_api_key(CODESTRAL_API_URL, cx))
 74    }
 75
 76    /// Uses Codestral's Fill-in-the-Middle API for code completion.
 77    async fn fetch_completion(
 78        http_client: Arc<dyn HttpClient>,
 79        api_key: &str,
 80        prompt: String,
 81        suffix: String,
 82        model: String,
 83        max_tokens: Option<u32>,
 84        api_url: String,
 85    ) -> Result<String> {
 86        let start_time = Instant::now();
 87
 88        log::debug!(
 89            "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
 90            model,
 91            max_tokens
 92        );
 93
 94        let request = CodestralRequest {
 95            model,
 96            prompt,
 97            suffix: if suffix.is_empty() {
 98                None
 99            } else {
100                Some(suffix)
101            },
102            max_tokens: max_tokens.or(Some(350)),
103            temperature: Some(0.2),
104            top_p: Some(1.0),
105            stream: Some(false),
106            stop: None,
107            random_seed: None,
108            min_tokens: None,
109        };
110
111        let request_body = serde_json::to_string(&request)?;
112
113        log::debug!("Codestral: Sending FIM request");
114
115        let http_request = http_client::Request::builder()
116            .method(http_client::Method::POST)
117            .uri(format!("{}/v1/fim/completions", api_url))
118            .header("Content-Type", "application/json")
119            .header("Authorization", format!("Bearer {}", api_key))
120            .body(http_client::AsyncBody::from(request_body))?;
121
122        let mut response = http_client.send(http_request).await?;
123        let status = response.status();
124
125        log::debug!("Codestral: Response status: {}", status);
126
127        if !status.is_success() {
128            let mut body = String::new();
129            response.body_mut().read_to_string(&mut body).await?;
130            return Err(anyhow::anyhow!(
131                "Codestral API error: {} - {}",
132                status,
133                body
134            ));
135        }
136
137        let mut body = String::new();
138        response.body_mut().read_to_string(&mut body).await?;
139
140        let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
141
142        let elapsed = start_time.elapsed();
143
144        if let Some(choice) = codestral_response.choices.first() {
145            let completion = &choice.message.content;
146
147            log::debug!(
148                "Codestral: Completion received ({} tokens, {:.2}s)",
149                codestral_response.usage.completion_tokens,
150                elapsed.as_secs_f64()
151            );
152
153            // Return just the completion text for insertion at cursor
154            Ok(completion.clone())
155        } else {
156            log::error!("Codestral: No completion returned in response");
157            Err(anyhow::anyhow!("No completion returned from Codestral"))
158        }
159    }
160}
161
162impl EditPredictionDelegate for CodestralEditPredictionDelegate {
163    fn name() -> &'static str {
164        "codestral"
165    }
166
167    fn display_name() -> &'static str {
168        "Codestral"
169    }
170
171    fn show_predictions_in_menu() -> bool {
172        true
173    }
174
175    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
176        Self::api_key(cx).is_some()
177    }
178
179    fn is_refreshing(&self, _cx: &App) -> bool {
180        self.pending_request.is_some()
181    }
182
183    fn refresh(
184        &mut self,
185        buffer: Entity<Buffer>,
186        cursor_position: language::Anchor,
187        debounce: bool,
188        cx: &mut Context<Self>,
189    ) {
190        log::debug!("Codestral: Refresh called (debounce: {})", debounce);
191
192        let Some(api_key) = Self::api_key(cx) else {
193            log::warn!("Codestral: No API key configured, skipping refresh");
194            return;
195        };
196
197        let snapshot = buffer.read(cx).snapshot();
198
199        // Check if current completion is still valid
200        if let Some(current_completion) = self.current_completion.as_ref() {
201            if current_completion.interpolate(&snapshot).is_some() {
202                return;
203            }
204        }
205
206        let http_client = self.http_client.clone();
207
208        // Get settings
209        let settings = all_language_settings(None, cx);
210        let model = settings
211            .edit_predictions
212            .codestral
213            .model
214            .clone()
215            .unwrap_or_else(|| "codestral-latest".to_string());
216        let max_tokens = settings.edit_predictions.codestral.max_tokens;
217        let api_url = settings
218            .edit_predictions
219            .codestral
220            .api_url
221            .clone()
222            .unwrap_or_else(|| CODESTRAL_API_URL.to_string());
223
224        self.pending_request = Some(cx.spawn(async move |this, cx| {
225            if debounce {
226                log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
227                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
228            }
229
230            let cursor_offset = cursor_position.to_offset(&snapshot);
231            let cursor_point = cursor_offset.to_point(&snapshot);
232
233            const MAX_CONTEXT_TOKENS: usize = 150;
234            const MAX_REWRITE_TOKENS: usize = 350;
235
236            let (_, context_range) =
237                cursor_excerpt::editable_and_context_ranges_for_cursor_position(
238                    cursor_point,
239                    &snapshot,
240                    MAX_REWRITE_TOKENS,
241                    MAX_CONTEXT_TOKENS,
242                );
243
244            let context_range = context_range.to_offset(&snapshot);
245            let excerpt_text = snapshot
246                .text_for_range(context_range.clone())
247                .collect::<String>();
248            let cursor_within_excerpt = cursor_offset
249                .saturating_sub(context_range.start)
250                .min(excerpt_text.len());
251            let prompt = excerpt_text[..cursor_within_excerpt].to_string();
252            let suffix = excerpt_text[cursor_within_excerpt..].to_string();
253
254            let completion_text = match Self::fetch_completion(
255                http_client,
256                &api_key,
257                prompt,
258                suffix,
259                model,
260                max_tokens,
261                api_url,
262            )
263            .await
264            {
265                Ok(completion) => completion,
266                Err(e) => {
267                    log::error!("Codestral: Failed to fetch completion: {}", e);
268                    this.update(cx, |this, cx| {
269                        this.pending_request = None;
270                        cx.notify();
271                    })?;
272                    return Err(e);
273                }
274            };
275
276            if completion_text.trim().is_empty() {
277                log::debug!("Codestral: Completion was empty after trimming; ignoring");
278                this.update(cx, |this, cx| {
279                    this.pending_request = None;
280                    cx.notify();
281                })?;
282                return Ok(());
283            }
284
285            let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
286                vec![(cursor_position..cursor_position, completion_text.into())].into();
287            let edit_preview = buffer
288                .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))
289                .await;
290
291            this.update(cx, |this, cx| {
292                this.current_completion = Some(CurrentCompletion {
293                    snapshot,
294                    edits,
295                    edit_preview,
296                });
297                this.pending_request = None;
298                cx.notify();
299            })?;
300
301            Ok(())
302        }));
303    }
304
305    fn accept(&mut self, _cx: &mut Context<Self>) {
306        log::debug!("Codestral: Completion accepted");
307        self.pending_request = None;
308        self.current_completion = None;
309    }
310
311    fn discard(&mut self, _cx: &mut Context<Self>) {
312        log::debug!("Codestral: Completion discarded");
313        self.pending_request = None;
314        self.current_completion = None;
315    }
316
317    /// Returns the completion suggestion, adjusted or invalidated based on user edits
318    fn suggest(
319        &mut self,
320        buffer: &Entity<Buffer>,
321        _cursor_position: Anchor,
322        cx: &mut Context<Self>,
323    ) -> Option<EditPrediction> {
324        let current_completion = self.current_completion.as_ref()?;
325        let buffer = buffer.read(cx);
326        let edits = current_completion.interpolate(&buffer.snapshot())?;
327        if edits.is_empty() {
328            return None;
329        }
330        Some(EditPrediction::Local {
331            id: None,
332            edits,
333            edit_preview: Some(current_completion.edit_preview.clone()),
334        })
335    }
336}
337
/// JSON body that `fetch_completion` sends to the `/v1/fim/completions` endpoint.
///
/// Every optional field is left out of the serialized JSON when it is `None`.
/// Field declaration order is also the JSON key order.
#[derive(Debug, Serialize, Deserialize)]
pub struct CodestralRequest {
    /// Model name; `refresh` falls back to `"codestral-latest"` when unset in settings.
    pub model: String,
    /// Text before the cursor.
    pub prompt: String,
    /// Text after the cursor; `fetch_completion` sends `None` when it is empty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// Token limit forwarded to the API; `fetch_completion` defaults it to 350.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Sampling temperature; `fetch_completion` sends 0.2.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus-sampling parameter; `fetch_completion` sends 1.0.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Streaming flag; `fetch_completion` sends `false` and reads a single JSON body.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Stop sequences; left unset by `fetch_completion`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Sampling seed; left unset by `fetch_completion`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u32>,
    /// Minimum token count; left unset by `fetch_completion`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
}
359
/// Response body of the FIM endpoint, as parsed by `fetch_completion`.
///
/// No field is optional or defaulted, so a response that omits any of them
/// fails to deserialize.
#[derive(Debug, Deserialize)]
pub struct CodestralResponse {
    /// `id` reported by the API; not read here.
    pub id: String,
    /// `object` tag reported by the API; not read here.
    pub object: String,
    /// `model` reported by the API; not read here.
    pub model: String,
    /// Token usage; `completion_tokens` is logged by `fetch_completion`.
    pub usage: Usage,
    /// `created` value reported by the API (presumably a Unix timestamp — TODO confirm); not read here.
    pub created: u64,
    /// Returned choices; `fetch_completion` uses only the first one and errors if there are none.
    pub choices: Vec<Choice>,
}
369
/// Token counts reported in a `CodestralResponse`.
#[derive(Debug, Deserialize)]
pub struct Usage {
    /// Tokens attributed to the prompt; not read here.
    pub prompt_tokens: u32,
    /// Tokens attributed to the completion; logged by `fetch_completion`.
    pub completion_tokens: u32,
    /// Total token count; not read here.
    pub total_tokens: u32,
}
376
/// One entry of `CodestralResponse::choices`.
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// Position of this choice in the list; not read here.
    pub index: u32,
    /// Message whose `content` is the completion text.
    pub message: Message,
    /// Not inspected here (presumably why generation stopped — TODO confirm values).
    /// Required: a `null` or missing value fails deserialization.
    pub finish_reason: String,
}
383
/// Message payload of a `Choice`.
#[derive(Debug, Deserialize)]
pub struct Message {
    /// Completion text; `refresh` inserts it at the cursor unless it is blank.
    pub content: String,
    /// `role` reported by the API; not read here.
    pub role: String,
}