1use std::fmt;
2use std::sync::Arc;
3
4use anyhow::{Context as _, Result};
5use client::Client;
6use cloud_api_client::ClientApiError;
7use cloud_api_types::OrganizationId;
8use cloud_api_types::websocket_protocol::MessageToClient;
9use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
10use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
11use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
12use thiserror::Error;
13
14#[derive(Error, Debug)]
15pub struct PaymentRequiredError;
16
17impl fmt::Display for PaymentRequiredError {
18 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
19 write!(
20 f,
21 "Payment required to use this language model. Please upgrade your account."
22 )
23 }
24}
25
26#[derive(Clone, Default)]
27pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
28
29impl LlmApiToken {
30 pub async fn acquire(
31 &self,
32 client: &Arc<Client>,
33 organization_id: Option<OrganizationId>,
34 ) -> Result<String> {
35 let lock = self.0.upgradable_read().await;
36 if let Some(token) = lock.as_ref() {
37 Ok(token.to_string())
38 } else {
39 Self::fetch(
40 RwLockUpgradableReadGuard::upgrade(lock).await,
41 client,
42 organization_id,
43 )
44 .await
45 }
46 }
47
48 pub async fn refresh(
49 &self,
50 client: &Arc<Client>,
51 organization_id: Option<OrganizationId>,
52 ) -> Result<String> {
53 Self::fetch(self.0.write().await, client, organization_id).await
54 }
55
56 async fn fetch(
57 mut lock: RwLockWriteGuard<'_, Option<String>>,
58 client: &Arc<Client>,
59 organization_id: Option<OrganizationId>,
60 ) -> Result<String> {
61 let system_id = client
62 .telemetry()
63 .system_id()
64 .map(|system_id| system_id.to_string());
65
66 let result = client
67 .cloud_client()
68 .create_llm_token(system_id, organization_id)
69 .await;
70 match result {
71 Ok(response) => {
72 *lock = Some(response.token.0.clone());
73 Ok(response.token.0)
74 }
75 Err(err) => match err {
76 ClientApiError::Unauthorized => {
77 client.request_sign_out();
78 Err(err).context("Failed to create LLM token")
79 }
80 ClientApiError::Other(err) => Err(err),
81 },
82 }
83 }
84}
85
/// Implemented by values that can report whether the LLM token has to be
/// refreshed.
pub trait NeedsLlmTokenRefresh {
    /// Returns `true` when the LLM token should be refreshed.
    fn needs_llm_token_refresh(&self) -> bool;
}
90
91impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
92 fn needs_llm_token_refresh(&self) -> bool {
93 self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
94 || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
95 }
96}
97
98struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
99
100impl Global for GlobalRefreshLlmTokenListener {}
101
102pub struct RefreshLlmTokenEvent;
103
104pub struct RefreshLlmTokenListener;
105
106impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
107
108impl RefreshLlmTokenListener {
109 pub fn register(client: Arc<Client>, cx: &mut App) {
110 let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
111 cx.set_global(GlobalRefreshLlmTokenListener(listener));
112 }
113
114 pub fn global(cx: &App) -> Entity<Self> {
115 GlobalRefreshLlmTokenListener::global(cx).0.clone()
116 }
117
118 fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
119 client.add_message_to_client_handler({
120 let this = cx.entity();
121 move |message, cx| {
122 Self::handle_refresh_llm_token(this.clone(), message, cx);
123 }
124 });
125
126 Self
127 }
128
129 fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
130 match message {
131 MessageToClient::UserUpdated => {
132 this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
133 }
134 }
135 }
136}