cloud_model.rs

  1use std::fmt;
  2use std::sync::Arc;
  3
  4use anyhow::Result;
  5use client::Client;
  6use cloud_api_types::websocket_protocol::MessageToClient;
  7use cloud_llm_client::Plan;
  8use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
  9use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 10use thiserror::Error;
 11
 12#[derive(Error, Debug)]
 13pub struct PaymentRequiredError;
 14
 15impl fmt::Display for PaymentRequiredError {
 16    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 17        write!(
 18            f,
 19            "Payment required to use this language model. Please upgrade your account."
 20        )
 21    }
 22}
 23
 24#[derive(Error, Debug)]
 25pub struct ModelRequestLimitReachedError {
 26    pub plan: Plan,
 27}
 28
 29impl fmt::Display for ModelRequestLimitReachedError {
 30    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 31        let message = match self.plan {
 32            Plan::ZedFree => "Model request limit reached. Upgrade to Zed Pro for more requests.",
 33            Plan::ZedPro => {
 34                "Model request limit reached. Upgrade to usage-based billing for more requests."
 35            }
 36            Plan::ZedProTrial => {
 37                "Model request limit reached. Upgrade to Zed Pro for more requests."
 38            }
 39            Plan::ZedProV2 | Plan::ZedProTrialV2 => "Model request limit reached.",
 40        };
 41
 42        write!(f, "{message}")
 43    }
 44}
 45
 46#[derive(Error, Debug)]
 47pub struct ToolUseLimitReachedError;
 48
 49impl fmt::Display for ToolUseLimitReachedError {
 50    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 51        write!(
 52            f,
 53            "Consecutive tool use limit reached. Enable Burn Mode for unlimited tool use."
 54        )
 55    }
 56}
 57
 58#[derive(Clone, Default)]
 59pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
 60
 61impl LlmApiToken {
 62    pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
 63        let lock = self.0.upgradable_read().await;
 64        if let Some(token) = lock.as_ref() {
 65            Ok(token.to_string())
 66        } else {
 67            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
 68        }
 69    }
 70
 71    pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
 72        Self::fetch(self.0.write().await, client).await
 73    }
 74
 75    async fn fetch(
 76        mut lock: RwLockWriteGuard<'_, Option<String>>,
 77        client: &Arc<Client>,
 78    ) -> Result<String> {
 79        let system_id = client
 80            .telemetry()
 81            .system_id()
 82            .map(|system_id| system_id.to_string());
 83
 84        let response = client.cloud_client().create_llm_token(system_id).await?;
 85        *lock = Some(response.token.0.clone());
 86        Ok(response.token.0)
 87    }
 88}
 89
 90struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
 91
 92impl Global for GlobalRefreshLlmTokenListener {}
 93
 94pub struct RefreshLlmTokenEvent;
 95
 96pub struct RefreshLlmTokenListener;
 97
 98impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
 99
100impl RefreshLlmTokenListener {
101    pub fn register(client: Arc<Client>, cx: &mut App) {
102        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
103        cx.set_global(GlobalRefreshLlmTokenListener(listener));
104    }
105
106    pub fn global(cx: &App) -> Entity<Self> {
107        GlobalRefreshLlmTokenListener::global(cx).0.clone()
108    }
109
110    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
111        client.add_message_to_client_handler({
112            let this = cx.entity();
113            move |message, cx| {
114                Self::handle_refresh_llm_token(this.clone(), message, cx);
115            }
116        });
117
118        Self
119    }
120
121    fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
122        match message {
123            MessageToClient::UserUpdated => {
124                this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
125            }
126        }
127    }
128}