1use anyhow::{Context as _, Result};
  2use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
  3use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
  4use futures::AsyncReadExt;
  5use gpui::{App, Context, Entity, Task};
  6use http_client::HttpClient;
  7use language::{
  8    language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint,
  9};
 10use language_models::MistralLanguageModelProvider;
 11use mistral::CODESTRAL_API_URL;
 12use serde::{Deserialize, Serialize};
 13use std::{
 14    ops::Range,
 15    sync::Arc,
 16    time::{Duration, Instant},
 17};
 18use text::ToOffset;
 19
 20pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
 21
 22const EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
 23    max_bytes: 1050,
 24    min_bytes: 525,
 25    target_before_cursor_over_total_bytes: 0.66,
 26};
 27
 28/// Represents a completion that has been received and processed from Codestral.
 29/// This struct maintains the state needed to interpolate the completion as the user types.
 30#[derive(Clone)]
 31struct CurrentCompletion {
 32    /// The buffer snapshot at the time the completion was generated.
 33    /// Used to detect changes and interpolate edits.
 34    snapshot: BufferSnapshot,
 35    /// The edits that should be applied to transform the original text into the predicted text.
 36    /// Each edit is a range in the buffer and the text to replace it with.
 37    edits: Arc<[(Range<Anchor>, String)]>,
 38    /// Preview of how the buffer will look after applying the edits.
 39    edit_preview: EditPreview,
 40}
 41
 42impl CurrentCompletion {
 43    /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
 44    /// Returns None if the user's edits conflict with the predicted edits.
 45    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 46        edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
 47    }
 48}
 49
 50pub struct CodestralCompletionProvider {
 51    http_client: Arc<dyn HttpClient>,
 52    pending_request: Option<Task<Result<()>>>,
 53    current_completion: Option<CurrentCompletion>,
 54}
 55
 56impl CodestralCompletionProvider {
 57    pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
 58        Self {
 59            http_client,
 60            pending_request: None,
 61            current_completion: None,
 62        }
 63    }
 64
 65    pub fn has_api_key(cx: &App) -> bool {
 66        Self::api_key(cx).is_some()
 67    }
 68
 69    fn api_key(cx: &App) -> Option<Arc<str>> {
 70        MistralLanguageModelProvider::try_global(cx)
 71            .and_then(|provider| provider.codestral_api_key(CODESTRAL_API_URL, cx))
 72    }
 73
 74    /// Uses Codestral's Fill-in-the-Middle API for code completion.
 75    async fn fetch_completion(
 76        http_client: Arc<dyn HttpClient>,
 77        api_key: &str,
 78        prompt: String,
 79        suffix: String,
 80        model: String,
 81        max_tokens: Option<u32>,
 82        api_url: String,
 83    ) -> Result<String> {
 84        let start_time = Instant::now();
 85
 86        log::debug!(
 87            "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
 88            model,
 89            max_tokens
 90        );
 91
 92        let request = CodestralRequest {
 93            model,
 94            prompt,
 95            suffix: if suffix.is_empty() {
 96                None
 97            } else {
 98                Some(suffix)
 99            },
100            max_tokens: max_tokens.or(Some(350)),
101            temperature: Some(0.2),
102            top_p: Some(1.0),
103            stream: Some(false),
104            stop: None,
105            random_seed: None,
106            min_tokens: None,
107        };
108
109        let request_body = serde_json::to_string(&request)?;
110
111        log::debug!("Codestral: Sending FIM request");
112
113        let http_request = http_client::Request::builder()
114            .method(http_client::Method::POST)
115            .uri(format!("{}/v1/fim/completions", api_url))
116            .header("Content-Type", "application/json")
117            .header("Authorization", format!("Bearer {}", api_key))
118            .body(http_client::AsyncBody::from(request_body))?;
119
120        let mut response = http_client.send(http_request).await?;
121        let status = response.status();
122
123        log::debug!("Codestral: Response status: {}", status);
124
125        if !status.is_success() {
126            let mut body = String::new();
127            response.body_mut().read_to_string(&mut body).await?;
128            return Err(anyhow::anyhow!(
129                "Codestral API error: {} - {}",
130                status,
131                body
132            ));
133        }
134
135        let mut body = String::new();
136        response.body_mut().read_to_string(&mut body).await?;
137
138        let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
139
140        let elapsed = start_time.elapsed();
141
142        if let Some(choice) = codestral_response.choices.first() {
143            let completion = &choice.message.content;
144
145            log::debug!(
146                "Codestral: Completion received ({} tokens, {:.2}s)",
147                codestral_response.usage.completion_tokens,
148                elapsed.as_secs_f64()
149            );
150
151            // Return just the completion text for insertion at cursor
152            Ok(completion.clone())
153        } else {
154            log::error!("Codestral: No completion returned in response");
155            Err(anyhow::anyhow!("No completion returned from Codestral"))
156        }
157    }
158}
159
160impl EditPredictionProvider for CodestralCompletionProvider {
161    fn name() -> &'static str {
162        "codestral"
163    }
164
165    fn display_name() -> &'static str {
166        "Codestral"
167    }
168
169    fn show_completions_in_menu() -> bool {
170        true
171    }
172
173    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
174        Self::api_key(cx).is_some()
175    }
176
177    fn is_refreshing(&self) -> bool {
178        self.pending_request.is_some()
179    }
180
181    fn refresh(
182        &mut self,
183        buffer: Entity<Buffer>,
184        cursor_position: language::Anchor,
185        debounce: bool,
186        cx: &mut Context<Self>,
187    ) {
188        log::debug!("Codestral: Refresh called (debounce: {})", debounce);
189
190        let Some(api_key) = Self::api_key(cx) else {
191            log::warn!("Codestral: No API key configured, skipping refresh");
192            return;
193        };
194
195        let snapshot = buffer.read(cx).snapshot();
196
197        // Check if current completion is still valid
198        if let Some(current_completion) = self.current_completion.as_ref() {
199            if current_completion.interpolate(&snapshot).is_some() {
200                return;
201            }
202        }
203
204        let http_client = self.http_client.clone();
205
206        // Get settings
207        let settings = all_language_settings(None, cx);
208        let model = settings
209            .edit_predictions
210            .codestral
211            .model
212            .clone()
213            .unwrap_or_else(|| "codestral-latest".to_string());
214        let max_tokens = settings.edit_predictions.codestral.max_tokens;
215        let api_url = settings
216            .edit_predictions
217            .codestral
218            .api_url
219            .clone()
220            .unwrap_or_else(|| CODESTRAL_API_URL.to_string());
221
222        self.pending_request = Some(cx.spawn(async move |this, cx| {
223            if debounce {
224                log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
225                smol::Timer::after(DEBOUNCE_TIMEOUT).await;
226            }
227
228            let cursor_offset = cursor_position.to_offset(&snapshot);
229            let cursor_point = cursor_offset.to_point(&snapshot);
230            let excerpt = EditPredictionExcerpt::select_from_buffer(
231                cursor_point,
232                &snapshot,
233                &EXCERPT_OPTIONS,
234                None,
235            )
236            .context("Line containing cursor doesn't fit in excerpt max bytes")?;
237
238            let excerpt_text = excerpt.text(&snapshot);
239            let cursor_within_excerpt = cursor_offset
240                .saturating_sub(excerpt.range.start)
241                .min(excerpt_text.body.len());
242            let prompt = excerpt_text.body[..cursor_within_excerpt].to_string();
243            let suffix = excerpt_text.body[cursor_within_excerpt..].to_string();
244
245            let completion_text = match Self::fetch_completion(
246                http_client,
247                &api_key,
248                prompt,
249                suffix,
250                model,
251                max_tokens,
252                api_url,
253            )
254            .await
255            {
256                Ok(completion) => completion,
257                Err(e) => {
258                    log::error!("Codestral: Failed to fetch completion: {}", e);
259                    this.update(cx, |this, cx| {
260                        this.pending_request = None;
261                        cx.notify();
262                    })?;
263                    return Err(e);
264                }
265            };
266
267            if completion_text.trim().is_empty() {
268                log::debug!("Codestral: Completion was empty after trimming; ignoring");
269                this.update(cx, |this, cx| {
270                    this.pending_request = None;
271                    cx.notify();
272                })?;
273                return Ok(());
274            }
275
276            let edits: Arc<[(Range<Anchor>, String)]> =
277                vec![(cursor_position..cursor_position, completion_text)].into();
278            let edit_preview = buffer
279                .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))?
280                .await;
281
282            this.update(cx, |this, cx| {
283                this.current_completion = Some(CurrentCompletion {
284                    snapshot,
285                    edits,
286                    edit_preview,
287                });
288                this.pending_request = None;
289                cx.notify();
290            })?;
291
292            Ok(())
293        }));
294    }
295
296    fn cycle(
297        &mut self,
298        _buffer: Entity<Buffer>,
299        _cursor_position: Anchor,
300        _direction: Direction,
301        _cx: &mut Context<Self>,
302    ) {
303        // Codestral doesn't support multiple completions, so cycling does nothing
304    }
305
306    fn accept(&mut self, _cx: &mut Context<Self>) {
307        log::debug!("Codestral: Completion accepted");
308        self.pending_request = None;
309        self.current_completion = None;
310    }
311
312    fn discard(&mut self, _cx: &mut Context<Self>) {
313        log::debug!("Codestral: Completion discarded");
314        self.pending_request = None;
315        self.current_completion = None;
316    }
317
318    /// Returns the completion suggestion, adjusted or invalidated based on user edits
319    fn suggest(
320        &mut self,
321        buffer: &Entity<Buffer>,
322        _cursor_position: Anchor,
323        cx: &mut Context<Self>,
324    ) -> Option<EditPrediction> {
325        let current_completion = self.current_completion.as_ref()?;
326        let buffer = buffer.read(cx);
327        let edits = current_completion.interpolate(&buffer.snapshot())?;
328        if edits.is_empty() {
329            return None;
330        }
331        Some(EditPrediction::Local {
332            id: None,
333            edits,
334            edit_preview: Some(current_completion.edit_preview.clone()),
335        })
336    }
337}
338
339#[derive(Debug, Serialize, Deserialize)]
340pub struct CodestralRequest {
341    pub model: String,
342    pub prompt: String,
343    #[serde(skip_serializing_if = "Option::is_none")]
344    pub suffix: Option<String>,
345    #[serde(skip_serializing_if = "Option::is_none")]
346    pub max_tokens: Option<u32>,
347    #[serde(skip_serializing_if = "Option::is_none")]
348    pub temperature: Option<f32>,
349    #[serde(skip_serializing_if = "Option::is_none")]
350    pub top_p: Option<f32>,
351    #[serde(skip_serializing_if = "Option::is_none")]
352    pub stream: Option<bool>,
353    #[serde(skip_serializing_if = "Option::is_none")]
354    pub stop: Option<Vec<String>>,
355    #[serde(skip_serializing_if = "Option::is_none")]
356    pub random_seed: Option<u32>,
357    #[serde(skip_serializing_if = "Option::is_none")]
358    pub min_tokens: Option<u32>,
359}
360
361#[derive(Debug, Deserialize)]
362pub struct CodestralResponse {
363    pub id: String,
364    pub object: String,
365    pub model: String,
366    pub usage: Usage,
367    pub created: u64,
368    pub choices: Vec<Choice>,
369}
370
371#[derive(Debug, Deserialize)]
372pub struct Usage {
373    pub prompt_tokens: u32,
374    pub completion_tokens: u32,
375    pub total_tokens: u32,
376}
377
378#[derive(Debug, Deserialize)]
379pub struct Choice {
380    pub index: u32,
381    pub message: Message,
382    pub finish_reason: String,
383}
384
385#[derive(Debug, Deserialize)]
386pub struct Message {
387    pub content: String,
388    pub role: String,
389}