1use std::fmt;
2use std::sync::Arc;
3
4use cloud_api_client::ClientApiError;
5use cloud_api_client::CloudApiClient;
6use cloud_api_types::OrganizationId;
7use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
8use thiserror::Error;
9
10#[derive(Error, Debug)]
11pub struct PaymentRequiredError;
12
13impl fmt::Display for PaymentRequiredError {
14 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
15 write!(
16 f,
17 "Payment required to use this language model. Please upgrade your account."
18 )
19 }
20}
21
22#[derive(Clone, Default)]
23pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
24
25impl LlmApiToken {
26 pub async fn acquire(
27 &self,
28 client: &CloudApiClient,
29 system_id: Option<String>,
30 organization_id: Option<OrganizationId>,
31 ) -> Result<String, ClientApiError> {
32 let lock = self.0.upgradable_read().await;
33 if let Some(token) = lock.as_ref() {
34 Ok(token.to_string())
35 } else {
36 Self::fetch(
37 RwLockUpgradableReadGuard::upgrade(lock).await,
38 client,
39 system_id,
40 organization_id,
41 )
42 .await
43 }
44 }
45
46 pub async fn refresh(
47 &self,
48 client: &CloudApiClient,
49 system_id: Option<String>,
50 organization_id: Option<OrganizationId>,
51 ) -> Result<String, ClientApiError> {
52 Self::fetch(self.0.write().await, client, system_id, organization_id).await
53 }
54
55 /// Clears the existing token before attempting to fetch a new one.
56 ///
57 /// Used when switching organizations so that a failed refresh doesn't
58 /// leave a token for the wrong organization.
59 pub async fn clear_and_refresh(
60 &self,
61 client: &CloudApiClient,
62 system_id: Option<String>,
63 organization_id: Option<OrganizationId>,
64 ) -> Result<String, ClientApiError> {
65 let mut lock = self.0.write().await;
66 *lock = None;
67 Self::fetch(lock, client, system_id, organization_id).await
68 }
69
70 async fn fetch(
71 mut lock: RwLockWriteGuard<'_, Option<String>>,
72 client: &CloudApiClient,
73 system_id: Option<String>,
74 organization_id: Option<OrganizationId>,
75 ) -> Result<String, ClientApiError> {
76 let result = client.create_llm_token(system_id, organization_id).await;
77 match result {
78 Ok(response) => {
79 *lock = Some(response.token.0.clone());
80 Ok(response.token.0)
81 }
82 Err(err) => {
83 *lock = None;
84 Err(err)
85 }
86 }
87 }
88}