1use std::fmt;
2use std::sync::Arc;
3
4use anyhow::Result;
5use client::Client;
6use cloud_api_types::websocket_protocol::MessageToClient;
7use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
8use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
9use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
10use thiserror::Error;
11
12#[derive(Error, Debug)]
13pub struct PaymentRequiredError;
14
15impl fmt::Display for PaymentRequiredError {
16 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
17 write!(
18 f,
19 "Payment required to use this language model. Please upgrade your account."
20 )
21 }
22}
23
24#[derive(Clone, Default)]
25pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
26
27impl LlmApiToken {
28 pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
29 let lock = self.0.upgradable_read().await;
30 if let Some(token) = lock.as_ref() {
31 Ok(token.to_string())
32 } else {
33 Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
34 }
35 }
36
37 pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
38 Self::fetch(self.0.write().await, client).await
39 }
40
41 async fn fetch(
42 mut lock: RwLockWriteGuard<'_, Option<String>>,
43 client: &Arc<Client>,
44 ) -> Result<String> {
45 let system_id = client
46 .telemetry()
47 .system_id()
48 .map(|system_id| system_id.to_string());
49
50 let response = client.cloud_client().create_llm_token(system_id).await?;
51 *lock = Some(response.token.0.clone());
52 Ok(response.token.0)
53 }
54}
55
/// Implemented by values (such as HTTP responses) that can signal that the current
/// LLM token should be replaced.
pub trait NeedsLlmTokenRefresh {
    /// Returns `true` when the LLM token needs to be refreshed, `false` otherwise.
    fn needs_llm_token_refresh(&self) -> bool;
}
60
61impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
62 fn needs_llm_token_refresh(&self) -> bool {
63 self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
64 || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
65 }
66}
67
68struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
69
70impl Global for GlobalRefreshLlmTokenListener {}
71
72pub struct RefreshLlmTokenEvent;
73
74pub struct RefreshLlmTokenListener;
75
76impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
77
78impl RefreshLlmTokenListener {
79 pub fn register(client: Arc<Client>, cx: &mut App) {
80 let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
81 cx.set_global(GlobalRefreshLlmTokenListener(listener));
82 }
83
84 pub fn global(cx: &App) -> Entity<Self> {
85 GlobalRefreshLlmTokenListener::global(cx).0.clone()
86 }
87
88 fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
89 client.add_message_to_client_handler({
90 let this = cx.entity();
91 move |message, cx| {
92 Self::handle_refresh_llm_token(this.clone(), message, cx);
93 }
94 });
95
96 Self
97 }
98
99 fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
100 match message {
101 MessageToClient::UserUpdated => {
102 this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
103 }
104 }
105 }
106}