cloud_model.rs

  1use std::fmt;
  2use std::sync::Arc;
  3
  4use anyhow::{Context as _, Result};
  5use client::Client;
  6use cloud_api_client::ClientApiError;
  7use cloud_api_types::websocket_protocol::MessageToClient;
  8use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
  9use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
 10use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 11use thiserror::Error;
 12
 13#[derive(Error, Debug)]
 14pub struct PaymentRequiredError;
 15
 16impl fmt::Display for PaymentRequiredError {
 17    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 18        write!(
 19            f,
 20            "Payment required to use this language model. Please upgrade your account."
 21        )
 22    }
 23}
 24
 25#[derive(Clone, Default)]
 26pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
 27
 28impl LlmApiToken {
 29    pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
 30        let lock = self.0.upgradable_read().await;
 31        if let Some(token) = lock.as_ref() {
 32            Ok(token.to_string())
 33        } else {
 34            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
 35        }
 36    }
 37
 38    pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
 39        Self::fetch(self.0.write().await, client).await
 40    }
 41
 42    async fn fetch(
 43        mut lock: RwLockWriteGuard<'_, Option<String>>,
 44        client: &Arc<Client>,
 45    ) -> Result<String> {
 46        let system_id = client
 47            .telemetry()
 48            .system_id()
 49            .map(|system_id| system_id.to_string());
 50
 51        let result = client.cloud_client().create_llm_token(system_id).await;
 52        match result {
 53            Ok(response) => {
 54                *lock = Some(response.token.0.clone());
 55                Ok(response.token.0)
 56            }
 57            Err(err) => match err {
 58                ClientApiError::Unauthorized => {
 59                    client.request_sign_out();
 60                    Err(err).context("Failed to create LLM token")
 61                }
 62                ClientApiError::Other(err) => Err(err),
 63            },
 64        }
 65    }
 66}
 67
 68pub trait NeedsLlmTokenRefresh {
 69    /// Returns whether the LLM token needs to be refreshed.
 70    fn needs_llm_token_refresh(&self) -> bool;
 71}
 72
 73impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
 74    fn needs_llm_token_refresh(&self) -> bool {
 75        self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
 76            || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
 77    }
 78}
 79
 80struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
 81
 82impl Global for GlobalRefreshLlmTokenListener {}
 83
 84pub struct RefreshLlmTokenEvent;
 85
 86pub struct RefreshLlmTokenListener;
 87
 88impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
 89
 90impl RefreshLlmTokenListener {
 91    pub fn register(client: Arc<Client>, cx: &mut App) {
 92        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
 93        cx.set_global(GlobalRefreshLlmTokenListener(listener));
 94    }
 95
 96    pub fn global(cx: &App) -> Entity<Self> {
 97        GlobalRefreshLlmTokenListener::global(cx).0.clone()
 98    }
 99
100    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
101        client.add_message_to_client_handler({
102            let this = cx.entity();
103            move |message, cx| {
104                Self::handle_refresh_llm_token(this.clone(), message, cx);
105            }
106        });
107
108        Self
109    }
110
111    fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
112        match message {
113            MessageToClient::UserUpdated => {
114                this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
115            }
116        }
117    }
118}