cloud_model.rs

  1use std::fmt;
  2use std::sync::Arc;
  3
  4use anyhow::Result;
  5use client::Client;
  6use cloud_llm_client::Plan;
  7use gpui::{
  8    App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
  9};
 10use proto::TypedEnvelope;
 11use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 12use thiserror::Error;
 13
 14#[derive(Error, Debug)]
 15pub struct PaymentRequiredError;
 16
 17impl fmt::Display for PaymentRequiredError {
 18    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 19        write!(
 20            f,
 21            "Payment required to use this language model. Please upgrade your account."
 22        )
 23    }
 24}
 25
 26#[derive(Error, Debug)]
 27pub struct ModelRequestLimitReachedError {
 28    pub plan: Plan,
 29}
 30
 31impl fmt::Display for ModelRequestLimitReachedError {
 32    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 33        let message = match self.plan {
 34            Plan::ZedFree => "Model request limit reached. Upgrade to Zed Pro for more requests.",
 35            Plan::ZedPro => {
 36                "Model request limit reached. Upgrade to usage-based billing for more requests."
 37            }
 38            Plan::ZedProTrial => {
 39                "Model request limit reached. Upgrade to Zed Pro for more requests."
 40            }
 41        };
 42
 43        write!(f, "{message}")
 44    }
 45}
 46
 47#[derive(Clone, Default)]
 48pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
 49
 50impl LlmApiToken {
 51    pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
 52        let lock = self.0.upgradable_read().await;
 53        if let Some(token) = lock.as_ref() {
 54            Ok(token.to_string())
 55        } else {
 56            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
 57        }
 58    }
 59
 60    pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
 61        Self::fetch(self.0.write().await, client).await
 62    }
 63
 64    async fn fetch(
 65        mut lock: RwLockWriteGuard<'_, Option<String>>,
 66        client: &Arc<Client>,
 67    ) -> Result<String> {
 68        let system_id = client
 69            .telemetry()
 70            .system_id()
 71            .map(|system_id| system_id.to_string());
 72
 73        let response = client.cloud_client().create_llm_token(system_id).await?;
 74        *lock = Some(response.token.0.clone());
 75        Ok(response.token.0.clone())
 76    }
 77}
 78
 79struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
 80
 81impl Global for GlobalRefreshLlmTokenListener {}
 82
 83pub struct RefreshLlmTokenEvent;
 84
 85pub struct RefreshLlmTokenListener {
 86    _llm_token_subscription: client::Subscription,
 87}
 88
 89impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
 90
 91impl RefreshLlmTokenListener {
 92    pub fn register(client: Arc<Client>, cx: &mut App) {
 93        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
 94        cx.set_global(GlobalRefreshLlmTokenListener(listener));
 95    }
 96
 97    pub fn global(cx: &App) -> Entity<Self> {
 98        GlobalRefreshLlmTokenListener::global(cx).0.clone()
 99    }
100
101    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
102        Self {
103            _llm_token_subscription: client
104                .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
105        }
106    }
107
108    async fn handle_refresh_llm_token(
109        this: Entity<Self>,
110        _: TypedEnvelope<proto::RefreshLlmToken>,
111        mut cx: AsyncApp,
112    ) -> Result<()> {
113        this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
114    }
115}