From 277fb546321d6f4740ba7e668a5443ecbf177075 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 12 Feb 2025 18:58:38 -0500 Subject: [PATCH] zeta: Respect `x-zed-minimum-required-version` header (#24771) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR makes it so Zeta respects the `x-zed-minimum-required-version` header sent back from the server. If the current Zed version is strictly less than the indicated minimum required version, we show an error indicating that an update is required in order to continue using Zeta: Screenshot 2025-02-12 at 6 15 44 PM Release Notes: - N/A --- Cargo.lock | 6 +- .../zed/src/zed/inline_completion_registry.rs | 8 +- crates/zeta/Cargo.toml | 2 + crates/zeta/src/zeta.rs | 135 +++++++++++++++--- 4 files changed, 132 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index eb44454c909377eb1f555a5a9ff9747045fce9be..9653529d7ecf80679b66b7b97f32026b9430559d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -16852,9 +16852,9 @@ dependencies = [ [[package]] name = "zed_llm_client" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614669bead4741b2fc352ae1967318be16949cf46f59013e548c6dbfdfc01252" +checksum = "1bf21350eced858d129840589158a8f6895c4fa4327ae56dd8c7d6a98495bed4" dependencies = [ "serde", "serde_json", @@ -17076,6 +17076,7 @@ dependencies = [ "postage", "project", "regex", + "release_channel", "reqwest_client", "rpc", "serde", @@ -17085,6 +17086,7 @@ dependencies = [ "telemetry", "telemetry_events", "theme", + "thiserror 1.0.69", "tree-sitter-go", "tree-sitter-rust", "ui", diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/inline_completion_registry.rs index 8639ad51f9948ffbc8d77b9cbf774c3e29f6123a..21351265447c5dd58b9a4aded99f077a6c52eeee 100644 --- a/crates/zed/src/zed/inline_completion_registry.rs +++ b/crates/zed/src/zed/inline_completion_registry.rs @@ -264,7 +264,13 @@ fn assign_edit_prediction_provider( } } - let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx); + let zeta = zeta::Zeta::register( + Some(cx.entity()), + worktree, + client.clone(), + user_store, + cx, + ); if let Some(buffer) = &singleton_buffer { if buffer.read(cx).file().is_some() { diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index 1904a4d2bac484394e07ce3f708358c78a79e81d..7e1f46c5fefa9ea6e1c21250a7971029be7b833b 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -39,6 +39,7 @@ menu.workspace = true postage.workspace = true project.workspace = true regex.workspace = true +release_channel.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true @@ -46,6 +47,7 @@ similar.workspace = true telemetry.workspace = true telemetry_events.workspace = true theme.workspace = true +thiserror.workspace = true ui.workspace = true util.workspace = true uuid.workspace = true diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index cc60ad46ec78f092c28e5944843cb909f5f5e106..7627b1e832e6f9beffd3675318dbd06462e5cad6 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -9,6 +9,7 @@ mod rate_completion_modal; pub(crate) use completion_diff_element::*; use db::kvp::KEY_VALUE_STORE; +use editor::Editor; pub use init::*; use inline_completion::DataCollectionState; pub use license_detection::is_license_eligible_for_data_collection; @@ -20,10 +21,10 @@ use anyhow::{anyhow, Context as _, Result}; use arrayvec::ArrayVec; use client::{Client, UserStore}; use collections::{HashMap, HashSet, VecDeque}; -use feature_flags::FeatureFlagAppExt as _; use futures::AsyncReadExt; use gpui::{ - actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task, + actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion, + Subscription, Task, }; use http_client::{HttpClient, Method}; use input_excerpt::excerpt_for_cursor_position; @@ -34,7 +35,9 @@ use language::{ use language_models::LlmApiToken; use postage::watch; use project::Project; +use release_channel::AppVersion; use settings::WorktreeId; +use std::str::FromStr; use std::{ borrow::Cow, cmp, @@ -48,10 +51,16 @@ use std::{ time::{Duration, Instant}, }; use telemetry_events::InlineCompletionRating; +use thiserror::Error; use util::ResultExt; use uuid::Uuid; +use workspace::notifications::{ErrorMessagePrompt, NotificationId}; +use workspace::Workspace; use worktree::Worktree; -use zed_llm_client::{PredictEditsBody, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME}; +use zed_llm_client::{ + PredictEditsBody, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, + MINIMUM_REQUIRED_VERSION_HEADER_NAME, +}; const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>"; const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>"; @@ -178,6 +187,7 @@ impl std::fmt::Debug for InlineCompletion { } pub struct Zeta { + editor: Option>, client: Arc, events: VecDeque, registered_buffers: HashMap, @@ -188,6 +198,8 @@ pub struct Zeta { _llm_token_subscription: Subscription, /// Whether the terms of service have been accepted. tos_accepted: bool, + /// Whether an update to a newer version of Zed is required to continue using Zeta. + update_required: bool, _user_store_subscription: Subscription, license_detection_watchers: HashMap>, } @@ -198,13 +210,14 @@ impl Zeta { } pub fn register( + editor: Option>, worktree: Option>, client: Arc, user_store: Entity, cx: &mut App, ) -> Entity { let this = Self::global(cx).unwrap_or_else(|| { - let entity = cx.new(|cx| Self::new(client, user_store, cx)); + let entity = cx.new(|cx| Self::new(editor, client, user_store, cx)); cx.set_global(ZetaGlobal(entity.clone())); entity }); @@ -226,13 +239,19 @@ impl Zeta { self.events.clear(); } - fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { + fn new( + editor: Option>, + client: Arc, + user_store: Entity, + cx: &mut Context, + ) -> Self { let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx); let data_collection_choice = Self::load_data_collection_choices(); let data_collection_choice = cx.new(|_| data_collection_choice); Self { + editor, client, events: VecDeque::new(), shown_completions: VecDeque::new(), @@ -256,6 +275,7 @@ impl Zeta { .read(cx) .current_user_has_accepted_terms() .unwrap_or(false), + update_required: false, _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| { match event { client::user::Event::PrivateUserInfoUpdated => { @@ -335,8 +355,10 @@ impl Zeta { } } - pub fn request_completion_impl( + #[allow(clippy::too_many_arguments)] + fn request_completion_impl( &mut self, + workspace: Option>, project: Option<&Entity>, buffer: &Entity, cursor: language::Anchor, @@ -345,7 +367,7 @@ impl Zeta { perform_predict_edits: F, ) -> Task>> where - F: FnOnce(Arc, LlmApiToken, bool, PredictEditsBody) -> R + 'static, + F: FnOnce(PerformPredictEditsParams) -> R + 'static, R: Future> + Send + 'static, { let snapshot = self.report_changes_for_buffer(&buffer, cx); @@ -358,9 +380,10 @@ impl Zeta { .map(|f| Arc::from(f.full_path(cx).as_path())) .unwrap_or_else(|| Arc::from(Path::new("untitled"))); + let zeta = cx.entity(); let client = self.client.clone(); let llm_token = self.llm_token.clone(); - let is_staff = cx.is_staff(); + let app_version = AppVersion::global(cx); let buffer = buffer.clone(); @@ -447,7 +470,46 @@ impl Zeta { }), }; - let response = perform_predict_edits(client, llm_token, is_staff, body).await?; + let response = perform_predict_edits(PerformPredictEditsParams { + client, + llm_token, + app_version, + body, + }) + .await; + let response = match response { + Ok(response) => response, + Err(err) => { + if err.is::() { + cx.update(|cx| { + zeta.update(cx, |zeta, _cx| { + zeta.update_required = true; + }); + + if let Some(workspace) = workspace { + workspace.update(cx, |workspace, cx| { + workspace.show_notification( + NotificationId::unique::(), + cx, + |cx| { + cx.new(|_| { + ErrorMessagePrompt::new(err.to_string()) + .with_link_button( + "Update Zed", + "https://zed.dev/releases", + ) + }) + }, + ); + }); + } + }) + .ok(); + } + + return Err(err); + } + }; log::debug!("completion response: {}", &response.output_excerpt); @@ -632,7 +694,7 @@ and then another ) -> Task>> { use std::future::ready; - self.request_completion_impl(project, buffer, position, false, cx, |_, _, _, _| { + self.request_completion_impl(None, project, buffer, position, false, cx, |_params| { ready(Ok(response)) }) } @@ -645,7 +707,12 @@ and then another can_collect_data: bool, cx: &mut Context, ) -> Task>> { + let workspace = self + .editor + .as_ref() + .and_then(|editor| editor.read(cx).workspace()); self.request_completion_impl( + workspace, project, buffer, position, @@ -656,12 +723,17 @@ and then another } fn perform_predict_edits( - client: Arc, - llm_token: LlmApiToken, - _is_staff: bool, - body: PredictEditsBody, + params: PerformPredictEditsParams, ) -> impl Future> { async move { + let PerformPredictEditsParams { + client, + llm_token, + app_version, + body, + .. + } = params; + let http_client = client.http_client(); let mut token = llm_token.acquire(&client).await?; let mut did_retry = false; @@ -685,6 +757,18 @@ and then another let mut response = http_client.send(request).await?; + if let Some(minimum_required_version) = response + .headers() + .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME) + .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok()) + { + if app_version < minimum_required_version { + return Err(anyhow!(ZedUpdateRequiredError { + minimum_version: minimum_required_version + })); + } + } + if response.status().is_success() { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; @@ -1011,6 +1095,21 @@ and then another } } +struct PerformPredictEditsParams { + pub client: Arc, + pub llm_token: LlmApiToken, + pub app_version: SemanticVersion, + pub body: PredictEditsBody, +} + +#[derive(Error, Debug)] +#[error( + "You must update to Zed version {minimum_version} or higher to continue using edit predictions." +)] +pub struct ZedUpdateRequiredError { + minimum_version: SemanticVersion, +} + struct LicenseDetectionWatcher { is_open_source_rx: watch::Receiver, _is_open_source_task: Task<()>, @@ -1406,6 +1505,10 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider return; } + if self.zeta.read(cx).update_required { + return; + } + if let Some(current_completion) = self.current_completion.as_ref() { let snapshot = buffer.read(cx).snapshot(); if current_completion @@ -1837,7 +1940,7 @@ mod tests { }); let server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(client, user_store, cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); @@ -1890,7 +1993,7 @@ mod tests { }); let server = FakeServer::for_client(42, &client, cx).await; let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(client, user_store, cx)); + let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());