cloud_model.rs

  1use std::fmt;
  2use std::sync::Arc;
  3
  4use anyhow::Result;
  5use client::Client;
  6use gpui::{
  7    App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
  8};
  9use proto::{Plan, TypedEnvelope};
 10use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 11use thiserror::Error;
 12
 13#[derive(Error, Debug)]
 14pub struct PaymentRequiredError;
 15
 16impl fmt::Display for PaymentRequiredError {
 17    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 18        write!(
 19            f,
 20            "Payment required to use this language model. Please upgrade your account."
 21        )
 22    }
 23}
 24
 25#[derive(Error, Debug)]
 26pub struct ModelRequestLimitReachedError {
 27    pub plan: Plan,
 28}
 29
 30impl fmt::Display for ModelRequestLimitReachedError {
 31    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 32        let message = match self.plan {
 33            Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.",
 34            Plan::ZedPro => {
 35                "Model request limit reached. Upgrade to usage-based billing for more requests."
 36            }
 37            Plan::ZedProTrial => {
 38                "Model request limit reached. Upgrade to Zed Pro for more requests."
 39            }
 40        };
 41
 42        write!(f, "{message}")
 43    }
 44}
 45
 46#[derive(Clone, Default)]
 47pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
 48
 49impl LlmApiToken {
 50    pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
 51        let lock = self.0.upgradable_read().await;
 52        if let Some(token) = lock.as_ref() {
 53            Ok(token.to_string())
 54        } else {
 55            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
 56        }
 57    }
 58
 59    pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
 60        Self::fetch(self.0.write().await, client).await
 61    }
 62
 63    async fn fetch(
 64        mut lock: RwLockWriteGuard<'_, Option<String>>,
 65        client: &Arc<Client>,
 66    ) -> Result<String> {
 67        let system_id = client
 68            .telemetry()
 69            .system_id()
 70            .map(|system_id| system_id.to_string());
 71
 72        let response = client.cloud_client().create_llm_token(system_id).await?;
 73        *lock = Some(response.token.0.clone());
 74        Ok(response.token.0.clone())
 75    }
 76}
 77
 78struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
 79
 80impl Global for GlobalRefreshLlmTokenListener {}
 81
 82pub struct RefreshLlmTokenEvent;
 83
 84pub struct RefreshLlmTokenListener {
 85    _llm_token_subscription: client::Subscription,
 86}
 87
 88impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
 89
 90impl RefreshLlmTokenListener {
 91    pub fn register(client: Arc<Client>, cx: &mut App) {
 92        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
 93        cx.set_global(GlobalRefreshLlmTokenListener(listener));
 94    }
 95
 96    pub fn global(cx: &App) -> Entity<Self> {
 97        GlobalRefreshLlmTokenListener::global(cx).0.clone()
 98    }
 99
100    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
101        Self {
102            _llm_token_subscription: client
103                .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
104        }
105    }
106
107    async fn handle_refresh_llm_token(
108        this: Entity<Self>,
109        _: TypedEnvelope<proto::RefreshLlmToken>,
110        mut cx: AsyncApp,
111    ) -> Result<()> {
112        this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
113    }
114}