1use anyhow::{Context as _, Result};
  2use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
  3use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
  4use futures::AsyncReadExt;
  5use gpui::{App, Context, Entity, Task};
  6use http_client::HttpClient;
  7use language::{
  8    language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint,
  9};
 10use language_models::MistralLanguageModelProvider;
 11use mistral::CODESTRAL_API_URL;
 12use serde::{Deserialize, Serialize};
 13use std::{
 14    ops::Range,
 15    sync::Arc,
 16    time::{Duration, Instant},
 17};
 18use text::ToOffset;
 19
/// How long a debounced `refresh` waits before building the prompt and sending the request.
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);

/// Size limits for the buffer excerpt around the cursor that is split into the FIM
/// prompt (text before the cursor) and suffix (text after the cursor).
const EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 1050,
    min_bytes: 525,
    // NOTE(review): presumably the target share of the excerpt placed before the cursor —
    // confirm against `EditPredictionExcerptOptions` docs.
    target_before_cursor_over_total_bytes: 0.66,
};
 27
/// Represents a completion that has been received and processed from Codestral.
/// This struct maintains the state needed to interpolate the completion as the user types.
#[derive(Clone)]
struct CurrentCompletion {
    /// The buffer snapshot at the time the completion was generated.
    /// Used to detect changes and interpolate edits.
    snapshot: BufferSnapshot,
    /// The edits that should be applied to transform the original text into the predicted text.
    /// Each edit is a range in the buffer and the text to replace it with.
    /// As built by `refresh`, this is a single insertion at the cursor position.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Preview of how the buffer will look after applying the edits.
    edit_preview: EditPreview,
}
 41
 42impl CurrentCompletion {
 43    /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
 44    /// Returns None if the user's edits conflict with the predicted edits.
 45    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 46        edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
 47    }
 48}
 49
