Detailed changes
@@ -23,14 +23,14 @@ use futures::{
use gpui::BackgroundExecutor;
use gpui::http_client::Url;
use gpui::{
- App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
+ App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
http_client::{self, AsyncBody, Method},
prelude::*,
};
use language::language_settings::all_language_settings;
use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
-use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
+use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
use semver::Version;
@@ -133,7 +133,6 @@ pub struct EditPredictionStore {
client: Arc<Client>,
user_store: Entity<UserStore>,
llm_token: LlmApiToken,
- _llm_token_subscription: Subscription,
_fetch_experiments_task: Task<()>,
projects: HashMap<EntityId, ProjectState>,
update_required: bool,
@@ -674,10 +673,9 @@ impl EditPredictionStore {
}
pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
- let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
let data_collection_choice = Self::load_data_collection_choice();
- let llm_token = LlmApiToken::default();
+ let llm_token = LlmApiToken::global(cx);
let (reject_tx, reject_rx) = mpsc::unbounded();
cx.background_spawn({
@@ -721,23 +719,6 @@ impl EditPredictionStore {
user_store,
llm_token,
_fetch_experiments_task: fetch_experiments_task,
- _llm_token_subscription: cx.subscribe(
- &refresh_llm_token_listener,
- |this, _listener, _event, cx| {
- let client = this.client.clone();
- let llm_token = this.llm_token.clone();
- let organization_id = this
- .user_store
- .read(cx)
- .current_organization()
- .map(|organization| organization.id.clone());
- cx.spawn(async move |_this, _cx| {
- llm_token.refresh(&client, organization_id).await?;
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
- },
- ),
update_required: false,
edit_prediction_model: EditPredictionModel::Zeta,
zeta2_raw_config: Self::zeta2_raw_config_from_env(),
@@ -21,6 +21,7 @@ use language::{
Anchor, Buffer, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet, DiagnosticSeverity,
Operation, Point, Selection, SelectionGoal,
};
+use language_model::RefreshLlmTokenListener;
use lsp::LanguageServerId;
use parking_lot::Mutex;
use pretty_assertions::{assert_eq, assert_matches};
@@ -30,6 +30,13 @@ impl fmt::Display for PaymentRequiredError {
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
impl LlmApiToken {
+ pub fn global(cx: &App) -> Self {
+ RefreshLlmTokenListener::global(cx)
+ .read(cx)
+ .llm_api_token
+ .clone()
+ }
+
pub async fn acquire(
&self,
client: &Arc<Client>,
@@ -102,13 +109,16 @@ struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
impl Global for GlobalRefreshLlmTokenListener {}
-pub struct RefreshLlmTokenEvent;
+pub struct LlmTokenRefreshedEvent;
pub struct RefreshLlmTokenListener {
+ client: Arc<Client>,
+ user_store: Entity<UserStore>,
+ llm_api_token: LlmApiToken,
_subscription: Subscription,
}
-impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
+impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
impl RefreshLlmTokenListener {
pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
@@ -128,21 +138,39 @@ impl RefreshLlmTokenListener {
}
});
- let subscription = cx.subscribe(&user_store, |_this, _user_store, event, cx| {
+ let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
if matches!(event, client::user::Event::OrganizationChanged) {
- cx.emit(RefreshLlmTokenEvent);
+ this.refresh(cx);
}
});
Self {
+ client,
+ user_store,
+ llm_api_token: LlmApiToken::default(),
_subscription: subscription,
}
}
+ fn refresh(&self, cx: &mut Context<Self>) {
+ let client = self.client.clone();
+ let llm_api_token = self.llm_api_token.clone();
+ let organization_id = self
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|o| o.id.clone());
+ cx.spawn(async move |this, cx| {
+ llm_api_token.refresh(&client, organization_id).await?;
+ this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
+ })
+ .detach_and_log_err(cx);
+ }
+
fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
match message {
MessageToClient::UserUpdated => {
- this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
+ this.update(cx, |this, cx| this.refresh(cx));
}
}
}
@@ -109,9 +109,10 @@ impl State {
cx: &mut Context<Self>,
) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
+ let llm_api_token = LlmApiToken::global(cx);
Self {
client: client.clone(),
- llm_api_token: LlmApiToken::default(),
+ llm_api_token,
user_store: user_store.clone(),
status,
models: Vec::new(),
@@ -158,9 +159,6 @@ impl State {
.current_organization()
.map(|o| o.id.clone());
cx.spawn(async move |this, cx| {
- llm_api_token
- .refresh(&client, organization_id.clone())
- .await?;
let response =
Self::fetch_models(client, llm_api_token, organization_id).await?;
this.update(cx, |this, cx| {
@@ -5,9 +5,9 @@ use client::{Client, UserStore};
use cloud_api_types::OrganizationId;
use cloud_llm_client::{WebSearchBody, WebSearchResponse};
use futures::AsyncReadExt as _;
-use gpui::{App, AppContext, Context, Entity, Subscription, Task};
+use gpui::{App, AppContext, Context, Entity, Task};
use http_client::{HttpClient, Method};
-use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
+use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
use web_search::{WebSearchProvider, WebSearchProviderId};
pub struct CloudWebSearchProvider {
@@ -26,34 +26,16 @@ pub struct State {
client: Arc<Client>,
user_store: Entity<UserStore>,
llm_api_token: LlmApiToken,
- _llm_token_subscription: Subscription,
}
impl State {
pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
- let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
+ let llm_api_token = LlmApiToken::global(cx);
Self {
client,
user_store,
- llm_api_token: LlmApiToken::default(),
- _llm_token_subscription: cx.subscribe(
- &refresh_llm_token_listener,
- |this, _, _event, cx| {
- let client = this.client.clone();
- let llm_api_token = this.llm_api_token.clone();
- let organization_id = this
- .user_store
- .read(cx)
- .current_organization()
- .map(|o| o.id.clone());
- cx.spawn(async move |_this, _cx| {
- llm_api_token.refresh(&client, organization_id).await?;
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
- },
- ),
+ llm_api_token,
}
}
}