language_model: Refresh the LLM token upon receiving a `UserUpdated` message from Cloud (#35839)

Marshall Bowers created

This PR makes it so we refresh the LLM token upon receiving a
`UserUpdated` message from Cloud over the WebSocket connection.

Release Notes:

- N/A

Change summary

Cargo.lock                                     |  1 
crates/client/src/client.rs                    |  4 +-
crates/language_model/Cargo.toml               |  1 
crates/language_model/src/model/cloud_model.rs | 34 ++++++++++----------
4 files changed, 21 insertions(+), 19 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -9127,6 +9127,7 @@ dependencies = [
  "anyhow",
  "base64 0.22.1",
  "client",
+ "cloud_api_types",
  "cloud_llm_client",
  "collections",
  "futures 0.3.31",

crates/client/src/client.rs 🔗

@@ -193,7 +193,7 @@ pub fn init(client: &Arc<Client>, cx: &mut App) {
     });
 }
 
-pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &App) + Send + Sync + 'static>;
+pub type MessageToClientHandler = Box<dyn Fn(&MessageToClient, &mut App) + Send + Sync + 'static>;
 
 struct GlobalClient(Arc<Client>);
 
@@ -1684,7 +1684,7 @@ impl Client {
 
     pub fn add_message_to_client_handler(
         self: &Arc<Client>,
-        handler: impl Fn(&MessageToClient, &App) + Send + Sync + 'static,
+        handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static,
     ) {
         self.message_to_client_handlers
             .lock()

crates/language_model/Cargo.toml 🔗

@@ -20,6 +20,7 @@ anthropic = { workspace = true, features = ["schemars"] }
 anyhow.workspace = true
 base64.workspace = true
 client.workspace = true
+cloud_api_types.workspace = true
 cloud_llm_client.workspace = true
 collections.workspace = true
 futures.workspace = true

crates/language_model/src/model/cloud_model.rs 🔗

@@ -3,11 +3,9 @@ use std::sync::Arc;
 
 use anyhow::Result;
 use client::Client;
+use cloud_api_types::websocket_protocol::MessageToClient;
 use cloud_llm_client::Plan;
-use gpui::{
-    App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
-};
-use proto::TypedEnvelope;
+use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
 use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 use thiserror::Error;
 
@@ -82,9 +80,7 @@ impl Global for GlobalRefreshLlmTokenListener {}
 
 pub struct RefreshLlmTokenEvent;
 
-pub struct RefreshLlmTokenListener {
-    _llm_token_subscription: client::Subscription,
-}
+pub struct RefreshLlmTokenListener;
 
 impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
 
@@ -99,17 +95,21 @@ impl RefreshLlmTokenListener {
     }
 
     fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
-        Self {
-            _llm_token_subscription: client
-                .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
-        }
+        client.add_message_to_client_handler({
+            let this = cx.entity();
+            move |message, cx| {
+                Self::handle_refresh_llm_token(this.clone(), message, cx);
+            }
+        });
+
+        Self
     }
 
-    async fn handle_refresh_llm_token(
-        this: Entity<Self>,
-        _: TypedEnvelope<proto::RefreshLlmToken>,
-        mut cx: AsyncApp,
-    ) -> Result<()> {
-        this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
+    fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
+        match message {
+            MessageToClient::UserUpdated => {
+                this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
+            }
+        }
     }
 }