/// Edit prediction provider backed by Codestral's Fill-in-the-Middle completion endpoint.
pub struct CodestralCompletionProvider {
    /// Client used to send completion requests.
    http_client: Arc<dyn HttpClient>,
    /// The most recently started refresh task, if any; `is_refreshing` reports whether
    /// this is set.
    pending_request: Option<Task<Result<()>>>,
    /// The latest completion received, until it is accepted, discarded, or replaced.
    current_completion: Option<CurrentCompletion>,
}
 55
 56impl CodestralCompletionProvider {
 57    pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
 58        Self {
 59            http_client,
 60            pending_request: None,
 61            current_completion: None,
 62        }
 63    }
 64
 65    pub fn has_api_key(cx: &App) -> bool {
 66        Self::api_key(cx).is_some()
 67    }
 68
 69    /// This is so we can immediately show Codestral as a provider users can
 70    /// switch to in the edit prediction menu, if the API has been added
 71    pub fn ensure_api_key_loaded(http_client: Arc<dyn HttpClient>, cx: &mut App) {
 72        MistralLanguageModelProvider::global(http_client, cx)
 73            .load_codestral_api_key(cx)
 74            .detach();
 75    }
 76
 77    fn api_key(cx: &App) -> Option<Arc<str>> {
 78        MistralLanguageModelProvider::try_global(cx)
 79            .and_then(|provider| provider.codestral_api_key(CODESTRAL_API_URL, cx))
 80    }
 81
 82    /// Uses Codestral's Fill-in-the-Middle API for code completion.
 83    async fn fetch_completion(
 84        http_client: Arc<dyn HttpClient>,
 85        api_key: &str,
 86        prompt: String,
 87        suffix: String,
 88        model: String,
 89        max_tokens: Option<u32>,
 90        api_url: String,
 91    ) -> Result<String> {
 92        let start_time = Instant::now();
 93
 94        log::debug!(
 95            "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
 96            model,
 97            max_tokens
 98        );
 99
100        let request = CodestralRequest {
101            model,
102            prompt,
103            suffix: if suffix.is_empty() {
104                None
105            } else {
106                Some(suffix)
107            },
108            max_tokens: max_tokens.or(Some(350)),
109            temperature: Some(0.2),
110            top_p: Some(1.0),
111            stream: Some(false),
112            stop: None,
113            random_seed: None,
114            min_tokens: None,
115        };
116
117        let request_body = serde_json::to_string(&request)?;
118
119        log::debug!("Codestral: Sending FIM request");
120
121        let http_request = http_client::Request::builder()
122            .method(http_client::Method::POST)
123            .uri(format!("{}/v1/fim/completions", api_url))
124            .header("Content-Type", "application/json")
125            .header("Authorization", format!("Bearer {}", api_key))
126            .body(http_client::AsyncBody::from(request_body))?;
127
128        let mut response = http_client.send(http_request).await?;
129        let status = response.status();
130
131        log::debug!("Codestral: Response status: {}", status);
132
133        if !status.is_success() {
134            let mut body = String::new();
135            response.body_mut().read_to_string(&mut body).await?;
136            return Err(anyhow::anyhow!(
137                "Codestral API error: {} - {}",
138                status,
139                body
140            ));
141        }
142
143        let mut body = String::new();
144        response.body_mut().read_to_string(&mut body).await?;
145
146        let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
147
148        let elapsed = start_time.elapsed();
149
150        if let Some(choice) = codestral_response.choices.first() {
151            let completion = &choice.message.content;
152
153            log::debug!(
154                "Codestral: Completion received ({} tokens, {:.2}s)",
155                codestral_response.usage.completion_tokens,
156                elapsed.as_secs_f64()
157            );
158
159            // Return just the completion text for insertion at cursor
160            Ok(completion.clone())
161        } else {
162            log::error!("Codestral: No completion returned in response");
163            Err(anyhow::anyhow!("No completion returned from Codestral"))
164        }
165    }
166}
167
168impl EditPredictionProvider for CodestralCompletionProvider {
169    fn name() -> &'static str {
170        "codestral"
171    }
172
173    fn display_name() -> &'static str {
174        "Codestral"
175    }
176
177    fn show_completions_in_menu() -> bool {
178        true
179    }
180
181    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
182        Self::api_key(cx).is_some()
183    }
184
185    fn is_refreshing(&self) -> bool {
186        self.pending_request.is_some()
187    }
188
189    fn refresh(
190        &mut self,
191        buffer: Entity<Buffer>,
192        cursor_position: language::Anchor,
193        debounce: bool,
194        cx: &mut Context<Self>,
195    ) {
196        log::debug!("Codestral: Refresh called (debounce: {})", debounce);
197
198        let Some(api_key) = Self::api_key(cx) else {
199            log::warn!("Codestral: No API key configured, skipping refresh");
200            return;
201        };
202
203        let snapshot = buffer.read(cx).snapshot();
204
205        // Check if current completion is still valid
206        if let Some(current_completion) = self.current_completion.as_ref() {
207            if current_completion.interpolate(&snapshot).is_some() {
208                return;
209            }
210        }
211
212        let http_client = self.http_client.clone();
213
214        // Get settings
215        let settings = all_language_settings(None, cx);
216        let model = settings
217            .edit_predictions
218            .codestral
219            .model
220            .clone()
221            .unwrap_or_else(|| "codestral-latest".to_string());
222        let max_tokens = settings.edit_predictions.codestral.max_tokens;
223        let api_url = settings
224            .edit_predictions
225            .codestral
226            .api_url
227            .clone()
228            .unwrap_or_else(|| CODESTRAL_API_URL.to_string());
229
230        self.pending_request = Some(cx.spawn(async move |this, cx| {
231            if debounce {
232                log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
233                smol::Timer::after(DEBOUNCE_TIMEOUT).await;
234            }
235
236            let cursor_offset = cursor_position.to_offset(&snapshot);
237            let cursor_point = cursor_offset.to_point(&snapshot);
238            let excerpt = EditPredictionExcerpt::select_from_buffer(
239                cursor_point,
240                &snapshot,
241                &EXCERPT_OPTIONS,
242                None,
243            )
244            .context("Line containing cursor doesn't fit in excerpt max bytes")?;
245
246            let excerpt_text = excerpt.text(&snapshot);
247            let cursor_within_excerpt = cursor_offset
248                .saturating_sub(excerpt.range.start)
249                .min(excerpt_text.body.len());
250            let prompt = excerpt_text.body[..cursor_within_excerpt].to_string();
251            let suffix = excerpt_text.body[cursor_within_excerpt..].to_string();
252
253            let completion_text = match Self::fetch_completion(
254                http_client,
255                &api_key,
256                prompt,
257                suffix,
258                model,
259                max_tokens,
260                api_url,
261            )
262            .await
263            {
264                Ok(completion) => completion,
265                Err(e) => {
266                    log::error!("Codestral: Failed to fetch completion: {}", e);
267                    this.update(cx, |this, cx| {
268                        this.pending_request = None;
269                        cx.notify();
270                    })?;
271                    return Err(e);
272                }
273            };
274
275            if completion_text.trim().is_empty() {
276                log::debug!("Codestral: Completion was empty after trimming; ignoring");
277                this.update(cx, |this, cx| {
278                    this.pending_request = None;
279                    cx.notify();
280                })?;
281                return Ok(());
282            }
283
284            let edits: Arc<[(Range<Anchor>, String)]> =
285                vec![(cursor_position..cursor_position, completion_text)].into();
286            let edit_preview = buffer
287                .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))?
288                .await;
289
290            this.update(cx, |this, cx| {
291                this.current_completion = Some(CurrentCompletion {
292                    snapshot,
293                    edits,
294                    edit_preview,
295                });
296                this.pending_request = None;
297                cx.notify();
298            })?;
299
300            Ok(())
301        }));
302    }
303
304    fn cycle(
305        &mut self,
306        _buffer: Entity<Buffer>,
307        _cursor_position: Anchor,
308        _direction: Direction,
309        _cx: &mut Context<Self>,
310    ) {
311        // Codestral doesn't support multiple completions, so cycling does nothing
312    }
313
314    fn accept(&mut self, _cx: &mut Context<Self>) {
315        log::debug!("Codestral: Completion accepted");
316        self.pending_request = None;
317        self.current_completion = None;
318    }
319
320    fn discard(&mut self, _cx: &mut Context<Self>) {
321        log::debug!("Codestral: Completion discarded");
322        self.pending_request = None;
323        self.current_completion = None;
324    }
325
326    /// Returns the completion suggestion, adjusted or invalidated based on user edits
327    fn suggest(
328        &mut self,
329        buffer: &Entity<Buffer>,
330        _cursor_position: Anchor,
331        cx: &mut Context<Self>,
332    ) -> Option<EditPrediction> {
333        let current_completion = self.current_completion.as_ref()?;
334        let buffer = buffer.read(cx);
335        let edits = current_completion.interpolate(&buffer.snapshot())?;
336        if edits.is_empty() {
337            return None;
338        }
339        Some(EditPrediction::Local {
340            id: None,
341            edits,
342            edit_preview: Some(current_completion.edit_preview.clone()),
343        })
344    }
345}
346
/// JSON body sent to `{api_url}/v1/fim/completions`.
///
/// Every `Option` field set to `None` is left out of the serialized JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct CodestralRequest {
    /// Model name (from settings; `"codestral-latest"` when unset).
    pub model: String,
    /// Text before the cursor.
    pub prompt: String,
    /// Text after the cursor; `fetch_completion` sets this to `None` when it is empty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// Token limit for the completion; `fetch_completion` falls back to 350 when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// `fetch_completion` always sends `Some(false)` and parses one JSON response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
}
368
/// Successful response from the FIM completions endpoint.
///
/// `fetch_completion` reads only `choices` (the first entry) and `usage.completion_tokens`.
/// The other fields are still required by deserialization because none are optional.
#[derive(Debug, Deserialize)]
pub struct CodestralResponse {
    pub id: String,
    pub object: String,
    pub model: String,
    pub usage: Usage,
    pub created: u64,
    pub choices: Vec<Choice>,
}
378
/// Token counts reported with a response. Only `completion_tokens` is read in this file,
/// and only for debug logging.
#[derive(Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
385
/// One completion choice; `fetch_completion` uses only the first choice's `message`.
#[derive(Debug, Deserialize)]
pub struct Choice {
    pub index: u32,
    pub message: Message,
    // NOTE(review): declared as a non-optional `String`, so a response carrying a null
    // `finish_reason` would fail to deserialize — confirm the API never sends null here.
    pub finish_reason: String,
}
392
/// Message payload of a choice; `content` is the text offered for insertion at the cursor.
#[derive(Debug, Deserialize)]
pub struct Message {
    pub content: String,
    pub role: String,
}