@@ -1,4 +1,4 @@
-use client::{Client, DisableAiSettings, UserStore};
+use client::{Client, CloudUserStore, DisableAiSettings};
use collections::HashMap;
use copilot::{Copilot, CopilotCompletionProvider};
use editor::Editor;
@@ -13,12 +13,12 @@ use util::ResultExt;
use workspace::Workspace;
use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider};
-pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
+pub fn init(client: Arc<Client>, cloud_user_store: Entity<CloudUserStore>, cx: &mut App) {
let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
cx.observe_new({
let editors = editors.clone();
let client = client.clone();
- let user_store = user_store.clone();
+ let cloud_user_store = cloud_user_store.clone();
move |editor: &mut Editor, window, cx: &mut Context<Editor>| {
if !editor.mode().is_full() {
return;
@@ -48,7 +48,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
editor,
provider,
&client,
- user_store.clone(),
+ cloud_user_store.clone(),
window,
cx,
);
@@ -60,7 +60,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
let mut provider = all_language_settings(None, cx).edit_predictions.provider;
cx.spawn({
- let user_store = user_store.clone();
+ let cloud_user_store = cloud_user_store.clone();
let editors = editors.clone();
let client = client.clone();
@@ -72,7 +72,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
&editors,
provider,
&client,
- user_store.clone(),
+ cloud_user_store.clone(),
cx,
);
})
@@ -85,15 +85,12 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
cx.observe_global::<SettingsStore>({
let editors = editors.clone();
let client = client.clone();
- let user_store = user_store.clone();
+ let cloud_user_store = cloud_user_store.clone();
move |cx| {
let new_provider = all_language_settings(None, cx).edit_predictions.provider;
if new_provider != provider {
- let tos_accepted = user_store
- .read(cx)
- .current_user_has_accepted_terms()
- .unwrap_or(false);
+ let tos_accepted = cloud_user_store.read(cx).has_accepted_tos();
telemetry::event!(
"Edit Prediction Provider Changed",
@@ -107,7 +104,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
&editors,
provider,
&client,
- user_store.clone(),
+ cloud_user_store.clone(),
cx,
);
@@ -148,7 +145,7 @@ fn assign_edit_prediction_providers(
editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>,
provider: EditPredictionProvider,
client: &Arc<Client>,
- user_store: Entity<UserStore>,
+ cloud_user_store: Entity<CloudUserStore>,
cx: &mut App,
) {
for (editor, window) in editors.borrow().iter() {
@@ -158,7 +155,7 @@ fn assign_edit_prediction_providers(
editor,
provider,
&client,
- user_store.clone(),
+ cloud_user_store.clone(),
window,
cx,
);
@@ -213,7 +210,7 @@ fn assign_edit_prediction_provider(
editor: &mut Editor,
provider: EditPredictionProvider,
client: &Arc<Client>,
- user_store: Entity<UserStore>,
+ cloud_user_store: Entity<CloudUserStore>,
window: &mut Window,
cx: &mut Context<Editor>,
) {
@@ -244,7 +241,7 @@ fn assign_edit_prediction_provider(
}
}
EditPredictionProvider::Zed => {
- if client.status().borrow().is_connected() {
+ if cloud_user_store.read(cx).is_authenticated() {
let mut worktree = None;
if let Some(buffer) = &singleton_buffer {
@@ -266,7 +263,7 @@ fn assign_edit_prediction_provider(
.map(|workspace| workspace.downgrade());
let zeta =
- zeta::Zeta::register(workspace, worktree, client.clone(), user_store, cx);
+ zeta::Zeta::register(workspace, worktree, client.clone(), cloud_user_store, cx);
if let Some(buffer) = &singleton_buffer {
if buffer.read(cx).file().is_some() {
@@ -16,7 +16,7 @@ pub use rate_completion_modal::*;
use anyhow::{Context as _, Result, anyhow};
use arrayvec::ArrayVec;
-use client::{Client, EditPredictionUsage, UserStore};
+use client::{Client, CloudUserStore, EditPredictionUsage, UserStore};
use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
@@ -226,12 +226,9 @@ pub struct Zeta {
data_collection_choice: Entity<DataCollectionChoice>,
llm_token: LlmApiToken,
_llm_token_subscription: Subscription,
- /// Whether the terms of service have been accepted.
- tos_accepted: bool,
/// Whether an update to a newer version of Zed is required to continue using Zeta.
update_required: bool,
- user_store: Entity<UserStore>,
- _user_store_subscription: Subscription,
+ cloud_user_store: Entity<CloudUserStore>,
license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
}
@@ -244,11 +241,11 @@ impl Zeta {
workspace: Option<WeakEntity<Workspace>>,
worktree: Option<Entity<Worktree>>,
client: Arc<Client>,
- user_store: Entity<UserStore>,
+ cloud_user_store: Entity<CloudUserStore>,
cx: &mut App,
) -> Entity<Self> {
let this = Self::global(cx).unwrap_or_else(|| {
- let entity = cx.new(|cx| Self::new(workspace, client, user_store, cx));
+ let entity = cx.new(|cx| Self::new(workspace, client, cloud_user_store, cx));
cx.set_global(ZetaGlobal(entity.clone()));
entity
});
@@ -271,13 +268,13 @@ impl Zeta {
}
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
- self.user_store.read(cx).edit_prediction_usage()
+ self.cloud_user_store.read(cx).edit_prediction_usage()
}
fn new(
workspace: Option<WeakEntity<Workspace>>,
client: Arc<Client>,
- user_store: Entity<UserStore>,
+ cloud_user_store: Entity<CloudUserStore>,
cx: &mut Context<Self>,
) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
@@ -306,24 +303,9 @@ impl Zeta {
.detach_and_log_err(cx);
},
),
- tos_accepted: user_store
- .read(cx)
- .current_user_has_accepted_terms()
- .unwrap_or(false),
update_required: false,
- _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
- match event {
- client::user::Event::PrivateUserInfoUpdated => {
- this.tos_accepted = user_store
- .read(cx)
- .current_user_has_accepted_terms()
- .unwrap_or(false);
- }
- _ => {}
- }
- }),
license_detection_watchers: HashMap::default(),
- user_store,
+ cloud_user_store,
}
}
@@ -552,8 +534,8 @@ impl Zeta {
if let Some(usage) = usage {
this.update(cx, |this, cx| {
- this.user_store.update(cx, |user_store, cx| {
- user_store.update_edit_prediction_usage(usage, cx);
+ this.cloud_user_store.update(cx, |cloud_user_store, cx| {
+ cloud_user_store.update_edit_prediction_usage(usage, cx);
});
})
.ok();
@@ -894,8 +876,8 @@ and then another
if response.status().is_success() {
if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
this.update(cx, |this, cx| {
- this.user_store.update(cx, |user_store, cx| {
- user_store.update_edit_prediction_usage(usage, cx);
+ this.cloud_user_store.update(cx, |cloud_user_store, cx| {
+ cloud_user_store.update_edit_prediction_usage(usage, cx);
});
})?;
}
@@ -1573,7 +1555,12 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider
}
fn needs_terms_acceptance(&self, cx: &App) -> bool {
- !self.zeta.read(cx).tos_accepted
+ !self
+ .zeta
+ .read(cx)
+ .cloud_user_store
+ .read(cx)
+ .has_accepted_tos()
}
fn is_refreshing(&self) -> bool {
@@ -1588,7 +1575,7 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider
_debounce: bool,
cx: &mut Context<Self>,
) {
- if !self.zeta.read(cx).tos_accepted {
+ if self.needs_terms_acceptance(cx) {
return;
}
@@ -1599,9 +1586,9 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider
if self
.zeta
.read(cx)
- .user_store
- .read_with(cx, |user_store, _| {
- user_store.account_too_young() || user_store.has_overdue_invoices()
+ .cloud_user_store
+ .read_with(cx, |cloud_user_store, _cx| {
+ cloud_user_store.account_too_young() || cloud_user_store.has_overdue_invoices()
})
{
return;
@@ -1819,15 +1806,51 @@ fn tokens_for_bytes(bytes: usize) -> usize {
mod tests {
use client::test::FakeServer;
use clock::FakeSystemClock;
+ use cloud_api_types::{
+ AuthenticatedUser, CreateLlmTokenResponse, GetAuthenticatedUserResponse, LlmToken, PlanInfo,
+ };
+ use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
use gpui::TestAppContext;
use http_client::FakeHttpClient;
use indoc::indoc;
use language::Point;
- use rpc::proto;
use settings::SettingsStore;
use super::*;
+ fn make_get_authenticated_user_response() -> GetAuthenticatedUserResponse {
+ GetAuthenticatedUserResponse {
+ user: AuthenticatedUser {
+ id: 1,
+ metrics_id: "metrics-id-1".to_string(),
+ avatar_url: "".to_string(),
+ github_login: "".to_string(),
+ name: None,
+ is_staff: false,
+ accepted_tos_at: None,
+ },
+ feature_flags: vec![],
+ plan: PlanInfo {
+ plan: Plan::ZedPro,
+ subscription_period: None,
+ usage: CurrentUsage {
+ model_requests: UsageData {
+ used: 0,
+ limit: UsageLimit::Limited(500),
+ },
+ edit_predictions: UsageData {
+ used: 250,
+ limit: UsageLimit::Unlimited,
+ },
+ },
+ trial_started_at: None,
+ is_usage_based_billing_enabled: false,
+ is_account_too_young: false,
+ has_overdue_invoices: false,
+ },
+ }
+ }
+
#[gpui::test]
async fn test_inline_completion_basic_interpolation(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
@@ -2027,28 +2050,55 @@ mod tests {
<|editable_region_end|>
```"};
- let http_client = FakeHttpClient::create(move |_| async move {
- Ok(http_client::Response::builder()
- .status(200)
- .body(
- serde_json::to_string(&PredictEditsResponse {
- request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
- .unwrap(),
- output_excerpt: completion_response.to_string(),
- })
- .unwrap()
- .into(),
- )
- .unwrap())
+ let http_client = FakeHttpClient::create(move |req| async move {
+ match (req.method(), req.uri().path()) {
+ (&Method::GET, "/client/users/me") => Ok(http_client::Response::builder()
+ .status(200)
+ .body(
+ serde_json::to_string(&make_get_authenticated_user_response())
+ .unwrap()
+ .into(),
+ )
+ .unwrap()),
+ (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
+ .status(200)
+ .body(
+ serde_json::to_string(&CreateLlmTokenResponse {
+ token: LlmToken("the-llm-token".to_string()),
+ })
+ .unwrap()
+ .into(),
+ )
+ .unwrap()),
+ (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
+ .status(200)
+ .body(
+ serde_json::to_string(&PredictEditsResponse {
+ request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45")
+ .unwrap(),
+ output_excerpt: completion_response.to_string(),
+ })
+ .unwrap()
+ .into(),
+ )
+ .unwrap()),
+ _ => Ok(http_client::Response::builder()
+ .status(404)
+ .body("Not Found".into())
+ .unwrap()),
+ }
});
let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
cx.update(|cx| {
RefreshLlmTokenListener::register(client.clone(), cx);
});
- let server = FakeServer::for_client(42, &client, cx).await;
+ // Construct the fake server to authenticate.
+ let _server = FakeServer::for_client(42, &client, cx).await;
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));
+ let cloud_user_store =
+ cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
+ let zeta = cx.new(|cx| Zeta::new(None, client, cloud_user_store, cx));
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
@@ -2056,13 +2106,6 @@ mod tests {
zeta.request_completion(None, &buffer, cursor, false, cx)
});
- server.receive::<proto::GetUsers>().await.unwrap();
- let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
- server.respond(
- token_request.receipt(),
- proto::GetLlmTokenResponse { token: "".into() },
- );
-
let completion = completion_task.await.unwrap().unwrap();
buffer.update(cx, |buffer, cx| {
buffer.edit(completion.edits.iter().cloned(), None, cx)
@@ -2079,20 +2122,44 @@ mod tests {
cx: &mut TestAppContext,
) -> Vec<(Range<Point>, String)> {
let completion_response = completion_response.to_string();
- let http_client = FakeHttpClient::create(move |_| {
+ let http_client = FakeHttpClient::create(move |req| {
let completion = completion_response.clone();
async move {
- Ok(http_client::Response::builder()
- .status(200)
- .body(
- serde_json::to_string(&PredictEditsResponse {
- request_id: Uuid::new_v4(),
- output_excerpt: completion,
- })
- .unwrap()
- .into(),
- )
- .unwrap())
+ match (req.method(), req.uri().path()) {
+ (&Method::GET, "/client/users/me") => Ok(http_client::Response::builder()
+ .status(200)
+ .body(
+ serde_json::to_string(&make_get_authenticated_user_response())
+ .unwrap()
+ .into(),
+ )
+ .unwrap()),
+ (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
+ .status(200)
+ .body(
+ serde_json::to_string(&CreateLlmTokenResponse {
+ token: LlmToken("the-llm-token".to_string()),
+ })
+ .unwrap()
+ .into(),
+ )
+ .unwrap()),
+ (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder()
+ .status(200)
+ .body(
+ serde_json::to_string(&PredictEditsResponse {
+ request_id: Uuid::new_v4(),
+ output_excerpt: completion,
+ })
+ .unwrap()
+ .into(),
+ )
+ .unwrap()),
+ _ => Ok(http_client::Response::builder()
+ .status(404)
+ .body("Not Found".into())
+ .unwrap()),
+ }
}
});
@@ -2100,9 +2167,12 @@ mod tests {
cx.update(|cx| {
RefreshLlmTokenListener::register(client.clone(), cx);
});
- let server = FakeServer::for_client(42, &client, cx).await;
+ // Construct the fake server to authenticate.
+ let _server = FakeServer::for_client(42, &client, cx).await;
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- let zeta = cx.new(|cx| Zeta::new(None, client, user_store, cx));
+ let cloud_user_store =
+ cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
+ let zeta = cx.new(|cx| Zeta::new(None, client, cloud_user_store, cx));
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
@@ -2111,13 +2181,6 @@ mod tests {
zeta.request_completion(None, &buffer, cursor, false, cx)
});
- server.receive::<proto::GetUsers>().await.unwrap();
- let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
- server.respond(
- token_request.receipt(),
- proto::GetLlmTokenResponse { token: "".into() },
- );
-
let completion = completion_task.await.unwrap().unwrap();
completion
.edits