cloud_model.rs

use std::fmt;
use std::sync::Arc;

use anyhow::Result;
use client::Client;
use gpui::{
    App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
};
use proto::{Plan, TypedEnvelope};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use strum::EnumIter;
use thiserror::Error;

use crate::LanguageModelToolSchemaFormat;

/// A language model served by one of the supported cloud providers.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "provider", rename_all = "lowercase")]
pub enum CloudModel {
    Anthropic(anthropic::Model),
    OpenAi(open_ai::Model),
    Google(google_ai::Model),
}
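
// With the serde attributes above, `CloudModel` is internally tagged: the
// variant name is lowercased into a `provider` field that sits alongside the
// inner model's own representation. A hedged sketch of the resulting JSON
// (the inner layout depends entirely on how each provider's `Model` type
// serializes, so this is illustrative, not normative):
//
//     { "provider": "anthropic", ... }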

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
pub enum ZedModel {
    #[serde(rename = "Qwen/Qwen2-7B-Instruct")]
    Qwen2_7bInstruct,
}

impl Default for CloudModel {
    fn default() -> Self {
        Self::Anthropic(anthropic::Model::default())
    }
}

impl CloudModel {
    pub fn id(&self) -> &str {
        match self {
            Self::Anthropic(model) => model.id(),
            Self::OpenAi(model) => model.id(),
            Self::Google(model) => model.id(),
        }
    }

    pub fn display_name(&self) -> &str {
        match self {
            Self::Anthropic(model) => model.display_name(),
            Self::OpenAi(model) => model.display_name(),
            Self::Google(model) => model.display_name(),
        }
    }

    pub fn max_token_count(&self) -> usize {
        match self {
            Self::Anthropic(model) => model.max_token_count(),
            Self::OpenAi(model) => model.max_token_count(),
            Self::Google(model) => model.max_token_count(),
        }
    }

    pub fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self {
            Self::Anthropic(_) | Self::OpenAi(_) => LanguageModelToolSchemaFormat::JsonSchema,
            // Google's function-calling API accepts only a subset of JSON Schema.
            Self::Google(_) => LanguageModelToolSchemaFormat::JsonSchemaSubset,
        }
    }
}
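
// A hedged usage sketch: the accessors above give callers a uniform view over
// the per-provider model enums, so display code never has to match on the
// provider itself.
//
//     let model = CloudModel::default();
//     println!("{} ({} tokens)", model.display_name(), model.max_token_count());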

#[derive(Error, Debug)]
pub struct PaymentRequiredError;

impl fmt::Display for PaymentRequiredError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(
            f,
            "Payment required to use this language model. Please upgrade your account."
        )
    }
}

#[derive(Error, Debug)]
pub struct ModelRequestLimitReachedError {
    pub plan: Plan,
}

impl fmt::Display for ModelRequestLimitReachedError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let message = match self.plan {
            Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.",
            Plan::ZedPro => {
                "Model request limit reached. Upgrade to usage-based billing for more requests."
            }
            Plan::ZedProTrial => {
                "Model request limit reached. Upgrade to Zed Pro for more requests."
            }
        };

        write!(f, "{message}")
    }
}

/// A lazily fetched LLM API token, shared behind a read-write lock.
#[derive(Clone, Default)]
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);

impl LlmApiToken {
    /// Returns the cached token, fetching a fresh one from the server if
    /// none is cached yet.
    pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
        // Take an upgradable read lock so concurrent readers can proceed,
        // while only one caller at a time may upgrade to write and fetch.
        let lock = self.0.upgradable_read().await;
        if let Some(token) = lock.as_ref() {
            Ok(token.to_string())
        } else {
            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
        }
    }

    /// Unconditionally fetches a new token, replacing any cached one.
    pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
        Self::fetch(self.0.write().await, client).await
    }

    async fn fetch(
        mut lock: RwLockWriteGuard<'_, Option<String>>,
        client: &Arc<Client>,
    ) -> Result<String> {
        let response = client.request(proto::GetLlmToken {}).await?;
        *lock = Some(response.token.clone());
        Ok(response.token)
    }
}
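
// A hedged usage sketch for `LlmApiToken`: callers `acquire` a token before a
// request and `refresh` it once the server rejects the cached one.
// `send_llm_request` and `is_unauthorized` below are hypothetical stand-ins
// for the caller's own HTTP plumbing.
//
//     let token = llm_api_token.acquire(&client).await?;
//     match send_llm_request(&token).await {
//         Err(err) if is_unauthorized(&err) => {
//             let token = llm_api_token.refresh(&client).await?;
//             send_llm_request(&token).await?;
//         }
//         result => {
//             result?;
//         }
//     }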

struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);

impl Global for GlobalRefreshLlmTokenListener {}

pub struct RefreshLlmTokenEvent;

/// Listens for `RefreshLlmToken` messages from the server and re-emits them
/// as `RefreshLlmTokenEvent`s for interested entities to observe.
pub struct RefreshLlmTokenListener {
    _llm_token_subscription: client::Subscription,
}

impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}

impl RefreshLlmTokenListener {
    /// Creates the listener and registers it as a global.
    pub fn register(client: Arc<Client>, cx: &mut App) {
        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
        cx.set_global(GlobalRefreshLlmTokenListener(listener));
    }

    pub fn global(cx: &App) -> Entity<Self> {
        GlobalRefreshLlmTokenListener::global(cx).0.clone()
    }

    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
        Self {
            _llm_token_subscription: client
                .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
        }
    }

    async fn handle_refresh_llm_token(
        this: Entity<Self>,
        _: TypedEnvelope<proto::RefreshLlmToken>,
        mut cx: AsyncApp,
    ) -> Result<()> {
        this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
    }
}
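
// A hedged test sketch exercising the pure accessors above; it relies only on
// behavior visible in this file, plus the assumption that the default
// Anthropic model reports a non-empty id.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_model_uses_full_json_schema_tools() {
        // The default is Anthropic's default model (see `impl Default` above).
        let model = CloudModel::default();
        assert!(matches!(model, CloudModel::Anthropic(_)));

        // Anthropic models map to the full JSON Schema tool format.
        assert!(matches!(
            model.tool_input_format(),
            LanguageModelToolSchemaFormat::JsonSchema
        ));

        // Assumption: provider model ids are non-empty strings.
        assert!(!model.id().is_empty());
    }
}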