Add Codestral edit predictions provider (#34371)

David and Michael Sloan created

Release Notes:

- Added Codestral edit predictions provider which can be enabled by adding an API key in the Mistral section of agent settings.

![2025-07-13 11 35
33](https://github.com/user-attachments/assets/8bf599d7-33c7-4556-b878-6c645d69661f)


## Config

Get API key from https://console.mistral.ai/codestral and add it in the Mistral section of the agent settings. 

```
  "features": {
    "edit_prediction_provider": "codestral"
  },
  "edit_predictions": {
    "codestral": {
      "model": "codestral-latest",
      "max_tokens": 150
    }
  },
```

---------

Co-authored-by: Michael Sloan <michael@zed.dev>

Change summary

Cargo.lock                                                        |  23 
Cargo.toml                                                        |   2 
assets/settings/default.json                                      |  13 
crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs |   4 
crates/codestral/Cargo.toml                                       |  28 
crates/codestral/LICENSE-GPL                                      |   1 
crates/codestral/src/codestral.rs                                 | 381 +
crates/edit_prediction/src/edit_prediction.rs                     |  50 
crates/edit_prediction_button/Cargo.toml                          |   1 
crates/edit_prediction_button/src/edit_prediction_button.rs       |  82 
crates/language/src/language_settings.rs                          |  17 
crates/language_model/src/registry.rs                             |  12 
crates/language_models/src/language_models.rs                     |  46 
crates/language_models/src/provider/mistral.rs                    | 286 
crates/mistral/src/mistral.rs                                     |   1 
crates/settings/src/settings_content/language.rs                  |  20 
crates/zed/Cargo.toml                                             |   1 
crates/zed/src/zed/edit_prediction_registry.rs                    |  11 
crates/zeta/src/zeta.rs                                           |  53 
19 files changed, 913 insertions(+), 119 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -3316,6 +3316,27 @@ dependencies = [
  "unicode-width",
 ]
 
+[[package]]
+name = "codestral"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "edit_prediction",
+ "edit_prediction_context",
+ "futures 0.3.31",
+ "gpui",
+ "language",
+ "language_models",
+ "log",
+ "mistral",
+ "serde",
+ "serde_json",
+ "smol",
+ "text",
+ "workspace-hack",
+ "zed-http-client",
+]
+
 [[package]]
 name = "collab"
 version = "0.44.0"
@@ -5115,6 +5136,7 @@ dependencies = [
  "anyhow",
  "client",
  "cloud_llm_client",
+ "codestral",
  "copilot",
  "edit_prediction",
  "editor",
@@ -20005,6 +20027,7 @@ dependencies = [
  "clap",
  "cli",
  "client",
+ "codestral",
  "collab_ui",
  "command_palette",
  "component",

Cargo.toml 🔗

@@ -164,6 +164,7 @@ members = [
     "crates/sum_tree",
     "crates/supermaven",
     "crates/supermaven_api",
+    "crates/codestral",
     "crates/svg_preview",
     "crates/system_specs",
     "crates/tab_switcher",
@@ -398,6 +399,7 @@ streaming_diff = { path = "crates/streaming_diff" }
 sum_tree = { path = "crates/sum_tree", package = "zed-sum-tree", version = "0.1.0" }
 supermaven = { path = "crates/supermaven" }
 supermaven_api = { path = "crates/supermaven_api" }
+codestral = { path = "crates/codestral" }
 system_specs = { path = "crates/system_specs" }
 tab_switcher = { path = "crates/tab_switcher" }
 task = { path = "crates/task" }

assets/settings/default.json 🔗

@@ -1311,15 +1311,18 @@
     //   "proxy": "",
     //   "proxy_no_verify": false
     // },
-    // Whether edit predictions are enabled when editing text threads.
-    // This setting has no effect if globally disabled.
-    "enabled_in_text_threads": true,
-
     "copilot": {
       "enterprise_uri": null,
       "proxy": null,
       "proxy_no_verify": null
-    }
+    },
+    "codestral": {
+      "model": null,
+      "max_tokens": null
+    },
+    // Whether edit predictions are enabled when editing text threads.
+    // This setting has no effect if globally disabled.
+    "enabled_in_text_threads": true
   },
   // Settings specific to journaling
   "journal": {

crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs 🔗

@@ -619,10 +619,10 @@ mod tests {
         cx.update(|_window, cx| {
             LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                 registry.register_provider(
-                    FakeLanguageModelProvider::new(
+                    Arc::new(FakeLanguageModelProvider::new(
                         LanguageModelProviderId::new("someprovider"),
                         LanguageModelProviderName::new("Some Provider"),
-                    ),
+                    )),
                     cx,
                 );
             });

crates/codestral/Cargo.toml 🔗

@@ -0,0 +1,28 @@
+[package]
+name = "codestral"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lib]
+path = "src/codestral.rs"
+
+[dependencies]
+anyhow.workspace = true
+edit_prediction.workspace = true
+edit_prediction_context.workspace = true
+futures.workspace = true
+gpui.workspace = true
+http_client.workspace = true
+language.workspace = true
+language_models.workspace = true
+log.workspace = true
+mistral.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+smol.workspace = true
+text.workspace = true
+workspace-hack.workspace = true
+
+[dev-dependencies]

crates/codestral/src/codestral.rs 🔗

@@ -0,0 +1,381 @@
+use anyhow::{Context as _, Result};
+use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
+use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
+use futures::AsyncReadExt;
+use gpui::{App, Context, Entity, Task};
+use http_client::HttpClient;
+use language::{
+    language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint,
+};
+use language_models::MistralLanguageModelProvider;
+use mistral::CODESTRAL_API_URL;
+use serde::{Deserialize, Serialize};
+use std::{
+    ops::Range,
+    sync::Arc,
+    time::{Duration, Instant},
+};
+use text::ToOffset;
+
+pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
+
+const EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
+    max_bytes: 1050,
+    min_bytes: 525,
+    target_before_cursor_over_total_bytes: 0.66,
+};
+
+/// Represents a completion that has been received and processed from Codestral.
+/// This struct maintains the state needed to interpolate the completion as the user types.
+#[derive(Clone)]
+struct CurrentCompletion {
+    /// The buffer snapshot at the time the completion was generated.
+    /// Used to detect changes and interpolate edits.
+    snapshot: BufferSnapshot,
+    /// The edits that should be applied to transform the original text into the predicted text.
+    /// Each edit is a range in the buffer and the text to replace it with.
+    edits: Arc<[(Range<Anchor>, String)]>,
+    /// Preview of how the buffer will look after applying the edits.
+    edit_preview: EditPreview,
+}
+
+impl CurrentCompletion {
+    /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
+    /// Returns None if the user's edits conflict with the predicted edits.
+    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
+        edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
+    }
+}
+
+pub struct CodestralCompletionProvider {
+    http_client: Arc<dyn HttpClient>,
+    pending_request: Option<Task<Result<()>>>,
+    current_completion: Option<CurrentCompletion>,
+}
+
+impl CodestralCompletionProvider {
+    pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
+        Self {
+            http_client,
+            pending_request: None,
+            current_completion: None,
+        }
+    }
+
+    pub fn has_api_key(cx: &App) -> bool {
+        Self::api_key(cx).is_some()
+    }
+
+    fn api_key(cx: &App) -> Option<Arc<str>> {
+        MistralLanguageModelProvider::try_global(cx)
+            .and_then(|provider| provider.codestral_api_key(CODESTRAL_API_URL, cx))
+    }
+
+    /// Uses Codestral's Fill-in-the-Middle API for code completion.
+    async fn fetch_completion(
+        http_client: Arc<dyn HttpClient>,
+        api_key: &str,
+        prompt: String,
+        suffix: String,
+        model: String,
+        max_tokens: Option<u32>,
+    ) -> Result<String> {
+        let start_time = Instant::now();
+
+        log::debug!(
+            "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
+            model,
+            max_tokens
+        );
+
+        let request = CodestralRequest {
+            model,
+            prompt,
+            suffix: if suffix.is_empty() {
+                None
+            } else {
+                Some(suffix)
+            },
+            max_tokens: max_tokens.or(Some(350)),
+            temperature: Some(0.2),
+            top_p: Some(1.0),
+            stream: Some(false),
+            stop: None,
+            random_seed: None,
+            min_tokens: None,
+        };
+
+        let request_body = serde_json::to_string(&request)?;
+
+        log::debug!("Codestral: Sending FIM request");
+
+        let http_request = http_client::Request::builder()
+            .method(http_client::Method::POST)
+            .uri(format!("{}/v1/fim/completions", CODESTRAL_API_URL))
+            .header("Content-Type", "application/json")
+            .header("Authorization", format!("Bearer {}", api_key))
+            .body(http_client::AsyncBody::from(request_body))?;
+
+        let mut response = http_client.send(http_request).await?;
+        let status = response.status();
+
+        log::debug!("Codestral: Response status: {}", status);
+
+        if !status.is_success() {
+            let mut body = String::new();
+            response.body_mut().read_to_string(&mut body).await?;
+            return Err(anyhow::anyhow!(
+                "Codestral API error: {} - {}",
+                status,
+                body
+            ));
+        }
+
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
+
+        let elapsed = start_time.elapsed();
+
+        if let Some(choice) = codestral_response.choices.first() {
+            let completion = &choice.message.content;
+
+            log::debug!(
+                "Codestral: Completion received ({} tokens, {:.2}s)",
+                codestral_response.usage.completion_tokens,
+                elapsed.as_secs_f64()
+            );
+
+            // Return just the completion text for insertion at cursor
+            Ok(completion.clone())
+        } else {
+            log::error!("Codestral: No completion returned in response");
+            Err(anyhow::anyhow!("No completion returned from Codestral"))
+        }
+    }
+}
+
+impl EditPredictionProvider for CodestralCompletionProvider {
+    fn name() -> &'static str {
+        "codestral"
+    }
+
+    fn display_name() -> &'static str {
+        "Codestral"
+    }
+
+    fn show_completions_in_menu() -> bool {
+        true
+    }
+
+    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
+        Self::api_key(cx).is_some()
+    }
+
+    fn is_refreshing(&self) -> bool {
+        self.pending_request.is_some()
+    }
+
+    fn refresh(
+        &mut self,
+        buffer: Entity<Buffer>,
+        cursor_position: language::Anchor,
+        debounce: bool,
+        cx: &mut Context<Self>,
+    ) {
+        log::debug!("Codestral: Refresh called (debounce: {})", debounce);
+
+        let Some(api_key) = Self::api_key(cx) else {
+            log::warn!("Codestral: No API key configured, skipping refresh");
+            return;
+        };
+
+        let snapshot = buffer.read(cx).snapshot();
+
+        // Check if current completion is still valid
+        if let Some(current_completion) = self.current_completion.as_ref() {
+            if current_completion.interpolate(&snapshot).is_some() {
+                return;
+            }
+        }
+
+        let http_client = self.http_client.clone();
+
+        // Get settings
+        let settings = all_language_settings(None, cx);
+        let model = settings
+            .edit_predictions
+            .codestral
+            .model
+            .clone()
+            .unwrap_or_else(|| "codestral-latest".to_string());
+        let max_tokens = settings.edit_predictions.codestral.max_tokens;
+
+        self.pending_request = Some(cx.spawn(async move |this, cx| {
+            if debounce {
+                log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
+                smol::Timer::after(DEBOUNCE_TIMEOUT).await;
+            }
+
+            let cursor_offset = cursor_position.to_offset(&snapshot);
+            let cursor_point = cursor_offset.to_point(&snapshot);
+            let excerpt = EditPredictionExcerpt::select_from_buffer(
+                cursor_point,
+                &snapshot,
+                &EXCERPT_OPTIONS,
+                None,
+            )
+            .context("Line containing cursor doesn't fit in excerpt max bytes")?;
+
+            let excerpt_text = excerpt.text(&snapshot);
+            let cursor_within_excerpt = cursor_offset
+                .saturating_sub(excerpt.range.start)
+                .min(excerpt_text.body.len());
+            let prompt = excerpt_text.body[..cursor_within_excerpt].to_string();
+            let suffix = excerpt_text.body[cursor_within_excerpt..].to_string();
+
+            let completion_text = match Self::fetch_completion(
+                http_client,
+                &api_key,
+                prompt,
+                suffix,
+                model,
+                max_tokens,
+            )
+            .await
+            {
+                Ok(completion) => completion,
+                Err(e) => {
+                    log::error!("Codestral: Failed to fetch completion: {}", e);
+                    this.update(cx, |this, cx| {
+                        this.pending_request = None;
+                        cx.notify();
+                    })?;
+                    return Err(e);
+                }
+            };
+
+            if completion_text.trim().is_empty() {
+                log::debug!("Codestral: Completion was empty after trimming; ignoring");
+                this.update(cx, |this, cx| {
+                    this.pending_request = None;
+                    cx.notify();
+                })?;
+                return Ok(());
+            }
+
+            let edits: Arc<[(Range<Anchor>, String)]> =
+                vec![(cursor_position..cursor_position, completion_text)].into();
+            let edit_preview = buffer
+                .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))?
+                .await;
+
+            this.update(cx, |this, cx| {
+                this.current_completion = Some(CurrentCompletion {
+                    snapshot,
+                    edits,
+                    edit_preview,
+                });
+                this.pending_request = None;
+                cx.notify();
+            })?;
+
+            Ok(())
+        }));
+    }
+
+    fn cycle(
+        &mut self,
+        _buffer: Entity<Buffer>,
+        _cursor_position: Anchor,
+        _direction: Direction,
+        _cx: &mut Context<Self>,
+    ) {
+        // Codestral doesn't support multiple completions, so cycling does nothing
+    }
+
+    fn accept(&mut self, _cx: &mut Context<Self>) {
+        log::debug!("Codestral: Completion accepted");
+        self.pending_request = None;
+        self.current_completion = None;
+    }
+
+    fn discard(&mut self, _cx: &mut Context<Self>) {
+        log::debug!("Codestral: Completion discarded");
+        self.pending_request = None;
+        self.current_completion = None;
+    }
+
+    /// Returns the completion suggestion, adjusted or invalidated based on user edits
+    fn suggest(
+        &mut self,
+        buffer: &Entity<Buffer>,
+        _cursor_position: Anchor,
+        cx: &mut Context<Self>,
+    ) -> Option<EditPrediction> {
+        let current_completion = self.current_completion.as_ref()?;
+        let buffer = buffer.read(cx);
+        let edits = current_completion.interpolate(&buffer.snapshot())?;
+        if edits.is_empty() {
+            return None;
+        }
+        Some(EditPrediction::Local {
+            id: None,
+            edits,
+            edit_preview: Some(current_completion.edit_preview.clone()),
+        })
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CodestralRequest {
+    pub model: String,
+    pub prompt: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub suffix: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_tokens: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stop: Option<Vec<String>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub random_seed: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub min_tokens: Option<u32>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct CodestralResponse {
+    pub id: String,
+    pub object: String,
+    pub model: String,
+    pub usage: Usage,
+    pub created: u64,
+    pub choices: Vec<Choice>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Usage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Choice {
+    pub index: u32,
+    pub message: Message,
+    pub finish_reason: String,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Message {
+    pub content: String,
+    pub role: String,
+}

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -2,7 +2,7 @@ use std::ops::Range;
 
 use client::EditPredictionUsage;
 use gpui::{App, Context, Entity, SharedString};
-use language::Buffer;
+use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt};
 
 // TODO: Find a better home for `Direction`.
 //
@@ -242,3 +242,51 @@ where
         self.update(cx, |this, cx| this.suggest(buffer, cursor_position, cx))
     }
 }
+
+/// Returns edits updated based on user edits since the old snapshot. None is returned if any user
+/// edit is not a prefix of a predicted insertion.
+pub fn interpolate_edits(
+    old_snapshot: &BufferSnapshot,
+    new_snapshot: &BufferSnapshot,
+    current_edits: &[(Range<Anchor>, String)],
+) -> Option<Vec<(Range<Anchor>, String)>> {
+    let mut edits = Vec::new();
+
+    let mut model_edits = current_edits.iter().peekable();
+    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
+        while let Some((model_old_range, _)) = model_edits.peek() {
+            let model_old_range = model_old_range.to_offset(old_snapshot);
+            if model_old_range.end < user_edit.old.start {
+                let (model_old_range, model_new_text) = model_edits.next().unwrap();
+                edits.push((model_old_range.clone(), model_new_text.clone()));
+            } else {
+                break;
+            }
+        }
+
+        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
+            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
+            if user_edit.old == model_old_offset_range {
+                let user_new_text = new_snapshot
+                    .text_for_range(user_edit.new.clone())
+                    .collect::<String>();
+
+                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
+                    if !model_suffix.is_empty() {
+                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
+                        edits.push((anchor..anchor, model_suffix.to_string()));
+                    }
+
+                    model_edits.next();
+                    continue;
+                }
+            }
+        }
+
+        return None;
+    }
+
+    edits.extend(model_edits.cloned());
+
+    if edits.is_empty() { None } else { Some(edits) }
+}

