cloud_model.rs

  1use std::fmt;
  2use std::sync::Arc;
  3
  4use anyhow::{Context as _, Result};
  5use client::Client;
  6use cloud_api_client::ClientApiError;
  7use cloud_api_types::OrganizationId;
  8use cloud_api_types::websocket_protocol::MessageToClient;
  9use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
 10use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
 11use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 12use thiserror::Error;
 13
 14#[derive(Error, Debug)]
 15pub struct PaymentRequiredError;
 16
 17impl fmt::Display for PaymentRequiredError {
 18    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 19        write!(
 20            f,
 21            "Payment required to use this language model. Please upgrade your account."
 22        )
 23    }
 24}
 25
 26#[derive(Clone, Default)]
 27pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
 28
 29impl LlmApiToken {
 30    pub async fn acquire(
 31        &self,
 32        client: &Arc<Client>,
 33        organization_id: Option<OrganizationId>,
 34    ) -> Result<String> {
 35        let lock = self.0.upgradable_read().await;
 36        if let Some(token) = lock.as_ref() {
 37            Ok(token.to_string())
 38        } else {
 39            Self::fetch(
 40                RwLockUpgradableReadGuard::upgrade(lock).await,
 41                client,
 42                organization_id,
 43            )
 44            .await
 45        }
 46    }
 47
 48    pub async fn refresh(
 49        &self,
 50        client: &Arc<Client>,
 51        organization_id: Option<OrganizationId>,
 52    ) -> Result<String> {
 53        Self::fetch(self.0.write().await, client, organization_id).await
 54    }
 55
 56    async fn fetch(
 57        mut lock: RwLockWriteGuard<'_, Option<String>>,
 58        client: &Arc<Client>,
 59        organization_id: Option<OrganizationId>,
 60    ) -> Result<String> {
 61        let system_id = client
 62            .telemetry()
 63            .system_id()
 64            .map(|system_id| system_id.to_string());
 65
 66        let result = client
 67            .cloud_client()
 68            .create_llm_token(system_id, organization_id)
 69            .await;
 70        match result {
 71            Ok(response) => {
 72                *lock = Some(response.token.0.clone());
 73                Ok(response.token.0)
 74            }
 75            Err(err) => match err {
 76                ClientApiError::Unauthorized => {
 77                    client.request_sign_out();
 78                    Err(err).context("Failed to create LLM token")
 79                }
 80                ClientApiError::Other(err) => Err(err),
 81            },
 82        }
 83    }
 84}
 85
 86pub trait NeedsLlmTokenRefresh {
 87    /// Returns whether the LLM token needs to be refreshed.
 88    fn needs_llm_token_refresh(&self) -> bool;
 89}
 90
 91impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
 92    fn needs_llm_token_refresh(&self) -> bool {
 93        self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
 94            || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
 95    }
 96}
 97
 98struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
 99
100impl Global for GlobalRefreshLlmTokenListener {}
101
102pub struct RefreshLlmTokenEvent;
103
104pub struct RefreshLlmTokenListener;
105
106impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
107
108impl RefreshLlmTokenListener {
109    pub fn register(client: Arc<Client>, cx: &mut App) {
110        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
111        cx.set_global(GlobalRefreshLlmTokenListener(listener));
112    }
113
114    pub fn global(cx: &App) -> Entity<Self> {
115        GlobalRefreshLlmTokenListener::global(cx).0.clone()
116    }
117
118    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
119        client.add_message_to_client_handler({
120            let this = cx.entity();
121            move |message, cx| {
122                Self::handle_refresh_llm_token(this.clone(), message, cx);
123            }
124        });
125
126        Self
127    }
128
129    fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
130        match message {
131            MessageToClient::UserUpdated => {
132                this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
133            }
134        }
135    }
136}