// cloud_model.rs

use std::fmt;
use std::sync::Arc;

use anyhow::Result;
use client::Client;
use gpui::{
    App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
};
use icons::IconName;
use proto::{Plan, TypedEnvelope};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use strum::EnumIter;
use thiserror::Error;

use crate::{LanguageModelAvailability, LanguageModelToolSchemaFormat};

/// A language model hosted behind the cloud LLM service, wrapping the
/// provider-specific model type for each supported provider.
///
/// Serialized internally tagged on `provider` (lowercased variant name),
/// e.g. `{ "provider": "anthropic", ... }`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "provider", rename_all = "lowercase")]
pub enum CloudModel {
    Anthropic(anthropic::Model),
    OpenAi(open_ai::Model),
    Google(google_ai::Model),
}

/// Models hosted directly by Zed (as opposed to a third-party provider).
///
/// NOTE(review): not referenced elsewhere in this file — confirm it is still
/// consumed by callers before extending.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
pub enum ZedModel {
    /// Serialized under its full Hugging Face model path.
    #[serde(rename = "Qwen/Qwen2-7B-Instruct")]
    Qwen2_7bInstruct,
}

impl Default for CloudModel {
    /// Defaults to the Anthropic provider with its own default model.
    fn default() -> Self {
        Self::Anthropic(anthropic::Model::default())
    }
}

 39impl CloudModel {
 40    pub fn id(&self) -> &str {
 41        match self {
 42            Self::Anthropic(model) => model.id(),
 43            Self::OpenAi(model) => model.id(),
 44            Self::Google(model) => model.id(),
 45        }
 46    }
 47
 48    pub fn display_name(&self) -> &str {
 49        match self {
 50            Self::Anthropic(model) => model.display_name(),
 51            Self::OpenAi(model) => model.display_name(),
 52            Self::Google(model) => model.display_name(),
 53        }
 54    }
 55
 56    pub fn icon(&self) -> Option<IconName> {
 57        match self {
 58            Self::Anthropic(_) => Some(IconName::AiAnthropicHosted),
 59            _ => None,
 60        }
 61    }
 62
 63    pub fn max_token_count(&self) -> usize {
 64        match self {
 65            Self::Anthropic(model) => model.max_token_count(),
 66            Self::OpenAi(model) => model.max_token_count(),
 67            Self::Google(model) => model.max_token_count(),
 68        }
 69    }
 70
 71    /// Returns the availability of this model.
 72    pub fn availability(&self) -> LanguageModelAvailability {
 73        match self {
 74            Self::Anthropic(model) => match model {
 75                anthropic::Model::Claude3_5Sonnet
 76                | anthropic::Model::Claude3_7Sonnet
 77                | anthropic::Model::Claude3_7SonnetThinking => {
 78                    LanguageModelAvailability::RequiresPlan(Plan::Free)
 79                }
 80                anthropic::Model::Claude3Opus
 81                | anthropic::Model::Claude3Sonnet
 82                | anthropic::Model::Claude3Haiku
 83                | anthropic::Model::Claude3_5Haiku
 84                | anthropic::Model::Custom { .. } => {
 85                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
 86                }
 87            },
 88            Self::OpenAi(model) => match model {
 89                open_ai::Model::ThreePointFiveTurbo
 90                | open_ai::Model::Four
 91                | open_ai::Model::FourTurbo
 92                | open_ai::Model::FourOmni
 93                | open_ai::Model::FourOmniMini
 94                | open_ai::Model::O1Mini
 95                | open_ai::Model::O1Preview
 96                | open_ai::Model::O1
 97                | open_ai::Model::O3Mini
 98                | open_ai::Model::Custom { .. } => {
 99                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
100                }
101            },
102            Self::Google(model) => match model {
103                google_ai::Model::Gemini15Pro
104                | google_ai::Model::Gemini15Flash
105                | google_ai::Model::Gemini20Pro
106                | google_ai::Model::Gemini20Flash
107                | google_ai::Model::Gemini20FlashThinking
108                | google_ai::Model::Gemini20FlashLite
109                | google_ai::Model::Gemini25ProExp0325
110                | google_ai::Model::Gemini25ProPreview0325
111                | google_ai::Model::Custom { .. } => {
112                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
113                }
114            },
115        }
116    }
117
118    pub fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
119        match self {
120            Self::Anthropic(_) | Self::OpenAi(_) => LanguageModelToolSchemaFormat::JsonSchema,
121            Self::Google(_) => LanguageModelToolSchemaFormat::JsonSchemaSubset,
122        }
123    }
124}
125
126#[derive(Error, Debug)]
127pub struct PaymentRequiredError;
128
129impl fmt::Display for PaymentRequiredError {
130    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
131        write!(
132            f,
133            "Payment required to use this language model. Please upgrade your account."
134        )
135    }
136}
137
138#[derive(Error, Debug)]
139pub struct MaxMonthlySpendReachedError;
140
141impl fmt::Display for MaxMonthlySpendReachedError {
142    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
143        write!(
144            f,
145            "Maximum spending limit reached for this month. For more usage, increase your spending limit."
146        )
147    }
148}
149
/// A cached LLM API token, shared across clones via `Arc` and guarded by an
/// async read/write lock so concurrent callers can reuse a single cached value.
#[derive(Clone, Default)]
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);

153impl LlmApiToken {
154    pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
155        let lock = self.0.upgradable_read().await;
156        if let Some(token) = lock.as_ref() {
157            Ok(token.to_string())
158        } else {
159            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
160        }
161    }
162
163    pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
164        Self::fetch(self.0.write().await, client).await
165    }
166
167    async fn fetch(
168        mut lock: RwLockWriteGuard<'_, Option<String>>,
169        client: &Arc<Client>,
170    ) -> Result<String> {
171        let response = client.request(proto::GetLlmToken {}).await?;
172        *lock = Some(response.token.clone());
173        Ok(response.token.clone())
174    }
175}
176
/// Global holder for the singleton [`RefreshLlmTokenListener`] entity.
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);

impl Global for GlobalRefreshLlmTokenListener {}

/// Event emitted when the server asks this client to refresh its LLM token.
pub struct RefreshLlmTokenEvent;

/// Listens for `RefreshLlmToken` messages from the server and re-emits them
/// as [`RefreshLlmTokenEvent`]s for in-app subscribers.
pub struct RefreshLlmTokenListener {
    // Held so the RPC message handler stays registered for the lifetime of
    // this entity.
    _llm_token_subscription: client::Subscription,
}

impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}

impl RefreshLlmTokenListener {
    /// Creates the listener entity and installs it as the app-wide global.
    pub fn register(client: Arc<Client>, cx: &mut App) {
        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
        cx.set_global(GlobalRefreshLlmTokenListener(listener));
    }

    /// Returns the globally registered listener.
    ///
    /// NOTE(review): presumably panics if [`Self::register`] has not been
    /// called yet — confirm against `ReadGlobal::global` semantics.
    pub fn global(cx: &App) -> Entity<Self> {
        GlobalRefreshLlmTokenListener::global(cx).0.clone()
    }

    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
        Self {
            // Register an RPC handler for `RefreshLlmToken` messages; the
            // returned subscription is stored to keep the handler alive.
            _llm_token_subscription: client
                .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
        }
    }

    /// RPC handler: forwards the server's refresh request to observers by
    /// emitting a [`RefreshLlmTokenEvent`].
    async fn handle_refresh_llm_token(
        this: Entity<Self>,
        _: TypedEnvelope<proto::RefreshLlmToken>,
        mut cx: AsyncApp,
    ) -> Result<()> {
        this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
    }
}