cloud_model.rs

 1use std::fmt;
 2use std::sync::Arc;
 3
 4use anyhow::Result;
 5use client::Client;
 6use cloud_api_types::websocket_protocol::MessageToClient;
 7use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
 8use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 9use thiserror::Error;
10
11#[derive(Error, Debug)]
12pub struct PaymentRequiredError;
13
14impl fmt::Display for PaymentRequiredError {
15    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
16        write!(
17            f,
18            "Payment required to use this language model. Please upgrade your account."
19        )
20    }
21}
22
23#[derive(Clone, Default)]
24pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
25
26impl LlmApiToken {
27    pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
28        let lock = self.0.upgradable_read().await;
29        if let Some(token) = lock.as_ref() {
30            Ok(token.to_string())
31        } else {
32            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
33        }
34    }
35
36    pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
37        Self::fetch(self.0.write().await, client).await
38    }
39
40    async fn fetch(
41        mut lock: RwLockWriteGuard<'_, Option<String>>,
42        client: &Arc<Client>,
43    ) -> Result<String> {
44        let system_id = client
45            .telemetry()
46            .system_id()
47            .map(|system_id| system_id.to_string());
48
49        let response = client.cloud_client().create_llm_token(system_id).await?;
50        *lock = Some(response.token.0.clone());
51        Ok(response.token.0)
52    }
53}
54
55struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
56
57impl Global for GlobalRefreshLlmTokenListener {}
58
59pub struct RefreshLlmTokenEvent;
60
61pub struct RefreshLlmTokenListener;
62
63impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
64
65impl RefreshLlmTokenListener {
66    pub fn register(client: Arc<Client>, cx: &mut App) {
67        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
68        cx.set_global(GlobalRefreshLlmTokenListener(listener));
69    }
70
71    pub fn global(cx: &App) -> Entity<Self> {
72        GlobalRefreshLlmTokenListener::global(cx).0.clone()
73    }
74
75    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
76        client.add_message_to_client_handler({
77            let this = cx.entity();
78            move |message, cx| {
79                Self::handle_refresh_llm_token(this.clone(), message, cx);
80            }
81        });
82
83        Self
84    }
85
86    fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
87        match message {
88            MessageToClient::UserUpdated => {
89                this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
90            }
91        }
92    }
93}