@@ -9,6 +9,7 @@ mod rate_completion_modal;
pub(crate) use completion_diff_element::*;
use db::kvp::KEY_VALUE_STORE;
+use editor::Editor;
pub use init::*;
use inline_completion::DataCollectionState;
pub use license_detection::is_license_eligible_for_data_collection;
@@ -20,10 +21,10 @@ use anyhow::{anyhow, Context as _, Result};
use arrayvec::ArrayVec;
use client::{Client, UserStore};
use collections::{HashMap, HashSet, VecDeque};
-use feature_flags::FeatureFlagAppExt as _;
use futures::AsyncReadExt;
use gpui::{
- actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task,
+ actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
+ Subscription, Task,
};
use http_client::{HttpClient, Method};
use input_excerpt::excerpt_for_cursor_position;
@@ -34,7 +35,9 @@ use language::{
use language_models::LlmApiToken;
use postage::watch;
use project::Project;
+use release_channel::AppVersion;
use settings::WorktreeId;
+use std::str::FromStr;
use std::{
borrow::Cow,
cmp,
@@ -48,10 +51,16 @@ use std::{
time::{Duration, Instant},
};
use telemetry_events::InlineCompletionRating;
+use thiserror::Error;
use util::ResultExt;
use uuid::Uuid;
+use workspace::notifications::{ErrorMessagePrompt, NotificationId};
+use workspace::Workspace;
use worktree::Worktree;
-use zed_llm_client::{PredictEditsBody, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
+use zed_llm_client::{
+ PredictEditsBody, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
+ MINIMUM_REQUIRED_VERSION_HEADER_NAME,
+};
const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>";
@@ -178,6 +187,7 @@ impl std::fmt::Debug for InlineCompletion {
}
pub struct Zeta {
+ editor: Option<Entity<Editor>>,
client: Arc<Client>,
events: VecDeque<Event>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
@@ -188,6 +198,8 @@ pub struct Zeta {
_llm_token_subscription: Subscription,
/// Whether the terms of service have been accepted.
tos_accepted: bool,
+ /// Whether an update to a newer version of Zed is required to continue using Zeta.
+ update_required: bool,
_user_store_subscription: Subscription,
license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
}
@@ -198,13 +210,14 @@ impl Zeta {
}
pub fn register(
+ editor: Option<Entity<Editor>>,
worktree: Option<Entity<Worktree>>,
client: Arc<Client>,
user_store: Entity<UserStore>,
cx: &mut App,
) -> Entity<Self> {
let this = Self::global(cx).unwrap_or_else(|| {
- let entity = cx.new(|cx| Self::new(client, user_store, cx));
+ let entity = cx.new(|cx| Self::new(editor, client, user_store, cx));
cx.set_global(ZetaGlobal(entity.clone()));
entity
});
@@ -226,13 +239,19 @@ impl Zeta {
self.events.clear();
}
- fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
+ fn new(
+ editor: Option<Entity<Editor>>,
+ client: Arc<Client>,
+ user_store: Entity<UserStore>,
+ cx: &mut Context<Self>,
+ ) -> Self {
let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
let data_collection_choice = Self::load_data_collection_choices();
let data_collection_choice = cx.new(|_| data_collection_choice);
Self {
+ editor,
client,
events: VecDeque::new(),
shown_completions: VecDeque::new(),
@@ -256,6 +275,7 @@ impl Zeta {
.read(cx)
.current_user_has_accepted_terms()
.unwrap_or(false),
+ update_required: false,
_user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
match event {
client::user::Event::PrivateUserInfoUpdated => {
@@ -335,8 +355,10 @@ impl Zeta {
}
}
- pub fn request_completion_impl<F, R>(
+ #[allow(clippy::too_many_arguments)]
+ fn request_completion_impl<F, R>(
&mut self,
+ workspace: Option<Entity<Workspace>>,
project: Option<&Entity<Project>>,
buffer: &Entity<Buffer>,
cursor: language::Anchor,
@@ -345,7 +367,7 @@ impl Zeta {
perform_predict_edits: F,
) -> Task<Result<Option<InlineCompletion>>>
where
- F: FnOnce(Arc<Client>, LlmApiToken, bool, PredictEditsBody) -> R + 'static,
+ F: FnOnce(PerformPredictEditsParams) -> R + 'static,
R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
{
let snapshot = self.report_changes_for_buffer(&buffer, cx);
@@ -358,9 +380,10 @@ impl Zeta {
.map(|f| Arc::from(f.full_path(cx).as_path()))
.unwrap_or_else(|| Arc::from(Path::new("untitled")));
+ let zeta = cx.entity();
let client = self.client.clone();
let llm_token = self.llm_token.clone();
- let is_staff = cx.is_staff();
+ let app_version = AppVersion::global(cx);
let buffer = buffer.clone();
@@ -447,7 +470,46 @@ impl Zeta {
}),
};
- let response = perform_predict_edits(client, llm_token, is_staff, body).await?;
+ let response = perform_predict_edits(PerformPredictEditsParams {
+ client,
+ llm_token,
+ app_version,
+ body,
+ })
+ .await;
+ let response = match response {
+ Ok(response) => response,
+ Err(err) => {
+ if err.is::<ZedUpdateRequiredError>() {
+ cx.update(|cx| {
+ zeta.update(cx, |zeta, _cx| {
+ zeta.update_required = true;
+ });
+
+ if let Some(workspace) = workspace {
+ workspace.update(cx, |workspace, cx| {
+ workspace.show_notification(
+ NotificationId::unique::<ZedUpdateRequiredError>(),
+ cx,
+ |cx| {
+ cx.new(|_| {
+ ErrorMessagePrompt::new(err.to_string())
+ .with_link_button(
+ "Update Zed",
+ "https://zed.dev/releases",
+ )
+ })
+ },
+ );
+ });
+ }
+ })
+ .ok();
+ }
+
+ return Err(err);
+ }
+ };
log::debug!("completion response: {}", &response.output_excerpt);
@@ -632,7 +694,7 @@ and then another
) -> Task<Result<Option<InlineCompletion>>> {
use std::future::ready;
- self.request_completion_impl(project, buffer, position, false, cx, |_, _, _, _| {
+ self.request_completion_impl(None, project, buffer, position, false, cx, |_params| {
ready(Ok(response))
})
}
@@ -645,7 +707,12 @@ and then another
can_collect_data: bool,
cx: &mut Context<Self>,
) -> Task<Result<Option<InlineCompletion>>> {
+ let workspace = self
+ .editor
+ .as_ref()
+ .and_then(|editor| editor.read(cx).workspace());
self.request_completion_impl(
+ workspace,
project,
buffer,
position,
@@ -656,12 +723,17 @@ and then another
}
fn perform_predict_edits(
- client: Arc<Client>,
- llm_token: LlmApiToken,
- _is_staff: bool,
- body: PredictEditsBody,
+ params: PerformPredictEditsParams,
) -> impl Future<Output = Result<PredictEditsResponse>> {
async move {
+ let PerformPredictEditsParams {
+ client,
+ llm_token,
+ app_version,
+ body,
+ ..
+ } = params;
+
let http_client = client.http_client();
let mut token = llm_token.acquire(&client).await?;
let mut did_retry = false;
@@ -685,6 +757,18 @@ and then another
let mut response = http_client.send(request).await?;
+ if let Some(minimum_required_version) = response
+ .headers()
+ .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
+ .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
+ {
+ if app_version < minimum_required_version {
+ return Err(anyhow!(ZedUpdateRequiredError {
+ minimum_version: minimum_required_version
+ }));
+ }
+ }
+
if response.status().is_success() {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
@@ -1011,6 +1095,21 @@ and then another
}
}
+struct PerformPredictEditsParams {
+ pub client: Arc<Client>,
+ pub llm_token: LlmApiToken,
+ pub app_version: SemanticVersion,
+ pub body: PredictEditsBody,
+}
+
+#[derive(Error, Debug)]
+#[error(
+ "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
+)]
+pub struct ZedUpdateRequiredError {
+ minimum_version: SemanticVersion,
+}
+
struct LicenseDetectionWatcher {
is_open_source_rx: watch::Receiver<bool>,
_is_open_source_task: Task<()>,
@@ -1406,6 +1505,10 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider
return;
}
+ if self.zeta.read(cx).update_required {
+ return;
+ }
+
if let Some(current_completion) = self.current_completion.as_ref() {
let snapshot = buffer.read(cx).snapshot();
if current_completion
@@ -1837,7 +1940,7 @@ mod tests {
});
let server = FakeServer::for_client(42, &client, cx).await;
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- let zeta = cx.new(|cx| Zeta::new(client, user_store, cx));
+ let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
@@ -1890,7 +1993,7 @@ mod tests {
});
let server = FakeServer::for_client(42, &client, cx).await;
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- let zeta = cx.new(|cx| Zeta::new(client, user_store, cx));
+ let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());