crates/edit_prediction_button/Cargo.toml 🔗

@@ -16,6 +16,7 @@ doctest = false
 anyhow.workspace = true
 client.workspace = true
 cloud_llm_client.workspace = true
+codestral.workspace = true
 copilot.workspace = true
 editor.workspace = true
 feature_flags.workspace = true

crates/edit_prediction_button/src/edit_prediction_button.rs 🔗

@@ -1,6 +1,7 @@
 use anyhow::Result;
 use client::{UserStore, zed_urls};
 use cloud_llm_client::UsageLimit;
+use codestral::CodestralCompletionProvider;
 use copilot::{Copilot, Status};
 use editor::{Editor, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll};
 use feature_flags::{FeatureFlagAppExt, PredictEditsRateCompletionsFeatureFlag};
@@ -234,6 +235,67 @@ impl Render for EditPredictionButton {
                 )
             }
 
+            EditPredictionProvider::Codestral => {
+                let enabled = self.editor_enabled.unwrap_or(true);
+                let has_api_key = CodestralCompletionProvider::has_api_key(cx);
+                let fs = self.fs.clone();
+                let this = cx.entity();
+
+                div().child(
+                    PopoverMenu::new("codestral")
+                        .menu(move |window, cx| {
+                            if has_api_key {
+                                Some(this.update(cx, |this, cx| {
+                                    this.build_codestral_context_menu(window, cx)
+                                }))
+                            } else {
+                                Some(ContextMenu::build(window, cx, |menu, _, _| {
+                                    let fs = fs.clone();
+                                    menu.entry("Use Zed AI instead", None, move |_, cx| {
+                                        set_completion_provider(
+                                            fs.clone(),
+                                            cx,
+                                            EditPredictionProvider::Zed,
+                                        )
+                                    })
+                                    .separator()
+                                    .entry(
+                                        "Configure Codestral API Key",
+                                        None,
+                                        move |window, cx| {
+                                            window.dispatch_action(
+                                                zed_actions::agent::OpenSettings.boxed_clone(),
+                                                cx,
+                                            );
+                                        },
+                                    )
+                                }))
+                            }
+                        })
+                        .anchor(Corner::BottomRight)
+                        .trigger_with_tooltip(
+                            IconButton::new("codestral-icon", IconName::AiMistral)
+                                .shape(IconButtonShape::Square)
+                                .when(!has_api_key, |this| {
+                                    this.indicator(Indicator::dot().color(Color::Error))
+                                        .indicator_border_color(Some(
+                                            cx.theme().colors().status_bar_background,
+                                        ))
+                                })
+                                .when(has_api_key && !enabled, |this| {
+                                    this.indicator(Indicator::dot().color(Color::Ignored))
+                                        .indicator_border_color(Some(
+                                            cx.theme().colors().status_bar_background,
+                                        ))
+                                }),
+                            move |window, cx| {
+                                Tooltip::for_action("Codestral", &ToggleMenu, window, cx)
+                            },
+                        )
+                        .with_handle(self.popover_menu_handle.clone()),
+                )
+            }
+
             EditPredictionProvider::Zed => {
                 let enabled = self.editor_enabled.unwrap_or(true);
 
@@ -493,6 +555,7 @@ impl EditPredictionButton {
             EditPredictionProvider::Zed
                 | EditPredictionProvider::Copilot
                 | EditPredictionProvider::Supermaven
+                | EditPredictionProvider::Codestral
         ) {
             menu = menu
                 .separator()
@@ -719,6 +782,25 @@ impl EditPredictionButton {
         })
     }
 
