// cloud_model.rs

  1use std::fmt;
  2use std::sync::Arc;
  3
  4use anyhow::Result;
  5use client::Client;
  6use gpui::{
  7    App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
  8};
  9use proto::{Plan, TypedEnvelope};
 10use schemars::JsonSchema;
 11use serde::{Deserialize, Serialize};
 12use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 13use strum::EnumIter;
 14use thiserror::Error;
 15
 16use crate::{LanguageModelAvailability, LanguageModelToolSchemaFormat};
 17
/// A language model served through Zed's cloud, tagged with its upstream provider.
///
/// Serialized with a `provider` tag using the lowercased variant name,
/// e.g. `{"provider": "anthropic", ...}`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "provider", rename_all = "lowercase")]
pub enum CloudModel {
    Anthropic(anthropic::Model),
    OpenAi(open_ai::Model),
    Google(google_ai::Model),
}
 25
/// Models hosted by Zed itself rather than a third-party provider.
///
/// NOTE(review): not referenced by `CloudModel` in this file — presumably
/// consumed elsewhere via `EnumIter`; confirm before removing.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
pub enum ZedModel {
    /// Serialized as the full Hugging Face model id.
    #[serde(rename = "Qwen/Qwen2-7B-Instruct")]
    Qwen2_7bInstruct,
}
 31
 32impl Default for CloudModel {
 33    fn default() -> Self {
 34        Self::Anthropic(anthropic::Model::default())
 35    }
 36}
 37
 38impl CloudModel {
 39    pub fn id(&self) -> &str {
 40        match self {
 41            Self::Anthropic(model) => model.id(),
 42            Self::OpenAi(model) => model.id(),
 43            Self::Google(model) => model.id(),
 44        }
 45    }
 46
 47    pub fn display_name(&self) -> &str {
 48        match self {
 49            Self::Anthropic(model) => model.display_name(),
 50            Self::OpenAi(model) => model.display_name(),
 51            Self::Google(model) => model.display_name(),
 52        }
 53    }
 54
 55    pub fn max_token_count(&self) -> usize {
 56        match self {
 57            Self::Anthropic(model) => model.max_token_count(),
 58            Self::OpenAi(model) => model.max_token_count(),
 59            Self::Google(model) => model.max_token_count(),
 60        }
 61    }
 62
 63    /// Returns the availability of this model.
 64    pub fn availability(&self) -> LanguageModelAvailability {
 65        match self {
 66            Self::Anthropic(model) => match model {
 67                anthropic::Model::Claude3_5Sonnet
 68                | anthropic::Model::Claude3_7Sonnet
 69                | anthropic::Model::Claude3_7SonnetThinking => {
 70                    LanguageModelAvailability::RequiresPlan(Plan::Free)
 71                }
 72                anthropic::Model::Claude3Opus
 73                | anthropic::Model::Claude3Sonnet
 74                | anthropic::Model::Claude3Haiku
 75                | anthropic::Model::Claude3_5Haiku
 76                | anthropic::Model::Custom { .. } => {
 77                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
 78                }
 79            },
 80            Self::OpenAi(model) => match model {
 81                open_ai::Model::ThreePointFiveTurbo
 82                | open_ai::Model::Four
 83                | open_ai::Model::FourTurbo
 84                | open_ai::Model::FourOmni
 85                | open_ai::Model::FourOmniMini
 86                | open_ai::Model::FourPointOne
 87                | open_ai::Model::FourPointOneMini
 88                | open_ai::Model::FourPointOneNano
 89                | open_ai::Model::O1Mini
 90                | open_ai::Model::O1Preview
 91                | open_ai::Model::O1
 92                | open_ai::Model::O3Mini
 93                | open_ai::Model::O3
 94                | open_ai::Model::O4Mini
 95                | open_ai::Model::Custom { .. } => {
 96                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
 97                }
 98            },
 99            Self::Google(model) => match model {
100                google_ai::Model::Gemini15Pro
101                | google_ai::Model::Gemini15Flash
102                | google_ai::Model::Gemini20Pro
103                | google_ai::Model::Gemini20Flash
104                | google_ai::Model::Gemini20FlashThinking
105                | google_ai::Model::Gemini20FlashLite
106                | google_ai::Model::Gemini25ProExp0325
107                | google_ai::Model::Gemini25ProPreview0325
108                | google_ai::Model::Custom { .. } => {
109                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
110                }
111            },
112        }
113    }
114
115    pub fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
116        match self {
117            Self::Anthropic(_) | Self::OpenAi(_) => LanguageModelToolSchemaFormat::JsonSchema,
118            Self::Google(_) => LanguageModelToolSchemaFormat::JsonSchemaSubset,
119        }
120    }
121}
122
123#[derive(Error, Debug)]
124pub struct PaymentRequiredError;
125
126impl fmt::Display for PaymentRequiredError {
127    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
128        write!(
129            f,
130            "Payment required to use this language model. Please upgrade your account."
131        )
132    }
133}
134
135#[derive(Error, Debug)]
136pub struct MaxMonthlySpendReachedError;
137
138impl fmt::Display for MaxMonthlySpendReachedError {
139    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
140        write!(
141            f,
142            "Maximum spending limit reached for this month. For more usage, increase your spending limit."
143        )
144    }
145}
146
/// Error returned when the user has exhausted the model-request quota for
/// their current plan. `Display` is implemented manually below because the
/// message depends on `plan`.
#[derive(Error, Debug)]
pub struct ModelRequestLimitReachedError {
    /// The plan in effect when the limit was hit; selects the upgrade advice shown.
    pub plan: Plan,
}
151
152impl fmt::Display for ModelRequestLimitReachedError {
153    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
154        let message = match self.plan {
155            Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.",
156            Plan::ZedPro => {
157                "Model request limit reached. Upgrade to usage-based billing for more requests."
158            }
159            Plan::ZedProTrial => {
160                "Model request limit reached. Upgrade to Zed Pro for more requests."
161            }
162        };
163
164        write!(f, "{message}")
165    }
166}
167
/// A shared, lazily-fetched LLM API token. `None` until the first fetch;
/// clones share the same underlying cell via `Arc`.
#[derive(Clone, Default)]
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
170
171impl LlmApiToken {
172    pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
173        let lock = self.0.upgradable_read().await;
174        if let Some(token) = lock.as_ref() {
175            Ok(token.to_string())
176        } else {
177            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
178        }
179    }
180
181    pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
182        Self::fetch(self.0.write().await, client).await
183    }
184
185    async fn fetch(
186        mut lock: RwLockWriteGuard<'_, Option<String>>,
187        client: &Arc<Client>,
188    ) -> Result<String> {
189        let response = client.request(proto::GetLlmToken {}).await?;
190        *lock = Some(response.token.clone());
191        Ok(response.token.clone())
192    }
193}
194
/// Newtype wrapper registering the [`RefreshLlmTokenListener`] entity in
/// gpui's global registry so it can be looked up from anywhere in the app.
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);

