diff --git a/Cargo.lock b/Cargo.lock index eb44454c909377eb1f555a5a9ff9747045fce9be..9653529d7ecf80679b66b7b97f32026b9430559d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -16852,9 +16852,9 @@ dependencies = [ [[package]] name = "zed_llm_client" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614669bead4741b2fc352ae1967318be16949cf46f59013e548c6dbfdfc01252" +checksum = "1bf21350eced858d129840589158a8f6895c4fa4327ae56dd8c7d6a98495bed4" dependencies = [ "serde", "serde_json", @@ -17076,6 +17076,7 @@ dependencies = [ "postage", "project", "regex", + "release_channel", "reqwest_client", "rpc", "serde", @@ -17085,6 +17086,7 @@ dependencies = [ "telemetry", "telemetry_events", "theme", + "thiserror 1.0.69", "tree-sitter-go", "tree-sitter-rust", "ui", diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/inline_completion_registry.rs index 8639ad51f9948ffbc8d77b9cbf774c3e29f6123a..21351265447c5dd58b9a4aded99f077a6c52eeee 100644 --- a/crates/zed/src/zed/inline_completion_registry.rs +++ b/crates/zed/src/zed/inline_completion_registry.rs @@ -264,7 +264,13 @@ fn assign_edit_prediction_provider( } } - let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx); + let zeta = zeta::Zeta::register( + Some(cx.entity()), + worktree, + client.clone(), + user_store, + cx, + ); if let Some(buffer) = &singleton_buffer { if buffer.read(cx).file().is_some() { diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index 1904a4d2bac484394e07ce3f708358c78a79e81d..7e1f46c5fefa9ea6e1c21250a7971029be7b833b 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -39,6 +39,7 @@ menu.workspace = true postage.workspace = true project.workspace = true regex.workspace = true +release_channel.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true @@ -46,6 +47,7 @@ similar.workspace = true telemetry.workspace = true telemetry_events.workspace = true theme.workspace = true +thiserror.workspace = true ui.workspace = true util.workspace = true uuid.workspace = true diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index cc60ad46ec78f092c28e5944843cb909f5f5e106..7627b1e832e6f9beffd3675318dbd06462e5cad6 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -9,6 +9,7 @@ mod rate_completion_modal; pub(crate) use completion_diff_element::*; use db::kvp::KEY_VALUE_STORE; +use editor::Editor; pub use init::*; use inline_completion::DataCollectionState; pub use license_detection::is_license_eligible_for_data_collection; @@ -20,10 +21,10 @@ use anyhow::{anyhow, Context as _, Result}; use arrayvec::ArrayVec; use client::{Client, UserStore}; use collections::{HashMap, HashSet, VecDeque}; -use feature_flags::FeatureFlagAppExt as _; use futures::AsyncReadExt; use gpui::{ - actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task, + actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion, + Subscription, Task, }; use http_client::{HttpClient, Method}; use input_excerpt::excerpt_for_cursor_position; @@ -34,7 +35,9 @@ use language::{ use language_models::LlmApiToken; use postage::watch; use project::Project; +use release_channel::AppVersion; use settings::WorktreeId; +use std::str::FromStr; use std::{ borrow::Cow, cmp, @@ -48,10 +51,16 @@ use std::{ time::{Duration, Instant}, }; use telemetry_events::InlineCompletionRating; +use thiserror::Error; use util::ResultExt; use uuid::Uuid; +use workspace::notifications::{ErrorMessagePrompt, NotificationId}; +use workspace::Workspace; use worktree::Worktree; -use zed_llm_client::{PredictEditsBody, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME}; +use zed_llm_client::{ + PredictEditsBody, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, + MINIMUM_REQUIRED_VERSION_HEADER_NAME, +}; const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>"; const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>"; @@ -178,6 +187,7 @@ impl std::fmt::Debug for InlineCompletion { } pub struct Zeta { + editor: Option>, client: Arc, events: VecDeque, registered_buffers: HashMap, @@ -188,6 +198,8 @@ pub struct Zeta { _llm_token_subscription: Subscription, /// Whether the terms of service have been accepted. tos_accepted: bool, + /// Whether an update to a newer version of Zed is required to continue using Zeta. + update_required: bool, _user_store_subscription: Subscription, license_detection_watchers: HashMap>, } @@ -198,13 +210,14 @@ impl Zeta { } pub fn register( + editor: Option>, worktree: Option>, client: Arc, user_store: Entity, cx: &mut App, ) -> Entity { let this = Self::global(cx).unwrap_or_else(|| { - let entity = cx.new(|cx| Self::new(client, user_store, cx)); + let entity = cx.new(|cx| Self::new(editor, client, user_store, cx)); cx.set_global(ZetaGlobal(entity.clone())); entity }); @@ -226,13 +239,19 @@ impl Zeta { self.events.clear(); } - fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { + fn new( + editor: Option>, + client: Arc, + user_store: Entity, + cx: &mut Context, + ) -> Self { let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx); let data_collection_choice = Self::load_data_collection_choices(); let data_collection_choice = cx.new(|_| data_collection_choice); Self { + editor, client, events: VecDeque::new(), shown_completions: VecDeque::new(), @@ -256,6 +275,7 @@ impl Zeta { .read(cx) .current_user_has_accepted_terms() .unwrap_or(false), + update_required: false, _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| { match event { client::user::Event::PrivateUserInfoUpdated => { @@ -335,8 +355,10 @@ impl Zeta { } } - pub fn request_completion_impl( + #[allow(clippy::too_many_arguments)] + fn request_completion_impl( &mut self, + workspace: Option>, project: Option<&Entity>, buffer: &Entity, cursor: language::Anchor, @@ -345,7 +367,7 @@ impl Zeta { perform_predict_edits: F, ) -> Task>> where - F: FnOnce(Arc, LlmApiToken, bool, PredictEditsBody) -> R + 'static, + F: FnOnce(PerformPredictEditsParams) -> R + 'static, R: Future> + Send + 'static, { let snapshot = self.report_changes_for_buffer(&buffer, cx); @@ -358,9 +380,10 @@ impl Zeta { .map(|f| Arc::from(f.full_path(cx).as_path())) .unwrap_or_else(|| Arc::from(Path::new("untitled"))); + let zeta = cx.entity(); let client = self.client.clone(); let llm_token = self.llm_token.clone(); - let is_staff = cx.is_staff(); + let app_version = AppVersion::global(cx); let buffer = buffer.clone(); @@ -447,7 +470,46 @@ impl Zeta { }), }; - let response = perform_predict_edits(client, llm_token, is_staff, body).await?; + let response = perform_predict_edits(PerformPredictEditsParams { + client, + llm_token, + app_version, + body, + }) + .await; + let response = match response { + Ok(response) => response, + Err(err) => { + if err.is::() { + cx.update(|cx| { + zeta.update(cx, |zeta, _cx| { + zeta.update_required = true; + }); + + if let Some(workspace) = workspace { + workspace.update(cx, |workspace, cx| { + workspace.show_notification( + NotificationId::unique::(), + cx, + |cx| { + cx.new(|_| { + ErrorMessagePrompt::new(err.to_string()) + .with_link_button( + "Update Zed", + "https://zed.dev/releases", + ) + }) + }, + ); + }); + } + }) + .ok(); + } + + return Err(err); + } + }; log::debug!("completion response: {}", &response.output_excerpt); @@ -632,7 +694,7 @@ and then another ) -> Task>> { use std::future::ready; - self.request_completion_impl(project, buffer, position, false, cx, |_, _, _, _| { + self.request_completion_impl(None, project, buffer, position, false, cx, |_params| { ready(Ok(response)) }) } @@ -645,7 +707,12 @@ and then another can_collect_data: bool, cx: &mut Context, ) -> Task>> { + let workspace = self + .editor + .as_ref() + .and_then(|editor| editor.read(cx).workspace()); self.request_completion_impl( + workspace, project, buffer, position, @@ -656,12 +723,17 @@ and then another } fn perform_predict_edits( - client: Arc, - llm_token: LlmApiToken, - _is_staff: bool, - body: PredictEditsBody, + params: PerformPredictEditsParams, ) -> impl Future> { async move { + let PerformPredictEditsParams { + client, + llm_token, + app_version, + body, + .. + } = params; + let http_client = client.http_client(); let mut token = llm_token.acquire(&client).await?; let mut did_retry = false; @@ -685,6 +757,18 @@ and then another let mut response = http_client.send(request).await?; + if let Some(minimum_required_version) = response + .headers() + .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME) + .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok()) + { + if app_version < minimum_required_version { + return Err(anyhow!(ZedUpdateRequiredError { + minimum_version: minimum_required_version + })); + } + } + if response.status().is_success() { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; @@ -1011,6 +1095,21 @@ and then another } } +struct PerformPredictEditsParams { + pub client: Arc, + pub llm_token: LlmApiToken, + pub app_version: SemanticVersion, + pub body: PredictEditsBody, +} + +#[derive(Error, Debug)] +#[error( + "You must update to Zed version {minimum_version} or higher to continue using edit predictions." +)] +pub struct ZedUpdateRequiredError { + minimum_version: SemanticVersion, +} + struct LicenseDetectionWatcher { is_open_source_rx: watch::Receiver, _is_open_source_task: Task<()>, @@ -1406,6 +1505,10 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider return; } + if self.zeta.read(cx).update_required { + return; + } + if let Some(current_completion) = self.current_completion.as_ref() { let snapshot = buffer.read(cx).snapshot(); if current_completion @@ -1837,7 +1940,7 @@ mod tests { }); let server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(client, user_store, cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); @@ -1890,7 +1993,7 @@ mod tests { }); let server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(client, user_store, cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());