language_model: Clear the LlmApiToken first on org switch (#51826)

Tom Houlé and Neel created

When we switch organizations, we try and refresh the token. If the token
refresh fails, we are left with the old LlmApiToken, which is for the
wrong organization. In this commit, we make sure to clear the old token
before trying a refresh on organization switch.

Release Notes:

- N/A

---------

Co-authored-by: Neel <neel@zed.dev>

Change summary

crates/language_model/src/model/cloud_model.rs | 51 ++++++++++++++++---
1 file changed, 41 insertions(+), 10 deletions(-)

Detailed changes

crates/language_model/src/model/cloud_model.rs 🔗

@@ -63,6 +63,20 @@ impl LlmApiToken {
         Self::fetch(self.0.write().await, client, organization_id).await
     }
 
+    /// Clears the existing token before attempting to fetch a new one.
+    ///
+    /// Used when switching organizations so that a failed refresh doesn't
+    /// leave a token for the wrong organization.
+    pub async fn clear_and_refresh(
+        &self,
+        client: &Arc<Client>,
+        organization_id: Option<OrganizationId>,
+    ) -> Result<String> {
+        let mut lock = self.0.write().await;
+        *lock = None;
+        Self::fetch(lock, client, organization_id).await
+    }
+
     async fn fetch(
         mut lock: RwLockWriteGuard<'_, Option<String>>,
         client: &Arc<Client>,
@@ -82,13 +96,16 @@ impl LlmApiToken {
                 *lock = Some(response.token.0.clone());
                 Ok(response.token.0)
             }
-            Err(err) => match err {
-                ClientApiError::Unauthorized => {
-                    client.request_sign_out();
-                    Err(err).context("Failed to create LLM token")
+            Err(err) => {
+                *lock = None;
+                match err {
+                    ClientApiError::Unauthorized => {
+                        client.request_sign_out();
+                        Err(err).context("Failed to create LLM token")
+                    }
+                    ClientApiError::Other(err) => Err(err),
                 }
-                ClientApiError::Other(err) => Err(err),
-            },
+            }
         }
     }
 }
@@ -105,6 +122,11 @@ impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
     }
 }
 
+enum TokenRefreshMode {
+    Refresh,
+    ClearAndRefresh,
+}
+
 struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
 
 impl Global for GlobalRefreshLlmTokenListener {}
@@ -140,7 +162,7 @@ impl RefreshLlmTokenListener {
 
         let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
             if matches!(event, client::user::Event::OrganizationChanged) {
-                this.refresh(cx);
+                this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
             }
         });
 
@@ -152,7 +174,7 @@ impl RefreshLlmTokenListener {
         }
     }
 
-    fn refresh(&self, cx: &mut Context<Self>) {
+    fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
         let client = self.client.clone();
         let llm_api_token = self.llm_api_token.clone();
         let organization_id = self
@@ -161,7 +183,16 @@ impl RefreshLlmTokenListener {
             .current_organization()
             .map(|organization| organization.id.clone());
         cx.spawn(async move |this, cx| {
-            llm_api_token.refresh(&client, organization_id).await?;
+            match mode {
+                TokenRefreshMode::Refresh => {
+                    llm_api_token.refresh(&client, organization_id).await?;
+                }
+                TokenRefreshMode::ClearAndRefresh => {
+                    llm_api_token
+                        .clear_and_refresh(&client, organization_id)
+                        .await?;
+                }
+            }
             this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
         })
         .detach_and_log_err(cx);
@@ -170,7 +201,7 @@ impl RefreshLlmTokenListener {
     fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
         match message {
             MessageToClient::UserUpdated => {
-                this.update(cx, |this, cx| this.refresh(cx));
+                this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
             }
         }
     }