1use std::fmt;
2use std::sync::Arc;
3
4use anyhow::{Context as _, Result};
5use client::Client;
6use cloud_api_client::ClientApiError;
7use cloud_api_types::websocket_protocol::MessageToClient;
8use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
9use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
10use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
11use thiserror::Error;
12
13#[derive(Error, Debug)]
14pub struct PaymentRequiredError;
15
16impl fmt::Display for PaymentRequiredError {
17 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
18 write!(
19 f,
20 "Payment required to use this language model. Please upgrade your account."
21 )
22 }
23}
24
25#[derive(Clone, Default)]
26pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
27
28impl LlmApiToken {
29 pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
30 let lock = self.0.upgradable_read().await;
31 if let Some(token) = lock.as_ref() {
32 Ok(token.to_string())
33 } else {
34 Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
35 }
36 }
37
38 pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
39 Self::fetch(self.0.write().await, client).await
40 }
41
42 async fn fetch(
43 mut lock: RwLockWriteGuard<'_, Option<String>>,
44 client: &Arc<Client>,
45 ) -> Result<String> {
46 let system_id = client
47 .telemetry()
48 .system_id()
49 .map(|system_id| system_id.to_string());
50
51 let result = client.cloud_client().create_llm_token(system_id).await;
52 match result {
53 Ok(response) => {
54 *lock = Some(response.token.0.clone());
55 Ok(response.token.0)
56 }
57 Err(err) => match err {
58 ClientApiError::Unauthorized => {
59 client.request_sign_out();
60 Err(err).context("Failed to create LLM token")
61 }
62 ClientApiError::Other(err) => Err(err),
63 },
64 }
65 }
66}
67
/// Implemented by values, such as HTTP responses, that can indicate the LLM token is stale.
pub trait NeedsLlmTokenRefresh {
    /// Returns `true` when the LLM token should be refreshed, `false` otherwise.
    fn needs_llm_token_refresh(&self) -> bool;
}
72
73impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
74 fn needs_llm_token_refresh(&self) -> bool {
75 self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
76 || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
77 }
78}
79
80struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
81
82impl Global for GlobalRefreshLlmTokenListener {}
83
84pub struct RefreshLlmTokenEvent;
85
86pub struct RefreshLlmTokenListener;
87
88impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
89
90impl RefreshLlmTokenListener {
91 pub fn register(client: Arc<Client>, cx: &mut App) {
92 let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
93 cx.set_global(GlobalRefreshLlmTokenListener(listener));
94 }
95
96 pub fn global(cx: &App) -> Entity<Self> {
97 GlobalRefreshLlmTokenListener::global(cx).0.clone()
98 }
99
100 fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
101 client.add_message_to_client_handler({
102 let this = cx.entity();
103 move |message, cx| {
104 Self::handle_refresh_llm_token(this.clone(), message, cx);
105 }
106 });
107
108 Self
109 }
110
111 fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
112 match message {
113 MessageToClient::UserUpdated => {
114 this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
115 }
116 }
117 }
118}