1use std::fmt;
2use std::sync::Arc;
3
4use anyhow::Result;
5use client::Client;
6use gpui::{
7 App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
8};
9use proto::{Plan, TypedEnvelope};
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize};
12use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
13use strum::EnumIter;
14use thiserror::Error;
15use ui::IconName;
16
17use crate::LanguageModelAvailability;
18
/// A cloud-hosted language model, tagged by the provider that serves it.
///
/// Serialized with an external `"provider"` tag whose value is the
/// lowercased variant name (e.g. `{"provider": "anthropic", ...}`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "provider", rename_all = "lowercase")]
pub enum CloudModel {
    Anthropic(anthropic::Model),
    OpenAi(open_ai::Model),
    Google(google_ai::Model),
}
26
/// Models hosted by Zed itself (as opposed to third-party providers).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
pub enum ZedModel {
    // Serialized under its Hugging Face model id.
    #[serde(rename = "Qwen/Qwen2-7B-Instruct")]
    Qwen2_7bInstruct,
}
32
33impl Default for CloudModel {
34 fn default() -> Self {
35 Self::Anthropic(anthropic::Model::default())
36 }
37}
38
39impl CloudModel {
40 pub fn id(&self) -> &str {
41 match self {
42 Self::Anthropic(model) => model.id(),
43 Self::OpenAi(model) => model.id(),
44 Self::Google(model) => model.id(),
45 }
46 }
47
48 pub fn display_name(&self) -> &str {
49 match self {
50 Self::Anthropic(model) => model.display_name(),
51 Self::OpenAi(model) => model.display_name(),
52 Self::Google(model) => model.display_name(),
53 }
54 }
55
56 pub fn icon(&self) -> Option<IconName> {
57 match self {
58 Self::Anthropic(_) => Some(IconName::AiAnthropicHosted),
59 _ => None,
60 }
61 }
62
63 pub fn max_token_count(&self) -> usize {
64 match self {
65 Self::Anthropic(model) => model.max_token_count(),
66 Self::OpenAi(model) => model.max_token_count(),
67 Self::Google(model) => model.max_token_count(),
68 }
69 }
70
71 /// Returns the availability of this model.
72 pub fn availability(&self) -> LanguageModelAvailability {
73 match self {
74 Self::Anthropic(model) => match model {
75 anthropic::Model::Claude3_5Sonnet | anthropic::Model::Claude3_7Sonnet => {
76 LanguageModelAvailability::RequiresPlan(Plan::Free)
77 }
78 anthropic::Model::Claude3Opus
79 | anthropic::Model::Claude3Sonnet
80 | anthropic::Model::Claude3Haiku
81 | anthropic::Model::Claude3_5Haiku
82 | anthropic::Model::Custom { .. } => {
83 LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
84 }
85 },
86 Self::OpenAi(model) => match model {
87 open_ai::Model::ThreePointFiveTurbo
88 | open_ai::Model::Four
89 | open_ai::Model::FourTurbo
90 | open_ai::Model::FourOmni
91 | open_ai::Model::FourOmniMini
92 | open_ai::Model::O1Mini
93 | open_ai::Model::O1Preview
94 | open_ai::Model::O1
95 | open_ai::Model::O3Mini
96 | open_ai::Model::Custom { .. } => {
97 LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
98 }
99 },
100 Self::Google(model) => match model {
101 google_ai::Model::Gemini15Pro
102 | google_ai::Model::Gemini15Flash
103 | google_ai::Model::Gemini20Pro
104 | google_ai::Model::Gemini20Flash
105 | google_ai::Model::Gemini20FlashThinking
106 | google_ai::Model::Gemini20FlashLite
107 | google_ai::Model::Custom { .. } => {
108 LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
109 }
110 },
111 }
112 }
113}
114
/// Error returned when the account must pay (upgrade) before using a
/// language model. `Display` is implemented manually rather than via a
/// thiserror `#[error(...)]` attribute.
#[derive(Error, Debug)]
pub struct PaymentRequiredError;
117
118impl fmt::Display for PaymentRequiredError {
119 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
120 write!(
121 f,
122 "Payment required to use this language model. Please upgrade your account."
123 )
124 }
125}
126
/// Error returned when the user's configured monthly spending cap has
/// been hit. `Display` is implemented manually rather than via a
/// thiserror `#[error(...)]` attribute.
#[derive(Error, Debug)]
pub struct MaxMonthlySpendReachedError;
129
130impl fmt::Display for MaxMonthlySpendReachedError {
131 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
132 write!(
133 f,
134 "Maximum spending limit reached for this month. For more usage, increase your spending limit."
135 )
136 }
137}
138
/// A lazily fetched LLM API token.
///
/// Wrapped in `Arc<RwLock<...>>` so clones share a single cache; `None`
/// means no token has been fetched yet.
#[derive(Clone, Default)]
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
141
142impl LlmApiToken {
143 pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
144 let lock = self.0.upgradable_read().await;
145 if let Some(token) = lock.as_ref() {
146 Ok(token.to_string())
147 } else {
148 Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
149 }
150 }
151
152 pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
153 Self::fetch(self.0.write().await, client).await
154 }
155
156 async fn fetch(
157 mut lock: RwLockWriteGuard<'_, Option<String>>,
158 client: &Arc<Client>,
159 ) -> Result<String> {
160 let response = client.request(proto::GetLlmToken {}).await?;
161 *lock = Some(response.token.clone());
162 Ok(response.token.clone())
163 }
164}
165
/// Global wrapper holding the app-wide [`RefreshLlmTokenListener`] entity.
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);

impl Global for GlobalRefreshLlmTokenListener {}
169
/// Event emitted when the server requests that the LLM token be refreshed.
pub struct RefreshLlmTokenEvent;
171
/// Listens for server-sent `RefreshLlmToken` messages and re-emits them as
/// [`RefreshLlmTokenEvent`]s for observers to react to.
pub struct RefreshLlmTokenListener {
    // Kept only to hold the message-handler subscription alive for the
    // lifetime of the listener.
    _llm_token_subscription: client::Subscription,
}

impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
177
178impl RefreshLlmTokenListener {
179 pub fn register(client: Arc<Client>, cx: &mut App) {
180 let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
181 cx.set_global(GlobalRefreshLlmTokenListener(listener));
182 }
183
184 pub fn global(cx: &App) -> Entity<Self> {
185 GlobalRefreshLlmTokenListener::global(cx).0.clone()
186 }
187
188 fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
189 Self {
190 _llm_token_subscription: client
191 .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
192 }
193 }
194
195 async fn handle_refresh_llm_token(
196 this: Entity<Self>,
197 _: TypedEnvelope<proto::RefreshLlmToken>,
198 mut cx: AsyncApp,
199 ) -> Result<()> {
200 this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
201 }
202}