Detailed changes
@@ -1423,7 +1423,7 @@ impl EditAgentTest {
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
settings::init(cx);
- language_model::init(client.clone(), cx);
+ language_model::init(user_store.clone(), client.clone(), cx);
language_models::init(user_store, client.clone(), cx);
});
@@ -3167,7 +3167,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
let clock = Arc::new(clock::FakeSystemClock::new());
let client = Client::new(clock, http_client, cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(client.clone(), cx);
+ language_model::init(user_store.clone(), client.clone(), cx);
language_models::init(user_store, client.clone(), cx);
LanguageModelRegistry::test(cx);
});
@@ -3791,7 +3791,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
cx.set_http_client(Arc::new(http_client));
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(client.clone(), cx);
+ language_model::init(user_store.clone(), client.clone(), cx);
language_models::init(user_store, client.clone(), cx);
}
};
@@ -2,6 +2,7 @@ use crate::{AgentServer, AgentServerDelegate};
use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
use agent_client_protocol as acp;
use futures::{FutureExt, StreamExt, channel::mpsc, select};
+use gpui::AppContext;
use gpui::{Entity, TestAppContext};
use indoc::indoc;
use project::{FakeFs, Project};
@@ -408,7 +409,8 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
cx.set_http_client(Arc::new(http_client));
let client = client::Client::production(cx);
- language_model::init(client, cx);
+ let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
+ language_model::init(user_store, client, cx);
#[cfg(test)]
project::agent_server_store::AllAgentServersSettings::override_global(
@@ -2120,7 +2120,7 @@ pub mod test {
client::init(&client, cx);
workspace::init(app_state.clone(), cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- language_model::init(client.clone(), cx);
+ language_model::init(user_store.clone(), client.clone(), cx);
language_models::init(user_store, client.clone(), cx);
cx.set_global(inline_assistant);
@@ -140,6 +140,7 @@ pub enum Event {
ParticipantIndicesChanged,
PrivateUserInfoUpdated,
PlanUpdated,
+ OrganizationChanged,
}
#[derive(Clone, Copy)]
@@ -694,8 +695,21 @@ impl UserStore {
self.current_organization.clone()
}
- pub fn set_current_organization(&mut self, organization: Arc<Organization>) {
- self.current_organization.replace(organization);
+ pub fn set_current_organization(
+ &mut self,
+ organization: Arc<Organization>,
+ cx: &mut Context<Self>,
+ ) {
+ let is_same_organization = self
+ .current_organization
+ .as_ref()
+ .is_some_and(|current| current.id == organization.id);
+
+ if !is_same_organization {
+ self.current_organization.replace(organization);
+ cx.emit(Event::OrganizationChanged);
+ cx.notify();
+ }
}
pub fn organizations(&self) -> &Vec<Arc<Organization>> {
@@ -533,8 +533,8 @@ mod tests {
zlog::init_test();
let http_client = FakeHttpClient::with_404_response();
let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
- language_model::init(client.clone(), cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+ language_model::init(user_store.clone(), client.clone(), cx);
EditPredictionStore::global(&client, &user_store, cx);
})
}
@@ -1850,9 +1850,8 @@ fn init_test_with_fake_client(
let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
client.cloud_client().set_credentials(1, "test".into());
- language_model::init(client.clone(), cx);
-
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+ language_model::init(user_store.clone(), client.clone(), cx);
let ep_store = EditPredictionStore::global(&client, &user_store, cx);
(
@@ -2218,8 +2217,9 @@ async fn make_test_ep_store(
});
let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
+ let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
cx.update(|cx| {
- RefreshLlmTokenListener::register(client.clone(), cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
});
let _server = FakeServer::for_client(42, &client, cx).await;
@@ -2301,8 +2301,9 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut
let client =
cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
+ let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
cx.update(|cx| {
- language_model::RefreshLlmTokenListener::register(client.clone(), cx);
+ language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
});
let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
@@ -105,7 +105,7 @@ pub fn init(cx: &mut App) -> EpAppState {
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
- language_model::init(client.clone(), cx);
+ language_model::init(user_store.clone(), client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
prompt_store::init(cx);
@@ -429,7 +429,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
let extension_host_proxy = ExtensionHostProxy::global(cx);
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
- language_model::init(client.clone(), cx);
+ language_model::init(user_store.clone(), client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
prompt_store::init(cx);
@@ -104,7 +104,7 @@ pub fn init(cx: &mut App) -> Arc<AgentCliAppState> {
let extension_host_proxy = ExtensionHostProxy::global(cx);
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
- language_model::init(client.clone(), cx);
+ language_model::init(user_store.clone(), client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
prompt_store::init(cx);
@@ -13,10 +13,11 @@ pub mod fake_provider;
use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
+use client::UserStore;
use cloud_llm_client::CompletionRequestStatus;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
-use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
+use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
@@ -61,9 +62,9 @@ pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProvider
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("Zed");
-pub fn init(client: Arc<Client>, cx: &mut App) {
+pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
init_settings(cx);
- RefreshLlmTokenListener::register(client, cx);
+ RefreshLlmTokenListener::register(client, user_store, cx);
}
pub fn init_settings(cx: &mut App) {
@@ -3,11 +3,14 @@ use std::sync::Arc;
use anyhow::{Context as _, Result};
use client::Client;
+use client::UserStore;
use cloud_api_client::ClientApiError;
use cloud_api_types::OrganizationId;
use cloud_api_types::websocket_protocol::MessageToClient;
use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
-use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
+use gpui::{
+ App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
+};
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use thiserror::Error;
@@ -101,13 +104,15 @@ impl Global for GlobalRefreshLlmTokenListener {}
pub struct RefreshLlmTokenEvent;
-pub struct RefreshLlmTokenListener;
+pub struct RefreshLlmTokenListener {
+ _subscription: Subscription,
+}
impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
impl RefreshLlmTokenListener {
- pub fn register(client: Arc<Client>, cx: &mut App) {
- let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
+ pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
+ let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
cx.set_global(GlobalRefreshLlmTokenListener(listener));
}
@@ -115,7 +120,7 @@ impl RefreshLlmTokenListener {
GlobalRefreshLlmTokenListener::global(cx).0.clone()
}
- fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
+ fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
client.add_message_to_client_handler({
let this = cx.entity();
move |message, cx| {
@@ -123,7 +128,15 @@ impl RefreshLlmTokenListener {
}
});
- Self
+ let subscription = cx.subscribe(&user_store, |_this, _user_store, event, cx| {
+ if matches!(event, client::user::Event::OrganizationChanged) {
+ cx.emit(RefreshLlmTokenEvent);
+ }
+ });
+
+ Self {
+ _subscription: subscription,
+ }
}
fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
@@ -1014,9 +1014,9 @@ impl TitleBar {
let user_store = user_store.clone();
let organization = organization.clone();
move |_window, cx| {
- user_store.update(cx, |user_store, _cx| {
+ user_store.update(cx, |user_store, cx| {
user_store
- .set_current_organization(organization.clone());
+ .set_current_organization(organization.clone(), cx);
});
}
},
@@ -657,7 +657,7 @@ fn main() {
);
copilot_ui::init(&app_state, cx);
- language_model::init(app_state.client.clone(), cx);
+ language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
acp_tools::init(cx);
zed::telemetry_log::init(cx);
@@ -200,7 +200,7 @@ fn run_visual_tests(project_path: PathBuf, update_baseline: bool) -> Result<()>
});
prompt_store::init(cx);
let prompt_builder = prompt_store::PromptBuilder::load(app_state.fs.clone(), false, cx);
- language_model::init(app_state.client.clone(), cx);
+ language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
git_ui::init(cx);
project::AgentRegistryStore::init_global(
@@ -5024,7 +5024,7 @@ mod tests {
cx,
);
image_viewer::init(cx);
- language_model::init(app_state.client.clone(), cx);
+ language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
web_search::init(cx);
git_graph::init(cx);
@@ -316,7 +316,7 @@ mod tests {
let app_state = cx.update(|cx| {
let app_state = AppState::test(cx);
client::init(&app_state.client, cx);
- language_model::init(app_state.client.clone(), cx);
+ language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
editor::init(cx);
app_state
});