cloud_model.rs

  1use std::fmt;
  2use std::sync::Arc;
  3
  4use anyhow::Result;
  5use client::Client;
  6use cloud_api_types::websocket_protocol::MessageToClient;
  7use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
  8use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
  9use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 10use thiserror::Error;
 11
 12#[derive(Error, Debug)]
 13pub struct PaymentRequiredError;
 14
 15impl fmt::Display for PaymentRequiredError {
 16    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 17        write!(
 18            f,
 19            "Payment required to use this language model. Please upgrade your account."
 20        )
 21    }
 22}
 23
 24#[derive(Clone, Default)]
 25pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
 26
 27impl LlmApiToken {
 28    pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
 29        let lock = self.0.upgradable_read().await;
 30        if let Some(token) = lock.as_ref() {
 31            Ok(token.to_string())
 32        } else {
 33            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
 34        }
 35    }
 36
 37    pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
 38        Self::fetch(self.0.write().await, client).await
 39    }
 40
 41    async fn fetch(
 42        mut lock: RwLockWriteGuard<'_, Option<String>>,
 43        client: &Arc<Client>,
 44    ) -> Result<String> {
 45        let system_id = client
 46            .telemetry()
 47            .system_id()
 48            .map(|system_id| system_id.to_string());
 49
 50        let response = client.cloud_client().create_llm_token(system_id).await?;
 51        *lock = Some(response.token.0.clone());
 52        Ok(response.token.0)
 53    }
 54}
 55
 56pub trait NeedsLlmTokenRefresh {
 57    /// Returns whether the LLM token needs to be refreshed.
 58    fn needs_llm_token_refresh(&self) -> bool;
 59}
 60
 61impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
 62    fn needs_llm_token_refresh(&self) -> bool {
 63        self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
 64            || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
 65    }
 66}
 67
 68struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
 69
 70impl Global for GlobalRefreshLlmTokenListener {}
 71
 72pub struct RefreshLlmTokenEvent;
 73
 74pub struct RefreshLlmTokenListener;
 75
 76impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
 77
 78impl RefreshLlmTokenListener {
 79    pub fn register(client: Arc<Client>, cx: &mut App) {
 80        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
 81        cx.set_global(GlobalRefreshLlmTokenListener(listener));
 82    }
 83
 84    pub fn global(cx: &App) -> Entity<Self> {
 85        GlobalRefreshLlmTokenListener::global(cx).0.clone()
 86    }
 87
 88    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
 89        client.add_message_to_client_handler({
 90            let this = cx.entity();
 91            move |message, cx| {
 92                Self::handle_refresh_llm_token(this.clone(), message, cx);
 93            }
 94        });
 95
 96        Self
 97    }
 98
 99    fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
100        match message {
101            MessageToClient::UserUpdated => {
102                this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
103            }
104        }
105    }
106}