@@ -63,6 +63,20 @@ impl LlmApiToken {
Self::fetch(self.0.write().await, client, organization_id).await
}
+ /// Clears the existing token before attempting to fetch a new one.
+ ///
+ /// Used when switching organizations so that a failed refresh doesn't
+ /// leave a token for the wrong organization.
+ pub async fn clear_and_refresh(
+ &self,
+ client: &Arc<Client>,
+ organization_id: Option<OrganizationId>,
+ ) -> Result<String> {
+ let mut lock = self.0.write().await;
+ *lock = None;
+ Self::fetch(lock, client, organization_id).await
+ }
+
async fn fetch(
mut lock: RwLockWriteGuard<'_, Option<String>>,
client: &Arc<Client>,
@@ -82,13 +96,16 @@ impl LlmApiToken {
*lock = Some(response.token.0.clone());
Ok(response.token.0)
}
- Err(err) => match err {
- ClientApiError::Unauthorized => {
- client.request_sign_out();
- Err(err).context("Failed to create LLM token")
+ Err(err) => {
+ *lock = None;
+ match err {
+ ClientApiError::Unauthorized => {
+ client.request_sign_out();
+ Err(err).context("Failed to create LLM token")
+ }
+ ClientApiError::Other(err) => Err(err),
}
- ClientApiError::Other(err) => Err(err),
- },
+ }
}
}
}
@@ -105,6 +122,11 @@ impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
}
}
+enum TokenRefreshMode {
+ Refresh,
+ ClearAndRefresh,
+}
+
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
impl Global for GlobalRefreshLlmTokenListener {}
@@ -140,7 +162,7 @@ impl RefreshLlmTokenListener {
let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
if matches!(event, client::user::Event::OrganizationChanged) {
- this.refresh(cx);
+ this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
}
});
@@ -152,7 +174,7 @@ impl RefreshLlmTokenListener {
}
}
- fn refresh(&self, cx: &mut Context<Self>) {
+ fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
let organization_id = self
@@ -161,7 +183,16 @@ impl RefreshLlmTokenListener {
.current_organization()
.map(|organization| organization.id.clone());
cx.spawn(async move |this, cx| {
- llm_api_token.refresh(&client, organization_id).await?;
+ match mode {
+ TokenRefreshMode::Refresh => {
+ llm_api_token.refresh(&client, organization_id).await?;
+ }
+ TokenRefreshMode::ClearAndRefresh => {
+ llm_api_token
+ .clear_and_refresh(&client, organization_id)
+ .await?;
+ }
+ }
this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
})
.detach_and_log_err(cx);
@@ -170,7 +201,7 @@ impl RefreshLlmTokenListener {
fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
match message {
MessageToClient::UserUpdated => {
- this.update(cx, |this, cx| this.refresh(cx));
+ this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
}
}
}