1use std::fmt;
2use std::sync::Arc;
3
4use anyhow::Result;
5use client::Client;
6use gpui::{
7 App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
8};
9use proto::{Plan, TypedEnvelope};
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize};
12use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
13use strum::EnumIter;
14use thiserror::Error;
15
16use crate::{LanguageModelAvailability, LanguageModelToolSchemaFormat};
17
/// A language model served through Zed's cloud, wrapping the model type of
/// whichever upstream provider (Anthropic, OpenAI, or Google) hosts it.
///
/// Serialized internally tagged: the `"provider"` field carries the lowercase
/// variant name, with the provider model's own fields inlined alongside it.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "provider", rename_all = "lowercase")]
pub enum CloudModel {
    Anthropic(anthropic::Model),
    OpenAi(open_ai::Model),
    Google(google_ai::Model),
}
25
/// Models hosted by Zed itself (as opposed to the proxied third-party
/// providers in [`CloudModel`]). `EnumIter` allows iterating all variants.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
pub enum ZedModel {
    // Serialized under its Hugging Face-style model id.
    #[serde(rename = "Qwen/Qwen2-7B-Instruct")]
    Qwen2_7bInstruct,
}
31
32impl Default for CloudModel {
33 fn default() -> Self {
34 Self::Anthropic(anthropic::Model::default())
35 }
36}
37
impl CloudModel {
    /// The provider-specific identifier for this model, delegated to the
    /// wrapped provider model (e.g. the id sent to the provider's API).
    pub fn id(&self) -> &str {
        match self {
            Self::Anthropic(model) => model.id(),
            Self::OpenAi(model) => model.id(),
            Self::Google(model) => model.id(),
        }
    }

    /// Human-readable model name, delegated to the wrapped provider model.
    pub fn display_name(&self) -> &str {
        match self {
            Self::Anthropic(model) => model.display_name(),
            Self::OpenAi(model) => model.display_name(),
            Self::Google(model) => model.display_name(),
        }
    }

    /// Maximum token count reported by the wrapped provider model
    /// (presumably the context-window size — confirm in the provider crates).
    pub fn max_token_count(&self) -> usize {
        match self {
            Self::Anthropic(model) => model.max_token_count(),
            Self::OpenAi(model) => model.max_token_count(),
            Self::Google(model) => model.max_token_count(),
        }
    }

    /// Returns the availability of this model.
    ///
    /// Variants are enumerated exhaustively (no `_` catch-all) so that adding
    /// a model to a provider crate forces an explicit plan decision here.
    pub fn availability(&self) -> LanguageModelAvailability {
        match self {
            // Only the Claude 3.5/3.7 Sonnet family is available on the free plan.
            Self::Anthropic(model) => match model {
                anthropic::Model::Claude3_5Sonnet
                | anthropic::Model::Claude3_7Sonnet
                | anthropic::Model::Claude3_7SonnetThinking => {
                    LanguageModelAvailability::RequiresPlan(Plan::Free)
                }
                anthropic::Model::Claude3Opus
                | anthropic::Model::Claude3Sonnet
                | anthropic::Model::Claude3Haiku
                | anthropic::Model::Claude3_5Haiku
                | anthropic::Model::Custom { .. } => {
                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
                }
            },
            // All OpenAI models currently require Zed Pro.
            Self::OpenAi(model) => match model {
                open_ai::Model::ThreePointFiveTurbo
                | open_ai::Model::Four
                | open_ai::Model::FourTurbo
                | open_ai::Model::FourOmni
                | open_ai::Model::FourOmniMini
                | open_ai::Model::FourPointOne
                | open_ai::Model::FourPointOneMini
                | open_ai::Model::FourPointOneNano
                | open_ai::Model::O1Mini
                | open_ai::Model::O1Preview
                | open_ai::Model::O1
                | open_ai::Model::O3Mini
                | open_ai::Model::O3
                | open_ai::Model::O4Mini
                | open_ai::Model::Custom { .. } => {
                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
                }
            },
            // All Google models currently require Zed Pro.
            Self::Google(model) => match model {
                google_ai::Model::Gemini15Pro
                | google_ai::Model::Gemini15Flash
                | google_ai::Model::Gemini20Pro
                | google_ai::Model::Gemini20Flash
                | google_ai::Model::Gemini20FlashThinking
                | google_ai::Model::Gemini20FlashLite
                | google_ai::Model::Gemini25ProExp0325
                | google_ai::Model::Gemini25ProPreview0325
                | google_ai::Model::Gemini25FlashPreview0417
                | google_ai::Model::Custom { .. } => {
                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
                }
            },
        }
    }

