From 5fd187769d52481ab75348db8839824ccd54491d Mon Sep 17 00:00:00 2001
From: David <688326+dvcrn@users.noreply.github.com>
Date: Thu, 9 Oct 2025 01:02:21 +0700
Subject: [PATCH] Add Codestral edit predictions provider (#34371)

Release Notes:

- Added Codestral edit predictions provider which can be enabled by adding an API key in the Mistral section of agent settings.

![2025-07-13 11 35 33](https://github.com/user-attachments/assets/8bf599d7-33c7-4556-b878-6c645d69661f)

## Config

Get API key from https://console.mistral.ai/codestral and add it in the Mistral section of the agent settings.

```
"features": {
  "edit_prediction_provider": "codestral"
},
"edit_predictions": {
  "codestral": {
    "model": "codestral-latest",
    "max_tokens": 150
  }
},
```

---------

Co-authored-by: Michael Sloan
---
 Cargo.lock                                    |  23 ++
 Cargo.toml                                    |   2 +
 assets/settings/default.json                  |  13 +-
 .../add_llm_provider_modal.rs                 |   4 +-
 crates/codestral/Cargo.toml                   |  28 ++
 crates/codestral/LICENSE-GPL                  |   1 +
 crates/codestral/src/codestral.rs             | 381 ++++++++++++++++++
 crates/edit_prediction/src/edit_prediction.rs |  50 ++-
 crates/edit_prediction_button/Cargo.toml      |   1 +
 .../src/edit_prediction_button.rs             |  82 ++++
 crates/language/src/language_settings.rs      |  17 +
 crates/language_model/src/registry.rs         |  12 +-
 crates/language_models/src/language_models.rs |  46 ++-
 .../language_models/src/provider/mistral.rs   | 286 +++++++++++--
 crates/mistral/src/mistral.rs                 |   1 +
 .../settings/src/settings_content/language.rs |  20 +-
 crates/zed/Cargo.toml                         |   1 +
 .../zed/src/zed/edit_prediction_registry.rs   |  11 +
 crates/zeta/src/zeta.rs                       |  53 +--
 19 files changed, 913 insertions(+), 119 deletions(-)
 create mode 100644 crates/codestral/Cargo.toml
 create mode 120000 crates/codestral/LICENSE-GPL
 create mode 100644 crates/codestral/src/codestral.rs

diff --git a/Cargo.lock b/Cargo.lock
index 78160c100c81b0f524d8194a31f7e62f7c73d61e..8aba19b5c0ee2777fb0809956712bbaf74997c5d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3316,6 +3316,27 @@ dependencies = [
  "unicode-width",
 ]
 
+[[package]]
+name = "codestral"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "edit_prediction",
+ "edit_prediction_context",
+ "futures 0.3.31",
+ "gpui",
+ "language",
+ "language_models",
+ "log",
+ "mistral",
+ "serde",
+ "serde_json",
+ "smol",
+ "text",
+ "workspace-hack",
+ "zed-http-client",
+]
+
 [[package]]
 name = "collab"
 version = "0.44.0"
@@ -5115,6 +5136,7 @@ dependencies = [
  "anyhow",
  "client",
  "cloud_llm_client",
+ "codestral",
  "copilot",
  "edit_prediction",
  "editor",
@@ -20005,6 +20027,7 @@ dependencies = [
  "clap",
  "cli",
  "client",
+ "codestral",
  "collab_ui",
  "command_palette",
  "component",
diff --git a/Cargo.toml b/Cargo.toml
index da7a892515d683ee3be675fd347e53f60c1a920d..87f912c6be8df1a5d93e6622b041c58d8f66e75f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -164,6 +164,7 @@ members = [
     "crates/sum_tree",
     "crates/supermaven",
     "crates/supermaven_api",
+    "crates/codestral",
     "crates/svg_preview",
     "crates/system_specs",
     "crates/tab_switcher",
@@ -398,6 +399,7 @@ streaming_diff = { path = "crates/streaming_diff" }
 sum_tree = { path = "crates/sum_tree", package = "zed-sum-tree", version = "0.1.0" }
 supermaven = { path = "crates/supermaven" }
 supermaven_api = { path = "crates/supermaven_api" }
+codestral = { path = "crates/codestral" }
 system_specs = { path = "crates/system_specs" }
 tab_switcher = { path = "crates/tab_switcher" }
 task = { path = "crates/task" }
diff --git a/assets/settings/default.json b/assets/settings/default.json
index 
02df278d669b7c150d8b9d99b0167debb26c08fc..a7d912748f70e5f386413b27eab134558c5730bf 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -1311,15 +1311,18 @@ // "proxy": "", // "proxy_no_verify": false // }, - // Whether edit predictions are enabled when editing text threads. - // This setting has no effect if globally disabled. - "enabled_in_text_threads": true, - "copilot": { "enterprise_uri": null, "proxy": null, "proxy_no_verify": null - } + }, + "codestral": { + "model": null, + "max_tokens": null + }, + // Whether edit predictions are enabled when editing text threads. + // This setting has no effect if globally disabled. + "enabled_in_text_threads": true }, // Settings specific to journaling "journal": { diff --git a/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs index 373756b2c45ceeb65afebaf1f2d82b1fc16c017d..5e1712e626da98c60834da28906afa3eb30b92e6 100644 --- a/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs +++ b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs @@ -619,10 +619,10 @@ mod tests { cx.update(|_window, cx| { LanguageModelRegistry::global(cx).update(cx, |registry, cx| { registry.register_provider( - FakeLanguageModelProvider::new( + Arc::new(FakeLanguageModelProvider::new( LanguageModelProviderId::new("someprovider"), LanguageModelProviderName::new("Some Provider"), - ), + )), cx, ); }); diff --git a/crates/codestral/Cargo.toml b/crates/codestral/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..932834827f3516f48fed06ccf6c430935c725fee --- /dev/null +++ b/crates/codestral/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "codestral" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lib] +path = "src/codestral.rs" + +[dependencies] +anyhow.workspace = true +edit_prediction.workspace = true +edit_prediction_context.workspace = true +futures.workspace = true +gpui.workspace = true +http_client.workspace = true +language.workspace = true +language_models.workspace = true +log.workspace = true +mistral.workspace = true +serde.workspace = true +serde_json.workspace = true +smol.workspace = true +text.workspace = true +workspace-hack.workspace = true + +[dev-dependencies] diff --git a/crates/codestral/LICENSE-GPL b/crates/codestral/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/codestral/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/codestral/src/codestral.rs b/crates/codestral/src/codestral.rs new file mode 100644 index 0000000000000000000000000000000000000000..a266212355795c2284fa30b054338608cb45fa9c --- /dev/null +++ b/crates/codestral/src/codestral.rs @@ -0,0 +1,381 @@ +use anyhow::{Context as _, Result}; +use edit_prediction::{Direction, EditPrediction, EditPredictionProvider}; +use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions}; +use futures::AsyncReadExt; +use gpui::{App, Context, Entity, Task}; +use http_client::HttpClient; +use language::{ + language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint, +}; +use language_models::MistralLanguageModelProvider; +use mistral::CODESTRAL_API_URL; +use serde::{Deserialize, Serialize}; +use std::{ + ops::Range, + sync::Arc, + time::{Duration, Instant}, +}; +use text::ToOffset; + +pub const 
DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150); + +const EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions { + max_bytes: 1050, + min_bytes: 525, + target_before_cursor_over_total_bytes: 0.66, +}; + +/// Represents a completion that has been received and processed from Codestral. +/// This struct maintains the state needed to interpolate the completion as the user types. +#[derive(Clone)] +struct CurrentCompletion { + /// The buffer snapshot at the time the completion was generated. + /// Used to detect changes and interpolate edits. + snapshot: BufferSnapshot, + /// The edits that should be applied to transform the original text into the predicted text. + /// Each edit is a range in the buffer and the text to replace it with. + edits: Arc<[(Range, String)]>, + /// Preview of how the buffer will look after applying the edits. + edit_preview: EditPreview, +} + +impl CurrentCompletion { + /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated. + /// Returns None if the user's edits conflict with the predicted edits. + fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option, String)>> { + edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits) + } +} + +pub struct CodestralCompletionProvider { + http_client: Arc, + pending_request: Option>>, + current_completion: Option, +} + +impl CodestralCompletionProvider { + pub fn new(http_client: Arc) -> Self { + Self { + http_client, + pending_request: None, + current_completion: None, + } + } + + pub fn has_api_key(cx: &App) -> bool { + Self::api_key(cx).is_some() + } + + fn api_key(cx: &App) -> Option> { + MistralLanguageModelProvider::try_global(cx) + .and_then(|provider| provider.codestral_api_key(CODESTRAL_API_URL, cx)) + } + + /// Uses Codestral's Fill-in-the-Middle API for code completion. 
+ async fn fetch_completion( + http_client: Arc, + api_key: &str, + prompt: String, + suffix: String, + model: String, + max_tokens: Option, + ) -> Result { + let start_time = Instant::now(); + + log::debug!( + "Codestral: Requesting completion (model: {}, max_tokens: {:?})", + model, + max_tokens + ); + + let request = CodestralRequest { + model, + prompt, + suffix: if suffix.is_empty() { + None + } else { + Some(suffix) + }, + max_tokens: max_tokens.or(Some(350)), + temperature: Some(0.2), + top_p: Some(1.0), + stream: Some(false), + stop: None, + random_seed: None, + min_tokens: None, + }; + + let request_body = serde_json::to_string(&request)?; + + log::debug!("Codestral: Sending FIM request"); + + let http_request = http_client::Request::builder() + .method(http_client::Method::POST) + .uri(format!("{}/v1/fim/completions", CODESTRAL_API_URL)) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(http_client::AsyncBody::from(request_body))?; + + let mut response = http_client.send(http_request).await?; + let status = response.status(); + + log::debug!("Codestral: Response status: {}", status); + + if !status.is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow::anyhow!( + "Codestral API error: {} - {}", + status, + body + )); + } + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + let codestral_response: CodestralResponse = serde_json::from_str(&body)?; + + let elapsed = start_time.elapsed(); + + if let Some(choice) = codestral_response.choices.first() { + let completion = &choice.message.content; + + log::debug!( + "Codestral: Completion received ({} tokens, {:.2}s)", + codestral_response.usage.completion_tokens, + elapsed.as_secs_f64() + ); + + // Return just the completion text for insertion at cursor + Ok(completion.clone()) + } else { + log::error!("Codestral: No completion returned in response"); + Err(anyhow::anyhow!("No completion returned from Codestral")) + } + } +} + +impl EditPredictionProvider for CodestralCompletionProvider { + fn name() -> &'static str { + "codestral" + } + + fn display_name() -> &'static str { + "Codestral" + } + + fn show_completions_in_menu() -> bool { + true + } + + fn is_enabled(&self, _buffer: &Entity, _cursor_position: Anchor, cx: &App) -> bool { + Self::api_key(cx).is_some() + } + + fn is_refreshing(&self) -> bool { + self.pending_request.is_some() + } + + fn refresh( + &mut self, + buffer: Entity, + cursor_position: language::Anchor, + debounce: bool, + cx: &mut Context, + ) { + log::debug!("Codestral: Refresh called (debounce: {})", debounce); + + let Some(api_key) = Self::api_key(cx) else { + log::warn!("Codestral: No API key configured, skipping refresh"); + return; + }; + + let snapshot = buffer.read(cx).snapshot(); + + // Check if current completion is still valid + if let Some(current_completion) = self.current_completion.as_ref() { + if current_completion.interpolate(&snapshot).is_some() { + return; + } + } + + let http_client = self.http_client.clone(); + + // Get settings + let settings = all_language_settings(None, cx); + let model = settings + .edit_predictions + .codestral + .model + .clone() + .unwrap_or_else(|| "codestral-latest".to_string()); + let max_tokens = settings.edit_predictions.codestral.max_tokens; + + self.pending_request = Some(cx.spawn(async move |this, cx| { + if debounce { + log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT); + 
smol::Timer::after(DEBOUNCE_TIMEOUT).await; + } + + let cursor_offset = cursor_position.to_offset(&snapshot); + let cursor_point = cursor_offset.to_point(&snapshot); + let excerpt = EditPredictionExcerpt::select_from_buffer( + cursor_point, + &snapshot, + &EXCERPT_OPTIONS, + None, + ) + .context("Line containing cursor doesn't fit in excerpt max bytes")?; + + let excerpt_text = excerpt.text(&snapshot); + let cursor_within_excerpt = cursor_offset + .saturating_sub(excerpt.range.start) + .min(excerpt_text.body.len()); + let prompt = excerpt_text.body[..cursor_within_excerpt].to_string(); + let suffix = excerpt_text.body[cursor_within_excerpt..].to_string(); + + let completion_text = match Self::fetch_completion( + http_client, + &api_key, + prompt, + suffix, + model, + max_tokens, + ) + .await + { + Ok(completion) => completion, + Err(e) => { + log::error!("Codestral: Failed to fetch completion: {}", e); + this.update(cx, |this, cx| { + this.pending_request = None; + cx.notify(); + })?; + return Err(e); + } + }; + + if completion_text.trim().is_empty() { + log::debug!("Codestral: Completion was empty after trimming; ignoring"); + this.update(cx, |this, cx| { + this.pending_request = None; + cx.notify(); + })?; + return Ok(()); + } + + let edits: Arc<[(Range, String)]> = + vec![(cursor_position..cursor_position, completion_text)].into(); + let edit_preview = buffer + .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))? + .await; + + this.update(cx, |this, cx| { + this.current_completion = Some(CurrentCompletion { + snapshot, + edits, + edit_preview, + }); + this.pending_request = None; + cx.notify(); + })?; + + Ok(()) + })); + } + + fn cycle( + &mut self, + _buffer: Entity, + _cursor_position: Anchor, + _direction: Direction, + _cx: &mut Context, + ) { + // Codestral doesn't support multiple completions, so cycling does nothing + } + + fn accept(&mut self, _cx: &mut Context) { + log::debug!("Codestral: Completion accepted"); + self.pending_request = None; + self.current_completion = None; + } + + fn discard(&mut self, _cx: &mut Context) { + log::debug!("Codestral: Completion discarded"); + self.pending_request = None; + self.current_completion = None; + } + + /// Returns the completion suggestion, adjusted or invalidated based on user edits + fn suggest( + &mut self, + buffer: &Entity, + _cursor_position: Anchor, + cx: &mut Context, + ) -> Option { + let current_completion = self.current_completion.as_ref()?; + let buffer = buffer.read(cx); + let edits = current_completion.interpolate(&buffer.snapshot())?; + if edits.is_empty() { + return None; + } + Some(EditPrediction::Local { + id: None, + edits, + edit_preview: Some(current_completion.edit_preview.clone()), + }) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CodestralRequest { + pub model: String, + pub prompt: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub suffix: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub random_seed: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: Option, +} + +#[derive(Debug, Deserialize)] +pub struct CodestralResponse { + pub id: String, + pub 
object: String, + pub model: String, + pub usage: Usage, + pub created: u64, + pub choices: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Debug, Deserialize)] +pub struct Choice { + pub index: u32, + pub message: Message, + pub finish_reason: String, +} + +#[derive(Debug, Deserialize)] +pub struct Message { + pub content: String, + pub role: String, +} diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 90cad9f9227ae8071da6e256c6d9b494e61ac67c..22cb1047d1dda93b639990e549f9b76b3ff385f5 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -2,7 +2,7 @@ use std::ops::Range; use client::EditPredictionUsage; use gpui::{App, Context, Entity, SharedString}; -use language::Buffer; +use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt}; // TODO: Find a better home for `Direction`. // @@ -242,3 +242,51 @@ where self.update(cx, |this, cx| this.suggest(buffer, cursor_position, cx)) } } + +/// Returns edits updated based on user edits since the old snapshot. None is returned if any user +/// edit is not a prefix of a predicted insertion. +pub fn interpolate_edits( + old_snapshot: &BufferSnapshot, + new_snapshot: &BufferSnapshot, + current_edits: &[(Range, String)], +) -> Option, String)>> { + let mut edits = Vec::new(); + + let mut model_edits = current_edits.iter().peekable(); + for user_edit in new_snapshot.edits_since::(&old_snapshot.version) { + while let Some((model_old_range, _)) = model_edits.peek() { + let model_old_range = model_old_range.to_offset(old_snapshot); + if model_old_range.end < user_edit.old.start { + let (model_old_range, model_new_text) = model_edits.next().unwrap(); + edits.push((model_old_range.clone(), model_new_text.clone())); + } else { + break; + } + } + + if let Some((model_old_range, model_new_text)) = model_edits.peek() { + let model_old_offset_range = model_old_range.to_offset(old_snapshot); + if user_edit.old == model_old_offset_range { + let user_new_text = new_snapshot + .text_for_range(user_edit.new.clone()) + .collect::(); + + if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) { + if !model_suffix.is_empty() { + let anchor = old_snapshot.anchor_after(user_edit.old.end); + edits.push((anchor..anchor, model_suffix.to_string())); + } + + model_edits.next(); + continue; + } + } + } + + return None; + } + + edits.extend(model_edits.cloned()); + + if edits.is_empty() { None } else { Some(edits) } +} diff --git a/crates/edit_prediction_button/Cargo.toml b/crates/edit_prediction_button/Cargo.toml index 07447280fa0d3b8041f1d35eba9c368288322c25..597a83da33cf49cd8170630a53675bdd6da92af4 100644 --- a/crates/edit_prediction_button/Cargo.toml +++ b/crates/edit_prediction_button/Cargo.toml @@ -16,6 +16,7 @@ doctest = false anyhow.workspace = true client.workspace = true cloud_llm_client.workspace = true +codestral.workspace = true copilot.workspace = true editor.workspace = true feature_flags.workspace = true diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs index b2186c6aae592b5a4f73f1caaeb9e6c267d82afc..6f050fc86c708e2c97f9b34f2fa786516ba0aca9 100644 --- a/crates/edit_prediction_button/src/edit_prediction_button.rs +++ b/crates/edit_prediction_button/src/edit_prediction_button.rs @@ -1,6 +1,7 @@ use anyhow::Result; use 
client::{UserStore, zed_urls}; use cloud_llm_client::UsageLimit; +use codestral::CodestralCompletionProvider; use copilot::{Copilot, Status}; use editor::{Editor, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll}; use feature_flags::{FeatureFlagAppExt, PredictEditsRateCompletionsFeatureFlag}; @@ -234,6 +235,67 @@ impl Render for EditPredictionButton { ) } + EditPredictionProvider::Codestral => { + let enabled = self.editor_enabled.unwrap_or(true); + let has_api_key = CodestralCompletionProvider::has_api_key(cx); + let fs = self.fs.clone(); + let this = cx.entity(); + + div().child( + PopoverMenu::new("codestral") + .menu(move |window, cx| { + if has_api_key { + Some(this.update(cx, |this, cx| { + this.build_codestral_context_menu(window, cx) + })) + } else { + Some(ContextMenu::build(window, cx, |menu, _, _| { + let fs = fs.clone(); + menu.entry("Use Zed AI instead", None, move |_, cx| { + set_completion_provider( + fs.clone(), + cx, + EditPredictionProvider::Zed, + ) + }) + .separator() + .entry( + "Configure Codestral API Key", + None, + move |window, cx| { + window.dispatch_action( + zed_actions::agent::OpenSettings.boxed_clone(), + cx, + ); + }, + ) + })) + } + }) + .anchor(Corner::BottomRight) + .trigger_with_tooltip( + IconButton::new("codestral-icon", IconName::AiMistral) + .shape(IconButtonShape::Square) + .when(!has_api_key, |this| { + this.indicator(Indicator::dot().color(Color::Error)) + .indicator_border_color(Some( + cx.theme().colors().status_bar_background, + )) + }) + .when(has_api_key && !enabled, |this| { + this.indicator(Indicator::dot().color(Color::Ignored)) + .indicator_border_color(Some( + cx.theme().colors().status_bar_background, + )) + }), + move |window, cx| { + Tooltip::for_action("Codestral", &ToggleMenu, window, cx) + }, + ) + .with_handle(self.popover_menu_handle.clone()), + ) + } + EditPredictionProvider::Zed => { let enabled = self.editor_enabled.unwrap_or(true); @@ -493,6 +555,7 @@ impl EditPredictionButton { EditPredictionProvider::Zed | EditPredictionProvider::Copilot | EditPredictionProvider::Supermaven + | EditPredictionProvider::Codestral ) { menu = menu .separator() @@ -719,6 +782,25 @@ impl EditPredictionButton { }) } + fn build_codestral_context_menu( + &self, + window: &mut Window, + cx: &mut Context, + ) -> Entity { + let fs = self.fs.clone(); + ContextMenu::build(window, cx, |menu, window, cx| { + self.build_language_settings_menu(menu, window, cx) + .separator() + .entry("Use Zed AI instead", None, move |_, cx| { + set_completion_provider(fs.clone(), cx, EditPredictionProvider::Zed) + }) + .separator() + .entry("Configure Codestral API Key", None, move |window, cx| { + window.dispatch_action(zed_actions::agent::OpenSettings.boxed_clone(), cx); + }) + }) + } + fn build_zeta_context_menu( &self, window: &mut Window, diff --git a/crates/language/src/language_settings.rs b/crates/language/src/language_settings.rs index f815ecc8517d1e3e83f8614c21786e7b1a6cbfd0..f74fc8749ce88963078ee243073e889415080c6f 100644 --- a/crates/language/src/language_settings.rs +++ b/crates/language/src/language_settings.rs @@ -377,6 +377,8 @@ pub struct EditPredictionSettings { pub mode: settings::EditPredictionsMode, /// Settings specific to GitHub Copilot. pub copilot: CopilotSettings, + /// Settings specific to Codestral. + pub codestral: CodestralSettings, /// Whether edit predictions are enabled in the assistant panel. /// This setting has no effect if globally disabled. 
pub enabled_in_text_threads: bool, @@ -412,6 +414,14 @@ pub struct CopilotSettings { pub enterprise_uri: Option, } +#[derive(Clone, Debug, Default)] +pub struct CodestralSettings { + /// Model to use for completions. + pub model: Option, + /// Maximum tokens to generate. + pub max_tokens: Option, +} + impl AllLanguageSettings { /// Returns the [`LanguageSettings`] for the language with the specified name. pub fn language<'a>( @@ -622,6 +632,12 @@ impl settings::Settings for AllLanguageSettings { enterprise_uri: copilot.enterprise_uri, }; + let codestral = edit_predictions.codestral.unwrap(); + let codestral_settings = CodestralSettings { + model: codestral.model, + max_tokens: codestral.max_tokens, + }; + let enabled_in_text_threads = edit_predictions.enabled_in_text_threads.unwrap(); let mut file_types: FxHashMap, GlobSet> = FxHashMap::default(); @@ -655,6 +671,7 @@ impl settings::Settings for AllLanguageSettings { .collect(), mode: edit_predictions_mode, copilot: copilot_settings, + codestral: codestral_settings, enabled_in_text_threads, }, defaults: default_language_settings, diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index bab258bca1728ac45f5ef5c0397149b93f0d6031..6ed8bf07c4e976c88fecebd929843335333b1fa6 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -118,14 +118,14 @@ impl LanguageModelRegistry { } #[cfg(any(test, feature = "test-support"))] - pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider { - let fake_provider = crate::fake_provider::FakeLanguageModelProvider::default(); + pub fn test(cx: &mut App) -> Arc { + let fake_provider = Arc::new(crate::fake_provider::FakeLanguageModelProvider::default()); let registry = cx.new(|cx| { let mut registry = Self::default(); registry.register_provider(fake_provider.clone(), cx); let model = fake_provider.provided_models(cx)[0].clone(); let configured_model = ConfiguredModel { - provider: Arc::new(fake_provider.clone()), + provider: fake_provider.clone(), model, }; registry.set_default_model(Some(configured_model), cx); @@ -137,7 +137,7 @@ impl LanguageModelRegistry { pub fn register_provider( &mut self, - provider: T, + provider: Arc, cx: &mut Context, ) { let id = provider.id(); @@ -152,7 +152,7 @@ impl LanguageModelRegistry { subscription.detach(); } - self.providers.insert(id.clone(), Arc::new(provider)); + self.providers.insert(id.clone(), provider); cx.emit(Event::AddedProvider(id)); } @@ -395,7 +395,7 @@ mod tests { fn test_register_providers(cx: &mut App) { let registry = cx.new(|_| LanguageModelRegistry::default()); - let provider = FakeLanguageModelProvider::default(); + let provider = Arc::new(FakeLanguageModelProvider::default()); registry.update(cx, |registry, cx| { registry.register_provider(provider.clone(), cx); }); diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index 61e1a794695310421397469515a43a4d5bf5deb8..1b7243780ad30d737118046c8fc71fe9e4186fa6 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -18,7 +18,7 @@ use crate::provider::cloud::CloudLanguageModelProvider; use crate::provider::copilot_chat::CopilotChatLanguageModelProvider; use crate::provider::google::GoogleLanguageModelProvider; use crate::provider::lmstudio::LmStudioLanguageModelProvider; -use crate::provider::mistral::MistralLanguageModelProvider; +pub use crate::provider::mistral::MistralLanguageModelProvider; use 
crate::provider::ollama::OllamaLanguageModelProvider; use crate::provider::open_ai::OpenAiLanguageModelProvider; use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider; @@ -87,11 +87,11 @@ fn register_openai_compatible_providers( for provider_id in new { if !old.contains(provider_id) { registry.register_provider( - OpenAiCompatibleLanguageModelProvider::new( + Arc::new(OpenAiCompatibleLanguageModelProvider::new( provider_id.clone(), client.http_client(), cx, - ), + )), cx, ); } @@ -105,50 +105,62 @@ fn register_language_model_providers( cx: &mut Context, ) { registry.register_provider( - CloudLanguageModelProvider::new(user_store, client.clone(), cx), + Arc::new(CloudLanguageModelProvider::new( + user_store, + client.clone(), + cx, + )), + cx, + ); + registry.register_provider( + Arc::new(AnthropicLanguageModelProvider::new( + client.http_client(), + cx, + )), cx, ); - registry.register_provider( - AnthropicLanguageModelProvider::new(client.http_client(), cx), + Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)), cx, ); registry.register_provider( - OpenAiLanguageModelProvider::new(client.http_client(), cx), + Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)), cx, ); registry.register_provider( - OllamaLanguageModelProvider::new(client.http_client(), cx), + Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)), cx, ); registry.register_provider( - LmStudioLanguageModelProvider::new(client.http_client(), cx), + Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)), cx, ); registry.register_provider( - DeepSeekLanguageModelProvider::new(client.http_client(), cx), + Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)), cx, ); registry.register_provider( - GoogleLanguageModelProvider::new(client.http_client(), cx), + MistralLanguageModelProvider::global(client.http_client(), cx), cx, ); registry.register_provider( - MistralLanguageModelProvider::new(client.http_client(), cx), + Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)), cx, ); registry.register_provider( - BedrockLanguageModelProvider::new(client.http_client(), cx), + Arc::new(OpenRouterLanguageModelProvider::new( + client.http_client(), + cx, + )), cx, ); registry.register_provider( - OpenRouterLanguageModelProvider::new(client.http_client(), cx), + Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)), cx, ); registry.register_provider( - VercelLanguageModelProvider::new(client.http_client(), cx), + Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)), cx, ); - registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx); - registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx); + registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx); } diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 7623f123c60e75a9c1fc6716e56075e4ea5b882b..ad7bf600d56354ee12e72c9ebc2bfe09f0094da7 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -1,7 +1,8 @@ use anyhow::{Result, anyhow}; use collections::BTreeMap; +use fs::Fs; use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream}; -use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window}; use 
http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -10,9 +11,9 @@ use language_model::{ LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage, }; -use mistral::{MISTRAL_API_URL, StreamResponse}; +use mistral::{CODESTRAL_API_URL, MISTRAL_API_URL, StreamResponse}; pub use settings::MistralAvailableModel as AvailableModel; -use settings::{Settings, SettingsStore}; +use settings::{EditPredictionProvider, Settings, SettingsStore, update_settings_file}; use std::collections::HashMap; use std::pin::Pin; use std::str::FromStr; @@ -31,6 +32,9 @@ const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new( const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY"; static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); +const CODESTRAL_API_KEY_ENV_VAR_NAME: &str = "CODESTRAL_API_KEY"; +static CODESTRAL_API_KEY_ENV_VAR: LazyLock = env_var!(CODESTRAL_API_KEY_ENV_VAR_NAME); + #[derive(Default, Clone, Debug, PartialEq)] pub struct MistralSettings { pub api_url: String, @@ -44,6 +48,7 @@ pub struct MistralLanguageModelProvider { pub struct State { api_key_state: ApiKeyState, + codestral_api_key_state: ApiKeyState, } impl State { @@ -57,6 +62,19 @@ impl State { .store(api_url, api_key, |this| &mut this.api_key_state, cx) } + fn set_codestral_api_key( + &mut self, + api_key: Option, + cx: &mut Context, + ) -> Task> { + self.codestral_api_key_state.store( + CODESTRAL_API_URL.into(), + api_key, + |this| &mut this.codestral_api_key_state, + cx, + ) + } + fn authenticate(&mut self, cx: &mut Context) -> Task> { let api_url = MistralLanguageModelProvider::api_url(cx); self.api_key_state.load_if_needed( @@ -66,10 +84,34 @@ impl State { cx, ) } + + fn authenticate_codestral( + &mut self, + cx: &mut Context, + ) -> Task> { + self.codestral_api_key_state.load_if_needed( + CODESTRAL_API_URL.into(), + &CODESTRAL_API_KEY_ENV_VAR, + |this| &mut this.codestral_api_key_state, + cx, + ) + } } +struct GlobalMistralLanguageModelProvider(Arc); + +impl Global for GlobalMistralLanguageModelProvider {} + impl MistralLanguageModelProvider { - pub fn new(http_client: Arc, cx: &mut App) -> Self { + pub fn try_global(cx: &App) -> Option<&Arc> { + cx.try_global::() + .map(|this| &this.0) + } + + pub fn global(http_client: Arc, cx: &mut App) -> Arc { + if let Some(this) = cx.try_global::() { + return this.0.clone(); + } let state = cx.new(|cx| { cx.observe_global::(|this: &mut State, cx| { let api_url = Self::api_url(cx); @@ -84,10 +126,22 @@ impl MistralLanguageModelProvider { .detach(); State { api_key_state: ApiKeyState::new(Self::api_url(cx)), + codestral_api_key_state: ApiKeyState::new(CODESTRAL_API_URL.into()), } }); - Self { http_client, state } + let this = Arc::new(Self { http_client, state }); + cx.set_global(GlobalMistralLanguageModelProvider(this)); + cx.global::().0.clone() + } + + pub fn load_codestral_api_key(&self, cx: &mut App) -> Task> { + self.state + .update(cx, |state, cx| state.authenticate_codestral(cx)) + } + + pub fn codestral_api_key(&self, url: &str, cx: &App) -> Option> { + self.state.read(cx).codestral_api_key_state.key(url) } fn create_language_model(&self, model: mistral::Model) -> Arc { @@ -691,6 +745,7 @@ struct RawToolCall { struct ConfigurationView { api_key_editor: Entity, + codestral_api_key_editor: Entity, state: Entity, load_credentials_task: Option>, } @@ -699,6 +754,8 @@ impl ConfigurationView { 
fn new(state: Entity, window: &mut Window, cx: &mut Context) -> Self { let api_key_editor = cx.new(|cx| SingleLineInput::new(window, cx, "0aBCDEFGhIjKLmNOpqrSTUVwxyzabCDE1f2")); + let codestral_api_key_editor = + cx.new(|cx| SingleLineInput::new(window, cx, "0aBCDEFGhIjKLmNOpqrSTUVwxyzabCDE1f2")); cx.observe(&state, |_, _, cx| { cx.notify(); @@ -715,6 +772,12 @@ impl ConfigurationView { // We don't log an error, because "not signed in" is also an error. let _ = task.await; } + if let Some(task) = state + .update(cx, |state, cx| state.authenticate_codestral(cx)) + .log_err() + { + let _ = task.await; + } this.update(cx, |this, cx| { this.load_credentials_task = None; @@ -726,6 +789,7 @@ impl ConfigurationView { Self { api_key_editor, + codestral_api_key_editor, state, load_credentials_task, } @@ -763,47 +827,92 @@ impl ConfigurationView { .detach_and_log_err(cx); } - fn should_render_editor(&self, cx: &mut Context) -> bool { - !self.state.read(cx).is_authenticated() + fn save_codestral_api_key( + &mut self, + _: &menu::Confirm, + window: &mut Window, + cx: &mut Context, + ) { + let api_key = self + .codestral_api_key_editor + .read(cx) + .text(cx) + .trim() + .to_string(); + if api_key.is_empty() { + return; + } + + // url changes can cause the editor to be displayed again + self.codestral_api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| { + state.set_codestral_api_key(Some(api_key), cx) + })? + .await?; + cx.update(|_window, cx| { + set_edit_prediction_provider(EditPredictionProvider::Codestral, cx) + }) + }) + .detach_and_log_err(cx); } -} -impl Render for ConfigurationView { - fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); + fn reset_codestral_api_key(&mut self, window: &mut Window, cx: &mut Context) { + self.codestral_api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); - if self.load_credentials_task.is_some() { - div().child(Label::new("Loading credentials...")).into_any() - } else if self.should_render_editor(cx) { + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_codestral_api_key(None, cx))? + .await?; + cx.update(|_window, cx| set_edit_prediction_provider(EditPredictionProvider::Zed, cx)) + }) + .detach_and_log_err(cx); + } + + fn should_render_api_key_editor(&self, cx: &mut Context) -> bool { + !self.state.read(cx).is_authenticated() + } + + fn render_codestral_api_key_editor(&mut self, cx: &mut Context) -> AnyElement { + let key_state = &self.state.read(cx).codestral_api_key_state; + let should_show_editor = !key_state.has_key(); + let env_var_set = key_state.is_from_env_var(); + if should_show_editor { v_flex() + .id("codestral") .size_full() - .on_action(cx.listener(Self::save_api_key)) - .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:")) + .mt_2() + .on_action(cx.listener(Self::save_codestral_api_key)) + .child(Label::new( + "To use Codestral as an edit prediction provider, \ + you need to add a Codestral-specific API key. 
Follow these steps:", + )) .child( List::new() .child(InstructionListItem::new( "Create one by visiting", - Some("Mistral's console"), - Some("https://console.mistral.ai/api-keys"), + Some("the Codestral section of Mistral's console"), + Some("https://console.mistral.ai/codestral"), )) - .child(InstructionListItem::text_only( - "Ensure your Mistral account has credits", - )) - .child(InstructionListItem::text_only( - "Paste your API key below and hit enter to start using the assistant", - )), + .child(InstructionListItem::text_only("Paste your API key below and hit enter")), ) - .child(self.api_key_editor.clone()) + .child(self.codestral_api_key_editor.clone()) .child( Label::new( - format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."), + format!("You can also assign the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable and restart Zed."), ) .size(LabelSize::Small).color(Color::Muted), - ) - .into_any() + ).into_any() } else { h_flex() - .mt_1() + .id("codestral") + .mt_2() .p_1() .justify_between() .rounded_md() @@ -815,14 +924,9 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable") + format!("API key set in {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable") } else { - let api_url = MistralLanguageModelProvider::api_url(cx); - if api_url == MISTRAL_API_URL { - "API key configured".to_string() - } else { - format!("API key configured for {}", truncate_and_trailoff(&api_url, 32)) - } + "Codestral API key configured".to_string() })), ) .child( @@ -833,15 +937,121 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .disabled(env_var_set) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) + this.tooltip(Tooltip::text(format!( + "To reset your API key, \ + unset the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable." + ))) }) - .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), + .on_click( + cx.listener(|this, _, window, cx| this.reset_codestral_api_key(window, cx)), + ), + ).into_any() + } + } +} + +impl Render for ConfigurationView { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); + + if self.load_credentials_task.is_some() { + div().child(Label::new("Loading credentials...")).into_any() + } else if self.should_render_api_key_editor(cx) { + v_flex() + .size_full() + .on_action(cx.listener(Self::save_api_key)) + .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. 
Follow these steps:")) + .child( + List::new() + .child(InstructionListItem::new( + "Create one by visiting", + Some("Mistral's console"), + Some("https://console.mistral.ai/api-keys"), + )) + .child(InstructionListItem::text_only( + "Ensure your Mistral account has credits", + )) + .child(InstructionListItem::text_only( + "Paste your API key below and hit enter to start using the assistant", + )), ) + .child(self.api_key_editor.clone()) + .child( + Label::new( + format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."), + ) + .size(LabelSize::Small).color(Color::Muted), + ) + .child(self.render_codestral_api_key_editor(cx)) + .into_any() + } else { + v_flex() + .size_full() + .child( + h_flex() + .mt_1() + .p_1() + .justify_between() + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().background) + .child( + h_flex() + .gap_1() + .child(Icon::new(IconName::Check).color(Color::Success)) + .child(Label::new(if env_var_set { + format!( + "API key set in {API_KEY_ENV_VAR_NAME} environment variable" + ) + } else { + let api_url = MistralLanguageModelProvider::api_url(cx); + if api_url == MISTRAL_API_URL { + "API key configured".to_string() + } else { + format!( + "API key configured for {}", + truncate_and_trailoff(&api_url, 32) + ) + } + })), + ) + .child( + Button::new("reset-key", "Reset Key") + .label_size(LabelSize::Small) + .icon(Some(IconName::Trash)) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .disabled(env_var_set) + .when(env_var_set, |this| { + this.tooltip(Tooltip::text(format!( + "To reset your API key, \ + unset the {API_KEY_ENV_VAR_NAME} environment variable." + ))) + }) + .on_click(cx.listener(|this, _, window, cx| { + this.reset_api_key(window, cx) + })), + ), + ) + .child(self.render_codestral_api_key_editor(cx)) .into_any() } } } +fn set_edit_prediction_provider(provider: EditPredictionProvider, cx: &mut App) { + let fs = ::global(cx); + update_settings_file(fs, cx, move |settings, _| { + settings + .project + .all_languages + .features + .get_or_insert_default() + .edit_prediction_provider = Some(provider); + }); +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/mistral/src/mistral.rs b/crates/mistral/src/mistral.rs index 2e79a8d59f67389c97ffff50fa30c4ca92318209..eca4743d0442b9ca169ac966f78af0112565fcbc 100644 --- a/crates/mistral/src/mistral.rs +++ b/crates/mistral/src/mistral.rs @@ -7,6 +7,7 @@ use std::convert::TryFrom; use strum::EnumIter; pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1"; +pub const CODESTRAL_API_URL: &str = "https://codestral.mistral.ai"; #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(rename_all = "lowercase")] diff --git a/crates/settings/src/settings_content/language.rs b/crates/settings/src/settings_content/language.rs index b56e64465336dbc7726c20ed187fe9e71068cb65..2abc7db574edea27b2e1d8cc809955a3b9d3cfe8 100644 --- a/crates/settings/src/settings_content/language.rs +++ b/crates/settings/src/settings_content/language.rs @@ -82,6 +82,7 @@ pub enum EditPredictionProvider { Copilot, Supermaven, Zed, + Codestral, } impl EditPredictionProvider { @@ -90,7 +91,8 @@ impl EditPredictionProvider { EditPredictionProvider::Zed => true, EditPredictionProvider::None | EditPredictionProvider::Copilot - | EditPredictionProvider::Supermaven => false, + | EditPredictionProvider::Supermaven + | EditPredictionProvider::Codestral => false, } } } @@ -108,6 +110,8 @@ pub struct 
EditPredictionSettingsContent { pub mode: Option, /// Settings specific to GitHub Copilot. pub copilot: Option, + /// Settings specific to Codestral. + pub codestral: Option, /// Whether edit predictions are enabled in the assistant prompt editor. /// This has no effect if globally disabled. pub enabled_in_text_threads: Option, @@ -130,6 +134,20 @@ pub struct CopilotSettingsContent { pub enterprise_uri: Option, } +#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, MergeFrom, PartialEq)] +pub struct CodestralSettingsContent { + /// Model to use for completions. + /// + /// Default: "codestral-latest" + #[serde(default)] + pub model: Option, + /// Maximum tokens to generate. + /// + /// Default: 150 + #[serde(default)] + pub max_tokens: Option, +} + /// The mode in which edit predictions should be displayed. #[derive( Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize, JsonSchema, MergeFrom, diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 1c19f9d889a3b8a2dfbab0c4e539de7dd6c018af..abaeb40fa6dc1b78c93f24af21a186f1ef0bb0c3 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -39,6 +39,7 @@ channel.workspace = true clap.workspace = true cli.workspace = true client.workspace = true +codestral.workspace = true collab_ui.workspace = true collections.workspace = true command_palette.workspace = true diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index a1ae52fc0650b7eb4eacd37b3670a0d93eed532e..a9bd0395347dadcb9caa706fcbcc81f58d6af944 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -1,9 +1,11 @@ use client::{Client, UserStore}; +use codestral::CodestralCompletionProvider; use collections::HashMap; use copilot::{Copilot, CopilotCompletionProvider}; use editor::Editor; use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity}; use language::language_settings::{EditPredictionProvider, all_language_settings}; +use language_models::MistralLanguageModelProvider; use settings::SettingsStore; use std::{cell::RefCell, rc::Rc, sync::Arc}; use supermaven::{Supermaven, SupermavenCompletionProvider}; @@ -109,6 +111,10 @@ fn assign_edit_prediction_providers( user_store: Entity, cx: &mut App, ) { + if provider == EditPredictionProvider::Codestral { + let mistral = MistralLanguageModelProvider::global(client.http_client(), cx); + mistral.load_codestral_api_key(cx).detach(); + } for (editor, window) in editors.borrow().iter() { _ = window.update(cx, |_window, window, cx| { _ = editor.update(cx, |editor, cx| { @@ -189,6 +195,11 @@ fn assign_edit_prediction_provider( editor.set_edit_prediction_provider(Some(provider), window, cx); } } + EditPredictionProvider::Codestral => { + let http_client = client.http_client(); + let provider = cx.new(|_| CodestralCompletionProvider::new(http_client)); + editor.set_edit_prediction_provider(Some(provider), window, cx); + } EditPredictionProvider::Zed => { if user_store.read(cx).current_user().is_some() { let mut worktree = None; diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 3a156f351d8f34e858ce199aa1244729fe07a227..1d48571d7b06f35d82934122919e75bbbd087ffa 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -151,56 +151,10 @@ impl EditPrediction { } fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option, String)>> { - interpolate(&self.snapshot, new_snapshot, self.edits.clone()) + 
edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits) } } -fn interpolate( - old_snapshot: &BufferSnapshot, - new_snapshot: &BufferSnapshot, - current_edits: Arc<[(Range, String)]>, -) -> Option, String)>> { - let mut edits = Vec::new(); - - let mut model_edits = current_edits.iter().peekable(); - for user_edit in new_snapshot.edits_since::(&old_snapshot.version) { - while let Some((model_old_range, _)) = model_edits.peek() { - let model_old_range = model_old_range.to_offset(old_snapshot); - if model_old_range.end < user_edit.old.start { - let (model_old_range, model_new_text) = model_edits.next().unwrap(); - edits.push((model_old_range.clone(), model_new_text.clone())); - } else { - break; - } - } - - if let Some((model_old_range, model_new_text)) = model_edits.peek() { - let model_old_offset_range = model_old_range.to_offset(old_snapshot); - if user_edit.old == model_old_offset_range { - let user_new_text = new_snapshot - .text_for_range(user_edit.new.clone()) - .collect::(); - - if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) { - if !model_suffix.is_empty() { - let anchor = old_snapshot.anchor_after(user_edit.old.end); - edits.push((anchor..anchor, model_suffix.to_string())); - } - - model_edits.next(); - continue; - } - } - } - - return None; - } - - edits.extend(model_edits.cloned()); - - if edits.is_empty() { None } else { Some(edits) } -} - impl std::fmt::Debug for EditPrediction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("EditPrediction") @@ -769,10 +723,11 @@ impl Zeta { let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, { let edits = edits.clone(); - |buffer, cx| { + move |buffer, cx| { let new_snapshot = buffer.snapshot(); let edits: Arc<[(Range, String)]> = - interpolate(&snapshot, &new_snapshot, edits)?.into(); + edit_prediction::interpolate_edits(&snapshot, &new_snapshot, &edits)? + .into(); Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx))) } })?