1use std::fmt;
2use std::sync::Arc;
3
4use anyhow::{Context as _, Result};
5use client::Client;
6use client::UserStore;
7use cloud_api_client::ClientApiError;
8use cloud_api_types::OrganizationId;
9use cloud_api_types::websocket_protocol::MessageToClient;
10use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
11use gpui::{
12 App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
13};
14use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
15use thiserror::Error;
16
17#[derive(Error, Debug)]
18pub struct PaymentRequiredError;
19
20impl fmt::Display for PaymentRequiredError {
21 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
22 write!(
23 f,
24 "Payment required to use this language model. Please upgrade your account."
25 )
26 }
27}
28
29#[derive(Clone, Default)]
30pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
31
32impl LlmApiToken {
33 pub async fn acquire(
34 &self,
35 client: &Arc<Client>,
36 organization_id: Option<OrganizationId>,
37 ) -> Result<String> {
38 let lock = self.0.upgradable_read().await;
39 if let Some(token) = lock.as_ref() {
40 Ok(token.to_string())
41 } else {
42 Self::fetch(
43 RwLockUpgradableReadGuard::upgrade(lock).await,
44 client,
45 organization_id,
46 )
47 .await
48 }
49 }
50
51 pub async fn refresh(
52 &self,
53 client: &Arc<Client>,
54 organization_id: Option<OrganizationId>,
55 ) -> Result<String> {
56 Self::fetch(self.0.write().await, client, organization_id).await
57 }
58
59 async fn fetch(
60 mut lock: RwLockWriteGuard<'_, Option<String>>,
61 client: &Arc<Client>,
62 organization_id: Option<OrganizationId>,
63 ) -> Result<String> {
64 let system_id = client
65 .telemetry()
66 .system_id()
67 .map(|system_id| system_id.to_string());
68
69 let result = client
70 .cloud_client()
71 .create_llm_token(system_id, organization_id)
72 .await;
73 match result {
74 Ok(response) => {
75 *lock = Some(response.token.0.clone());
76 Ok(response.token.0)
77 }
78 Err(err) => match err {
79 ClientApiError::Unauthorized => {
80 client.request_sign_out();
81 Err(err).context("Failed to create LLM token")
82 }
83 ClientApiError::Other(err) => Err(err),
84 },
85 }
86 }
87}
88
89pub trait NeedsLlmTokenRefresh {
90 /// Returns whether the LLM token needs to be refreshed.
91 fn needs_llm_token_refresh(&self) -> bool;
92}
93
94impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
95 fn needs_llm_token_refresh(&self) -> bool {
96 self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
97 || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
98 }
99}
100
101struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
102
103impl Global for GlobalRefreshLlmTokenListener {}
104
105pub struct RefreshLlmTokenEvent;
106
107pub struct RefreshLlmTokenListener {
108 _subscription: Subscription,
109}
110
111impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
112
113impl RefreshLlmTokenListener {
114 pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
115 let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
116 cx.set_global(GlobalRefreshLlmTokenListener(listener));
117 }
118
119 pub fn global(cx: &App) -> Entity<Self> {
120 GlobalRefreshLlmTokenListener::global(cx).0.clone()
121 }
122
123 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
124 client.add_message_to_client_handler({
125 let this = cx.entity();
126 move |message, cx| {
127 Self::handle_refresh_llm_token(this.clone(), message, cx);
128 }
129 });
130
131 let subscription = cx.subscribe(&user_store, |_this, _user_store, event, cx| {
132 if matches!(event, client::user::Event::OrganizationChanged) {
133 cx.emit(RefreshLlmTokenEvent);
134 }
135 });
136
137 Self {
138 _subscription: subscription,
139 }
140 }
141
142 fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
143 match message {
144 MessageToClient::UserUpdated => {
145 this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
146 }
147 }
148 }
149}