cloud_model.rs

  1use std::fmt;
  2use std::sync::Arc;
  3
  4use anyhow::{Context as _, Result};
  5use client::Client;
  6use client::UserStore;
  7use cloud_api_client::ClientApiError;
  8use cloud_api_types::OrganizationId;
  9use cloud_api_types::websocket_protocol::MessageToClient;
 10use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
 11use gpui::{
 12    App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
 13};
 14use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 15use thiserror::Error;
 16
 17#[derive(Error, Debug)]
 18pub struct PaymentRequiredError;
 19
 20impl fmt::Display for PaymentRequiredError {
 21    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 22        write!(
 23            f,
 24            "Payment required to use this language model. Please upgrade your account."
 25        )
 26    }
 27}
 28
 29#[derive(Clone, Default)]
 30pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
 31
 32impl LlmApiToken {
 33    pub async fn acquire(
 34        &self,
 35        client: &Arc<Client>,
 36        organization_id: Option<OrganizationId>,
 37    ) -> Result<String> {
 38        let lock = self.0.upgradable_read().await;
 39        if let Some(token) = lock.as_ref() {
 40            Ok(token.to_string())
 41        } else {
 42            Self::fetch(
 43                RwLockUpgradableReadGuard::upgrade(lock).await,
 44                client,
 45                organization_id,
 46            )
 47            .await
 48        }
 49    }
 50
 51    pub async fn refresh(
 52        &self,
 53        client: &Arc<Client>,
 54        organization_id: Option<OrganizationId>,
 55    ) -> Result<String> {
 56        Self::fetch(self.0.write().await, client, organization_id).await
 57    }
 58
 59    async fn fetch(
 60        mut lock: RwLockWriteGuard<'_, Option<String>>,
 61        client: &Arc<Client>,
 62        organization_id: Option<OrganizationId>,
 63    ) -> Result<String> {
 64        let system_id = client
 65            .telemetry()
 66            .system_id()
 67            .map(|system_id| system_id.to_string());
 68
 69        let result = client
 70            .cloud_client()
 71            .create_llm_token(system_id, organization_id)
 72            .await;
 73        match result {
 74            Ok(response) => {
 75                *lock = Some(response.token.0.clone());
 76                Ok(response.token.0)
 77            }
 78            Err(err) => match err {
 79                ClientApiError::Unauthorized => {
 80                    client.request_sign_out();
 81                    Err(err).context("Failed to create LLM token")
 82                }
 83                ClientApiError::Other(err) => Err(err),
 84            },
 85        }
 86    }
 87}
 88
 89pub trait NeedsLlmTokenRefresh {
 90    /// Returns whether the LLM token needs to be refreshed.
 91    fn needs_llm_token_refresh(&self) -> bool;
 92}
 93
 94impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
 95    fn needs_llm_token_refresh(&self) -> bool {
 96        self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
 97            || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
 98    }
 99}
100
101struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
102
103impl Global for GlobalRefreshLlmTokenListener {}
104
105pub struct RefreshLlmTokenEvent;
106
107pub struct RefreshLlmTokenListener {
108    _subscription: Subscription,
109}
110
111impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
112
113impl RefreshLlmTokenListener {
114    pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
115        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
116        cx.set_global(GlobalRefreshLlmTokenListener(listener));
117    }
118
119    pub fn global(cx: &App) -> Entity<Self> {
120        GlobalRefreshLlmTokenListener::global(cx).0.clone()
121    }
122
123    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
124        client.add_message_to_client_handler({
125            let this = cx.entity();
126            move |message, cx| {
127                Self::handle_refresh_llm_token(this.clone(), message, cx);
128            }
129        });
130
131        let subscription = cx.subscribe(&user_store, |_this, _user_store, event, cx| {
132            if matches!(event, client::user::Event::OrganizationChanged) {
133                cx.emit(RefreshLlmTokenEvent);
134            }
135        });
136
137        Self {
138            _subscription: subscription,
139        }
140    }
141
142    fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
143        match message {
144            MessageToClient::UserUpdated => {
145                this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
146            }
147        }
148    }
149}