1use std::fmt;
2use std::sync::Arc;
3
4use anyhow::{Context as _, Result};
5use client::Client;
6use client::UserStore;
7use cloud_api_client::ClientApiError;
8use cloud_api_types::OrganizationId;
9use cloud_api_types::websocket_protocol::MessageToClient;
10use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
11use gpui::{
12 App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
13};
14use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
15use thiserror::Error;
16
17#[derive(Error, Debug)]
18pub struct PaymentRequiredError;
19
20impl fmt::Display for PaymentRequiredError {
21 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
22 write!(
23 f,
24 "Payment required to use this language model. Please upgrade your account."
25 )
26 }
27}
28
29#[derive(Clone, Default)]
30pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
31
32impl LlmApiToken {
33 pub fn global(cx: &App) -> Self {
34 RefreshLlmTokenListener::global(cx)
35 .read(cx)
36 .llm_api_token
37 .clone()
38 }
39
40 pub async fn acquire(
41 &self,
42 client: &Arc<Client>,
43 organization_id: Option<OrganizationId>,
44 ) -> Result<String> {
45 let lock = self.0.upgradable_read().await;
46 if let Some(token) = lock.as_ref() {
47 Ok(token.to_string())
48 } else {
49 Self::fetch(
50 RwLockUpgradableReadGuard::upgrade(lock).await,
51 client,
52 organization_id,
53 )
54 .await
55 }
56 }
57
58 pub async fn refresh(
59 &self,
60 client: &Arc<Client>,
61 organization_id: Option<OrganizationId>,
62 ) -> Result<String> {
63 Self::fetch(self.0.write().await, client, organization_id).await
64 }
65
66 /// Clears the existing token before attempting to fetch a new one.
67 ///
68 /// Used when switching organizations so that a failed refresh doesn't
69 /// leave a token for the wrong organization.
70 pub async fn clear_and_refresh(
71 &self,
72 client: &Arc<Client>,
73 organization_id: Option<OrganizationId>,
74 ) -> Result<String> {
75 let mut lock = self.0.write().await;
76 *lock = None;
77 Self::fetch(lock, client, organization_id).await
78 }
79
80 async fn fetch(
81 mut lock: RwLockWriteGuard<'_, Option<String>>,
82 client: &Arc<Client>,
83 organization_id: Option<OrganizationId>,
84 ) -> Result<String> {
85 let system_id = client
86 .telemetry()
87 .system_id()
88 .map(|system_id| system_id.to_string());
89
90 let result = client
91 .cloud_client()
92 .create_llm_token(system_id, organization_id)
93 .await;
94 match result {
95 Ok(response) => {
96 *lock = Some(response.token.0.clone());
97 Ok(response.token.0)
98 }
99 Err(err) => {
100 *lock = None;
101 match err {
102 ClientApiError::Unauthorized => {
103 client.request_sign_out();
104 Err(err).context("Failed to create LLM token")
105 }
106 ClientApiError::Other(err) => Err(err),
107 }
108 }
109 }
110 }
111}
112
113pub trait NeedsLlmTokenRefresh {
114 /// Returns whether the LLM token needs to be refreshed.
115 fn needs_llm_token_refresh(&self) -> bool;
116}
117
118impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
119 fn needs_llm_token_refresh(&self) -> bool {
120 self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
121 || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
122 }
123}
124
/// How `RefreshLlmTokenListener::refresh` treats the cached token.
///
/// A fieldless mode flag, so it is `Copy` and comparable. It also derives
/// `Debug` so it can appear in diagnostics.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TokenRefreshMode {
    /// Fetch a new token via `LlmApiToken::refresh`.
    Refresh,
    /// Discard the cached token, then fetch a new one via
    /// `LlmApiToken::clear_and_refresh`; used when the organization changes.
    ClearAndRefresh,
}
129
/// Global slot holding the app-wide `RefreshLlmTokenListener` entity.
///
/// Installed by `RefreshLlmTokenListener::register` and read back by
/// `RefreshLlmTokenListener::global`.
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);

impl Global for GlobalRefreshLlmTokenListener {}
133
/// Emitted by `RefreshLlmTokenListener` after the shared LLM token has been
/// successfully re-fetched.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct LlmTokenRefreshedEvent;
135
136pub struct RefreshLlmTokenListener {
137 client: Arc<Client>,
138 user_store: Entity<UserStore>,
139 llm_api_token: LlmApiToken,
140 _subscription: Subscription,
141}
142
143impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
144
145impl RefreshLlmTokenListener {
146 pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
147 let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
148 cx.set_global(GlobalRefreshLlmTokenListener(listener));
149 }
150
151 pub fn global(cx: &App) -> Entity<Self> {
152 GlobalRefreshLlmTokenListener::global(cx).0.clone()
153 }
154
155 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
156 client.add_message_to_client_handler({
157 let this = cx.weak_entity();
158 move |message, cx| {
159 if let Some(this) = this.upgrade() {
160 Self::handle_refresh_llm_token(this, message, cx);
161 }
162 }
163 });
164
165 let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
166 if matches!(event, client::user::Event::OrganizationChanged) {
167 this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
168 }
169 });
170
171 Self {
172 client,
173 user_store,
174 llm_api_token: LlmApiToken::default(),
175 _subscription: subscription,
176 }
177 }
178
179 fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
180 let client = self.client.clone();
181 let llm_api_token = self.llm_api_token.clone();
182 let organization_id = self
183 .user_store
184 .read(cx)
185 .current_organization()
186 .map(|organization| organization.id.clone());
187 cx.spawn(async move |this, cx| {
188 match mode {
189 TokenRefreshMode::Refresh => {
190 llm_api_token.refresh(&client, organization_id).await?;
191 }
192 TokenRefreshMode::ClearAndRefresh => {
193 llm_api_token
194 .clear_and_refresh(&client, organization_id)
195 .await?;
196 }
197 }
198 this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
199 })
200 .detach_and_log_err(cx);
201 }
202
203 fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
204 match message {
205 MessageToClient::UserUpdated => {
206 this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
207 }
208 }
209 }
210}