1use std::fmt;
2use std::sync::Arc;
3
4use anyhow::{Context as _, Result};
5use client::Client;
6use client::UserStore;
7use cloud_api_client::ClientApiError;
8use cloud_api_types::OrganizationId;
9use cloud_api_types::websocket_protocol::MessageToClient;
10use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
11use gpui::{
12 App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
13};
14use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
15use thiserror::Error;
16
17#[derive(Error, Debug)]
18pub struct PaymentRequiredError;
19
20impl fmt::Display for PaymentRequiredError {
21 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
22 write!(
23 f,
24 "Payment required to use this language model. Please upgrade your account."
25 )
26 }
27}
28
29#[derive(Clone, Default)]
30pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
31
32impl LlmApiToken {
33 pub fn global(cx: &App) -> Self {
34 RefreshLlmTokenListener::global(cx)
35 .read(cx)
36 .llm_api_token
37 .clone()
38 }
39
40 pub async fn acquire(
41 &self,
42 client: &Arc<Client>,
43 organization_id: Option<OrganizationId>,
44 ) -> Result<String> {
45 let lock = self.0.upgradable_read().await;
46 if let Some(token) = lock.as_ref() {
47 Ok(token.to_string())
48 } else {
49 Self::fetch(
50 RwLockUpgradableReadGuard::upgrade(lock).await,
51 client,
52 organization_id,
53 )
54 .await
55 }
56 }
57
58 pub async fn refresh(
59 &self,
60 client: &Arc<Client>,
61 organization_id: Option<OrganizationId>,
62 ) -> Result<String> {
63 Self::fetch(self.0.write().await, client, organization_id).await
64 }
65
66 async fn fetch(
67 mut lock: RwLockWriteGuard<'_, Option<String>>,
68 client: &Arc<Client>,
69 organization_id: Option<OrganizationId>,
70 ) -> Result<String> {
71 let system_id = client
72 .telemetry()
73 .system_id()
74 .map(|system_id| system_id.to_string());
75
76 let result = client
77 .cloud_client()
78 .create_llm_token(system_id, organization_id)
79 .await;
80 match result {
81 Ok(response) => {
82 *lock = Some(response.token.0.clone());
83 Ok(response.token.0)
84 }
85 Err(err) => match err {
86 ClientApiError::Unauthorized => {
87 client.request_sign_out();
88 Err(err).context("Failed to create LLM token")
89 }
90 ClientApiError::Other(err) => Err(err),
91 },
92 }
93 }
94}
95
96pub trait NeedsLlmTokenRefresh {
97 /// Returns whether the LLM token needs to be refreshed.
98 fn needs_llm_token_refresh(&self) -> bool;
99}
100
101impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
102 fn needs_llm_token_refresh(&self) -> bool {
103 self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
104 || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
105 }
106}
107
/// gpui global wrapper holding the app-wide [`RefreshLlmTokenListener`] entity, installed by
/// [`RefreshLlmTokenListener::register`] and read back by [`RefreshLlmTokenListener::global`].
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);

impl Global for GlobalRefreshLlmTokenListener {}
111
/// Emitted by [`RefreshLlmTokenListener`] after the shared LLM token was refreshed successfully.
pub struct LlmTokenRefreshedEvent;
113
/// Keeps the shared [`LlmApiToken`] current by refreshing it when the server pushes a
/// `UserUpdated` message or the user store reports that the current organization changed.
pub struct RefreshLlmTokenListener {
    /// Client used both to receive server messages and to request new tokens.
    client: Arc<Client>,
    /// Source of the current organization, whose id is sent with each token request.
    user_store: Entity<UserStore>,
    /// The token handed out by [`LlmApiToken::global`].
    llm_api_token: LlmApiToken,
    // Keeps the `UserStore` subscription created in `new` alive for the listener's lifetime
    // (gpui subscriptions are cancelled when dropped).
    _subscription: Subscription,
}

impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
122
123impl RefreshLlmTokenListener {
124 pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
125 let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
126 cx.set_global(GlobalRefreshLlmTokenListener(listener));
127 }
128
129 pub fn global(cx: &App) -> Entity<Self> {
130 GlobalRefreshLlmTokenListener::global(cx).0.clone()
131 }
132
133 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
134 client.add_message_to_client_handler({
135 let this = cx.entity();
136 move |message, cx| {
137 Self::handle_refresh_llm_token(this.clone(), message, cx);
138 }
139 });
140
141 let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
142 if matches!(event, client::user::Event::OrganizationChanged) {
143 this.refresh(cx);
144 }
145 });
146
147 Self {
148 client,
149 user_store,
150 llm_api_token: LlmApiToken::default(),
151 _subscription: subscription,
152 }
153 }
154
155 fn refresh(&self, cx: &mut Context<Self>) {
156 let client = self.client.clone();
157 let llm_api_token = self.llm_api_token.clone();
158 let organization_id = self
159 .user_store
160 .read(cx)
161 .current_organization()
162 .map(|o| o.id.clone());
163 cx.spawn(async move |this, cx| {
164 llm_api_token.refresh(&client, organization_id).await?;
165 this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
166 })
167 .detach_and_log_err(cx);
168 }
169
170 fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
171 match message {
172 MessageToClient::UserUpdated => {
173 this.update(cx, |this, cx| this.refresh(cx));
174 }
175 }
176 }
177}