Detailed changes
@@ -3316,6 +3316,27 @@ dependencies = [
"unicode-width",
]
+[[package]]
+name = "codestral"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "edit_prediction",
+ "edit_prediction_context",
+ "futures 0.3.31",
+ "gpui",
+ "language",
+ "language_models",
+ "log",
+ "mistral",
+ "serde",
+ "serde_json",
+ "smol",
+ "text",
+ "workspace-hack",
+ "zed-http-client",
+]
+
[[package]]
name = "collab"
version = "0.44.0"
@@ -5115,6 +5136,7 @@ dependencies = [
"anyhow",
"client",
"cloud_llm_client",
+ "codestral",
"copilot",
"edit_prediction",
"editor",
@@ -20005,6 +20027,7 @@ dependencies = [
"clap",
"cli",
"client",
+ "codestral",
"collab_ui",
"command_palette",
"component",
@@ -164,6 +164,7 @@ members = [
"crates/sum_tree",
"crates/supermaven",
"crates/supermaven_api",
+ "crates/codestral",
"crates/svg_preview",
"crates/system_specs",
"crates/tab_switcher",
@@ -398,6 +399,7 @@ streaming_diff = { path = "crates/streaming_diff" }
sum_tree = { path = "crates/sum_tree", package = "zed-sum-tree", version = "0.1.0" }
supermaven = { path = "crates/supermaven" }
supermaven_api = { path = "crates/supermaven_api" }
+codestral = { path = "crates/codestral" }
system_specs = { path = "crates/system_specs" }
tab_switcher = { path = "crates/tab_switcher" }
task = { path = "crates/task" }
@@ -1311,15 +1311,18 @@
// "proxy": "",
// "proxy_no_verify": false
// },
- // Whether edit predictions are enabled when editing text threads.
- // This setting has no effect if globally disabled.
- "enabled_in_text_threads": true,
-
"copilot": {
"enterprise_uri": null,
"proxy": null,
"proxy_no_verify": null
- }
+ },
+ "codestral": {
+ "model": null,
+ "max_tokens": null
+ },
+ // Whether edit predictions are enabled when editing text threads.
+ // This setting has no effect if globally disabled.
+ "enabled_in_text_threads": true
},
// Settings specific to journaling
"journal": {
@@ -619,10 +619,10 @@ mod tests {
cx.update(|_window, cx| {
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.register_provider(
- FakeLanguageModelProvider::new(
+ Arc::new(FakeLanguageModelProvider::new(
LanguageModelProviderId::new("someprovider"),
LanguageModelProviderName::new("Some Provider"),
- ),
+ )),
cx,
);
});
@@ -0,0 +1,28 @@
+[package]
+name = "codestral"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lib]
+path = "src/codestral.rs"
+
+[dependencies]
+anyhow.workspace = true
+edit_prediction.workspace = true
+edit_prediction_context.workspace = true
+futures.workspace = true
+gpui.workspace = true
+http_client.workspace = true
+language.workspace = true
+language_models.workspace = true
+log.workspace = true
+mistral.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+smol.workspace = true
+text.workspace = true
+workspace-hack.workspace = true
+
+[dev-dependencies]
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,381 @@
+use anyhow::{Context as _, Result};
+use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
+use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
+use futures::AsyncReadExt;
+use gpui::{App, Context, Entity, Task};
+use http_client::HttpClient;
+use language::{
+ language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, EditPreview, ToPoint,
+};
+use language_models::MistralLanguageModelProvider;
+use mistral::CODESTRAL_API_URL;
+use serde::{Deserialize, Serialize};
+use std::{
+ ops::Range,
+ sync::Arc,
+ time::{Duration, Instant},
+};
+use text::ToOffset;
+
+pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(150);
+
+const EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
+ max_bytes: 1050,
+ min_bytes: 525,
+ target_before_cursor_over_total_bytes: 0.66,
+};
+
+/// Represents a completion that has been received and processed from Codestral.
+/// This struct maintains the state needed to interpolate the completion as the user types.
+#[derive(Clone)]
+struct CurrentCompletion {
+ /// The buffer snapshot at the time the completion was generated.
+ /// Used to detect changes and interpolate edits.
+ snapshot: BufferSnapshot,
+ /// The edits that should be applied to transform the original text into the predicted text.
+ /// Each edit is a range in the buffer and the text to replace it with.
+ edits: Arc<[(Range<Anchor>, String)]>,
+ /// Preview of how the buffer will look after applying the edits.
+ edit_preview: EditPreview,
+}
+
+impl CurrentCompletion {
+ /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
+ /// Returns None if the user's edits conflict with the predicted edits.
+ fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
+ edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
+ }
+}
+
+pub struct CodestralCompletionProvider {
+ http_client: Arc<dyn HttpClient>,
+ pending_request: Option<Task<Result<()>>>,
+ current_completion: Option<CurrentCompletion>,
+}
+
+impl CodestralCompletionProvider {
+ pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
+ Self {
+ http_client,
+ pending_request: None,
+ current_completion: None,
+ }
+ }
+
+ pub fn has_api_key(cx: &App) -> bool {
+ Self::api_key(cx).is_some()
+ }
+
+ fn api_key(cx: &App) -> Option<Arc<str>> {
+ MistralLanguageModelProvider::try_global(cx)
+ .and_then(|provider| provider.codestral_api_key(CODESTRAL_API_URL, cx))
+ }
+
+ /// Uses Codestral's Fill-in-the-Middle API for code completion.
+ async fn fetch_completion(
+ http_client: Arc<dyn HttpClient>,
+ api_key: &str,
+ prompt: String,
+ suffix: String,
+ model: String,
+ max_tokens: Option<u32>,
+ ) -> Result<String> {
+ let start_time = Instant::now();
+
+ log::debug!(
+ "Codestral: Requesting completion (model: {}, max_tokens: {:?})",
+ model,
+ max_tokens
+ );
+
+ let request = CodestralRequest {
+ model,
+ prompt,
+ suffix: if suffix.is_empty() {
+ None
+ } else {
+ Some(suffix)
+ },
+ max_tokens: max_tokens.or(Some(350)),
+ temperature: Some(0.2),
+ top_p: Some(1.0),
+ stream: Some(false),
+ stop: None,
+ random_seed: None,
+ min_tokens: None,
+ };
+
+ let request_body = serde_json::to_string(&request)?;
+
+ log::debug!("Codestral: Sending FIM request");
+
+ let http_request = http_client::Request::builder()
+ .method(http_client::Method::POST)
+ .uri(format!("{}/v1/fim/completions", CODESTRAL_API_URL))
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .body(http_client::AsyncBody::from(request_body))?;
+
+ let mut response = http_client.send(http_request).await?;
+ let status = response.status();
+
+ log::debug!("Codestral: Response status: {}", status);
+
+ if !status.is_success() {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ return Err(anyhow::anyhow!(
+ "Codestral API error: {} - {}",
+ status,
+ body
+ ));
+ }
+
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ let codestral_response: CodestralResponse = serde_json::from_str(&body)?;
+
+ let elapsed = start_time.elapsed();
+
+ if let Some(choice) = codestral_response.choices.first() {
+ let completion = &choice.message.content;
+
+ log::debug!(
+ "Codestral: Completion received ({} tokens, {:.2}s)",
+ codestral_response.usage.completion_tokens,
+ elapsed.as_secs_f64()
+ );
+
+ // Return just the completion text for insertion at cursor
+ Ok(completion.clone())
+ } else {
+ log::error!("Codestral: No completion returned in response");
+ Err(anyhow::anyhow!("No completion returned from Codestral"))
+ }
+ }
+}
+
+impl EditPredictionProvider for CodestralCompletionProvider {
+ fn name() -> &'static str {
+ "codestral"
+ }
+
+ fn display_name() -> &'static str {
+ "Codestral"
+ }
+
+ fn show_completions_in_menu() -> bool {
+ true
+ }
+
+ fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, cx: &App) -> bool {
+ Self::api_key(cx).is_some()
+ }
+
+ fn is_refreshing(&self) -> bool {
+ self.pending_request.is_some()
+ }
+
+ fn refresh(
+ &mut self,
+ buffer: Entity<Buffer>,
+ cursor_position: language::Anchor,
+ debounce: bool,
+ cx: &mut Context<Self>,
+ ) {
+ log::debug!("Codestral: Refresh called (debounce: {})", debounce);
+
+ let Some(api_key) = Self::api_key(cx) else {
+ log::warn!("Codestral: No API key configured, skipping refresh");
+ return;
+ };
+
+ let snapshot = buffer.read(cx).snapshot();
+
+ // Check if current completion is still valid
+ if let Some(current_completion) = self.current_completion.as_ref() {
+ if current_completion.interpolate(&snapshot).is_some() {
+ return;
+ }
+ }
+
+ let http_client = self.http_client.clone();
+
+ // Get settings
+ let settings = all_language_settings(None, cx);
+ let model = settings
+ .edit_predictions
+ .codestral
+ .model
+ .clone()
+ .unwrap_or_else(|| "codestral-latest".to_string());
+ let max_tokens = settings.edit_predictions.codestral.max_tokens;
+
+ self.pending_request = Some(cx.spawn(async move |this, cx| {
+ if debounce {
+ log::debug!("Codestral: Debouncing for {:?}", DEBOUNCE_TIMEOUT);
+ smol::Timer::after(DEBOUNCE_TIMEOUT).await;
+ }
+
+ let cursor_offset = cursor_position.to_offset(&snapshot);
+ let cursor_point = cursor_offset.to_point(&snapshot);
+ let excerpt = EditPredictionExcerpt::select_from_buffer(
+ cursor_point,
+ &snapshot,
+ &EXCERPT_OPTIONS,
+ None,
+ )
+ .context("Line containing cursor doesn't fit in excerpt max bytes")?;
+
+ let excerpt_text = excerpt.text(&snapshot);
+ let cursor_within_excerpt = cursor_offset
+ .saturating_sub(excerpt.range.start)
+ .min(excerpt_text.body.len());
+ let prompt = excerpt_text.body[..cursor_within_excerpt].to_string();
+ let suffix = excerpt_text.body[cursor_within_excerpt..].to_string();
+
+ let completion_text = match Self::fetch_completion(
+ http_client,
+ &api_key,
+ prompt,
+ suffix,
+ model,
+ max_tokens,
+ )
+ .await
+ {
+ Ok(completion) => completion,
+ Err(e) => {
+ log::error!("Codestral: Failed to fetch completion: {}", e);
+ this.update(cx, |this, cx| {
+ this.pending_request = None;
+ cx.notify();
+ })?;
+ return Err(e);
+ }
+ };
+
+ if completion_text.trim().is_empty() {
+ log::debug!("Codestral: Completion was empty after trimming; ignoring");
+ this.update(cx, |this, cx| {
+ this.pending_request = None;
+ cx.notify();
+ })?;
+ return Ok(());
+ }
+
+ let edits: Arc<[(Range<Anchor>, String)]> =
+ vec![(cursor_position..cursor_position, completion_text)].into();
+ let edit_preview = buffer
+ .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))?
+ .await;
+
+ this.update(cx, |this, cx| {
+ this.current_completion = Some(CurrentCompletion {
+ snapshot,
+ edits,
+ edit_preview,
+ });
+ this.pending_request = None;
+ cx.notify();
+ })?;
+
+ Ok(())
+ }));
+ }
+
+ fn cycle(
+ &mut self,
+ _buffer: Entity<Buffer>,
+ _cursor_position: Anchor,
+ _direction: Direction,
+ _cx: &mut Context<Self>,
+ ) {
+ // Codestral doesn't support multiple completions, so cycling does nothing
+ }
+
+ fn accept(&mut self, _cx: &mut Context<Self>) {
+ log::debug!("Codestral: Completion accepted");
+ self.pending_request = None;
+ self.current_completion = None;
+ }
+
+ fn discard(&mut self, _cx: &mut Context<Self>) {
+ log::debug!("Codestral: Completion discarded");
+ self.pending_request = None;
+ self.current_completion = None;
+ }
+
+ /// Returns the completion suggestion, adjusted or invalidated based on user edits
+ fn suggest(
+ &mut self,
+ buffer: &Entity<Buffer>,
+ _cursor_position: Anchor,
+ cx: &mut Context<Self>,
+ ) -> Option<EditPrediction> {
+ let current_completion = self.current_completion.as_ref()?;
+ let buffer = buffer.read(cx);
+ let edits = current_completion.interpolate(&buffer.snapshot())?;
+ if edits.is_empty() {
+ return None;
+ }
+ Some(EditPrediction::Local {
+ id: None,
+ edits,
+ edit_preview: Some(current_completion.edit_preview.clone()),
+ })
+ }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CodestralRequest {
+ pub model: String,
+ pub prompt: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub suffix: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub max_tokens: Option<u32>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub temperature: Option<f32>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub top_p: Option<f32>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub stream: Option<bool>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub stop: Option<Vec<String>>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub random_seed: Option<u32>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub min_tokens: Option<u32>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct CodestralResponse {
+ pub id: String,
+ pub object: String,
+ pub model: String,
+ pub usage: Usage,
+ pub created: u64,
+ pub choices: Vec<Choice>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Usage {
+ pub prompt_tokens: u32,
+ pub completion_tokens: u32,
+ pub total_tokens: u32,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Choice {
+ pub index: u32,
+ pub message: Message,
+ pub finish_reason: String,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Message {
+ pub content: String,
+ pub role: String,
+}
@@ -2,7 +2,7 @@ use std::ops::Range;
use client::EditPredictionUsage;
use gpui::{App, Context, Entity, SharedString};
-use language::Buffer;
+use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt};
// TODO: Find a better home for `Direction`.
//
@@ -242,3 +242,51 @@ where
self.update(cx, |this, cx| this.suggest(buffer, cursor_position, cx))
}
}
+
+/// Adjusts the predicted edits to account for user edits made since `old_snapshot`. Returns
+/// `None` if any user edit conflicts with the prediction, i.e. is not a prefix of a predicted insertion.
+pub fn interpolate_edits(
+ old_snapshot: &BufferSnapshot,
+ new_snapshot: &BufferSnapshot,
+ current_edits: &[(Range<Anchor>, String)],
+) -> Option<Vec<(Range<Anchor>, String)>> {
+ let mut edits = Vec::new();
+
+ let mut model_edits = current_edits.iter().peekable();
+ for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
+ while let Some((model_old_range, _)) = model_edits.peek() {
+ let model_old_range = model_old_range.to_offset(old_snapshot);
+ if model_old_range.end < user_edit.old.start {
+ let (model_old_range, model_new_text) = model_edits.next().unwrap();
+ edits.push((model_old_range.clone(), model_new_text.clone()));
+ } else {
+ break;
+ }
+ }
+
+ if let Some((model_old_range, model_new_text)) = model_edits.peek() {
+ let model_old_offset_range = model_old_range.to_offset(old_snapshot);
+ if user_edit.old == model_old_offset_range {
+ let user_new_text = new_snapshot
+ .text_for_range(user_edit.new.clone())
+ .collect::<String>();
+
+ if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
+ if !model_suffix.is_empty() {
+ let anchor = old_snapshot.anchor_after(user_edit.old.end);
+ edits.push((anchor..anchor, model_suffix.to_string()));
+ }
+
+ model_edits.next();
+ continue;
+ }
+ }
+ }
+
+ return None;
+ }
+
+ edits.extend(model_edits.cloned());
+
+ if edits.is_empty() { None } else { Some(edits) }
+}
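As a rough standalone illustration of the rule this function applies (assuming plain strings in place of buffer anchors and snapshots): a user edit only survives interpolation if the typed text is a prefix of the predicted insertion, in which case the prediction shrinks to the remaining suffix; any other edit invalidates it.

/// Simplified model of the prefix rule in `interpolate_edits`: given the text the
/// prediction wanted to insert and the text the user has typed at that spot since,
/// return the part of the prediction that is still pending, or `None` if the user's
/// typing diverged from it (or consumed it entirely).
fn remaining_prediction(predicted: &str, typed: &str) -> Option<String> {
    let suffix = predicted.strip_prefix(typed)?;
    if suffix.is_empty() {
        None
    } else {
        Some(suffix.to_string())
    }
}

fn main() {
    // Typing a prefix of the prediction keeps the rest of it alive.
    assert_eq!(
        remaining_prediction("println!(\"hello\");", "println!("),
        Some("\"hello\");".to_string())
    );
    // Typing something else drops the prediction, just as the real function returns `None`.
    assert_eq!(remaining_prediction("println!(\"hello\");", "eprintln"), None);
}
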
@@ -16,6 +16,7 @@ doctest = false
anyhow.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
+codestral.workspace = true
copilot.workspace = true
editor.workspace = true
feature_flags.workspace = true
@@ -1,6 +1,7 @@
use anyhow::Result;
use client::{UserStore, zed_urls};
use cloud_llm_client::UsageLimit;
+use codestral::CodestralCompletionProvider;
use copilot::{Copilot, Status};
use editor::{Editor, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll};
use feature_flags::{FeatureFlagAppExt, PredictEditsRateCompletionsFeatureFlag};
@@ -234,6 +235,67 @@ impl Render for EditPredictionButton {
)
}
+ EditPredictionProvider::Codestral => {
+ let enabled = self.editor_enabled.unwrap_or(true);
+ let has_api_key = CodestralCompletionProvider::has_api_key(cx);
+ let fs = self.fs.clone();
+ let this = cx.entity();
+
+ div().child(
+ PopoverMenu::new("codestral")
+ .menu(move |window, cx| {
+ if has_api_key {
+ Some(this.update(cx, |this, cx| {
+ this.build_codestral_context_menu(window, cx)
+ }))
+ } else {
+ Some(ContextMenu::build(window, cx, |menu, _, _| {
+ let fs = fs.clone();
+ menu.entry("Use Zed AI instead", None, move |_, cx| {
+ set_completion_provider(
+ fs.clone(),
+ cx,
+ EditPredictionProvider::Zed,
+ )
+ })
+ .separator()
+ .entry(
+ "Configure Codestral API Key",
+ None,
+ move |window, cx| {
+ window.dispatch_action(
+ zed_actions::agent::OpenSettings.boxed_clone(),
+ cx,
+ );
+ },
+ )
+ }))
+ }
+ })
+ .anchor(Corner::BottomRight)
+ .trigger_with_tooltip(
+ IconButton::new("codestral-icon", IconName::AiMistral)
+ .shape(IconButtonShape::Square)
+ .when(!has_api_key, |this| {
+ this.indicator(Indicator::dot().color(Color::Error))
+ .indicator_border_color(Some(
+ cx.theme().colors().status_bar_background,
+ ))
+ })
+ .when(has_api_key && !enabled, |this| {
+ this.indicator(Indicator::dot().color(Color::Ignored))
+ .indicator_border_color(Some(
+ cx.theme().colors().status_bar_background,
+ ))
+ }),
+ move |window, cx| {
+ Tooltip::for_action("Codestral", &ToggleMenu, window, cx)
+ },
+ )
+ .with_handle(self.popover_menu_handle.clone()),
+ )
+ }
+
EditPredictionProvider::Zed => {
let enabled = self.editor_enabled.unwrap_or(true);
@@ -493,6 +555,7 @@ impl EditPredictionButton {
EditPredictionProvider::Zed
| EditPredictionProvider::Copilot
| EditPredictionProvider::Supermaven
+ | EditPredictionProvider::Codestral
) {
menu = menu
.separator()
@@ -719,6 +782,25 @@ impl EditPredictionButton {
})
}
+ fn build_codestral_context_menu(
+ &self,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Entity<ContextMenu> {
+ let fs = self.fs.clone();
+ ContextMenu::build(window, cx, |menu, window, cx| {
+ self.build_language_settings_menu(menu, window, cx)
+ .separator()
+ .entry("Use Zed AI instead", None, move |_, cx| {
+ set_completion_provider(fs.clone(), cx, EditPredictionProvider::Zed)
+ })
+ .separator()
+ .entry("Configure Codestral API Key", None, move |window, cx| {
+ window.dispatch_action(zed_actions::agent::OpenSettings.boxed_clone(), cx);
+ })
+ })
+ }
+
fn build_zeta_context_menu(
&self,
window: &mut Window,
@@ -377,6 +377,8 @@ pub struct EditPredictionSettings {
pub mode: settings::EditPredictionsMode,
/// Settings specific to GitHub Copilot.
pub copilot: CopilotSettings,
+ /// Settings specific to Codestral.
+ pub codestral: CodestralSettings,
/// Whether edit predictions are enabled in the assistant panel.
/// This setting has no effect if globally disabled.
pub enabled_in_text_threads: bool,
@@ -412,6 +414,14 @@ pub struct CopilotSettings {
pub enterprise_uri: Option<String>,
}
+#[derive(Clone, Debug, Default)]
+pub struct CodestralSettings {
+ /// Model to use for completions.
+ pub model: Option<String>,
+ /// Maximum tokens to generate.
+ pub max_tokens: Option<u32>,
+}
+
impl AllLanguageSettings {
/// Returns the [`LanguageSettings`] for the language with the specified name.
pub fn language<'a>(
@@ -622,6 +632,12 @@ impl settings::Settings for AllLanguageSettings {
enterprise_uri: copilot.enterprise_uri,
};
+ let codestral = edit_predictions.codestral.unwrap();
+ let codestral_settings = CodestralSettings {
+ model: codestral.model,
+ max_tokens: codestral.max_tokens,
+ };
+
let enabled_in_text_threads = edit_predictions.enabled_in_text_threads.unwrap();
let mut file_types: FxHashMap<Arc<str>, GlobSet> = FxHashMap::default();
@@ -655,6 +671,7 @@ impl settings::Settings for AllLanguageSettings {
.collect(),
mode: edit_predictions_mode,
copilot: copilot_settings,
+ codestral: codestral_settings,
enabled_in_text_threads,
},
defaults: default_language_settings,
@@ -118,14 +118,14 @@ impl LanguageModelRegistry {
}
#[cfg(any(test, feature = "test-support"))]
- pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
- let fake_provider = crate::fake_provider::FakeLanguageModelProvider::default();
+ pub fn test(cx: &mut App) -> Arc<crate::fake_provider::FakeLanguageModelProvider> {
+ let fake_provider = Arc::new(crate::fake_provider::FakeLanguageModelProvider::default());
let registry = cx.new(|cx| {
let mut registry = Self::default();
registry.register_provider(fake_provider.clone(), cx);
let model = fake_provider.provided_models(cx)[0].clone();
let configured_model = ConfiguredModel {
- provider: Arc::new(fake_provider.clone()),
+ provider: fake_provider.clone(),
model,
};
registry.set_default_model(Some(configured_model), cx);
@@ -137,7 +137,7 @@ impl LanguageModelRegistry {
pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
&mut self,
- provider: T,
+ provider: Arc<T>,
cx: &mut Context<Self>,
) {
let id = provider.id();
@@ -152,7 +152,7 @@ impl LanguageModelRegistry {
subscription.detach();
}
- self.providers.insert(id.clone(), Arc::new(provider));
+ self.providers.insert(id.clone(), provider);
cx.emit(Event::AddedProvider(id));
}
@@ -395,7 +395,7 @@ mod tests {
fn test_register_providers(cx: &mut App) {
let registry = cx.new(|_| LanguageModelRegistry::default());
- let provider = FakeLanguageModelProvider::default();
+ let provider = Arc::new(FakeLanguageModelProvider::default());
registry.update(cx, |registry, cx| {
registry.register_provider(provider.clone(), cx);
});
@@ -18,7 +18,7 @@ use crate::provider::cloud::CloudLanguageModelProvider;
use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
use crate::provider::google::GoogleLanguageModelProvider;
use crate::provider::lmstudio::LmStudioLanguageModelProvider;
-use crate::provider::mistral::MistralLanguageModelProvider;
+pub use crate::provider::mistral::MistralLanguageModelProvider;
use crate::provider::ollama::OllamaLanguageModelProvider;
use crate::provider::open_ai::OpenAiLanguageModelProvider;
use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
@@ -87,11 +87,11 @@ fn register_openai_compatible_providers(
for provider_id in new {
if !old.contains(provider_id) {
registry.register_provider(
- OpenAiCompatibleLanguageModelProvider::new(
+ Arc::new(OpenAiCompatibleLanguageModelProvider::new(
provider_id.clone(),
client.http_client(),
cx,
- ),
+ )),
cx,
);
}
@@ -105,50 +105,62 @@ fn register_language_model_providers(
cx: &mut Context<LanguageModelRegistry>,
) {
registry.register_provider(
- CloudLanguageModelProvider::new(user_store, client.clone(), cx),
+ Arc::new(CloudLanguageModelProvider::new(
+ user_store,
+ client.clone(),
+ cx,
+ )),
+ cx,
+ );
+ registry.register_provider(
+ Arc::new(AnthropicLanguageModelProvider::new(
+ client.http_client(),
+ cx,
+ )),
cx,
);
-
registry.register_provider(
- AnthropicLanguageModelProvider::new(client.http_client(), cx),
+ Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
cx,
);
registry.register_provider(
- OpenAiLanguageModelProvider::new(client.http_client(), cx),
+ Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
cx,
);
registry.register_provider(
- OllamaLanguageModelProvider::new(client.http_client(), cx),
+ Arc::new(LmStudioLanguageModelProvider::new(client.http_client(), cx)),
cx,
);
registry.register_provider(
- LmStudioLanguageModelProvider::new(client.http_client(), cx),
+ Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
cx,
);
registry.register_provider(
- DeepSeekLanguageModelProvider::new(client.http_client(), cx),
+ Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
cx,
);
registry.register_provider(
- GoogleLanguageModelProvider::new(client.http_client(), cx),
+ MistralLanguageModelProvider::global(client.http_client(), cx),
cx,
);
registry.register_provider(
- MistralLanguageModelProvider::new(client.http_client(), cx),
+ Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
cx,
);
registry.register_provider(
- BedrockLanguageModelProvider::new(client.http_client(), cx),
+ Arc::new(OpenRouterLanguageModelProvider::new(
+ client.http_client(),
+ cx,
+ )),
cx,
);
registry.register_provider(
- OpenRouterLanguageModelProvider::new(client.http_client(), cx),
+ Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
cx,
);
registry.register_provider(
- VercelLanguageModelProvider::new(client.http_client(), cx),
+ Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
cx,
);
- registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx);
- registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
+ registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
}
@@ -1,7 +1,8 @@
use anyhow::{Result, anyhow};
use collections::BTreeMap;
+use fs::Fs;
use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream};
-use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
+use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
@@ -10,9 +11,9 @@ use language_model::{
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason, TokenUsage,
};
-use mistral::{MISTRAL_API_URL, StreamResponse};
+use mistral::{CODESTRAL_API_URL, MISTRAL_API_URL, StreamResponse};
pub use settings::MistralAvailableModel as AvailableModel;
-use settings::{Settings, SettingsStore};
+use settings::{EditPredictionProvider, Settings, SettingsStore, update_settings_file};
use std::collections::HashMap;
use std::pin::Pin;
use std::str::FromStr;
@@ -31,6 +32,9 @@ const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new(
const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
+const CODESTRAL_API_KEY_ENV_VAR_NAME: &str = "CODESTRAL_API_KEY";
+static CODESTRAL_API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(CODESTRAL_API_KEY_ENV_VAR_NAME);
+
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
pub api_url: String,
@@ -44,6 +48,7 @@ pub struct MistralLanguageModelProvider {
pub struct State {
api_key_state: ApiKeyState,
+ codestral_api_key_state: ApiKeyState,
}
impl State {
@@ -57,6 +62,19 @@ impl State {
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
+ fn set_codestral_api_key(
+ &mut self,
+ api_key: Option<String>,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<()>> {
+ self.codestral_api_key_state.store(
+ CODESTRAL_API_URL.into(),
+ api_key,
+ |this| &mut this.codestral_api_key_state,
+ cx,
+ )
+ }
+
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = MistralLanguageModelProvider::api_url(cx);
self.api_key_state.load_if_needed(
@@ -66,10 +84,34 @@ impl State {
cx,
)
}
+
+ fn authenticate_codestral(
+ &mut self,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<(), AuthenticateError>> {
+ self.codestral_api_key_state.load_if_needed(
+ CODESTRAL_API_URL.into(),
+ &CODESTRAL_API_KEY_ENV_VAR,
+ |this| &mut this.codestral_api_key_state,
+ cx,
+ )
+ }
}
+struct GlobalMistralLanguageModelProvider(Arc<MistralLanguageModelProvider>);
+
+impl Global for GlobalMistralLanguageModelProvider {}
+
impl MistralLanguageModelProvider {
- pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ pub fn try_global(cx: &App) -> Option<&Arc<MistralLanguageModelProvider>> {
+ cx.try_global::<GlobalMistralLanguageModelProvider>()
+ .map(|this| &this.0)
+ }
+
+ pub fn global(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Arc<Self> {
+ if let Some(this) = cx.try_global::<GlobalMistralLanguageModelProvider>() {
+ return this.0.clone();
+ }
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let api_url = Self::api_url(cx);
@@ -84,10 +126,22 @@ impl MistralLanguageModelProvider {
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx)),
+ codestral_api_key_state: ApiKeyState::new(CODESTRAL_API_URL.into()),
}
});
- Self { http_client, state }
+ let this = Arc::new(Self { http_client, state });
+ cx.set_global(GlobalMistralLanguageModelProvider(this));
+ cx.global::<GlobalMistralLanguageModelProvider>().0.clone()
+ }
+
+ pub fn load_codestral_api_key(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
+ self.state
+ .update(cx, |state, cx| state.authenticate_codestral(cx))
+ }
+
+ pub fn codestral_api_key(&self, url: &str, cx: &App) -> Option<Arc<str>> {
+ self.state.read(cx).codestral_api_key_state.key(url)
}
fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
@@ -691,6 +745,7 @@ struct RawToolCall {
struct ConfigurationView {
api_key_editor: Entity<SingleLineInput>,
+ codestral_api_key_editor: Entity<SingleLineInput>,
state: Entity<State>,
load_credentials_task: Option<Task<()>>,
}
@@ -699,6 +754,8 @@ impl ConfigurationView {
fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
let api_key_editor =
cx.new(|cx| SingleLineInput::new(window, cx, "0aBCDEFGhIjKLmNOpqrSTUVwxyzabCDE1f2"));
+ let codestral_api_key_editor =
+ cx.new(|cx| SingleLineInput::new(window, cx, "0aBCDEFGhIjKLmNOpqrSTUVwxyzabCDE1f2"));
cx.observe(&state, |_, _, cx| {
cx.notify();
@@ -715,6 +772,12 @@ impl ConfigurationView {
// We don't log an error, because "not signed in" is also an error.
let _ = task.await;
}
+ if let Some(task) = state
+ .update(cx, |state, cx| state.authenticate_codestral(cx))
+ .log_err()
+ {
+ let _ = task.await;
+ }
this.update(cx, |this, cx| {
this.load_credentials_task = None;
@@ -726,6 +789,7 @@ impl ConfigurationView {
Self {
api_key_editor,
+ codestral_api_key_editor,
state,
load_credentials_task,
}
@@ -763,47 +827,92 @@ impl ConfigurationView {
.detach_and_log_err(cx);
}
- fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
- !self.state.read(cx).is_authenticated()
+ fn save_codestral_api_key(
+ &mut self,
+ _: &menu::Confirm,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ let api_key = self
+ .codestral_api_key_editor
+ .read(cx)
+ .text(cx)
+ .trim()
+ .to_string();
+ if api_key.is_empty() {
+ return;
+ }
+
+ // Clear the input once the key is saved so it isn't left visible in the editor
+ self.codestral_api_key_editor
+ .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| {
+ state.set_codestral_api_key(Some(api_key), cx)
+ })?
+ .await?;
+ cx.update(|_window, cx| {
+ set_edit_prediction_provider(EditPredictionProvider::Codestral, cx)
+ })
+ })
+ .detach_and_log_err(cx);
}
-}
-impl Render for ConfigurationView {
- fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
+ fn reset_codestral_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ self.codestral_api_key_editor
+ .update(cx, |editor, cx| editor.set_text("", window, cx));
- if self.load_credentials_task.is_some() {
- div().child(Label::new("Loading credentials...")).into_any()
- } else if self.should_render_editor(cx) {
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_codestral_api_key(None, cx))?
+ .await?;
+ cx.update(|_window, cx| set_edit_prediction_provider(EditPredictionProvider::Zed, cx))
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn should_render_api_key_editor(&self, cx: &mut Context<Self>) -> bool {
+ !self.state.read(cx).is_authenticated()
+ }
+
+ fn render_codestral_api_key_editor(&mut self, cx: &mut Context<Self>) -> AnyElement {
+ let key_state = &self.state.read(cx).codestral_api_key_state;
+ let should_show_editor = !key_state.has_key();
+ let env_var_set = key_state.is_from_env_var();
+ if should_show_editor {
v_flex()
+ .id("codestral")
.size_full()
- .on_action(cx.listener(Self::save_api_key))
- .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
+ .mt_2()
+ .on_action(cx.listener(Self::save_codestral_api_key))
+ .child(Label::new(
+ "To use Codestral as an edit prediction provider, \
+ you need to add a Codestral-specific API key. Follow these steps:",
+ ))
.child(
List::new()
.child(InstructionListItem::new(
"Create one by visiting",
- Some("Mistral's console"),
- Some("https://console.mistral.ai/api-keys"),
+ Some("the Codestral section of Mistral's console"),
+ Some("https://console.mistral.ai/codestral"),
))
- .child(InstructionListItem::text_only(
- "Ensure your Mistral account has credits",
- ))
- .child(InstructionListItem::text_only(
- "Paste your API key below and hit enter to start using the assistant",
- )),
+ .child(InstructionListItem::text_only("Paste your API key below and hit enter")),
)
- .child(self.api_key_editor.clone())
+ .child(self.codestral_api_key_editor.clone())
.child(
Label::new(
- format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
+ format!("You can also assign the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
)
.size(LabelSize::Small).color(Color::Muted),
- )
- .into_any()
+ ).into_any()
} else {
h_flex()
- .mt_1()
+ .id("codestral")
+ .mt_2()
.p_1()
.justify_between()
.rounded_md()
@@ -815,14 +924,9 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
- format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
+ format!("API key set in {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable")
} else {
- let api_url = MistralLanguageModelProvider::api_url(cx);
- if api_url == MISTRAL_API_URL {
- "API key configured".to_string()
- } else {
- format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
- }
+ "Codestral API key configured".to_string()
})),
)
.child(
@@ -833,15 +937,121 @@ impl Render for ConfigurationView {
.icon_position(IconPosition::Start)
.disabled(env_var_set)
.when(env_var_set, |this| {
- this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
+ this.tooltip(Tooltip::text(format!(
+ "To reset your API key, \
+ unset the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable."
+ )))
})
- .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
+ .on_click(
+ cx.listener(|this, _, window, cx| this.reset_codestral_api_key(window, cx)),
+ ),
+ ).into_any()
+ }
+ }
+}
+
+impl Render for ConfigurationView {
+ fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
+
+ if self.load_credentials_task.is_some() {
+ div().child(Label::new("Loading credentials...")).into_any()
+ } else if self.should_render_api_key_editor(cx) {
+ v_flex()
+ .size_full()
+ .on_action(cx.listener(Self::save_api_key))
+ .child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
+ .child(
+ List::new()
+ .child(InstructionListItem::new(
+ "Create one by visiting",
+ Some("Mistral's console"),
+ Some("https://console.mistral.ai/api-keys"),
+ ))
+ .child(InstructionListItem::text_only(
+ "Ensure your Mistral account has credits",
+ ))
+ .child(InstructionListItem::text_only(
+ "Paste your API key below and hit enter to start using the assistant",
+ )),
)
+ .child(self.api_key_editor.clone())
+ .child(
+ Label::new(
+ format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
+ )
+ .size(LabelSize::Small).color(Color::Muted),
+ )
+ .child(self.render_codestral_api_key_editor(cx))
+ .into_any()
+ } else {
+ v_flex()
+ .size_full()
+ .child(
+ h_flex()
+ .mt_1()
+ .p_1()
+ .justify_between()
+ .rounded_md()
+ .border_1()
+ .border_color(cx.theme().colors().border)
+ .bg(cx.theme().colors().background)
+ .child(
+ h_flex()
+ .gap_1()
+ .child(Icon::new(IconName::Check).color(Color::Success))
+ .child(Label::new(if env_var_set {
+ format!(
+ "API key set in {API_KEY_ENV_VAR_NAME} environment variable"
+ )
+ } else {
+ let api_url = MistralLanguageModelProvider::api_url(cx);
+ if api_url == MISTRAL_API_URL {
+ "API key configured".to_string()
+ } else {
+ format!(
+ "API key configured for {}",
+ truncate_and_trailoff(&api_url, 32)
+ )
+ }
+ })),
+ )
+ .child(
+ Button::new("reset-key", "Reset Key")
+ .label_size(LabelSize::Small)
+ .icon(Some(IconName::Trash))
+ .icon_size(IconSize::Small)
+ .icon_position(IconPosition::Start)
+ .disabled(env_var_set)
+ .when(env_var_set, |this| {
+ this.tooltip(Tooltip::text(format!(
+ "To reset your API key, \
+ unset the {API_KEY_ENV_VAR_NAME} environment variable."
+ )))
+ })
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.reset_api_key(window, cx)
+ })),
+ ),
+ )
+ .child(self.render_codestral_api_key_editor(cx))
.into_any()
}
}
}
+fn set_edit_prediction_provider(provider: EditPredictionProvider, cx: &mut App) {
+ let fs = <dyn Fs>::global(cx);
+ update_settings_file(fs, cx, move |settings, _| {
+ settings
+ .project
+ .all_languages
+ .features
+ .get_or_insert_default()
+ .edit_prediction_provider = Some(provider);
+ });
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -7,6 +7,7 @@ use std::convert::TryFrom;
use strum::EnumIter;
pub const MISTRAL_API_URL: &str = "https://api.mistral.ai/v1";
+pub const CODESTRAL_API_URL: &str = "https://codestral.mistral.ai";
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
@@ -82,6 +82,7 @@ pub enum EditPredictionProvider {
Copilot,
Supermaven,
Zed,
+ Codestral,
}
impl EditPredictionProvider {
@@ -90,7 +91,8 @@ impl EditPredictionProvider {
EditPredictionProvider::Zed => true,
EditPredictionProvider::None
| EditPredictionProvider::Copilot
- | EditPredictionProvider::Supermaven => false,
+ | EditPredictionProvider::Supermaven
+ | EditPredictionProvider::Codestral => false,
}
}
}
@@ -108,6 +110,8 @@ pub struct EditPredictionSettingsContent {
pub mode: Option<EditPredictionsMode>,
/// Settings specific to GitHub Copilot.
pub copilot: Option<CopilotSettingsContent>,
+ /// Settings specific to Codestral.
+ pub codestral: Option<CodestralSettingsContent>,
/// Whether edit predictions are enabled in the assistant prompt editor.
/// This has no effect if globally disabled.
pub enabled_in_text_threads: Option<bool>,
@@ -130,6 +134,20 @@ pub struct CopilotSettingsContent {
pub enterprise_uri: Option<String>,
}
+#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, MergeFrom, PartialEq)]
+pub struct CodestralSettingsContent {
+ /// Model to use for completions.
+ ///
+ /// Default: "codestral-latest"
+ #[serde(default)]
+ pub model: Option<String>,
+ /// Maximum tokens to generate.
+ ///
+ /// Default: 350
+ #[serde(default)]
+ pub max_tokens: Option<u32>,
+}
+
/// The mode in which edit predictions should be displayed.
#[derive(
Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize, JsonSchema, MergeFrom,
@@ -39,6 +39,7 @@ channel.workspace = true
clap.workspace = true
cli.workspace = true
client.workspace = true
+codestral.workspace = true
collab_ui.workspace = true
collections.workspace = true
command_palette.workspace = true
@@ -1,9 +1,11 @@
use client::{Client, UserStore};
+use codestral::CodestralCompletionProvider;
use collections::HashMap;
use copilot::{Copilot, CopilotCompletionProvider};
use editor::Editor;
use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
use language::language_settings::{EditPredictionProvider, all_language_settings};
+use language_models::MistralLanguageModelProvider;
use settings::SettingsStore;
use std::{cell::RefCell, rc::Rc, sync::Arc};
use supermaven::{Supermaven, SupermavenCompletionProvider};
@@ -109,6 +111,10 @@ fn assign_edit_prediction_providers(
user_store: Entity<UserStore>,
cx: &mut App,
) {
+ if provider == EditPredictionProvider::Codestral {
+ let mistral = MistralLanguageModelProvider::global(client.http_client(), cx);
+ mistral.load_codestral_api_key(cx).detach();
+ }
for (editor, window) in editors.borrow().iter() {
_ = window.update(cx, |_window, window, cx| {
_ = editor.update(cx, |editor, cx| {
@@ -189,6 +195,11 @@ fn assign_edit_prediction_provider(
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
}
+ EditPredictionProvider::Codestral => {
+ let http_client = client.http_client();
+ let provider = cx.new(|_| CodestralCompletionProvider::new(http_client));
+ editor.set_edit_prediction_provider(Some(provider), window, cx);
+ }
EditPredictionProvider::Zed => {
if user_store.read(cx).current_user().is_some() {
let mut worktree = None;
@@ -151,56 +151,10 @@ impl EditPrediction {
}
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
- interpolate(&self.snapshot, new_snapshot, self.edits.clone())
+ edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
}
}
-fn interpolate(
- old_snapshot: &BufferSnapshot,
- new_snapshot: &BufferSnapshot,
- current_edits: Arc<[(Range<Anchor>, String)]>,
-) -> Option<Vec<(Range<Anchor>, String)>> {
- let mut edits = Vec::new();
-
- let mut model_edits = current_edits.iter().peekable();
- for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
- while let Some((model_old_range, _)) = model_edits.peek() {
- let model_old_range = model_old_range.to_offset(old_snapshot);
- if model_old_range.end < user_edit.old.start {
- let (model_old_range, model_new_text) = model_edits.next().unwrap();
- edits.push((model_old_range.clone(), model_new_text.clone()));
- } else {
- break;
- }
- }
-
- if let Some((model_old_range, model_new_text)) = model_edits.peek() {
- let model_old_offset_range = model_old_range.to_offset(old_snapshot);
- if user_edit.old == model_old_offset_range {
- let user_new_text = new_snapshot
- .text_for_range(user_edit.new.clone())
- .collect::<String>();
-
- if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
- if !model_suffix.is_empty() {
- let anchor = old_snapshot.anchor_after(user_edit.old.end);
- edits.push((anchor..anchor, model_suffix.to_string()));
- }
-
- model_edits.next();
- continue;
- }
- }
- }
-
- return None;
- }
-
- edits.extend(model_edits.cloned());
-
- if edits.is_empty() { None } else { Some(edits) }
-}
-
impl std::fmt::Debug for EditPrediction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EditPrediction")
@@ -769,10 +723,11 @@ impl Zeta {
let Some((edits, snapshot, edit_preview)) = buffer.read_with(cx, {
let edits = edits.clone();
- |buffer, cx| {
+ move |buffer, cx| {
let new_snapshot = buffer.snapshot();
let edits: Arc<[(Range<Anchor>, String)]> =
- interpolate(&snapshot, &new_snapshot, edits)?.into();
+ edit_prediction::interpolate_edits(&snapshot, &new_snapshot, &edits)?
+ .into();
Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
}
})?