    /// Which JSON-schema dialect this model accepts for tool input; Google
    /// models get the restricted `JsonSchemaSubset` variant.
    pub fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self {
            Self::Anthropic(_) | Self::OpenAi(_) => LanguageModelToolSchemaFormat::JsonSchema,
            Self::Google(_) => LanguageModelToolSchemaFormat::JsonSchemaSubset,
        }
    }
}
123
124#[derive(Error, Debug)]
125pub struct PaymentRequiredError;
126
127impl fmt::Display for PaymentRequiredError {
128 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
129 write!(
130 f,
131 "Payment required to use this language model. Please upgrade your account."
132 )
133 }
134}
135
136#[derive(Error, Debug)]
137pub struct MaxMonthlySpendReachedError;
138
139impl fmt::Display for MaxMonthlySpendReachedError {
140 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
141 write!(
142 f,
143 "Maximum spending limit reached for this month. For more usage, increase your spending limit."
144 )
145 }
146}
147
/// Error returned when the user has exhausted their plan's model-request quota.
#[derive(Error, Debug)]
pub struct ModelRequestLimitReachedError {
    // The plan in effect when the limit was hit; selects which upgrade
    // suggestion appears in the displayed message.
    pub plan: Plan,
}
152
153impl fmt::Display for ModelRequestLimitReachedError {
154 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
155 let message = match self.plan {
156 Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.",
157 Plan::ZedPro => {
158 "Model request limit reached. Upgrade to usage-based billing for more requests."
159 }
160 Plan::ZedProTrial => {
161 "Model request limit reached. Upgrade to Zed Pro for more requests."
162 }
163 };
164
165 write!(f, "{message}")
166 }
167}
168
/// A lazily fetched, cached LLM API token.
///
/// The `Arc<RwLock<..>>` lets clones share one cache slot across tasks;
/// the slot is `None` until a token is first fetched.
#[derive(Clone, Default)]
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
171
172impl LlmApiToken {
173 pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
174 let lock = self.0.upgradable_read().await;
175 if let Some(token) = lock.as_ref() {
176 Ok(token.to_string())
177 } else {
178 Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
179 }
180 }
181
182 pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
183 Self::fetch(self.0.write().await, client).await
184 }
185
186 async fn fetch(
187 mut lock: RwLockWriteGuard<'_, Option<String>>,
188 client: &Arc<Client>,
189 ) -> Result<String> {
190 let response = client.request(proto::GetLlmToken {}).await?;
191 *lock = Some(response.token.clone());
192 Ok(response.token.clone())
193 }
194}
195
/// Newtype wrapper that registers the [`RefreshLlmTokenListener`] entity as a
/// GPUI global, so it can be retrieved anywhere via `RefreshLlmTokenListener::global`.
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);

impl Global for GlobalRefreshLlmTokenListener {}
199
/// Event emitted when the server sends a `RefreshLlmToken` message, telling
/// subscribers to refresh their cached LLM token.
pub struct RefreshLlmTokenEvent;
201
/// Listens for `RefreshLlmToken` messages from the server and re-broadcasts
/// them to GPUI observers as [`RefreshLlmTokenEvent`]s.
pub struct RefreshLlmTokenListener {
    // Retained so the client message-handler registration stays alive for
    // the lifetime of this listener.
    _llm_token_subscription: client::Subscription,
}

impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
207
208impl RefreshLlmTokenListener {
209 pub fn register(client: Arc<Client>, cx: &mut App) {
210 let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
211 cx.set_global(GlobalRefreshLlmTokenListener(listener));
212 }
213
214 pub fn global(cx: &App) -> Entity<Self> {
215 GlobalRefreshLlmTokenListener::global(cx).0.clone()
216 }
217
218 fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
219 Self {
220 _llm_token_subscription: client
221 .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
222 }
223 }
224
225 async fn handle_refresh_llm_token(
226 this: Entity<Self>,
227 _: TypedEnvelope<proto::RefreshLlmToken>,
228 mut cx: AsyncApp,
229 ) -> Result<()> {
230 this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
231 }
232}