+    fn build_codestral_context_menu(
+        &self,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) -> Entity<ContextMenu> {
+        let fs = self.fs.clone();
+        ContextMenu::build(window, cx, |menu, window, cx| {
+            self.build_language_settings_menu(menu, window, cx)
+                .separator()
+                .entry("Use Zed AI instead", None, move |_, cx| {
+                    set_completion_provider(fs.clone(), cx, EditPredictionProvider::Zed)
+                })
+                .separator()
+                .entry("Configure Codestral API Key", None, move |window, cx| {
+                    window.dispatch_action(zed_actions::agent::OpenSettings.boxed_clone(), cx);
+                })
+        })
+    }
+
     fn build_zeta_context_menu(
         &self,
         window: &mut Window,

crates/language/src/language_settings.rs 🔗

@@ -377,6 +377,8 @@ pub struct EditPredictionSettings {
     pub mode: settings::EditPredictionsMode,
     /// Settings specific to GitHub Copilot.
     pub copilot: CopilotSettings,
+    /// Settings specific to Codestral.
+    pub codestral: CodestralSettings,
     /// Whether edit predictions are enabled in the assistant panel.
     /// This setting has no effect if globally disabled.
     pub enabled_in_text_threads: bool,
@@ -412,6 +414,14 @@ pub struct CopilotSettings {
     pub enterprise_uri: Option<String>,
 }
 
+#[derive(Clone, Debug, Default)]
+pub struct CodestralSettings {
+    /// Model to use for completions.
+    pub model: Option<String>,
+    /// Maximum tokens to generate.
+    pub max_tokens: Option<u32>,
+}
+
 impl AllLanguageSettings {
     /// Returns the [`LanguageSettings`] for the language with the specified name.
     pub fn language<'a>(
@@ -622,6 +632,12 @@ impl settings::Settings for AllLanguageSettings {
             enterprise_uri: copilot.enterprise_uri,
         };
 
+        let codestral = edit_predictions.codestral.unwrap();
+        let codestral_settings = CodestralSettings {
+            model: codestral.model,
+            max_tokens: codestral.max_tokens,
+        };
+
         let enabled_in_text_threads = edit_predictions.enabled_in_text_threads.unwrap();
 
         let mut file_types: FxHashMap<Arc<str>, GlobSet> = FxHashMap::default();
@@ -655,6 +671,7 @@ impl settings::Settings for AllLanguageSettings {
                     .collect(),
                 mode: edit_predictions_mode,
                 copilot: copilot_settings,
+                codestral: codestral_settings,
                 enabled_in_text_threads,
             },
             defaults: default_language_settings,

crates/language_model/src/registry.rs 🔗

@@ -118,14 +118,14 @@ impl LanguageModelRegistry {
     }
 
     #[cfg(any(test, feature = "test-support"))]
-    pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
-        let fake_provider = crate::fake_provider::FakeLanguageModelProvider::default();
+    pub fn test(cx: &mut App) -> Arc<crate::fake_provider::FakeLanguageModelProvider> {
+        let fake_provider = Arc::new(crate::fake_provider::FakeLanguageModelProvider::default());
         let registry = cx.new(|cx| {
             let mut registry = Self::default();
             registry.register_provider(fake_provider.clone(), cx);
             let model = fake_provider.provided_models(cx)[0].clone();
             let configured_model = ConfiguredModel {
-                provider: Arc::new(fake_provider.clone()),
+                provider: fake_provider.clone(),
                 model,
             };
             registry.set_default_model(Some(configured_model), cx);
@@ -137,7 +137,7 @@ impl LanguageModelRegistry {
 
     pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
         &mut self,
-        provider: T,
+        provider: Arc<T>,
         cx: &mut Context<Self>,
     ) {
         let id = provider.id();
@@ -152,7 +152,7 @@ impl LanguageModelRegistry {
             subscription.detach();
         }
 
-        self.providers.insert(id.clone(), Arc::new(provider));
+        self.providers.insert(id.clone(), provider);
         cx.emit(Event::AddedProvider(id));
     }
 
@@ -395,7 +395,7 @@ mod tests {
     fn test_register_providers(cx: &mut App) {
         let registry = cx.new(|_| LanguageModelRegistry::default());
 
-        let provider = FakeLanguageModelProvider::default();
+        let provider = Arc::new(FakeLanguageModelProvider::default());
         registry.update(cx, |registry, cx| {
             registry.register_provider(provider.clone(), cx);
         });

crates/language_models/src/language_models.rs 🔗

@@ -18,7 +18,7 @@ use crate::provider::cloud::CloudLanguageModelProvider;
 use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
 use crate::provider::google::GoogleLanguageModelProvider;
 use crate::provider::lmstudio::LmStudioLanguageModelProvider;
-use crate::provider::mistral::MistralLanguageModelProvider;
+pub use crate::provider::mistral::MistralLanguageModelProvider;
 use crate::provider::ollama::OllamaLanguageModelProvider;
 use crate::provider::open_ai::OpenAiLanguageModelProvider;
 use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
@@ -87,11 +87,11 @@ fn register_openai_compatible_providers(
     for provider_id in new {
         if !old.contains(provider_id) {
             registry.register_provider(
-                OpenAiCompatibleLanguageModelProvider::new(
+                Arc::new(OpenAiCompatibleLanguageModelProvider::new(
                     provider_id.clone(),
                     client.http_client(),
                     cx,
-                ),
+                )),
                 cx,
             );
         }
@@ -105,50 +105,62 @@ fn register_language_model_providers(
     cx: &mut Context<LanguageModelRegistry>,
 ) {
     registry.register_provider(
-        CloudLanguageModelProvider::new(user_store, client.clone(), cx),
+        Arc::new(CloudLanguageModelProvider::new(
+            user_store,
+            client.clone(),
+            cx,
+        )),
+        cx,
+    );
+    registry.register_provider(
+        Arc::new(AnthropicLanguageModelProvider::new(
+            client.http_client(),
+            cx,
+        )),
         cx,
     );
-
     registry.register_provider(
-        AnthropicLanguageModelProvider::new(client.http_client(), cx),
+        Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
         cx,
     );
     registry.register_provider(
-        OpenAiLanguageModelProvider::new(client.http_client(), cx),
+        Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
         cx,
     );
     registry.register_provider(
-        OllamaLanguageModelProvider::new(client.http_client(), cx),
+        Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
         cx,
     );
     registry.register_provider(
-        LmStudioLanguageModelProvider::new(client.http_client(), cx),
+        Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
         cx,
     );
     registry.register_provider(
-        DeepSeekLanguageModelProvider::new(client.http_client(), cx),
+        Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
         cx,
     );
     registry.register_provider(
-        GoogleLanguageModelProvider::new(client.http_client(), cx),
+        MistralLanguageModelProvider::global(client.http_client(), cx),
         cx,
     );
     registry.register_provider(
-        MistralLanguageModelProvider::new(client.http_client(), cx),
+        Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
         cx,
     );
     registry.register_provider(
-        BedrockLanguageModelProvider::new(client.http_client(), cx),
+        Arc::new(OpenRouterLanguageModelProvider::new(
+            client.http_client(),
+            cx,
+        )),
         cx,
     );
     registry.register_provider(
-        OpenRouterLanguageModelProvider::new(client.http_client(), cx),
+        Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
         cx,
     );
     registry.register_provider(
-        VercelLanguageModelProvider::new(client.http_client(), cx),
+        Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
         cx,
     );
-    registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx);
-    registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
+    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
 }

crates/language_models/src/provider/mistral.rs 🔗

@@ -1,7 +1,8 @@
 use anyhow::{Result, anyhow};
 use collections::BTreeMap;
+use fs::Fs;
 use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream};
-use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
+use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
 use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
@@ -10,9 +11,9 @@ use language_model::{
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
     RateLimiter, Role, StopReason, TokenUsage,
 };
-use mistral::{MISTRAL_API_URL, StreamResponse};
+use mistral::{CODESTRAL_API_URL, MISTRAL_API_URL, StreamResponse};
 pub use settings::MistralAvailableModel as AvailableModel;
-use settings::{Settings, SettingsStore};
+use settings::{EditPredictionProvider, Settings, SettingsStore, update_settings_file};
 use std::collections::HashMap;
 use std::pin::Pin;
 use std::str::FromStr;
@@ -31,6 +32,9 @@ const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new(
 const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY";
 static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
 
+const CODESTRAL_API_KEY_ENV_VAR_NAME: &str = "CODESTRAL_API_KEY";
+static CODESTRAL_API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(CODESTRAL_API_KEY_ENV_VAR_NAME);
+
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct MistralSettings {
     pub api_url: String,
@@ -44,6 +48,7 @@ pub struct MistralLanguageModelProvider {
 
 pub struct State {
     api_key_state: ApiKeyState,
+    codestral_api_key_state: ApiKeyState,
 }
 
 impl State {
@@ -57,6 +62,19 @@ impl State {
             .store(api_url, api_key, |this| &mut this.api_key_state, cx)
     }
 
+    fn set_codestral_api_key(
+        &mut self,
+        api_key: Option<String>,
+        cx: &mut Context<Self>,
+    ) -> Task<Result<()>> {
+        self.codestral_api_key_state.store(
+            CODESTRAL_API_URL.into(),
+            api_key,
+            |this| &mut this.codestral_api_key_state,
+            cx,
+        )
+    }
+
     fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
         let api_url = MistralLanguageModelProvider::api_url(cx);
         self.api_key_state.load_if_needed(
@@ -66,10 +84,34 @@ impl State {
             cx,
         )
     }
+
+    fn authenticate_codestral(
+        &mut self,
+        cx: &mut Context<Self>,
+    ) -> Task<Result<(), AuthenticateError>> {
+        self.codestral_api_key_state.load_if_needed(
+            CODESTRAL_API_URL.into(),
+            &CODESTRAL_API_KEY_ENV_VAR,
+            |this| &mut this.codestral_api_key_state,
+            cx,
+        )
+    }
 }
 
+struct GlobalMistralLanguageModelProvider(Arc<MistralLanguageModelProvider>);
+
+impl Global for GlobalMistralLanguageModelProvider {}
+
 impl MistralLanguageModelProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+    pub fn try_global(cx: &App) -> Option<&Arc<MistralLanguageModelProvider>> {
+        cx.try_global::<GlobalMistralLanguageModelProvider>()
+            .map(|this| &this.0)
+    }
+
+    pub fn global(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Arc<Self> {
+        if let Some(this) = cx.try_global::<GlobalMistralLanguageModelProvider>() {
+            return this.0.clone();
+        }
         let state = cx.new(|cx| {
             cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                 let api_url = Self::api_url(cx);
@@ -84,10 +126,22 @@ impl MistralLanguageModelProvider {
             .detach();
             State {
                 api_key_state: ApiKeyState::new(Self::api_url(cx)),
+                codestral_api_key_state: ApiKeyState::new(CODESTRAL_API_URL.into()),
             }
         });
 
-        Self { http_client, state }
+        let this = Arc::new(Self { http_client, state });
+        cx.set_global(GlobalMistralLanguageModelProvider(this));
+        cx.global::<GlobalMistralLanguageModelProvider>().0.clone()
+    }
+
+    pub fn load_codestral_api_key(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
+        self.state
+            .update(cx, |state, cx| state.authenticate_codestral(cx))
+    }
+
+    pub fn codestral_api_key(&self, url: &str, cx: &App) -> Option<Arc<str>> {
+        self.state.read(cx).codestral_api_key_state.key(url)
     }
 
     fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
@@ -691,6 +745,7 @@ struct RawToolCall {
 
 struct ConfigurationView {
     api_key_editor: Entity<SingleLineInput>,
+    codestral_api_key_editor: Entity<SingleLineInput>,
     state: Entity<State>,
     load_credentials_task: Option<Task<()>>,
 }
@@ -699,6 +754,8 @@ impl ConfigurationView {
     fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
         let api_key_editor =
             cx.new(|cx| SingleLineInput::new(window, cx, "0aBCDEFGhIjKLmNOpqrSTUVwxyzabCDE1f2"));
+        let codestral_api_key_editor =
+            cx.new(|cx| SingleLineInput::new(window, cx, "0aBCDEFGhIjKLmNOpqrSTUVwxyzabCDE1f2"));
 
         cx.observe(&state, |_, _, cx| {
             cx.notify();
@@ -715,6 +772,12 @@ impl ConfigurationView {
                     // We don't log an error, because "not signed in" is also an error.
                     let _ = task.await;
                 }
+                if let Some(task) = state
+                    .update(cx, |state, cx| state.authenticate_codestral(cx))
+                    .log_err()
+                {
+                    let _ = task.await;
+                }
 
                 this.update(cx, |this, cx| {
                     this.load_credentials_task = None;
@@ -726,6 +789,7 @@ impl ConfigurationView {
 
         Self {
             api_key_editor,
+            codestral_api_key_editor,
             state,
             load_credentials_task,
         }
@@ -763,47 +827,92 @@ impl ConfigurationView {
         .detach_and_log_err(cx);
     }
 
-    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
-        !self.state.read(cx).is_authenticated()
+    fn save_codestral_api_key(
+        &mut self,
+        _: &menu::Confirm,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) {
+        let api_key = self
+            .codestral_api_key_editor
+            .read(cx)
+            .text(cx)
+            .trim()
+            .to_string();
+        if api_key.is_empty() {
+            return;
+        }
+
+        // url changes can cause the editor to be displayed again
+        self.codestral_api_key_editor
+            .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+        let state = self.state.clone();
+        cx.spawn_in(window, async move |_, cx| {
+            state
+                .update(cx, |state, cx| {
+                    state.set_codestral_api_key(Some(api_key), cx)
+                })?
+                .await?;
+            cx.update(|_window, cx| {
+                set_edit_prediction_provider(EditPredictionProvider::Codestral, cx)
+            })
+        })
+        .detach_and_log_err(cx);
     }
-}
 
-impl Render for ConfigurationView {
-    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
-        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
+    fn reset_codestral_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+        self.codestral_api_key_editor
+            .update(cx, |editor, cx| editor.set_text("", window, cx));
 
-        if self.load_credentials_task.is_some() {
-            div().child(Label::new("Loading credentials...")).into_any()
-        } else if self.should_render_editor(cx) {
+        let state = self.state.clone();
+        cx.spawn_in(window, async move |_, cx| {
+            state
+                .update(cx, |state, cx| state.set_codestral_api_key(None, cx))?
+                .await?;
+            cx.update(|_window, cx| set_edit_prediction_provider(EditPredictionProvider::Zed, cx))
+        })
+        .detach_and_log_err(cx);
+    }
+
+    fn should_render_api_key_editor(&self, cx: &mut Context<Self>) -> bool {
+        !self.state.read(cx).is_authenticated()
+    }
+
+    fn render_codestral_api_key_editor(&mut self, cx: &mut Context<Self>) -> AnyElement {
+        let key_state = &self.state.read(cx).codestral_api_key_state;
+        let should_show_editor = !key_state.has_key();
+        let env_var_set = key_state.is_from_env_var();
+        if should_show_editor {
             v_flex()
+                .id("codestral")
                 .size_full()
-                .on_action(cx.listener(Self::save_api_key))
-                .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
+                .mt_2()
+                .on_action(cx.listener(Self::save_codestral_api_key))
+                .child(Label::new(
+                    "To use Codestral as an edit prediction provider, \
+                    you need to add a Codestral-specific API key. Follow these steps:",
+                ))
                 .child(
                     List::new()
                         .child(InstructionListItem::new(
                             "Create one by visiting",
-                            Some("Mistral's console"),
-                            Some("https://console.mistral.ai/api-keys"),
+                            Some("the Codestral section of Mistral's console"),
+                            Some("https://console.mistral.ai/codestral"),
                         ))
-                        .child(InstructionListItem::text_only(
-                            "Ensure your Mistral account has credits",
-                        ))
-                        .child(InstructionListItem::text_only(
-                            "Paste your API key below and hit enter to start using the assistant",
-                        )),
+                        .child(InstructionListItem::text_only("Paste your API key below and hit enter")),
                 )
-                .child(self.api_key_editor.clone())
+                .child(self.codestral_api_key_editor.clone())
                 .child(
                     Label::new(
-                        format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
+                        format!("You can also assign the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
                     )
                     .size(LabelSize::Small).color(Color::Muted),
-                )
-                .into_any()
+                ).into_any()
         } else {
             h_flex()
-                .mt_1()
+                .id("codestral")
+                .mt_2()
                 .p_1()
                 .justify_between()
                 .rounded_md()
@@ -815,14 +924,9 @@ impl Render for ConfigurationView {
                         .gap_1()
                         .child(Icon::new(IconName::Check).color(Color::Success))
                         .child(Label::new(if env_var_set {
-                            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
+                            format!("API key set in {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable")
                         } else {
-                            let api_url = MistralLanguageModelProvider::api_url(cx);
-                            if api_url == MISTRAL_API_URL {
-                                "API key configured".to_string()
-                            } else {
-                                format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
-                            }
+                            "Codestral API key configured".to_string()
                         })),
                 )
                 .child(
@@ -833,15 +937,121 @@ impl Render for ConfigurationView {
                         .icon_position(IconPosition::Start)
                         .disabled(env_var_set)
                         .when(env_var_set, |this| {
-                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
+                            this.tooltip(Tooltip::text(format!(
+                                "To reset your API key, \
+                                unset the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable."
+                            )))
                         })
-                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
+                        .on_click(
+                            cx.listener(|this, _, window, cx| this.reset_codestral_api_key(window, cx)),
+                        ),
+                ).into_any()
+        }
+    }
+}
+
+impl Render for ConfigurationView {
+    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
+
+        if self.load_credentials_task.is_some() {
+            div().child(Label::new("Loading credentials...")).into_any()
+        } else if self.should_render_api_key_editor(cx) {
+            v_flex()
+                .size_full()
+                .on_action(cx.listener(Self::save_api_key))
+                .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
+                .child(
+                    List::new()
+                        .child(InstructionListItem::new(
+                            "Create one by visiting",
+                            Some("Mistral's console"),
+                            Some("https://console.mistral.ai/api-keys"),
+                        ))
+                        .child(InstructionListItem::text_only(
+                            "Ensure your Mistral account has credits",
+                        ))
+                        .child(InstructionListItem::text_only(
+                            "Paste your API key below and hit enter to start using the assistant",
+                        )),
                 )
+                .child(self.api_key_editor.clone())
+                .child(
+                    Label::new(
+                        format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
+                    )
+                    .size(LabelSize::Small).color(Color::Muted),
+                )
+                .child(self.render_codestral_api_key_editor(cx))
+                .into_any()
+        } else {
+            v_flex()
+                .size_full()
+                .child(
+                    h_flex()
+                        .mt_1()
+                        .p_1()
+                        .justify_between()
+                        .rounded_md()
+                        .border_1()
+                        .border_color(cx.theme().colors().border)
+                        .bg(cx.theme().colors().background)
+                        .child(
+                            h_flex()
+                                .gap_1()
+                                .child(Icon::new(IconName::Check).color(Color::Success))
+                                .child(Label::new(if env_var_set {
+                                    format!(
+                                        "API key set in {API_KEY_ENV_VAR_NAME} environment variable"
+                                    )
+                                } else {
+                                    let api_url = MistralLanguageModelProvider::api_url(cx);
+                                    if api_url == MISTRAL_API_URL {
+                                        "API key configured".to_string()
+                                    } else {
+                                        format!(
+                                            "API key configured for {}",
+                                            truncate_and_trailoff(&api_url, 32)
+                                        )
+                                    }
+                                })),
+                        )
+                        .child(
+                            Button::new("reset-key", "Reset Key")
+                                .label_size(LabelSize::Small)
+                                .icon(Some(IconName::Trash))
+                                .icon_size(IconSize::Small)
+                                .icon_position(IconPosition::Start)
+                                .disabled(env_var_set)
+                                .when(env_var_set, |this| {
+                                    this.tooltip(Tooltip::text(format!(
+                                        "To reset your API key, \
+                                        unset the {API_KEY_ENV_VAR_NAME} environment variable."
+                                    )))
+                                })
+                                .on_click(cx.listener(|this, _, window, cx| {
+                                    this.reset_api_key(window, cx)
+                                })),
+                        ),
+                )
+                .child(self.render_codestral_api_key_editor(cx))
                 .into_any()
         }
     }
 }
 
+fn set_edit_prediction_provider(provider: EditPredictionProvider, cx: &mut App) {
+    let fs = <dyn Fs>::global(cx);
+    update_settings_file(fs, cx, move |settings, _| {
+        settings
+            .project
+            .all_languages
+            .features
+            .get_or_insert_default()
+            .edit_prediction_provider = Some(provider);
+    });
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;

crates/mistral/src/mistral.rs 🔗

@@ -7,6 +7,7 @@ use std::convert::TryFrom;
 use strum::EnumIter;
 
 pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
+pub const CODESTRAL_API_URL: &str = "https://codestral.mistral.ai";
 
 #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 #[serde(rename_all = "lowercase")]

crates/settings/src/settings_content/language.rs 🔗

@@ -82,6 +82,7 @@ pub enum EditPredictionProvider {
     Copilot,
     Supermaven,
     Zed,
+    Codestral,
 }
 
 impl EditPredictionProvider {
@@ -90,7 +91,8 @@ impl EditPredictionProvider {
             EditPredictionProvider::Zed => true,
             EditPredictionProvider::None
             | EditPredictionProvider::Copilot
-            | EditPredictionProvider::Supermaven => false,
+            | EditPredictionProvider::Supermaven
+            | EditPredictionProvider::Codestral => false,
         }
     }
 }
@@ -108,6 +110,8 @@ pub struct EditPredictionSettingsContent {
     pub mode: Option<EditPredictionsMode>,
     /// Settings specific to GitHub Copilot.
     pub copilot: Option<CopilotSettingsContent>,
+    /// Settings specific to Codestral.
+    pub codestral: Option<CodestralSettingsContent>,
     /// Whether edit predictions are enabled in the assistant prompt editor.
     /// This has no effect if globally disabled.
     pub enabled_in_text_threads: Option<bool>,
@@ -130,6 +134,20 @@ pub struct CopilotSettingsContent {
     pub enterprise_uri: Option<String>,
 }
 
+#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, MergeFrom, PartialEq)]
+pub struct CodestralSettingsContent {
+    /// Model to use for completions.
+    ///
+    /// Default: "codestral-latest"
+    #[serde(default)]
+    pub model: Option<String>,
+    /// Maximum tokens to generate.
+    ///
+    /// Default: 150
+    #[serde(default)]
+    pub max_tokens: Option<u32>,
+}
+
 /// The mode in which edit predictions should be displayed.
 #[derive(
     Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize, JsonSchema, MergeFrom,

crates/zed/Cargo.toml 🔗

@@ -39,6 +39,7 @@ channel.workspace = true
 clap.workspace = true
 cli.workspace = true
 client.workspace = true
+codestral.workspace = true
 collab_ui.workspace = true
 collections.workspace = true
 command_palette.workspace = true

crates/zed/src/zed/edit_prediction_registry.rs 🔗

@@ -1,9 +1,11 @@
 use client::{Client, UserStore};
+use codestral::CodestralCompletionProvider;
 use collections::HashMap;
 use copilot::{Copilot, CopilotCompletionProvider};
 use editor::Editor;
 use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
 use language::language_settings::{EditPredictionProvider, all_language_settings};
+use language_models::MistralLanguageModelProvider;
 use settings::SettingsStore;
 use std::{cell::RefCell, rc::Rc, sync::Arc};
 use supermaven::{Supermaven, SupermavenCompletionProvider};
@@ -109,6 +111,10 @@ fn assign_edit_prediction_providers(
     user_store: Entity<UserStore>,
     cx: &mut App,
 ) {
+    if provider == EditPredictionProvider::Codestral {
+        let mistral = MistralLanguageModelProvider::global(client.http_client(), cx);
+        mistral.load_codestral_api_key(cx).detach();
+    }
     for (editor, window) in editors.borrow().iter() {
         _ = window.update(cx, |_window, window, cx| {
             _ = editor.update(cx, |editor, cx| {
@@ -189,6 +195,11 @@ fn assign_edit_prediction_provider(
                 editor.set_edit_prediction_provider(Some(provider), window, cx);
             }
         }
+        EditPredictionProvider::Codestral => {
+            let http_client = client.http_client();
+            let provider = cx.new(|_| CodestralCompletionProvider::new(http_client));
+            editor.set_edit_prediction_provider(Some(provider), window, cx);
+        }
         EditPredictionProvider::Zed => {
             if user_store.read(cx).current_user().is_some() {
                 let mut worktree = None;

crates/zeta/src/zeta.rs 🔗

@@ -151,56 +151,10 @@ impl EditPrediction {
     }
 
     fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
-        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
+        edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
     }
 }
 
-fn interpolate(
-    old_snapshot: &BufferSnapshot,
-    new_snapshot: &BufferSnapshot,
-    current_edits: Arc<[(Range<Anchor>, String)]>,
-) -> Option<Vec<(Range<Anchor>, String)>> {
-    let mut edits = Vec::new();
-
-    let mut model_edits = current_edits.iter().peekable();
-    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
-        while let Some((model_old_range, _)) = model_edits.peek() {
-            let model_old_range = model_old_range.to_offset(old_snapshot);
-            if model_old_range.end < user_edit.old.start {
-                let (model_old_range, model_new_text) = model_edits.next().unwrap();
-                edits.push((model_old_range.clone(), model_new_text.clone()));
-            } else {
-                break;
-            }
-        }
-
-        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
-            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
-            if user_edit.old == model_old_offset_range {
-                let user_new_text = new_snapshot
-                    .text_for_range(user_edit.new.clone())
-                    .collect::<String>();
-
-                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
-                    if !model_suffix.is_empty() {
-                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
-                        edits.push((anchor..anchor, model_suffix.to_string()));
-                    }
-
-                    model_edits.next();
-                    continue;
-                }
-            }
-        }
-
-        return None;
-    }
-
-    edits.extend(model_edits.cloned());
-
-    if edits.is_empty() { None } else { Some(edits) }
-}
-
 impl std::fmt::Debug for EditPrediction {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("EditPrediction")
@@ -769,10 +723,11 @@ impl Zeta {
 
             let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
                 let edits = edits.clone();
-                |buffer, cx| {
+                move |buffer, cx| {
                     let new_snapshot = buffer.snapshot();
                     let edits: Arc<[(Range<Anchor>, String)]> =
-                        interpolate(&snapshot, &new_snapshot, edits)?.into();
+                        edit_prediction::interpolate_edits(&snapshot, &new_snapshot, &edits)?
+                            .into();
                     Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                 }
             })?