Detailed changes
@@ -12,7 +12,7 @@ use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
-use client::{ModelRequestUsage, RequestUsage};
+use client::{CloudUserStore, ModelRequestUsage, RequestUsage};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
use collections::HashMap;
use feature_flags::{self, FeatureFlagAppExt};
@@ -374,6 +374,7 @@ pub struct Thread {
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
project: Entity<Project>,
+ cloud_user_store: Entity<CloudUserStore>,
prompt_builder: Arc<PromptBuilder>,
tools: Entity<ToolWorkingSet>,
tool_use: ToolUseState,
@@ -444,6 +445,7 @@ pub struct ExceededWindowError {
impl Thread {
pub fn new(
project: Entity<Project>,
+ cloud_user_store: Entity<CloudUserStore>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
system_prompt: SharedProjectContext,
@@ -470,6 +472,7 @@ impl Thread {
completion_count: 0,
pending_completions: Vec::new(),
project: project.clone(),
+ cloud_user_store,
prompt_builder,
tools: tools.clone(),
last_restore_checkpoint: None,
@@ -503,6 +506,7 @@ impl Thread {
id: ThreadId,
serialized: SerializedThread,
project: Entity<Project>,
+ cloud_user_store: Entity<CloudUserStore>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
project_context: SharedProjectContext,
@@ -603,6 +607,7 @@ impl Thread {
last_restore_checkpoint: None,
pending_checkpoint: None,
project: project.clone(),
+ cloud_user_store,
prompt_builder,
tools: tools.clone(),
tool_use,
@@ -3255,16 +3260,14 @@ impl Thread {
}
fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
- self.project.update(cx, |project, cx| {
- project.user_store().update(cx, |user_store, cx| {
- user_store.update_model_request_usage(
- ModelRequestUsage(RequestUsage {
- amount: amount as i32,
- limit,
- }),
- cx,
- )
- })
+ self.cloud_user_store.update(cx, |cloud_user_store, cx| {
+ cloud_user_store.update_model_request_usage(
+ ModelRequestUsage(RequestUsage {
+ amount: amount as i32,
+ limit,
+ }),
+ cx,
+ )
});
}
@@ -3883,6 +3886,7 @@ fn main() {{
thread.id.clone(),
serialized,
thread.project.clone(),
+ thread.cloud_user_store.clone(),
thread.tools.clone(),
thread.prompt_builder.clone(),
thread.project_context.clone(),
@@ -5479,10 +5483,16 @@ fn main() {{
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+ let (client, user_store) =
+ project.read_with(cx, |project, _cx| (project.client(), project.user_store()));
+ let cloud_user_store =
+ cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx));
+
let thread_store = cx
.update(|_, cx| {
ThreadStore::load(
project.clone(),
+ cloud_user_store,
cx.new(|_| ToolWorkingSet::default()),
None,
Arc::new(PromptBuilder::new(None).unwrap()),
@@ -8,6 +8,7 @@ use agent_settings::{AgentProfileId, CompletionMode};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{Tool, ToolId, ToolWorkingSet};
use chrono::{DateTime, Utc};
+use client::CloudUserStore;
use collections::HashMap;
use context_server::ContextServerId;
use futures::{
@@ -104,6 +105,7 @@ pub type TextThreadStore = assistant_context::ContextStore;
pub struct ThreadStore {
project: Entity<Project>,
+ cloud_user_store: Entity<CloudUserStore>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
prompt_store: Option<Entity<PromptStore>>,
@@ -124,6 +126,7 @@ impl EventEmitter<RulesLoadingError> for ThreadStore {}
impl ThreadStore {
pub fn load(
project: Entity<Project>,
+ cloud_user_store: Entity<CloudUserStore>,
tools: Entity<ToolWorkingSet>,
prompt_store: Option<Entity<PromptStore>>,
prompt_builder: Arc<PromptBuilder>,
@@ -133,8 +136,14 @@ impl ThreadStore {
let (thread_store, ready_rx) = cx.update(|cx| {
let mut option_ready_rx = None;
let thread_store = cx.new(|cx| {
- let (thread_store, ready_rx) =
- Self::new(project, tools, prompt_builder, prompt_store, cx);
+ let (thread_store, ready_rx) = Self::new(
+ project,
+ cloud_user_store,
+ tools,
+ prompt_builder,
+ prompt_store,
+ cx,
+ );
option_ready_rx = Some(ready_rx);
thread_store
});
@@ -147,6 +156,7 @@ impl ThreadStore {
fn new(
project: Entity<Project>,
+ cloud_user_store: Entity<CloudUserStore>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
prompt_store: Option<Entity<PromptStore>>,
@@ -190,6 +200,7 @@ impl ThreadStore {
let this = Self {
project,
+ cloud_user_store,
tools,
prompt_builder,
prompt_store,
@@ -407,6 +418,7 @@ impl ThreadStore {
cx.new(|cx| {
Thread::new(
self.project.clone(),
+ self.cloud_user_store.clone(),
self.tools.clone(),
self.prompt_builder.clone(),
self.project_context.clone(),
@@ -425,6 +437,7 @@ impl ThreadStore {
ThreadId::new(),
serialized,
self.project.clone(),
+ self.cloud_user_store.clone(),
self.tools.clone(),
self.prompt_builder.clone(),
self.project_context.clone(),
@@ -456,6 +469,7 @@ impl ThreadStore {
id.clone(),
thread,
this.project.clone(),
+ this.cloud_user_store.clone(),
this.tools.clone(),
this.prompt_builder.clone(),
this.project_context.clone(),
@@ -3820,6 +3820,7 @@ mod tests {
use super::*;
use agent::{MessageSegment, context::ContextLoadResult, thread_store};
use assistant_tool::{ToolRegistry, ToolWorkingSet};
+ use client::CloudUserStore;
use editor::EditorSettings;
use fs::FakeFs;
use gpui::{AppContext, TestAppContext, VisualTestContext};
@@ -4116,10 +4117,16 @@ mod tests {
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+ let (client, user_store) =
+ project.read_with(cx, |project, _cx| (project.client(), project.user_store()));
+ let cloud_user_store =
+ cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx));
+
let thread_store = cx
.update(|_, cx| {
ThreadStore::load(
project.clone(),
+ cloud_user_store,
cx.new(|_| ToolWorkingSet::default()),
None,
Arc::new(PromptBuilder::new(None).unwrap()),
@@ -1893,6 +1893,7 @@ mod tests {
use agent::thread_store::{self, ThreadStore};
use agent_settings::AgentSettings;
use assistant_tool::ToolWorkingSet;
+ use client::CloudUserStore;
use editor::EditorSettings;
use gpui::{TestAppContext, UpdateGlobal, VisualTestContext};
use project::{FakeFs, Project};
@@ -1932,11 +1933,17 @@ mod tests {
})
.unwrap();
+ let (client, user_store) =
+ project.read_with(cx, |project, _cx| (project.client(), project.user_store()));
+ let cloud_user_store =
+ cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx));
+
let prompt_store = None;
let thread_store = cx
.update(|cx| {
ThreadStore::load(
project.clone(),
+ cloud_user_store,
cx.new(|_| ToolWorkingSet::default()),
prompt_store,
Arc::new(PromptBuilder::new(None).unwrap()),
@@ -2098,11 +2105,17 @@ mod tests {
})
.unwrap();
+ let (client, user_store) =
+ project.read_with(cx, |project, _cx| (project.client(), project.user_store()));
+ let cloud_user_store =
+ cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store, cx));
+
let prompt_store = None;
let thread_store = cx
.update(|cx| {
ThreadStore::load(
project.clone(),
+ cloud_user_store,
cx.new(|_| ToolWorkingSet::default()),
prompt_store,
Arc::new(PromptBuilder::new(None).unwrap()),
@@ -43,7 +43,7 @@ use anyhow::{Result, anyhow};
use assistant_context::{AssistantContext, ContextEvent, ContextSummary};
use assistant_slash_command::SlashCommandWorkingSet;
use assistant_tool::ToolWorkingSet;
-use client::{DisableAiSettings, UserStore, zed_urls};
+use client::{CloudUserStore, DisableAiSettings, UserStore, zed_urls};
use cloud_llm_client::{CompletionIntent, UsageLimit};
use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
use feature_flags::{self, FeatureFlagAppExt};
@@ -427,6 +427,7 @@ impl ActiveView {
pub struct AgentPanel {
workspace: WeakEntity<Workspace>,
user_store: Entity<UserStore>,
+ cloud_user_store: Entity<CloudUserStore>,
project: Entity<Project>,
fs: Arc<dyn Fs>,
language_registry: Arc<LanguageRegistry>,
@@ -486,6 +487,7 @@ impl AgentPanel {
let project = workspace.project().clone();
ThreadStore::load(
project,
+ workspace.app_state().cloud_user_store.clone(),
tools.clone(),
prompt_store.clone(),
prompt_builder.clone(),
@@ -553,6 +555,7 @@ impl AgentPanel {
let thread = thread_store.update(cx, |this, cx| this.create_thread(cx));
let fs = workspace.app_state().fs.clone();
let user_store = workspace.app_state().user_store.clone();
+ let cloud_user_store = workspace.app_state().cloud_user_store.clone();
let project = workspace.project();
let language_registry = project.read(cx).languages().clone();
let client = workspace.client().clone();
@@ -579,7 +582,7 @@ impl AgentPanel {
MessageEditor::new(
fs.clone(),
workspace.clone(),
- user_store.clone(),
+ cloud_user_store.clone(),
message_editor_context_store.clone(),
prompt_store.clone(),
thread_store.downgrade(),
@@ -706,6 +709,7 @@ impl AgentPanel {
active_view,
workspace,
user_store,
+ cloud_user_store,
project: project.clone(),
fs: fs.clone(),
language_registry,
@@ -848,7 +852,7 @@ impl AgentPanel {
MessageEditor::new(
self.fs.clone(),
self.workspace.clone(),
- self.user_store.clone(),
+ self.cloud_user_store.clone(),
context_store.clone(),
self.prompt_store.clone(),
self.thread_store.downgrade(),
@@ -1122,7 +1126,7 @@ impl AgentPanel {
MessageEditor::new(
self.fs.clone(),
self.workspace.clone(),
- self.user_store.clone(),
+ self.cloud_user_store.clone(),
context_store,
self.prompt_store.clone(),
self.thread_store.downgrade(),
@@ -1821,8 +1825,8 @@ impl AgentPanel {
}
fn render_toolbar(&self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let user_store = self.user_store.read(cx);
- let usage = user_store.model_request_usage();
+ let cloud_user_store = self.cloud_user_store.read(cx);
+ let usage = cloud_user_store.model_request_usage();
let account_url = zed_urls::account_url(cx);
@@ -17,7 +17,7 @@ use agent::{
use agent_settings::{AgentSettings, CompletionMode};
use ai_onboarding::ApiKeysWithProviders;
use buffer_diff::BufferDiff;
-use client::UserStore;
+use client::CloudUserStore;
use cloud_llm_client::CompletionIntent;
use collections::{HashMap, HashSet};
use editor::actions::{MoveUp, Paste};
@@ -43,7 +43,6 @@ use language_model::{
use multi_buffer;
use project::Project;
use prompt_store::PromptStore;
-use proto::Plan;
use settings::Settings;
use std::time::Duration;
use theme::ThemeSettings;
@@ -79,7 +78,7 @@ pub struct MessageEditor {
editor: Entity<Editor>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
- user_store: Entity<UserStore>,
+ cloud_user_store: Entity<CloudUserStore>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
history_store: Option<WeakEntity<HistoryStore>>,
@@ -159,7 +158,7 @@ impl MessageEditor {
pub fn new(
fs: Arc<dyn Fs>,
workspace: WeakEntity<Workspace>,
- user_store: Entity<UserStore>,
+ cloud_user_store: Entity<CloudUserStore>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: WeakEntity<ThreadStore>,
@@ -231,7 +230,7 @@ impl MessageEditor {
Self {
editor: editor.clone(),
project: thread.read(cx).project().clone(),
- user_store,
+ cloud_user_store,
thread,
incompatible_tools_state: incompatible_tools.clone(),
workspace,
@@ -1287,26 +1286,16 @@ impl MessageEditor {
return None;
}
- let user_store = self.user_store.read(cx);
-
- let ubb_enable = user_store
- .usage_based_billing_enabled()
- .map_or(false, |enabled| enabled);
-
- if ubb_enable {
+ let cloud_user_store = self.cloud_user_store.read(cx);
+ if cloud_user_store.is_usage_based_billing_enabled() {
return None;
}
- let plan = user_store
- .current_plan()
- .map(|plan| match plan {
- Plan::Free => cloud_llm_client::Plan::ZedFree,
- Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
- Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
- })
+ let plan = cloud_user_store
+ .plan()
.unwrap_or(cloud_llm_client::Plan::ZedFree);
- let usage = user_store.model_request_usage()?;
+ let usage = cloud_user_store.model_request_usage()?;
Some(
div()
@@ -1769,7 +1758,7 @@ impl AgentPreview for MessageEditor {
) -> Option<AnyElement> {
if let Some(workspace) = workspace.upgrade() {
let fs = workspace.read(cx).app_state().fs.clone();
- let user_store = workspace.read(cx).app_state().user_store.clone();
+ let cloud_user_store = workspace.read(cx).app_state().cloud_user_store.clone();
let project = workspace.read(cx).project().clone();
let weak_project = project.downgrade();
let context_store = cx.new(|_cx| ContextStore::new(weak_project, None));
@@ -1782,7 +1771,7 @@ impl AgentPreview for MessageEditor {
MessageEditor::new(
fs,
workspace.downgrade(),
- user_store,
+ cloud_user_store,
context_store,
None,
thread_store.downgrade(),
@@ -7,7 +7,7 @@ use crate::{
};
use Role::*;
use assistant_tool::ToolRegistry;
-use client::{Client, UserStore};
+use client::{Client, CloudUserStore, UserStore};
use collections::HashMap;
use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture};
@@ -1470,12 +1470,14 @@ impl EditAgentTest {
client::init_settings(cx);
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+ let cloud_user_store =
+ cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
settings::init(cx);
Project::init_settings(cx);
language::init(cx);
language_model::init(client.clone(), cx);
- language_models::init(user_store.clone(), client.clone(), cx);
+ language_models::init(user_store.clone(), cloud_user_store, client.clone(), cx);
crate::init(client.http_client(), cx);
});
@@ -9,12 +9,13 @@ use gpui::{Context, Entity, Subscription, Task};
use util::{ResultExt as _, maybe};
use crate::user::Event as RpcUserStoreEvent;
-use crate::{EditPredictionUsage, RequestUsage, UserStore};
+use crate::{EditPredictionUsage, ModelRequestUsage, RequestUsage, UserStore};
pub struct CloudUserStore {
cloud_client: Arc<CloudApiClient>,
authenticated_user: Option<Arc<AuthenticatedUser>>,
plan_info: Option<Arc<PlanInfo>>,
+ model_request_usage: Option<ModelRequestUsage>,
edit_prediction_usage: Option<EditPredictionUsage>,
_maintain_authenticated_user_task: Task<()>,
_rpc_plan_updated_subscription: Subscription,
@@ -33,6 +34,7 @@ impl CloudUserStore {
cloud_client: cloud_client.clone(),
authenticated_user: None,
plan_info: None,
+ model_request_usage: None,
edit_prediction_usage: None,
_maintain_authenticated_user_task: cx.spawn(async move |this, cx| {
maybe!(async move {
@@ -104,6 +106,13 @@ impl CloudUserStore {
})
}
+ pub fn trial_started_at(&self) -> Option<DateTime<Utc>> {
+ self.plan_info
+ .as_ref()
+ .and_then(|plan| plan.trial_started_at)
+ .map(|trial_started_at| trial_started_at.0)
+ }
+
pub fn has_accepted_tos(&self) -> bool {
self.authenticated_user
.as_ref()
@@ -127,6 +136,22 @@ impl CloudUserStore {
.unwrap_or_default()
}
+ pub fn is_usage_based_billing_enabled(&self) -> bool {
+ self.plan_info
+ .as_ref()
+ .map(|plan| plan.is_usage_based_billing_enabled)
+ .unwrap_or_default()
+ }
+
+ pub fn model_request_usage(&self) -> Option<ModelRequestUsage> {
+ self.model_request_usage
+ }
+
+ pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
+ self.model_request_usage = Some(usage);
+ cx.notify();
+ }
+
pub fn edit_prediction_usage(&self) -> Option<EditPredictionUsage> {
self.edit_prediction_usage
}
@@ -142,6 +167,10 @@ impl CloudUserStore {
fn update_authenticated_user(&mut self, response: GetAuthenticatedUserResponse) {
self.authenticated_user = Some(Arc::new(response.user));
+ self.model_request_usage = Some(ModelRequestUsage(RequestUsage {
+ limit: response.plan.usage.model_requests.limit,
+ amount: response.plan.usage.model_requests.used as i32,
+ }));
self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage {
limit: response.plan.usage.edit_predictions.limit,
amount: response.plan.usage.edit_predictions.used as i32,
@@ -113,7 +113,6 @@ pub struct UserStore {
current_plan: Option<proto::Plan>,
subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
trial_started_at: Option<DateTime<Utc>>,
- model_request_usage: Option<ModelRequestUsage>,
is_usage_based_billing_enabled: Option<bool>,
account_too_young: Option<bool>,
has_overdue_invoices: Option<bool>,
@@ -191,7 +190,6 @@ impl UserStore {
current_plan: None,
subscription_period: None,
trial_started_at: None,
- model_request_usage: None,
is_usage_based_billing_enabled: None,
account_too_young: None,
has_overdue_invoices: None,
@@ -371,27 +369,12 @@ impl UserStore {
this.account_too_young = message.payload.account_too_young;
this.has_overdue_invoices = message.payload.has_overdue_invoices;
- if let Some(usage) = message.payload.usage {
- // limits are always present even though they are wrapped in Option
- this.model_request_usage = usage
- .model_requests_usage_limit
- .and_then(|limit| {
- RequestUsage::from_proto(usage.model_requests_usage_amount, limit)
- })
- .map(ModelRequestUsage);
- }
-
cx.emit(Event::PlanUpdated);
cx.notify();
})?;
Ok(())
}
- pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
- self.model_request_usage = Some(usage);
- cx.notify();
- }
-
fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> {
match message {
UpdateContacts::Wait(barrier) => {
@@ -776,10 +759,6 @@ impl UserStore {
self.is_usage_based_billing_enabled
}
- pub fn model_request_usage(&self) -> Option<ModelRequestUsage> {
- self.model_request_usage
- }
-
pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
self.current_user.clone()
}
@@ -13,7 +13,7 @@ pub(crate) use tool_metrics::*;
use ::fs::RealFs;
use clap::Parser;
-use client::{Client, ProxySettings, UserStore};
+use client::{Client, CloudUserStore, ProxySettings, UserStore};
use collections::{HashMap, HashSet};
use extension::ExtensionHostProxy;
use futures::future;
@@ -329,6 +329,7 @@ pub struct AgentAppState {
pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>,
pub user_store: Entity<UserStore>,
+ pub cloud_user_store: Entity<CloudUserStore>,
pub fs: Arc<dyn fs::Fs>,
pub node_runtime: NodeRuntime,
@@ -383,6 +384,8 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
let languages = Arc::new(languages);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+ let cloud_user_store =
+ cx.new(|cx| CloudUserStore::new(client.cloud_client(), user_store.clone(), cx));
extension::init(cx);
@@ -422,7 +425,12 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
languages.clone(),
);
language_model::init(client.clone(), cx);
- language_models::init(user_store.clone(), client.clone(), cx);
+ language_models::init(
+ user_store.clone(),
+ cloud_user_store.clone(),
+ client.clone(),
+ cx,
+ );
languages::init(languages.clone(), node_runtime.clone(), cx);
prompt_store::init(cx);
terminal_view::init(cx);
@@ -447,6 +455,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
languages,
client,
user_store,
+ cloud_user_store,
fs,
node_runtime,
prompt_builder,
@@ -221,6 +221,7 @@ impl ExampleInstance {
let prompt_store = None;
let thread_store = ThreadStore::load(
project.clone(),
+ app_state.cloud_user_store.clone(),
tools,
prompt_store,
app_state.prompt_builder.clone(),
@@ -1,7 +1,7 @@
use std::sync::Arc;
use ::settings::{Settings, SettingsStore};
-use client::{Client, UserStore};
+use client::{Client, CloudUserStore, UserStore};
use collections::HashSet;
use gpui::{App, Context, Entity};
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
@@ -26,11 +26,22 @@ use crate::provider::vercel::VercelLanguageModelProvider;
use crate::provider::x_ai::XAiLanguageModelProvider;
pub use crate::settings::*;
-pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
+pub fn init(
+ user_store: Entity<UserStore>,
+ cloud_user_store: Entity<CloudUserStore>,
+ client: Arc<Client>,
+ cx: &mut App,
+) {
crate::settings::init_settings(cx);
let registry = LanguageModelRegistry::global(cx);
registry.update(cx, |registry, cx| {
- register_language_model_providers(registry, user_store, client.clone(), cx);
+ register_language_model_providers(
+ registry,
+ user_store,
+ cloud_user_store,
+ client.clone(),
+ cx,
+ );
});
let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
@@ -100,11 +111,17 @@ fn register_openai_compatible_providers(
fn register_language_model_providers(
registry: &mut LanguageModelRegistry,
user_store: Entity<UserStore>,
+ cloud_user_store: Entity<CloudUserStore>,
client: Arc<Client>,
cx: &mut Context<LanguageModelRegistry>,
) {
registry.register_provider(
- CloudLanguageModelProvider::new(user_store.clone(), client.clone(), cx),
+ CloudLanguageModelProvider::new(
+ user_store.clone(),
+ cloud_user_store.clone(),
+ client.clone(),
+ cx,
+ ),
cx,
);
@@ -2,11 +2,11 @@ use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
-use client::{Client, ModelRequestUsage, UserStore, zed_urls};
+use client::{Client, CloudUserStore, ModelRequestUsage, UserStore, zed_urls};
use cloud_llm_client::{
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
- EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
+ EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan,
SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
@@ -27,7 +27,6 @@ use language_model::{
LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
-use proto::Plan;
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
@@ -118,6 +117,7 @@ pub struct State {
client: Arc<Client>,
llm_api_token: LlmApiToken,
user_store: Entity<UserStore>,
+ cloud_user_store: Entity<CloudUserStore>,
status: client::Status,
accept_terms_of_service_task: Option<Task<Result<()>>>,
models: Vec<Arc<cloud_llm_client::LanguageModel>>,
@@ -133,6 +133,7 @@ impl State {
fn new(
client: Arc<Client>,
user_store: Entity<UserStore>,
+ cloud_user_store: Entity<CloudUserStore>,
status: client::Status,
cx: &mut Context<Self>,
) -> Self {
@@ -142,6 +143,7 @@ impl State {
client: client.clone(),
llm_api_token: LlmApiToken::default(),
user_store,
+ cloud_user_store,
status,
accept_terms_of_service_task: None,
models: Vec::new(),
@@ -150,12 +152,19 @@ impl State {
recommended_models: Vec::new(),
_fetch_models_task: cx.spawn(async move |this, cx| {
maybe!(async move {
- let (client, llm_api_token) = this
- .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
+ let (client, cloud_user_store, llm_api_token) =
+ this.read_with(cx, |this, _cx| {
+ (
+ client.clone(),
+ this.cloud_user_store.clone(),
+ this.llm_api_token.clone(),
+ )
+ })?;
loop {
- let status = this.read_with(cx, |this, _cx| this.status)?;
- if matches!(status, client::Status::Connected { .. }) {
+ let is_authenticated =
+ cloud_user_store.read_with(cx, |this, _cx| this.is_authenticated())?;
+ if is_authenticated {
break;
}
@@ -194,8 +203,8 @@ impl State {
}
}
- fn is_signed_out(&self) -> bool {
- self.status.is_signed_out()
+ fn is_signed_out(&self, cx: &App) -> bool {
+ !self.cloud_user_store.read(cx).is_authenticated()
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
@@ -210,10 +219,7 @@ impl State {
}
fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
- self.user_store
- .read(cx)
- .current_user_has_accepted_terms()
- .unwrap_or(false)
+ self.cloud_user_store.read(cx).has_accepted_tos()
}
fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
@@ -297,11 +303,24 @@ impl State {
}
impl CloudLanguageModelProvider {
- pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
+ pub fn new(
+ user_store: Entity<UserStore>,
+ cloud_user_store: Entity<CloudUserStore>,
+ client: Arc<Client>,
+ cx: &mut App,
+ ) -> Self {
let mut status_rx = client.status();
let status = *status_rx.borrow();
- let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
+ let state = cx.new(|cx| {
+ State::new(
+ client.clone(),
+ user_store.clone(),
+ cloud_user_store.clone(),
+ status,
+ cx,
+ )
+ });
let state_ref = state.downgrade();
let maintain_client_status = cx.spawn(async move |cx| {
@@ -398,7 +417,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
fn is_authenticated(&self, cx: &App) -> bool {
let state = self.state.read(cx);
- !state.is_signed_out() && state.has_accepted_terms_of_service(cx)
+ !state.is_signed_out(cx) && state.has_accepted_terms_of_service(cx)
}
fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
@@ -614,9 +633,9 @@ impl CloudLanguageModel {
.and_then(|plan| cloud_llm_client::Plan::from_str(plan).ok())
{
let plan = match plan {
- cloud_llm_client::Plan::ZedFree => Plan::Free,
- cloud_llm_client::Plan::ZedPro => Plan::ZedPro,
- cloud_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
+ cloud_llm_client::Plan::ZedFree => proto::Plan::Free,
+ cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
+ cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
};
return Err(anyhow!(ModelRequestLimitReachedError { plan }));
}
@@ -1118,7 +1137,7 @@ fn response_lines<T: DeserializeOwned>(
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
is_connected: bool,
- plan: Option<proto::Plan>,
+ plan: Option<Plan>,
subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
eligible_for_trial: bool,
has_accepted_terms_of_service: bool,
@@ -1132,15 +1151,15 @@ impl RenderOnce for ZedAiConfiguration {
fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
let young_account_banner = YoungAccountBanner;
- let is_pro = self.plan == Some(proto::Plan::ZedPro);
+ let is_pro = self.plan == Some(Plan::ZedPro);
let subscription_text = match (self.plan, self.subscription_period) {
- (Some(proto::Plan::ZedPro), Some(_)) => {
+ (Some(Plan::ZedPro), Some(_)) => {
"You have access to Zed's hosted models through your Pro subscription."
}
- (Some(proto::Plan::ZedProTrial), Some(_)) => {
+ (Some(Plan::ZedProTrial), Some(_)) => {
"You have access to Zed's hosted models through your Pro trial."
}
- (Some(proto::Plan::Free), Some(_)) => {
+ (Some(Plan::ZedFree), Some(_)) => {
"You have basic access to Zed's hosted models through the Free plan."
}
_ => {
@@ -1262,15 +1281,15 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let state = self.state.read(cx);
- let user_store = state.user_store.read(cx);
+ let cloud_user_store = state.cloud_user_store.read(cx);
ZedAiConfiguration {
- is_connected: !state.is_signed_out(),
- plan: user_store.current_plan(),
- subscription_period: user_store.subscription_period(),
- eligible_for_trial: user_store.trial_started_at().is_none(),
+ is_connected: !state.is_signed_out(cx),
+ plan: cloud_user_store.plan(),
+ subscription_period: cloud_user_store.subscription_period(),
+ eligible_for_trial: cloud_user_store.trial_started_at().is_none(),
has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx),
- account_too_young: user_store.account_too_young(),
+ account_too_young: cloud_user_store.account_too_young(),
accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(),
accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(),
sign_in_callback: self.sign_in_callback.clone(),
@@ -1286,7 +1305,7 @@ impl Component for ZedAiConfiguration {
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
fn configuration(
is_connected: bool,
- plan: Option<proto::Plan>,
+ plan: Option<Plan>,
eligible_for_trial: bool,
account_too_young: bool,
has_accepted_terms_of_service: bool,
@@ -1330,15 +1349,15 @@ impl Component for ZedAiConfiguration {
),
single_example(
"Free Plan",
- configuration(true, Some(proto::Plan::Free), true, false, true),
+ configuration(true, Some(Plan::ZedFree), true, false, true),
),
single_example(
"Zed Pro Trial Plan",
- configuration(true, Some(proto::Plan::ZedProTrial), true, false, true),
+ configuration(true, Some(Plan::ZedProTrial), true, false, true),
),
single_example(
"Zed Pro Plan",
- configuration(true, Some(proto::Plan::ZedPro), true, false, true),
+ configuration(true, Some(Plan::ZedPro), true, false, true),
),
])
.into_any_element(),
@@ -556,7 +556,12 @@ pub fn main() {
);
supermaven::init(app_state.client.clone(), cx);
language_model::init(app_state.client.clone(), cx);
- language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+ language_models::init(
+ app_state.user_store.clone(),
+ app_state.cloud_user_store.clone(),
+ app_state.client.clone(),
+ cx,
+ );
agent_settings::init(cx);
agent_servers::init(cx);
web_search::init(cx);
@@ -4488,7 +4488,12 @@ mod tests {
);
image_viewer::init(cx);
language_model::init(app_state.client.clone(), cx);
- language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+ language_models::init(
+ app_state.user_store.clone(),
+ app_state.cloud_user_store.clone(),
+ app_state.client.clone(),
+ cx,
+ );
web_search::init(cx);
web_search_providers::init(app_state.client.clone(), cx);
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx);
@@ -17,9 +17,10 @@ pub fn load_preview_thread_store(
cx: &mut AsyncApp,
) -> Task<Result<Entity<ThreadStore>>> {
workspace
- .update(cx, |_, cx| {
+ .update(cx, |workspace, cx| {
ThreadStore::load(
project.clone(),
+ workspace.app_state().cloud_user_store.clone(),
cx.new(|_| ToolWorkingSet::default()),
None,
Arc::new(PromptBuilder::new(None).unwrap()),