1use std::sync::Arc;
2
3use cloud_api_types::OrganizationId;
4use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
5
6use crate::{ClientApiError, CloudApiClient};
7
8#[derive(Clone, Default)]
9pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
10
11impl LlmApiToken {
12 pub async fn acquire(
13 &self,
14 client: &CloudApiClient,
15 system_id: Option<String>,
16 organization_id: Option<OrganizationId>,
17 ) -> Result<String, ClientApiError> {
18 let lock = self.0.upgradable_read().await;
19 if let Some(token) = lock.as_ref() {
20 Ok(token.to_string())
21 } else {
22 Self::fetch(
23 RwLockUpgradableReadGuard::upgrade(lock).await,
24 client,
25 system_id,
26 organization_id,
27 )
28 .await
29 }
30 }
31
32 pub async fn refresh(
33 &self,
34 client: &CloudApiClient,
35 system_id: Option<String>,
36 organization_id: Option<OrganizationId>,
37 ) -> Result<String, ClientApiError> {
38 Self::fetch(self.0.write().await, client, system_id, organization_id).await
39 }
40
41 /// Clears the existing token before attempting to fetch a new one.
42 ///
43 /// Used when switching organizations so that a failed refresh doesn't
44 /// leave a token for the wrong organization.
45 pub async fn clear_and_refresh(
46 &self,
47 client: &CloudApiClient,
48 system_id: Option<String>,
49 organization_id: Option<OrganizationId>,
50 ) -> Result<String, ClientApiError> {
51 let mut lock = self.0.write().await;
52 *lock = None;
53 Self::fetch(lock, client, system_id, organization_id).await
54 }
55
56 async fn fetch(
57 mut lock: RwLockWriteGuard<'_, Option<String>>,
58 client: &CloudApiClient,
59 system_id: Option<String>,
60 organization_id: Option<OrganizationId>,
61 ) -> Result<String, ClientApiError> {
62 let result = client.create_llm_token(system_id, organization_id).await;
63 match result {
64 Ok(response) => {
65 *lock = Some(response.token.0.clone());
66 Ok(response.token.0)
67 }
68 Err(err) => {
69 *lock = None;
70 Err(err)
71 }
72 }
73 }
74}