cloud_model.rs

  1use std::fmt;
  2use std::sync::Arc;
  3
  4use anyhow::Result;
  5use client::Client;
  6use gpui::{
  7    App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
  8};
  9use proto::{Plan, TypedEnvelope};
 10use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 11use thiserror::Error;
 12
 13#[derive(Error, Debug)]
 14pub struct PaymentRequiredError;
 15
 16impl fmt::Display for PaymentRequiredError {
 17    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 18        write!(
 19            f,
 20            "Payment required to use this language model. Please upgrade your account."
 21        )
 22    }
 23}
 24
 25#[derive(Error, Debug)]
 26pub struct ModelRequestLimitReachedError {
 27    pub plan: Plan,
 28}
 29
 30impl fmt::Display for ModelRequestLimitReachedError {
 31    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 32        let message = match self.plan {
 33            Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.",
 34            Plan::ZedPro => {
 35                "Model request limit reached. Upgrade to usage-based billing for more requests."
 36            }
 37            Plan::ZedProTrial => {
 38                "Model request limit reached. Upgrade to Zed Pro for more requests."
 39            }
 40        };
 41
 42        write!(f, "{message}")
 43    }
 44}
 45
 46#[derive(Clone, Default)]
 47pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
 48
 49impl LlmApiToken {
 50    pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
 51        let lock = self.0.upgradable_read().await;
 52        if let Some(token) = lock.as_ref() {
 53            Ok(token.to_string())
 54        } else {
 55            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
 56        }
 57    }
 58
 59    pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
 60        Self::fetch(self.0.write().await, client).await
 61    }
 62
 63    async fn fetch(
 64        mut lock: RwLockWriteGuard<'_, Option<String>>,
 65        client: &Arc<Client>,
 66    ) -> Result<String> {
 67        let response = client.request(proto::GetLlmToken {}).await?;
 68        *lock = Some(response.token.clone());
 69        Ok(response.token.clone())
 70    }
 71}
 72
 73struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
 74
 75impl Global for GlobalRefreshLlmTokenListener {}
 76
 77pub struct RefreshLlmTokenEvent;
 78
 79pub struct RefreshLlmTokenListener {
 80    _llm_token_subscription: client::Subscription,
 81}
 82
 83impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
 84
 85impl RefreshLlmTokenListener {
 86    pub fn register(client: Arc<Client>, cx: &mut App) {
 87        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
 88        cx.set_global(GlobalRefreshLlmTokenListener(listener));
 89    }
 90
 91    pub fn global(cx: &App) -> Entity<Self> {
 92        GlobalRefreshLlmTokenListener::global(cx).0.clone()
 93    }
 94
 95    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
 96        Self {
 97            _llm_token_subscription: client
 98                .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
 99        }
100    }
101
102    async fn handle_refresh_llm_token(
103        this: Entity<Self>,
104        _: TypedEnvelope<proto::RefreshLlmToken>,
105        mut cx: AsyncApp,
106    ) -> Result<()> {
107        this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
108    }
109}