1use std::fmt;
2use std::sync::Arc;
3
4use anyhow::Result;
5use client::Client;
6use cloud_api_types::websocket_protocol::MessageToClient;
7use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
8use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
9use thiserror::Error;
10
11#[derive(Error, Debug)]
12pub struct PaymentRequiredError;
13
14impl fmt::Display for PaymentRequiredError {
15 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
16 write!(
17 f,
18 "Payment required to use this language model. Please upgrade your account."
19 )
20 }
21}
22
23#[derive(Clone, Default)]
24pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
25
26impl LlmApiToken {
27 pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
28 let lock = self.0.upgradable_read().await;
29 if let Some(token) = lock.as_ref() {
30 Ok(token.to_string())
31 } else {
32 Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
33 }
34 }
35
36 pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
37 Self::fetch(self.0.write().await, client).await
38 }
39
40 async fn fetch(
41 mut lock: RwLockWriteGuard<'_, Option<String>>,
42 client: &Arc<Client>,
43 ) -> Result<String> {
44 let system_id = client
45 .telemetry()
46 .system_id()
47 .map(|system_id| system_id.to_string());
48
49 let response = client.cloud_client().create_llm_token(system_id).await?;
50 *lock = Some(response.token.0.clone());
51 Ok(response.token.0)
52 }
53}
54
55struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
56
57impl Global for GlobalRefreshLlmTokenListener {}
58
59pub struct RefreshLlmTokenEvent;
60
61pub struct RefreshLlmTokenListener;
62
63impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
64
65impl RefreshLlmTokenListener {
66 pub fn register(client: Arc<Client>, cx: &mut App) {
67 let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
68 cx.set_global(GlobalRefreshLlmTokenListener(listener));
69 }
70
71 pub fn global(cx: &App) -> Entity<Self> {
72 GlobalRefreshLlmTokenListener::global(cx).0.clone()
73 }
74
75 fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
76 client.add_message_to_client_handler({
77 let this = cx.entity();
78 move |message, cx| {
79 Self::handle_refresh_llm_token(this.clone(), message, cx);
80 }
81 });
82
83 Self
84 }
85
86 fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
87 match message {
88 MessageToClient::UserUpdated => {
89 this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
90 }
91 }
92 }
93}