llm_token.rs

  1use super::{Client, UserStore};
  2use cloud_api_client::LlmApiToken;
  3use cloud_api_types::websocket_protocol::MessageToClient;
  4use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
  5use gpui::{
  6    App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
  7    TaskExt,
  8};
  9use std::sync::Arc;
 10
 11pub trait NeedsLlmTokenRefresh {
 12    /// Returns whether the LLM token needs to be refreshed.
 13    fn needs_llm_token_refresh(&self) -> bool;
 14}
 15
 16impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
 17    fn needs_llm_token_refresh(&self) -> bool {
 18        self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
 19            || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
 20    }
 21}
 22
 23enum TokenRefreshMode {
 24    Refresh,
 25    ClearAndRefresh,
 26}
 27
 28pub fn global_llm_token(cx: &App) -> LlmApiToken {
 29    RefreshLlmTokenListener::global(cx)
 30        .read(cx)
 31        .llm_api_token
 32        .clone()
 33}
 34
 35struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
 36
 37impl Global for GlobalRefreshLlmTokenListener {}
 38
 39pub struct LlmTokenRefreshedEvent;
 40
 41pub struct RefreshLlmTokenListener {
 42    client: Arc<Client>,
 43    user_store: Entity<UserStore>,
 44    llm_api_token: LlmApiToken,
 45    _subscription: Subscription,
 46}
 47
 48impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
 49
 50impl RefreshLlmTokenListener {
 51    pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
 52        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
 53        cx.set_global(GlobalRefreshLlmTokenListener(listener));
 54    }
 55
 56    pub fn global(cx: &App) -> Entity<Self> {
 57        GlobalRefreshLlmTokenListener::global(cx).0.clone()
 58    }
 59
 60    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 61        client.add_message_to_client_handler({
 62            let this = cx.weak_entity();
 63            move |message, cx| {
 64                if let Some(this) = this.upgrade() {
 65                    Self::handle_refresh_llm_token(this, message, cx);
 66                }
 67            }
 68        });
 69
 70        let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
 71            if matches!(event, super::user::Event::OrganizationChanged) {
 72                this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
 73            }
 74        });
 75
 76        Self {
 77            client,
 78            user_store,
 79            llm_api_token: LlmApiToken::default(),
 80            _subscription: subscription,
 81        }
 82    }
 83
 84    fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
 85        let client = self.client.clone();
 86        let llm_api_token = self.llm_api_token.clone();
 87        let organization_id = self
 88            .user_store
 89            .read(cx)
 90            .current_organization()
 91            .map(|organization| organization.id.clone());
 92        cx.spawn(async move |this, cx| {
 93            match mode {
 94                TokenRefreshMode::Refresh => {
 95                    client
 96                        .refresh_llm_token(&llm_api_token, organization_id)
 97                        .await?;
 98                }
 99                TokenRefreshMode::ClearAndRefresh => {
100                    client
101                        .clear_and_refresh_llm_token(&llm_api_token, organization_id)
102                        .await?;
103                }
104            }
105            this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
106        })
107        .detach_and_log_err(cx);
108    }
109
110    fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
111        match message {
112            MessageToClient::UserUpdated => {
113                this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
114            }
115        }
116    }
117}