1use std::fmt;
2use std::sync::Arc;
3
4use anyhow::Result;
5use client::Client;
6use gpui::{
7 App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
8};
9use proto::{Plan, TypedEnvelope};
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize};
12use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
13use strum::EnumIter;
14use thiserror::Error;
15
16use crate::LanguageModelToolSchemaFormat;
17
/// A cloud-hosted language model, tagged by its upstream provider.
///
/// Serialized with a `"provider"` tag whose value is the lowercased variant
/// name, e.g. `{ "provider": "anthropic", ... }`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "provider", rename_all = "lowercase")]
pub enum CloudModel {
    Anthropic(anthropic::Model),
    OpenAi(open_ai::Model),
    Google(google_ai::Model),
}
25
/// Models hosted by Zed itself (as opposed to a third-party provider).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
pub enum ZedModel {
    // Serialized under its Hugging Face repository id.
    #[serde(rename = "Qwen/Qwen2-7B-Instruct")]
    Qwen2_7bInstruct,
}
31
32impl Default for CloudModel {
33 fn default() -> Self {
34 Self::Anthropic(anthropic::Model::default())
35 }
36}
37
impl CloudModel {
    /// The provider-specific model identifier (delegates to the wrapped
    /// provider model).
    pub fn id(&self) -> &str {
        match self {
            Self::Anthropic(model) => model.id(),
            Self::OpenAi(model) => model.id(),
            Self::Google(model) => model.id(),
        }
    }

    /// The human-readable model name (delegates to the wrapped provider
    /// model).
    pub fn display_name(&self) -> &str {
        match self {
            Self::Anthropic(model) => model.display_name(),
            Self::OpenAi(model) => model.display_name(),
            Self::Google(model) => model.display_name(),
        }
    }

    /// The maximum context size, in tokens, that the model accepts
    /// (delegates to the wrapped provider model).
    pub fn max_token_count(&self) -> usize {
        match self {
            Self::Anthropic(model) => model.max_token_count(),
            Self::OpenAi(model) => model.max_token_count(),
            Self::Google(model) => model.max_token_count(),
        }
    }

    /// The JSON-schema dialect to use when describing tools to this
    /// provider: Anthropic and OpenAI take full JSON Schema, while Google
    /// models get the subset variant.
    pub fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self {
            Self::Anthropic(_) | Self::OpenAi(_) => LanguageModelToolSchemaFormat::JsonSchema,
            Self::Google(_) => LanguageModelToolSchemaFormat::JsonSchemaSubset,
        }
    }
}
70
71#[derive(Error, Debug)]
72pub struct PaymentRequiredError;
73
74impl fmt::Display for PaymentRequiredError {
75 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
76 write!(
77 f,
78 "Payment required to use this language model. Please upgrade your account."
79 )
80 }
81}
82
83#[derive(Error, Debug)]
84pub struct ModelRequestLimitReachedError {
85 pub plan: Plan,
86}
87
88impl fmt::Display for ModelRequestLimitReachedError {
89 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
90 let message = match self.plan {
91 Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.",
92 Plan::ZedPro => {
93 "Model request limit reached. Upgrade to usage-based billing for more requests."
94 }
95 Plan::ZedProTrial => {
96 "Model request limit reached. Upgrade to Zed Pro for more requests."
97 }
98 };
99
100 write!(f, "{message}")
101 }
102}
103
/// A lazily-fetched LLM API token, shared across clones (`Arc`) and guarded
/// by an async `RwLock`; `None` until the first fetch.
#[derive(Clone, Default)]
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
106
107impl LlmApiToken {
108 pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
109 let lock = self.0.upgradable_read().await;
110 if let Some(token) = lock.as_ref() {
111 Ok(token.to_string())
112 } else {
113 Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
114 }
115 }
116
117 pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
118 Self::fetch(self.0.write().await, client).await
119 }
120
121 async fn fetch(
122 mut lock: RwLockWriteGuard<'_, Option<String>>,
123 client: &Arc<Client>,
124 ) -> Result<String> {
125 let response = client.request(proto::GetLlmToken {}).await?;
126 *lock = Some(response.token.clone());
127 Ok(response.token.clone())
128 }
129}
130
/// Newtype that registers the [`RefreshLlmTokenListener`] entity as a gpui
/// global so it can be looked up from anywhere in the app.
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);

impl Global for GlobalRefreshLlmTokenListener {}
134
/// Event emitted when the server asks the client to refresh its LLM token.
pub struct RefreshLlmTokenEvent;

/// Listens for `RefreshLlmToken` messages from the server and re-emits them
/// as [`RefreshLlmTokenEvent`]s for interested entities to observe.
pub struct RefreshLlmTokenListener {
    // Held to keep the message-handler registration alive for the
    // listener's lifetime.
    _llm_token_subscription: client::Subscription,
}

impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
142
impl RefreshLlmTokenListener {
    /// Creates the listener entity and installs it as a process-wide global.
    pub fn register(client: Arc<Client>, cx: &mut App) {
        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
        cx.set_global(GlobalRefreshLlmTokenListener(listener));
    }

    /// Returns the globally registered listener.
    ///
    /// NOTE(review): assumes [`Self::register`] was called first — reading a
    /// missing gpui global presumably panics; verify against callers.
    pub fn global(cx: &App) -> Entity<Self> {
        GlobalRefreshLlmTokenListener::global(cx).0.clone()
    }

    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
        Self {
            // Route incoming `RefreshLlmToken` messages to our handler; the
            // returned subscription is stored so the registration outlives
            // construction.
            _llm_token_subscription: client
                .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
        }
    }

    /// Handles a server-sent `RefreshLlmToken` message by emitting a
    /// [`RefreshLlmTokenEvent`] from the listener entity.
    async fn handle_refresh_llm_token(
        this: Entity<Self>,
        _: TypedEnvelope<proto::RefreshLlmToken>,
        mut cx: AsyncApp,
    ) -> Result<()> {
        this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
    }
}