language_model: Centralize LlmApiToken to a singleton (#51225)

Tom Houlé and Marshall Bowers created

The edit prediction, web search and completions endpoints in Cloud all
use tokens called LlmApiToken. These were independently created, cached,
and refreshed in three places: the cloud language model provider, the
edit prediction store, and the cloud web search provider. Each held its
own LlmApiToken instance, meaning three separate requests to get these
tokens at startup / login and three redundant refreshes whenever the
server signaled a token update was needed.

We already had a global singleton reacting to the refresh signals:
RefreshLlmTokenListener. It now holds a single LlmApiToken that all
three services use, performs the refresh itself, and emits
RefreshLlmTokenEvent only after the token is fresh. That event is used
by the language model provider to re-fetch models after a refresh. The
singleton is accessed only through `LlmApiToken::global()`.

I have tested this manually, and it token acquisition and usage appear
to be working fine.

Edit: I've tested it with a long running session, and refresh seems to
be working fine too.

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>

Change summary

crates/edit_prediction/src/edit_prediction.rs       | 25 +--------
crates/edit_prediction/src/edit_prediction_tests.rs |  1 
crates/language_model/src/model.rs                  |  0 
crates/language_model/src/model/cloud_model.rs      | 38 +++++++++++++-
crates/language_models/src/provider/cloud.rs        |  6 -
crates/web_search_providers/src/cloud.rs            | 26 +--------
6 files changed, 43 insertions(+), 53 deletions(-)

Detailed changes

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -23,14 +23,14 @@ use futures::{
 use gpui::BackgroundExecutor;
 use gpui::http_client::Url;
 use gpui::{
-    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
+    App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
     http_client::{self, AsyncBody, Method},
     prelude::*,
 };
 use language::language_settings::all_language_settings;
 use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
 use language::{BufferSnapshot, OffsetRangeExt};
-use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
+use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
 use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
 use release_channel::AppVersion;
 use semver::Version;
@@ -133,7 +133,6 @@ pub struct EditPredictionStore {
     client: Arc<Client>,
     user_store: Entity<UserStore>,
     llm_token: LlmApiToken,
-    _llm_token_subscription: Subscription,
     _fetch_experiments_task: Task<()>,
     projects: HashMap<EntityId, ProjectState>,
     update_required: bool,
@@ -674,10 +673,9 @@ impl EditPredictionStore {
     }
 
     pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
-        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
         let data_collection_choice = Self::load_data_collection_choice();
 
-        let llm_token = LlmApiToken::default();
+        let llm_token = LlmApiToken::global(cx);
 
         let (reject_tx, reject_rx) = mpsc::unbounded();
         cx.background_spawn({
@@ -721,23 +719,6 @@ impl EditPredictionStore {
             user_store,
             llm_token,
             _fetch_experiments_task: fetch_experiments_task,
-            _llm_token_subscription: cx.subscribe(
-                &refresh_llm_token_listener,
-                |this, _listener, _event, cx| {
-                    let client = this.client.clone();
-                    let llm_token = this.llm_token.clone();
-                    let organization_id = this
-                        .user_store
-                        .read(cx)
-                        .current_organization()
-                        .map(|organization| organization.id.clone());
-                    cx.spawn(async move |_this, _cx| {
-                        llm_token.refresh(&client, organization_id).await?;
-                        anyhow::Ok(())
-                    })
-                    .detach_and_log_err(cx);
-                },
-            ),
             update_required: false,
             edit_prediction_model: EditPredictionModel::Zeta,
             zeta2_raw_config: Self::zeta2_raw_config_from_env(),

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -21,6 +21,7 @@ use language::{
     Anchor, Buffer, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet, DiagnosticSeverity,
     Operation, Point, Selection, SelectionGoal,
 };
+use language_model::RefreshLlmTokenListener;
 use lsp::LanguageServerId;
 use parking_lot::Mutex;
 use pretty_assertions::{assert_eq, assert_matches};

crates/language_model/src/model/cloud_model.rs 🔗

@@ -30,6 +30,13 @@ impl fmt::Display for PaymentRequiredError {
 pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
 
 impl LlmApiToken {
+    pub fn global(cx: &App) -> Self {
+        RefreshLlmTokenListener::global(cx)
+            .read(cx)
+            .llm_api_token
+            .clone()
+    }
+
     pub async fn acquire(
         &self,
         client: &Arc<Client>,
@@ -102,13 +109,16 @@ struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
 
 impl Global for GlobalRefreshLlmTokenListener {}
 
-pub struct RefreshLlmTokenEvent;
+pub struct LlmTokenRefreshedEvent;
 
 pub struct RefreshLlmTokenListener {
+    client: Arc<Client>,
+    user_store: Entity<UserStore>,
+    llm_api_token: LlmApiToken,
     _subscription: Subscription,
 }
 
-impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
+impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
 
 impl RefreshLlmTokenListener {
     pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
@@ -128,21 +138,39 @@ impl RefreshLlmTokenListener {
             }
         });
 
-        let subscription = cx.subscribe(&user_store, |_this, _user_store, event, cx| {
+        let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
             if matches!(event, client::user::Event::OrganizationChanged) {
-                cx.emit(RefreshLlmTokenEvent);
+                this.refresh(cx);
             }
         });
 
         Self {
+            client,
+            user_store,
+            llm_api_token: LlmApiToken::default(),
             _subscription: subscription,
         }
     }
 
+    fn refresh(&self, cx: &mut Context<Self>) {
+        let client = self.client.clone();
+        let llm_api_token = self.llm_api_token.clone();
+        let organization_id = self
+            .user_store
+            .read(cx)
+            .current_organization()
+            .map(|o| o.id.clone());
+        cx.spawn(async move |this, cx| {
+            llm_api_token.refresh(&client, organization_id).await?;
+            this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
+        })
+        .detach_and_log_err(cx);
+    }
+
     fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
         match message {
             MessageToClient::UserUpdated => {
-                this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
+                this.update(cx, |this, cx| this.refresh(cx));
             }
         }
     }

crates/language_models/src/provider/cloud.rs 🔗

@@ -109,9 +109,10 @@ impl State {
         cx: &mut Context<Self>,
     ) -> Self {
         let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
+        let llm_api_token = LlmApiToken::global(cx);
         Self {
             client: client.clone(),
-            llm_api_token: LlmApiToken::default(),
+            llm_api_token,
             user_store: user_store.clone(),
             status,
             models: Vec::new(),
@@ -158,9 +159,6 @@ impl State {
                         .current_organization()
                         .map(|o| o.id.clone());
                     cx.spawn(async move |this, cx| {
-                        llm_api_token
-                            .refresh(&client, organization_id.clone())
-                            .await?;
                         let response =
                             Self::fetch_models(client, llm_api_token, organization_id).await?;
                         this.update(cx, |this, cx| {

crates/web_search_providers/src/cloud.rs 🔗

@@ -5,9 +5,9 @@ use client::{Client, UserStore};
 use cloud_api_types::OrganizationId;
 use cloud_llm_client::{WebSearchBody, WebSearchResponse};
 use futures::AsyncReadExt as _;
-use gpui::{App, AppContext, Context, Entity, Subscription, Task};
+use gpui::{App, AppContext, Context, Entity, Task};
 use http_client::{HttpClient, Method};
-use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
+use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
 use web_search::{WebSearchProvider, WebSearchProviderId};
 
 pub struct CloudWebSearchProvider {
@@ -26,34 +26,16 @@ pub struct State {
     client: Arc<Client>,
     user_store: Entity<UserStore>,
     llm_api_token: LlmApiToken,
-    _llm_token_subscription: Subscription,
 }
 
 impl State {
     pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
-        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
+        let llm_api_token = LlmApiToken::global(cx);
 
         Self {
             client,
             user_store,
-            llm_api_token: LlmApiToken::default(),
-            _llm_token_subscription: cx.subscribe(
-                &refresh_llm_token_listener,
-                |this, _, _event, cx| {
-                    let client = this.client.clone();
-                    let llm_api_token = this.llm_api_token.clone();
-                    let organization_id = this
-                        .user_store
-                        .read(cx)
-                        .current_organization()
-                        .map(|o| o.id.clone());
-                    cx.spawn(async move |_this, _cx| {
-                        llm_api_token.refresh(&client, organization_id).await?;
-                        anyhow::Ok(())
-                    })
-                    .detach_and_log_err(cx);
-                },
-            ),
+            llm_api_token,
         }
     }
 }