cloud_model.rs

 1use std::fmt;
 2use std::sync::Arc;
 3
 4use cloud_api_client::ClientApiError;
 5use cloud_api_client::CloudApiClient;
 6use cloud_api_types::OrganizationId;
 7use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 8use thiserror::Error;
 9
10#[derive(Error, Debug)]
11pub struct PaymentRequiredError;
12
13impl fmt::Display for PaymentRequiredError {
14    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
15        write!(
16            f,
17            "Payment required to use this language model. Please upgrade your account."
18        )
19    }
20}
21
22#[derive(Clone, Default)]
23pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
24
25impl LlmApiToken {
26    pub async fn acquire(
27        &self,
28        client: &CloudApiClient,
29        system_id: Option<String>,
30        organization_id: Option<OrganizationId>,
31    ) -> Result<String, ClientApiError> {
32        let lock = self.0.upgradable_read().await;
33        if let Some(token) = lock.as_ref() {
34            Ok(token.to_string())
35        } else {
36            Self::fetch(
37                RwLockUpgradableReadGuard::upgrade(lock).await,
38                client,
39                system_id,
40                organization_id,
41            )
42            .await
43        }
44    }
45
46    pub async fn refresh(
47        &self,
48        client: &CloudApiClient,
49        system_id: Option<String>,
50        organization_id: Option<OrganizationId>,
51    ) -> Result<String, ClientApiError> {
52        Self::fetch(self.0.write().await, client, system_id, organization_id).await
53    }
54
55    /// Clears the existing token before attempting to fetch a new one.
56    ///
57    /// Used when switching organizations so that a failed refresh doesn't
58    /// leave a token for the wrong organization.
59    pub async fn clear_and_refresh(
60        &self,
61        client: &CloudApiClient,
62        system_id: Option<String>,
63        organization_id: Option<OrganizationId>,
64    ) -> Result<String, ClientApiError> {
65        let mut lock = self.0.write().await;
66        *lock = None;
67        Self::fetch(lock, client, system_id, organization_id).await
68    }
69
70    async fn fetch(
71        mut lock: RwLockWriteGuard<'_, Option<String>>,
72        client: &CloudApiClient,
73        system_id: Option<String>,
74        organization_id: Option<OrganizationId>,
75    ) -> Result<String, ClientApiError> {
76        let result = client.create_llm_token(system_id, organization_id).await;
77        match result {
78            Ok(response) => {
79                *lock = Some(response.token.0.clone());
80                Ok(response.token.0)
81            }
82            Err(err) => {
83                *lock = None;
84                Err(err)
85            }
86        }
87    }
88}