codestral.rs

  1use anyhow::{Context as _, Result};
  2use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
  3use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
  4use futures::AsyncReadExt;
  5use gpui::{App, Context, Entity, Task};
  6use http_client::HttpClient;
  7use language::{
  8    language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint,
  9};
 10use language_models::MistralLanguageModelProvider;
 11use mistral::CODESTRAL_API_URL;
 12use serde::{Deserialize, Serialize};
 13use std::{
 14    ops::Range,
 15    sync::Arc,
 16    time::{Duration, Instant},
 17};
 18use text::ToOffset;
 19
 20pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
 21
 22const EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
 23    max_bytes: 1050,
 24    min_bytes: 525,
 25    target_before_cursor_over_total_bytes: 0.66,
 26};
 27
 28/// Represents a completion that has been received and processed from Codestral.
 29/// This struct maintains the state needed to interpolate the completion as the user types.
 30#[derive(Clone)]
 31struct CurrentCompletion {
 32    /// The buffer snapshot at the time the completion was generated.
 33    /// Used to detect changes and interpolate edits.
 34    snapshot: BufferSnapshot,
 35    /// The edits that should be applied to transform the original text into the predicted text.
 36    /// Each edit is a range in the buffer and the text to replace it with.
 37    edits: Arc<[(Range<Anchor>, String)]>,
 38    /// Preview of how the buffer will look after applying the edits.
 39    edit_preview: EditPreview,
 40}
 41
 42impl CurrentCompletion {
 43    /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
 44    /// Returns None if the user's edits conflict with the predicted edits.
 45    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 46        edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
 47    }
 48}
 49
 50pub struct CodestralCompletionProvider {
 51    http_client: Arc<dyn HttpClient>,
 52    pending_request: Option<Task<Result<()>>>,
 53    current_completion: Option<CurrentCompletion>,
 54}
 55
 56impl CodestralCompletionProvider {
 57    pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
 58        Self {
 59            http_client,
 60            pending_request: None,
 61            current_completion: None,
 62        }
 63    }
 64
 65    pub fn has_api_key(cx: &App) -> bool {
 66        Self::api_key(cx).is_some()
 67    }
 68
 69    fn api_key(cx: &App) -> Option<Arc<str>> {
 70        MistralLanguageModelProvider::try_global(cx)
 71            .and_then(|provider| provider.codestral_api_key(CODESTRAL_API_URL, cx))
 72    }
 73
 74    /// Uses Codestral's Fill-in-the-Middle API for code completion.
 75    async fn fetch_completion(
 76        http_client: Arc<dyn HttpClient>,
 77        api_key: &str,
 78        prompt: String,
 79        suffix: String,
 80        model: String,
 81        max_tokens: Option<u32>,
 82    ) -> Result<String> {
 83        let start_time = Instant::now();
 84
 85        log::debug!(
 86            "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
 87            model,
 88            max_tokens
 89        );
 90
 91        let request = CodestralRequest {
 92            model,
 93            prompt,
 94            suffix: if suffix.is_empty() {
 95                None
 96            } else {
 97                Some(suffix)
 98            },
 99            max_tokens: max_tokens.or(Some(350)),
100            temperature: Some(0.2),
101            top_p: Some(1.0),
102            stream: Some(false),
103            stop: None,
104            random_seed: None,
105            min_tokens: None,
106        };
107
108        let request_body = serde_json::to_string(&request)?;
109
110        log::debug!("Codestral: Sending FIM request");
111
112        let http_request = http_client::Request::builder()
113            .method(http_client::Method::POST)
114            .uri(format!("{}/v1/fim/completions", CODESTRAL_API_URL))
115            .header("Content-Type", "application/json")
116            .header("Authorization", format!("Bearer {}", api_key))
117            .body(http_client::AsyncBody::from(request_body))?;
118
119        let mut response = http_client.send(http_request).await?;
120        let status = response.status();
121
122        log::debug!("Codestral: Response status: {}", status);
123
124        if !status.is_success() {
125            let mut body = String::new();
126            response.body_mut().read_to_string(&mut body).await?;
127            return Err(anyhow::anyhow!(
128                "Codestral API error: {} - {}",
129                status,
130                body
131            ));
132        }
133
134        let mut body = String::new();
135        response.body_mut().read_to_string(&mut body).await?;
136
137        let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
138
139        let elapsed = start_time.elapsed();
140
141        if let Some(choice) = codestral_response.choices.first() {
142            let completion = &choice.message.content;
143
144            log::debug!(
145                "Codestral: Completion received ({} tokens, {:.2}s)",
146                codestral_response.usage.completion_tokens,
147                elapsed.as_secs_f64()
148            );
149
150            // Return just the completion text for insertion at cursor
151            Ok(completion.clone())
152        } else {
153            log::error!("Codestral: No completion returned in response");
154            Err(anyhow::anyhow!("No completion returned from Codestral"))
155        }
156    }
157}
158
159impl EditPredictionProvider for CodestralCompletionProvider {
160    fn name() -> &'static str {
161        "codestral"
162    }
163
164    fn display_name() -> &'static str {
165        "Codestral"
166    }
167
168    fn show_completions_in_menu() -> bool {
169        true
170    }
171
172    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
173        Self::api_key(cx).is_some()
174    }
175
176    fn is_refreshing(&self) -> bool {
177        self.pending_request.is_some()
178    }
179
180    fn refresh(
181        &mut self,
182        buffer: Entity<Buffer>,
183        cursor_position: language::Anchor,
184        debounce: bool,
185        cx: &mut Context<Self>,
186    ) {
187        log::debug!("Codestral: Refresh called (debounce: {})", debounce);
188
189        let Some(api_key) = Self::api_key(cx) else {
190            log::warn!("Codestral: No API key configured, skipping refresh");
191            return;
192        };
193
194        let snapshot = buffer.read(cx).snapshot();
195
196        // Check if current completion is still valid
197        if let Some(current_completion) = self.current_completion.as_ref() {
198            if current_completion.interpolate(&snapshot).is_some() {
199                return;
200            }
201        }
202
203        let http_client = self.http_client.clone();
204
205        // Get settings
206        let settings = all_language_settings(None, cx);
207        let model = settings
208            .edit_predictions
209            .codestral
210            .model
211            .clone()
212            .unwrap_or_else(|| "codestral-latest".to_string());
213        let max_tokens = settings.edit_predictions.codestral.max_tokens;
214
215        self.pending_request = Some(cx.spawn(async move |this, cx| {
216            if debounce {
217                log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
218                smol::Timer::after(DEBOUNCE_TIMEOUT).await;
219            }
220
221            let cursor_offset = cursor_position.to_offset(&snapshot);
222            let cursor_point = cursor_offset.to_point(&snapshot);
223            let excerpt = EditPredictionExcerpt::select_from_buffer(
224                cursor_point,
225                &snapshot,
226                &EXCERPT_OPTIONS,
227                None,
228            )
229            .context("Line containing cursor doesn't fit in excerpt max bytes")?;
230
231            let excerpt_text = excerpt.text(&snapshot);
232            let cursor_within_excerpt = cursor_offset
233                .saturating_sub(excerpt.range.start)
234                .min(excerpt_text.body.len());
235            let prompt = excerpt_text.body[..cursor_within_excerpt].to_string();
236            let suffix = excerpt_text.body[cursor_within_excerpt..].to_string();
237
238            let completion_text = match Self::fetch_completion(
239                http_client,
240                &api_key,
241                prompt,
242                suffix,
243                model,
244                max_tokens,
245            )
246            .await
247            {
248                Ok(completion) => completion,
249                Err(e) => {
250                    log::error!("Codestral: Failed to fetch completion: {}", e);
251                    this.update(cx, |this, cx| {
252                        this.pending_request = None;
253                        cx.notify();
254                    })?;
255                    return Err(e);
256                }
257            };
258
259            if completion_text.trim().is_empty() {
260                log::debug!("Codestral: Completion was empty after trimming; ignoring");
261                this.update(cx, |this, cx| {
262                    this.pending_request = None;
263                    cx.notify();
264                })?;
265                return Ok(());
266            }
267
268            let edits: Arc<[(Range<Anchor>, String)]> =
269                vec![(cursor_position..cursor_position, completion_text)].into();
270            let edit_preview = buffer
271                .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))?
272                .await;
273
274            this.update(cx, |this, cx| {
275                this.current_completion = Some(CurrentCompletion {
276                    snapshot,
277                    edits,
278                    edit_preview,
279                });
280                this.pending_request = None;
281                cx.notify();
282            })?;
283
284            Ok(())
285        }));
286    }
287
288    fn cycle(
289        &mut self,
290        _buffer: Entity<Buffer>,
291        _cursor_position: Anchor,
292        _direction: Direction,
293        _cx: &mut Context<Self>,
294    ) {
295        // Codestral doesn't support multiple completions, so cycling does nothing
296    }
297
298    fn accept(&mut self, _cx: &mut Context<Self>) {
299        log::debug!("Codestral: Completion accepted");
300        self.pending_request = None;
301        self.current_completion = None;
302    }
303
304    fn discard(&mut self, _cx: &mut Context<Self>) {
305        log::debug!("Codestral: Completion discarded");
306        self.pending_request = None;
307        self.current_completion = None;
308    }
309
310    /// Returns the completion suggestion, adjusted or invalidated based on user edits
311    fn suggest(
312        &mut self,
313        buffer: &Entity<Buffer>,
314        _cursor_position: Anchor,
315        cx: &mut Context<Self>,
316    ) -> Option<EditPrediction> {
317        let current_completion = self.current_completion.as_ref()?;
318        let buffer = buffer.read(cx);
319        let edits = current_completion.interpolate(&buffer.snapshot())?;
320        if edits.is_empty() {
321            return None;
322        }
323        Some(EditPrediction::Local {
324            id: None,
325            edits,
326            edit_preview: Some(current_completion.edit_preview.clone()),
327        })
328    }
329}
330
331#[derive(Debug, Serialize, Deserialize)]
332pub struct CodestralRequest {
333    pub model: String,
334    pub prompt: String,
335    #[serde(skip_serializing_if = "Option::is_none")]
336    pub suffix: Option<String>,
337    #[serde(skip_serializing_if = "Option::is_none")]
338    pub max_tokens: Option<u32>,
339    #[serde(skip_serializing_if = "Option::is_none")]
340    pub temperature: Option<f32>,
341    #[serde(skip_serializing_if = "Option::is_none")]
342    pub top_p: Option<f32>,
343    #[serde(skip_serializing_if = "Option::is_none")]
344    pub stream: Option<bool>,
345    #[serde(skip_serializing_if = "Option::is_none")]
346    pub stop: Option<Vec<String>>,
347    #[serde(skip_serializing_if = "Option::is_none")]
348    pub random_seed: Option<u32>,
349    #[serde(skip_serializing_if = "Option::is_none")]
350    pub min_tokens: Option<u32>,
351}
352
353#[derive(Debug, Deserialize)]
354pub struct CodestralResponse {
355    pub id: String,
356    pub object: String,
357    pub model: String,
358    pub usage: Usage,
359    pub created: u64,
360    pub choices: Vec<Choice>,
361}
362
363#[derive(Debug, Deserialize)]
364pub struct Usage {
365    pub prompt_tokens: u32,
366    pub completion_tokens: u32,
367    pub total_tokens: u32,
368}
369
370#[derive(Debug, Deserialize)]
371pub struct Choice {
372    pub index: u32,
373    pub message: Message,
374    pub finish_reason: String,
375}
376
377#[derive(Debug, Deserialize)]
378pub struct Message {
379    pub content: String,
380    pub role: String,
381}