cloud_model.rs

  1use std::fmt;
  2use std::sync::Arc;
  3
  4use anyhow::{Context as _, Result};
  5use client::Client;
  6use client::UserStore;
  7use cloud_api_client::ClientApiError;
  8use cloud_api_types::OrganizationId;
  9use cloud_api_types::websocket_protocol::MessageToClient;
 10use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
 11use gpui::{
 12    App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
 13};
 14use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 15use thiserror::Error;
 16
 17#[derive(Error, Debug)]
 18pub struct PaymentRequiredError;
 19
 20impl fmt::Display for PaymentRequiredError {
 21    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 22        write!(
 23            f,
 24            "Payment required to use this language model. Please upgrade your account."
 25        )
 26    }
 27}
 28
 29#[derive(Clone, Default)]
 30pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
 31
 32impl LlmApiToken {
 33    pub fn global(cx: &App) -> Self {
 34        RefreshLlmTokenListener::global(cx)
 35            .read(cx)
 36            .llm_api_token
 37            .clone()
 38    }
 39
 40    pub async fn acquire(
 41        &self,
 42        client: &Arc<Client>,
 43        organization_id: Option<OrganizationId>,
 44    ) -> Result<String> {
 45        let lock = self.0.upgradable_read().await;
 46        if let Some(token) = lock.as_ref() {
 47            Ok(token.to_string())
 48        } else {
 49            Self::fetch(
 50                RwLockUpgradableReadGuard::upgrade(lock).await,
 51                client,
 52                organization_id,
 53            )
 54            .await
 55        }
 56    }
 57
 58    pub async fn refresh(
 59        &self,
 60        client: &Arc<Client>,
 61        organization_id: Option<OrganizationId>,
 62    ) -> Result<String> {
 63        Self::fetch(self.0.write().await, client, organization_id).await
 64    }
 65
 66    /// Clears the existing token before attempting to fetch a new one.
 67    ///
 68    /// Used when switching organizations so that a failed refresh doesn't
 69    /// leave a token for the wrong organization.
 70    pub async fn clear_and_refresh(
 71        &self,
 72        client: &Arc<Client>,
 73        organization_id: Option<OrganizationId>,
 74    ) -> Result<String> {
 75        let mut lock = self.0.write().await;
 76        *lock = None;
 77        Self::fetch(lock, client, organization_id).await
 78    }
 79
 80    async fn fetch(
 81        mut lock: RwLockWriteGuard<'_, Option<String>>,
 82        client: &Arc<Client>,
 83        organization_id: Option<OrganizationId>,
 84    ) -> Result<String> {
 85        let system_id = client
 86            .telemetry()
 87            .system_id()
 88            .map(|system_id| system_id.to_string());
 89
 90        let result = client
 91            .cloud_client()
 92            .create_llm_token(system_id, organization_id)
 93            .await;
 94        match result {
 95            Ok(response) => {
 96                *lock = Some(response.token.0.clone());
 97                Ok(response.token.0)
 98            }
 99            Err(err) => {
100                *lock = None;
101                match err {
102                    ClientApiError::Unauthorized => {
103                        client.request_sign_out();
104                        Err(err).context("Failed to create LLM token")
105                    }
106                    ClientApiError::Other(err) => Err(err),
107                }
108            }
109        }
110    }
111}
112
113pub trait NeedsLlmTokenRefresh {
114    /// Returns whether the LLM token needs to be refreshed.
115    fn needs_llm_token_refresh(&self) -> bool;
116}
117
118impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
119    fn needs_llm_token_refresh(&self) -> bool {
120        self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
121            || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
122    }
123}
124
/// How [`RefreshLlmTokenListener`] should replace the cached LLM token.
///
/// Fieldless, so it is cheap to copy into spawned tasks and useful to log.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TokenRefreshMode {
    /// Fetch a new token; the cached value is only overwritten once the
    /// request completes.
    Refresh,
    /// Drop the cached token before fetching a new one (used when the current
    /// organization changes).
    ClearAndRefresh,
}
129
130struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
131
132impl Global for GlobalRefreshLlmTokenListener {}
133
134pub struct LlmTokenRefreshedEvent;
135
136pub struct RefreshLlmTokenListener {
137    client: Arc<Client>,
138    user_store: Entity<UserStore>,
139    llm_api_token: LlmApiToken,
140    _subscription: Subscription,
141}
142
143impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
144
145impl RefreshLlmTokenListener {
146    pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
147        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
148        cx.set_global(GlobalRefreshLlmTokenListener(listener));
149    }
150
151    pub fn global(cx: &App) -> Entity<Self> {
152        GlobalRefreshLlmTokenListener::global(cx).0.clone()
153    }
154
155    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
156        client.add_message_to_client_handler({
157            let this = cx.weak_entity();
158            move |message, cx| {
159                if let Some(this) = this.upgrade() {
160                    Self::handle_refresh_llm_token(this, message, cx);
161                }
162            }
163        });
164
165        let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
166            if matches!(event, client::user::Event::OrganizationChanged) {
167                this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
168            }
169        });
170
171        Self {
172            client,
173            user_store,
174            llm_api_token: LlmApiToken::default(),
175            _subscription: subscription,
176        }
177    }
178
179    fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
180        let client = self.client.clone();
181        let llm_api_token = self.llm_api_token.clone();
182        let organization_id = self
183            .user_store
184            .read(cx)
185            .current_organization()
186            .map(|organization| organization.id.clone());
187        cx.spawn(async move |this, cx| {
188            match mode {
189                TokenRefreshMode::Refresh => {
190                    llm_api_token.refresh(&client, organization_id).await?;
191                }
192                TokenRefreshMode::ClearAndRefresh => {
193                    llm_api_token
194                        .clear_and_refresh(&client, organization_id)
195                        .await?;
196                }
197            }
198            this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
199        })
200        .detach_and_log_err(cx);
201    }
202
203    fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
204        match message {
205            MessageToClient::UserUpdated => {
206                this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
207            }
208        }
209    }
210}