llm_token.rs

  1use super::{Client, UserStore};
  2use cloud_api_client::LlmApiToken;
  3use cloud_api_types::websocket_protocol::MessageToClient;
  4use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
  5use gpui::{
  6    App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
  7};
  8use std::sync::Arc;
  9
 10pub trait NeedsLlmTokenRefresh {
 11    /// Returns whether the LLM token needs to be refreshed.
 12    fn needs_llm_token_refresh(&self) -> bool;
 13}
 14
 15impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
 16    fn needs_llm_token_refresh(&self) -> bool {
 17        self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
 18            || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
 19    }
 20}
 21
 22enum TokenRefreshMode {
 23    Refresh,
 24    ClearAndRefresh,
 25}
 26
 27pub fn global_llm_token(cx: &App) -> LlmApiToken {
 28    RefreshLlmTokenListener::global(cx)
 29        .read(cx)
 30        .llm_api_token
 31        .clone()
 32}
 33
 34struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
 35
 36impl Global for GlobalRefreshLlmTokenListener {}
 37
 38pub struct LlmTokenRefreshedEvent;
 39
 40pub struct RefreshLlmTokenListener {
 41    client: Arc<Client>,
 42    user_store: Entity<UserStore>,
 43    llm_api_token: LlmApiToken,
 44    _subscription: Subscription,
 45}
 46
 47impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
 48
 49impl RefreshLlmTokenListener {
 50    pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
 51        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
 52        cx.set_global(GlobalRefreshLlmTokenListener(listener));
 53    }
 54
 55    pub fn global(cx: &App) -> Entity<Self> {
 56        GlobalRefreshLlmTokenListener::global(cx).0.clone()
 57    }
 58
 59    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 60        client.add_message_to_client_handler({
 61            let this = cx.weak_entity();
 62            move |message, cx| {
 63                if let Some(this) = this.upgrade() {
 64                    Self::handle_refresh_llm_token(this, message, cx);
 65                }
 66            }
 67        });
 68
 69        let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
 70            if matches!(event, super::user::Event::OrganizationChanged) {
 71                this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
 72            }
 73        });
 74
 75        Self {
 76            client,
 77            user_store,
 78            llm_api_token: LlmApiToken::default(),
 79            _subscription: subscription,
 80        }
 81    }
 82
 83    fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
 84        let client = self.client.clone();
 85        let llm_api_token = self.llm_api_token.clone();
 86        let organization_id = self
 87            .user_store
 88            .read(cx)
 89            .current_organization()
 90            .map(|organization| organization.id.clone());
 91        cx.spawn(async move |this, cx| {
 92            match mode {
 93                TokenRefreshMode::Refresh => {
 94                    client
 95                        .refresh_llm_token(&llm_api_token, organization_id)
 96                        .await?;
 97                }
 98                TokenRefreshMode::ClearAndRefresh => {
 99                    client
100                        .clear_and_refresh_llm_token(&llm_api_token, organization_id)
101                        .await?;
102                }
103            }
104            this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
105        })
106        .detach_and_log_err(cx);
107    }
108
109    fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
110        match message {
111            MessageToClient::UserUpdated => {
112                this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
113            }
114        }
115    }
116}