diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 527d24ec18c0f9ef08576a71fe92562dd94d4afd..f6ad907483e5946652752895d0a48ec129660b0b 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -63,6 +63,20 @@ impl LlmApiToken { Self::fetch(self.0.write().await, client, organization_id).await } + /// Clears the existing token before attempting to fetch a new one. + /// + /// Used when switching organizations so that a failed refresh doesn't + /// leave a token for the wrong organization. + pub async fn clear_and_refresh( + &self, + client: &Arc, + organization_id: Option, + ) -> Result { + let mut lock = self.0.write().await; + *lock = None; + Self::fetch(lock, client, organization_id).await + } + async fn fetch( mut lock: RwLockWriteGuard<'_, Option>, client: &Arc, @@ -82,13 +96,16 @@ impl LlmApiToken { *lock = Some(response.token.0.clone()); Ok(response.token.0) } - Err(err) => match err { - ClientApiError::Unauthorized => { - client.request_sign_out(); - Err(err).context("Failed to create LLM token") + Err(err) => { + *lock = None; + match err { + ClientApiError::Unauthorized => { + client.request_sign_out(); + Err(err).context("Failed to create LLM token") + } + ClientApiError::Other(err) => Err(err), } - ClientApiError::Other(err) => Err(err), - }, + } } } } @@ -105,6 +122,11 @@ impl NeedsLlmTokenRefresh for http_client::Response { } } +enum TokenRefreshMode { + Refresh, + ClearAndRefresh, +} + struct GlobalRefreshLlmTokenListener(Entity); impl Global for GlobalRefreshLlmTokenListener {} @@ -140,7 +162,7 @@ impl RefreshLlmTokenListener { let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| { if matches!(event, client::user::Event::OrganizationChanged) { - this.refresh(cx); + this.refresh(TokenRefreshMode::ClearAndRefresh, cx); } }); @@ -152,7 +174,7 @@ impl RefreshLlmTokenListener { } } - fn refresh(&self, cx: &mut Context) { + fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context) { let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); let organization_id = self @@ -161,7 +183,16 @@ impl RefreshLlmTokenListener { .current_organization() .map(|organization| organization.id.clone()); cx.spawn(async move |this, cx| { - llm_api_token.refresh(&client, organization_id).await?; + match mode { + TokenRefreshMode::Refresh => { + llm_api_token.refresh(&client, organization_id).await?; + } + TokenRefreshMode::ClearAndRefresh => { + llm_api_token + .clear_and_refresh(&client, organization_id) + .await?; + } + } this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent)) }) .detach_and_log_err(cx); @@ -170,7 +201,7 @@ impl RefreshLlmTokenListener { fn handle_refresh_llm_token(this: Entity, message: &MessageToClient, cx: &mut App) { match message { MessageToClient::UserUpdated => { - this.update(cx, |this, cx| this.refresh(cx)); + this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx)); } } }