cloud_model.rs

  1use std::fmt;
  2use std::sync::Arc;
  3
  4use anyhow::{Context as _, Result};
  5use client::Client;
  6use client::UserStore;
  7use cloud_api_client::ClientApiError;
  8use cloud_api_types::OrganizationId;
  9use cloud_api_types::websocket_protocol::MessageToClient;
 10use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
 11use gpui::{
 12    App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
 13};
 14use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 15use thiserror::Error;
 16
 17#[derive(Error, Debug)]
 18pub struct PaymentRequiredError;
 19
 20impl fmt::Display for PaymentRequiredError {
 21    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 22        write!(
 23            f,
 24            "Payment required to use this language model. Please upgrade your account."
 25        )
 26    }
 27}
 28
 29#[derive(Clone, Default)]
 30pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
 31
 32impl LlmApiToken {
 33    pub fn global(cx: &App) -> Self {
 34        RefreshLlmTokenListener::global(cx)
 35            .read(cx)
 36            .llm_api_token
 37            .clone()
 38    }
 39
 40    pub async fn acquire(
 41        &self,
 42        client: &Arc<Client>,
 43        organization_id: Option<OrganizationId>,
 44    ) -> Result<String> {
 45        let lock = self.0.upgradable_read().await;
 46        if let Some(token) = lock.as_ref() {
 47            Ok(token.to_string())
 48        } else {
 49            Self::fetch(
 50                RwLockUpgradableReadGuard::upgrade(lock).await,
 51                client,
 52                organization_id,
 53            )
 54            .await
 55        }
 56    }
 57
 58    pub async fn refresh(
 59        &self,
 60        client: &Arc<Client>,
 61        organization_id: Option<OrganizationId>,
 62    ) -> Result<String> {
 63        Self::fetch(self.0.write().await, client, organization_id).await
 64    }
 65
 66    async fn fetch(
 67        mut lock: RwLockWriteGuard<'_, Option<String>>,
 68        client: &Arc<Client>,
 69        organization_id: Option<OrganizationId>,
 70    ) -> Result<String> {
 71        let system_id = client
 72            .telemetry()
 73            .system_id()
 74            .map(|system_id| system_id.to_string());
 75
 76        let result = client
 77            .cloud_client()
 78            .create_llm_token(system_id, organization_id)
 79            .await;
 80        match result {
 81            Ok(response) => {
 82                *lock = Some(response.token.0.clone());
 83                Ok(response.token.0)
 84            }
 85            Err(err) => match err {
 86                ClientApiError::Unauthorized => {
 87                    client.request_sign_out();
 88                    Err(err).context("Failed to create LLM token")
 89                }
 90                ClientApiError::Other(err) => Err(err),
 91            },
 92        }
 93    }
 94}
 95
 96pub trait NeedsLlmTokenRefresh {
 97    /// Returns whether the LLM token needs to be refreshed.
 98    fn needs_llm_token_refresh(&self) -> bool;
 99}
100
101impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
102    fn needs_llm_token_refresh(&self) -> bool {
103        self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
104            || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
105    }
106}
107
108struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
109
110impl Global for GlobalRefreshLlmTokenListener {}
111
112pub struct LlmTokenRefreshedEvent;
113
114pub struct RefreshLlmTokenListener {
115    client: Arc<Client>,
116    user_store: Entity<UserStore>,
117    llm_api_token: LlmApiToken,
118    _subscription: Subscription,
119}
120
121impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
122
123impl RefreshLlmTokenListener {
124    pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
125        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
126        cx.set_global(GlobalRefreshLlmTokenListener(listener));
127    }
128
129    pub fn global(cx: &App) -> Entity<Self> {
130        GlobalRefreshLlmTokenListener::global(cx).0.clone()
131    }
132
133    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
134        client.add_message_to_client_handler({
135            let this = cx.entity();
136            move |message, cx| {
137                Self::handle_refresh_llm_token(this.clone(), message, cx);
138            }
139        });
140
141        let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
142            if matches!(event, client::user::Event::OrganizationChanged) {
143                this.refresh(cx);
144            }
145        });
146
147        Self {
148            client,
149            user_store,
150            llm_api_token: LlmApiToken::default(),
151            _subscription: subscription,
152        }
153    }
154
155    fn refresh(&self, cx: &mut Context<Self>) {
156        let client = self.client.clone();
157        let llm_api_token = self.llm_api_token.clone();
158        let organization_id = self
159            .user_store
160            .read(cx)
161            .current_organization()
162            .map(|o| o.id.clone());
163        cx.spawn(async move |this, cx| {
164            llm_api_token.refresh(&client, organization_id).await?;
165            this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
166        })
167        .detach_and_log_err(cx);
168    }
169
170    fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
171        match message {
172            MessageToClient::UserUpdated => {
173                this.update(cx, |this, cx| this.refresh(cx));
174            }
175        }
176    }
177}