use std::fmt;
use std::sync::Arc;

use anyhow::Result;
use client::Client;
use gpui::{
    App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
};
use icons::IconName;
use proto::{Plan, TypedEnvelope};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use strum::EnumIter;
use thiserror::Error;

use crate::{LanguageModelAvailability, LanguageModelToolSchemaFormat};

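/// A cloud-hosted language model, identified by the upstream provider that serves it.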
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "provider", rename_all = "lowercase")]
pub enum CloudModel {
    Anthropic(anthropic::Model),
    OpenAi(open_ai::Model),
    Google(google_ai::Model),
}

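/// Models hosted by Zed itself rather than by an upstream provider.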
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
pub enum ZedModel {
    #[serde(rename = "Qwen/Qwen2-7B-Instruct")]
    Qwen2_7bInstruct,
}

impl Default for CloudModel {
    fn default() -> Self {
        Self::Anthropic(anthropic::Model::default())
    }
}

impl CloudModel {
    /// Returns the provider-specific identifier for this model.
    pub fn id(&self) -> &str {
        match self {
            Self::Anthropic(model) => model.id(),
            Self::OpenAi(model) => model.id(),
            Self::Google(model) => model.id(),
        }
    }

    /// Returns the human-readable display name for this model.
    pub fn display_name(&self) -> &str {
        match self {
            Self::Anthropic(model) => model.display_name(),
            Self::OpenAi(model) => model.display_name(),
            Self::Google(model) => model.display_name(),
        }
    }

    /// Returns the icon associated with this model, if it has a dedicated one.
    pub fn icon(&self) -> Option<IconName> {
        match self {
            Self::Anthropic(_) => Some(IconName::AiAnthropicHosted),
            _ => None,
        }
    }

    /// Returns the maximum token count supported by this model.
    pub fn max_token_count(&self) -> usize {
        match self {
            Self::Anthropic(model) => model.max_token_count(),
            Self::OpenAi(model) => model.max_token_count(),
            Self::Google(model) => model.max_token_count(),
        }
    }

    /// Returns the availability of this model.
    pub fn availability(&self) -> LanguageModelAvailability {
        match self {
            Self::Anthropic(model) => match model {
                anthropic::Model::Claude3_5Sonnet
                | anthropic::Model::Claude3_7Sonnet
                | anthropic::Model::Claude3_7SonnetThinking => {
                    LanguageModelAvailability::RequiresPlan(Plan::Free)
                }
                anthropic::Model::Claude3Opus
                | anthropic::Model::Claude3Sonnet
                | anthropic::Model::Claude3Haiku
                | anthropic::Model::Claude3_5Haiku
                | anthropic::Model::Custom { .. } => {
                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
                }
            },
            Self::OpenAi(model) => match model {
                open_ai::Model::ThreePointFiveTurbo
                | open_ai::Model::Four
                | open_ai::Model::FourTurbo
                | open_ai::Model::FourOmni
                | open_ai::Model::FourOmniMini
                | open_ai::Model::FourPointOne
                | open_ai::Model::FourPointOneMini
                | open_ai::Model::FourPointOneNano
                | open_ai::Model::O1Mini
                | open_ai::Model::O1Preview
                | open_ai::Model::O1
                | open_ai::Model::O3Mini
                | open_ai::Model::Custom { .. } => {
                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
                }
            },
            Self::Google(model) => match model {
                google_ai::Model::Gemini15Pro
                | google_ai::Model::Gemini15Flash
                | google_ai::Model::Gemini20Pro
                | google_ai::Model::Gemini20Flash
                | google_ai::Model::Gemini20FlashThinking
                | google_ai::Model::Gemini20FlashLite
                | google_ai::Model::Gemini25ProExp0325
                | google_ai::Model::Gemini25ProPreview0325
                | google_ai::Model::Custom { .. } => {
                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
                }
            },
        }
    }

    /// Returns the tool input schema format supported by this model.
    pub fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self {
            Self::Anthropic(_) | Self::OpenAi(_) => LanguageModelToolSchemaFormat::JsonSchema,
            Self::Google(_) => LanguageModelToolSchemaFormat::JsonSchemaSubset,
        }
    }
}

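/// The error returned when a request fails because payment is required.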
#[derive(Error, Debug)]
pub struct PaymentRequiredError;

impl fmt::Display for PaymentRequiredError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(
            f,
            "Payment required to use this language model. Please upgrade your account."
        )
    }
}

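/// The error returned when the user's maximum monthly spending limit has been reached.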
#[derive(Error, Debug)]
pub struct MaxMonthlySpendReachedError;

impl fmt::Display for MaxMonthlySpendReachedError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(
            f,
            "Maximum spending limit reached for this month. For more usage, increase your spending limit."
        )
    }
}

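/// A cached LLM API token, shared behind an async read-write lock.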
#[derive(Clone, Default)]
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);

impl LlmApiToken {
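    /// Returns the cached token, fetching a fresh one from the server if none is cached yet.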
    pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
        // Take an upgradable read lock so concurrent callers can share the cached
        // token, and only escalate to a write lock when a fetch is actually needed.
        let lock = self.0.upgradable_read().await;
        if let Some(token) = lock.as_ref() {
            Ok(token.to_string())
        } else {
            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
        }
    }

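    /// Unconditionally fetches a fresh token from the server, replacing any cached value.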
    pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
        Self::fetch(self.0.write().await, client).await
    }

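    /// Requests a new token from the server while holding the write lock and caches it.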
    async fn fetch(
        mut lock: RwLockWriteGuard<'_, Option<String>>,
        client: &Arc<Client>,
    ) -> Result<String> {
        let response = client.request(proto::GetLlmToken {}).await?;
        *lock = Some(response.token.clone());
        Ok(response.token)
    }
}

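/// The global holding the shared [`RefreshLlmTokenListener`] entity.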
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);

impl Global for GlobalRefreshLlmTokenListener {}

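/// An event indicating that the LLM API token should be refreshed.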
pub struct RefreshLlmTokenEvent;

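/// Listens for `RefreshLlmToken` messages from the server and re-emits them as
/// [`RefreshLlmTokenEvent`]s for interested observers.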
pub struct RefreshLlmTokenListener {
    _llm_token_subscription: client::Subscription,
}

impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}

impl RefreshLlmTokenListener {
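    /// Creates the listener and registers it as a global.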
    pub fn register(client: Arc<Client>, cx: &mut App) {
        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
        cx.set_global(GlobalRefreshLlmTokenListener(listener));
    }

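    /// Returns the globally registered listener.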
    pub fn global(cx: &App) -> Entity<Self> {
        GlobalRefreshLlmTokenListener::global(cx).0.clone()
    }

    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
        Self {
            _llm_token_subscription: client
                .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
        }
    }

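    /// Handles a `RefreshLlmToken` message from the server by emitting a
    /// [`RefreshLlmTokenEvent`].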
    async fn handle_refresh_llm_token(
        this: Entity<Self>,
        _: TypedEnvelope<proto::RefreshLlmToken>,
        mut cx: AsyncApp,
    ) -> Result<()> {
        this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
    }
}