// cloud_model.rs

  1use std::fmt;
  2use std::sync::Arc;
  3
  4use anyhow::Result;
  5use client::Client;
  6use cloud_api_types::websocket_protocol::MessageToClient;
  7use cloud_llm_client::{Plan, PlanV1};
  8use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
  9use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 10use thiserror::Error;
 11
 12#[derive(Error, Debug)]
 13pub struct PaymentRequiredError;
 14
 15impl fmt::Display for PaymentRequiredError {
 16    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 17        write!(
 18            f,
 19            "Payment required to use this language model. Please upgrade your account."
 20        )
 21    }
 22}
 23
 24#[derive(Error, Debug)]
 25pub struct ModelRequestLimitReachedError {
 26    pub plan: Plan,
 27}
 28
 29impl fmt::Display for ModelRequestLimitReachedError {
 30    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 31        let message = match self.plan {
 32            Plan::V1(PlanV1::ZedFree) => {
 33                "Model request limit reached. Upgrade to Zed Pro for more requests."
 34            }
 35            Plan::V1(PlanV1::ZedPro) => {
 36                "Model request limit reached. Upgrade to usage-based billing for more requests."
 37            }
 38            Plan::V1(PlanV1::ZedProTrial) => {
 39                "Model request limit reached. Upgrade to Zed Pro for more requests."
 40            }
 41            Plan::V2(_) => "Model request limit reached.",
 42        };
 43
 44        write!(f, "{message}")
 45    }
 46}
 47
 48#[derive(Error, Debug)]
 49pub struct ToolUseLimitReachedError;
 50
 51impl fmt::Display for ToolUseLimitReachedError {
 52    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 53        write!(
 54            f,
 55            "Consecutive tool use limit reached. Enable Burn Mode for unlimited tool use."
 56        )
 57    }
 58}
 59
 60#[derive(Clone, Default)]
 61pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
 62
 63impl LlmApiToken {
 64    pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
 65        let lock = self.0.upgradable_read().await;
 66        if let Some(token) = lock.as_ref() {
 67            Ok(token.to_string())
 68        } else {
 69            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
 70        }
 71    }
 72
 73    pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
 74        Self::fetch(self.0.write().await, client).await
 75    }
 76
 77    async fn fetch(
 78        mut lock: RwLockWriteGuard<'_, Option<String>>,
 79        client: &Arc<Client>,
 80    ) -> Result<String> {
 81        let system_id = client
 82            .telemetry()
 83            .system_id()
 84            .map(|system_id| system_id.to_string());
 85
 86        let response = client.cloud_client().create_llm_token(system_id).await?;
 87        *lock = Some(response.token.0.clone());
 88        Ok(response.token.0)
 89    }
 90}
 91
 92struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
 93
 94impl Global for GlobalRefreshLlmTokenListener {}
 95
 96pub struct RefreshLlmTokenEvent;
 97
 98pub struct RefreshLlmTokenListener;
 99
100impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
101
102impl RefreshLlmTokenListener {
103    pub fn register(client: Arc<Client>, cx: &mut App) {
104        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
105        cx.set_global(GlobalRefreshLlmTokenListener(listener));
106    }
107
108    pub fn global(cx: &App) -> Entity<Self> {
109        GlobalRefreshLlmTokenListener::global(cx).0.clone()
110    }
111
112    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
113        client.add_message_to_client_handler({
114            let this = cx.entity();
115            move |message, cx| {
116                Self::handle_refresh_llm_token(this.clone(), message, cx);
117            }
118        });
119
120        Self
121    }
122
123    fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
124        match message {
125            MessageToClient::UserUpdated => {
126                this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
127            }
128        }
129    }
130}