codestral.rs

  1use anyhow::Result;
  2use edit_prediction::cursor_excerpt;
  3use edit_prediction_types::{
  4    EditPrediction, EditPredictionDelegate, EditPredictionDiscardReason, EditPredictionIconSet,
  5};
  6use futures::AsyncReadExt;
  7use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
  8use http_client::HttpClient;
  9use icons::IconName;
 10use language::{
 11    Anchor, Buffer, BufferSnapshot, EditPreview, language_settings::all_language_settings,
 12};
 13use language_model::{ApiKeyState, AuthenticateError, EnvVar, env_var};
 14use serde::{Deserialize, Serialize};
 15
 16use std::{
 17    ops::Range,
 18    sync::Arc,
 19    time::{Duration, Instant},
 20};
 21use text::ToOffset;
 22
 23pub const CODESTRAL_API_URL: &str = "https://codestral.mistral.ai";
 24pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
 25
 26static CODESTRAL_API_KEY_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("CODESTRAL_API_KEY");
 27
 28struct GlobalCodestralApiKey(Entity<ApiKeyState>);
 29
 30impl Global for GlobalCodestralApiKey {}
 31
 32pub fn codestral_api_key_state(cx: &mut App) -> Entity<ApiKeyState> {
 33    if let Some(global) = cx.try_global::<GlobalCodestralApiKey>() {
 34        return global.0.clone();
 35    }
 36    let entity =
 37        cx.new(|cx| ApiKeyState::new(codestral_api_url(cx), CODESTRAL_API_KEY_ENV_VAR.clone()));
 38    cx.set_global(GlobalCodestralApiKey(entity.clone()));
 39    entity
 40}
 41
 42pub fn codestral_api_key(cx: &App) -> Option<Arc<str>> {
 43    let url = codestral_api_url(cx);
 44    cx.try_global::<GlobalCodestralApiKey>()?
 45        .0
 46        .read(cx)
 47        .key(&url)
 48}
 49
 50pub fn load_codestral_api_key(cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 51    let credentials_provider = zed_credentials_provider::global(cx);
 52    let api_url = codestral_api_url(cx);
 53    codestral_api_key_state(cx).update(cx, |key_state, cx| {
 54        key_state.load_if_needed(api_url, |s| s, credentials_provider, cx)
 55    })
 56}
 57
 58pub fn codestral_api_url(cx: &App) -> SharedString {
 59    all_language_settings(None, cx)
 60        .edit_predictions
 61        .codestral
 62        .api_url
 63        .clone()
 64        .unwrap_or_else(|| CODESTRAL_API_URL.to_string())
 65        .into()
 66}
 67
 68/// Represents a completion that has been received and processed from Codestral.
 69/// This struct maintains the state needed to interpolate the completion as the user types.
 70#[derive(Clone)]
 71struct CurrentCompletion {
 72    /// The buffer snapshot at the time the completion was generated.
 73    /// Used to detect changes and interpolate edits.
 74    snapshot: BufferSnapshot,
 75    /// The edits that should be applied to transform the original text into the predicted text.
 76    /// Each edit is a range in the buffer and the text to replace it with.
 77    edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 78    /// Preview of how the buffer will look after applying the edits.
 79    edit_preview: EditPreview,
 80}
 81
 82impl CurrentCompletion {
 83    /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
 84    /// Returns None if the user's edits conflict with the predicted edits.
 85    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
 86        edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
 87    }
 88}
 89
 90pub struct CodestralEditPredictionDelegate {
 91    http_client: Arc<dyn HttpClient>,
 92    pending_request: Option<Task<Result<()>>>,
 93    current_completion: Option<CurrentCompletion>,
 94}
 95
 96impl CodestralEditPredictionDelegate {
 97    pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
 98        Self {
 99            http_client,
100            pending_request: None,
101            current_completion: None,
102        }
103    }
104
105    pub fn ensure_api_key_loaded(cx: &mut App) {
106        load_codestral_api_key(cx).detach();
107    }
108
109    /// Uses Codestral's Fill-in-the-Middle API for code completion.
110    async fn fetch_completion(
111        http_client: Arc<dyn HttpClient>,
112        api_key: &str,
113        prompt: String,
114        suffix: String,
115        model: String,
116        max_tokens: Option<u32>,
117        api_url: String,
118    ) -> Result<String> {
119        let start_time = Instant::now();
120
121        log::debug!(
122            "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
123            model,
124            max_tokens
125        );
126
127        let request = CodestralRequest {
128            model,
129            prompt,
130            suffix: if suffix.is_empty() {
131                None
132            } else {
133                Some(suffix)
134            },
135            max_tokens: max_tokens.or(Some(350)),
136            temperature: Some(0.2),
137            top_p: Some(1.0),
138            stream: Some(false),
139            stop: None,
140            random_seed: None,
141            min_tokens: None,
142        };
143
144        let request_body = serde_json::to_string(&request)?;
145
146        log::debug!("Codestral: Sending FIM request");
147
148        let http_request = http_client::Request::builder()
149            .method(http_client::Method::POST)
150            .uri(format!("{}/v1/fim/completions", api_url))
151            .header("Content-Type", "application/json")
152            .header("Authorization", format!("Bearer {}", api_key))
153            .body(http_client::AsyncBody::from(request_body))?;
154
155        let mut response = http_client.send(http_request).await?;
156        let status = response.status();
157
158        log::debug!("Codestral: Response status: {}", status);
159
160        if !status.is_success() {
161            let mut body = String::new();
162            response.body_mut().read_to_string(&mut body).await?;
163            return Err(anyhow::anyhow!(
164                "Codestral API error: {} - {}",
165                status,
166                body
167            ));
168        }
169
170        let mut body = String::new();
171        response.body_mut().read_to_string(&mut body).await?;
172
173        let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
174
175        let elapsed = start_time.elapsed();
176
177        if let Some(choice) = codestral_response.choices.first() {
178            let completion = &choice.message.content;
179
180            log::debug!(
181                "Codestral: Completion received ({} tokens, {:.2}s)",
182                codestral_response.usage.completion_tokens,
183                elapsed.as_secs_f64()
184            );
185
186            // Return just the completion text for insertion at cursor
187            Ok(completion.clone())
188        } else {
189            log::error!("Codestral: No completion returned in response");
190            Err(anyhow::anyhow!("No completion returned from Codestral"))
191        }
192    }
193}
194
195impl EditPredictionDelegate for CodestralEditPredictionDelegate {
196    fn name() -> &'static str {
197        "codestral"
198    }
199
200    fn display_name() -> &'static str {
201        "Codestral"
202    }
203
204    fn show_predictions_in_menu() -> bool {
205        true
206    }
207
208    fn icons(&self, _cx: &App) -> EditPredictionIconSet {
209        EditPredictionIconSet::new(IconName::AiMistral)
210    }
211
212    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
213        codestral_api_key(cx).is_some()
214    }
215
216    fn is_refreshing(&self, _cx: &App) -> bool {
217        self.pending_request.is_some()
218    }
219
220    fn refresh(
221        &mut self,
222        buffer: Entity<Buffer>,
223        cursor_position: language::Anchor,
224        debounce: bool,
225        cx: &mut Context<Self>,
226    ) {
227        log::debug!("Codestral: Refresh called (debounce: {})", debounce);
228
229        let Some(api_key) = codestral_api_key(cx) else {
230            log::warn!("Codestral: No API key configured, skipping refresh");
231            return;
232        };
233
234        let snapshot = buffer.read(cx).snapshot();
235
236        // Check if current completion is still valid
237        if let Some(current_completion) = self.current_completion.as_ref() {
238            if current_completion.interpolate(&snapshot).is_some() {
239                return;
240            }
241        }
242
243        let http_client = self.http_client.clone();
244
245        // Get settings
246        let settings = all_language_settings(None, cx);
247        let model = settings
248            .edit_predictions
249            .codestral
250            .model
251            .clone()
252            .unwrap_or_else(|| "codestral-latest".to_string());
253        let max_tokens = settings.edit_predictions.codestral.max_tokens;
254        let api_url = codestral_api_url(cx).to_string();
255
256        self.pending_request = Some(cx.spawn(async move |this, cx| {
257            if debounce {
258                log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
259                cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
260            }
261
262            let cursor_offset = cursor_position.to_offset(&snapshot);
263
264            const MAX_EDITABLE_TOKENS: usize = 350;
265            const MAX_CONTEXT_TOKENS: usize = 150;
266
267            let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
268                cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
269            let syntax_ranges = cursor_excerpt::compute_syntax_ranges(
270                &snapshot,
271                cursor_offset,
272                &excerpt_offset_range,
273            );
274            let excerpt_text: String = snapshot.text_for_range(excerpt_point_range).collect();
275            let (_, context_range) = zeta_prompt::compute_editable_and_context_ranges(
276                &excerpt_text,
277                cursor_offset_in_excerpt,
278                &syntax_ranges,
279                MAX_EDITABLE_TOKENS,
280                MAX_CONTEXT_TOKENS,
281            );
282            let context_text = &excerpt_text[context_range.clone()];
283            let cursor_within_excerpt = cursor_offset_in_excerpt
284                .saturating_sub(context_range.start)
285                .min(context_text.len());
286            let prompt = context_text[..cursor_within_excerpt].to_string();
287            let suffix = context_text[cursor_within_excerpt..].to_string();
288
289            let completion_text = match Self::fetch_completion(
290                http_client,
291                &api_key,
292                prompt,
293                suffix,
294                model,
295                max_tokens,
296                api_url,
297            )
298            .await
299            {
300                Ok(completion) => completion,
301                Err(e) => {
302                    log::error!("Codestral: Failed to fetch completion: {}", e);
303                    this.update(cx, |this, cx| {
304                        this.pending_request = None;
305                        cx.notify();
306                    })?;
307                    return Err(e);
308                }
309            };
310
311            if completion_text.trim().is_empty() {
312                log::debug!("Codestral: Completion was empty after trimming; ignoring");
313                this.update(cx, |this, cx| {
314                    this.pending_request = None;
315                    cx.notify();
316                })?;
317                return Ok(());
318            }
319
320            let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
321                vec![(cursor_position..cursor_position, completion_text.into())].into();
322            let edit_preview = buffer
323                .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))
324                .await;
325
326            this.update(cx, |this, cx| {
327                this.current_completion = Some(CurrentCompletion {
328                    snapshot,
329                    edits,
330                    edit_preview,
331                });
332                this.pending_request = None;
333                cx.notify();
334            })?;
335
336            Ok(())
337        }));
338    }
339
340    fn accept(&mut self, _cx: &mut Context<Self>) {
341        log::debug!("Codestral: Completion accepted");
342        self.pending_request = None;
343        self.current_completion = None;
344    }
345
346    fn discard(&mut self, _reason: EditPredictionDiscardReason, _cx: &mut Context<Self>) {
347        log::debug!("Codestral: Completion discarded");
348        self.pending_request = None;
349        self.current_completion = None;
350    }
351
352    /// Returns the completion suggestion, adjusted or invalidated based on user edits
353    fn suggest(
354        &mut self,
355        buffer: &Entity<Buffer>,
356        _cursor_position: Anchor,
357        cx: &mut Context<Self>,
358    ) -> Option<EditPrediction> {
359        let current_completion = self.current_completion.as_ref()?;
360        let buffer = buffer.read(cx);
361        let edits = current_completion.interpolate(&buffer.snapshot())?;
362        if edits.is_empty() {
363            return None;
364        }
365        Some(EditPrediction::Local {
366            id: None,
367            edits,
368            cursor_position: None,
369            edit_preview: Some(current_completion.edit_preview.clone()),
370        })
371    }
372}
373
374#[derive(Debug, Serialize, Deserialize)]
375pub struct CodestralRequest {
376    pub model: String,
377    pub prompt: String,
378    #[serde(skip_serializing_if = "Option::is_none")]
379    pub suffix: Option<String>,
380    #[serde(skip_serializing_if = "Option::is_none")]
381    pub max_tokens: Option<u32>,
382    #[serde(skip_serializing_if = "Option::is_none")]
383    pub temperature: Option<f32>,
384    #[serde(skip_serializing_if = "Option::is_none")]
385    pub top_p: Option<f32>,
386    #[serde(skip_serializing_if = "Option::is_none")]
387    pub stream: Option<bool>,
388    #[serde(skip_serializing_if = "Option::is_none")]
389    pub stop: Option<Vec<String>>,
390    #[serde(skip_serializing_if = "Option::is_none")]
391    pub random_seed: Option<u32>,
392    #[serde(skip_serializing_if = "Option::is_none")]
393    pub min_tokens: Option<u32>,
394}
395
396#[derive(Debug, Deserialize)]
397pub struct CodestralResponse {
398    pub id: String,
399    pub object: String,
400    pub model: String,
401    pub usage: Usage,
402    pub created: u64,
403    pub choices: Vec<Choice>,
404}
405
406#[derive(Debug, Deserialize)]
407pub struct Usage {
408    pub prompt_tokens: u32,
409    pub completion_tokens: u32,
410    pub total_tokens: u32,
411}
412
413#[derive(Debug, Deserialize)]
414pub struct Choice {
415    pub index: u32,
416    pub message: Message,
417    pub finish_reason: String,
418}
419
420#[derive(Debug, Deserialize)]
421pub struct Message {
422    pub content: String,
423    pub role: String,
424}