impl Global for GlobalRefreshLlmTokenListener {}
198
/// Event emitted whenever the server asks this client to refresh its LLM token.
pub struct RefreshLlmTokenEvent;

/// Listens for `proto::RefreshLlmToken` RPC messages and re-emits them as
/// [`RefreshLlmTokenEvent`]s for in-process subscribers.
pub struct RefreshLlmTokenListener {
    // Held (never read) so the message-handler subscription stays registered
    // for the lifetime of the listener.
    _llm_token_subscription: client::Subscription,
}

impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
206
impl RefreshLlmTokenListener {
    /// Creates the listener entity and installs it as the gpui global.
    /// Call once during app startup, before [`Self::global`] is used.
    pub fn register(client: Arc<Client>, cx: &mut App) {
        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
        cx.set_global(GlobalRefreshLlmTokenListener(listener));
    }

    /// Returns the registered listener entity.
    /// NOTE(review): presumably panics if [`Self::register`] was never called,
    /// per gpui's usual global-read semantics — confirm.
    pub fn global(cx: &App) -> Entity<Self> {
        GlobalRefreshLlmTokenListener::global(cx).0.clone()
    }

    /// Subscribes to `RefreshLlmToken` messages; the returned subscription is
    /// stored so the handler remains registered while `self` lives.
    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
        Self {
            _llm_token_subscription: client
                .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
        }
    }

    /// RPC handler: forwards the server's refresh request to subscribers as a
    /// [`RefreshLlmTokenEvent`]. The message payload itself is unused.
    async fn handle_refresh_llm_token(
        this: Entity<Self>,
        _: TypedEnvelope<proto::RefreshLlmToken>,
        mut cx: AsyncApp,
    ) -> Result<()> {
        this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
    }
}