1use std::fmt;
2use std::sync::Arc;
3
4use anyhow::Result;
5use client::Client;
6use cloud_api_types::websocket_protocol::MessageToClient;
7use cloud_llm_client::Plan;
8use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
9use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
10use thiserror::Error;
11
12#[derive(Error, Debug)]
13pub struct PaymentRequiredError;
14
15impl fmt::Display for PaymentRequiredError {
16 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
17 write!(
18 f,
19 "Payment required to use this language model. Please upgrade your account."
20 )
21 }
22}
23
24#[derive(Error, Debug)]
25pub struct ModelRequestLimitReachedError {
26 pub plan: Plan,
27}
28
29impl fmt::Display for ModelRequestLimitReachedError {
30 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
31 let message = match self.plan {
32 Plan::ZedFree => "Model request limit reached. Upgrade to Zed Pro for more requests.",
33 Plan::ZedPro => {
34 "Model request limit reached. Upgrade to usage-based billing for more requests."
35 }
36 Plan::ZedProTrial => {
37 "Model request limit reached. Upgrade to Zed Pro for more requests."
38 }
39 };
40
41 write!(f, "{message}")
42 }
43}
44
45#[derive(Error, Debug)]
46pub struct ToolUseLimitReachedError;
47
48impl fmt::Display for ToolUseLimitReachedError {
49 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
50 write!(
51 f,
52 "Consecutive tool use limit reached. Enable Burn Mode for unlimited tool use."
53 )
54 }
55}
56
57#[derive(Clone, Default)]
58pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
59
60impl LlmApiToken {
61 pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
62 let lock = self.0.upgradable_read().await;
63 if let Some(token) = lock.as_ref() {
64 Ok(token.to_string())
65 } else {
66 Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
67 }
68 }
69
70 pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
71 Self::fetch(self.0.write().await, client).await
72 }
73
74 async fn fetch(
75 mut lock: RwLockWriteGuard<'_, Option<String>>,
76 client: &Arc<Client>,
77 ) -> Result<String> {
78 let system_id = client
79 .telemetry()
80 .system_id()
81 .map(|system_id| system_id.to_string());
82
83 let response = client.cloud_client().create_llm_token(system_id).await?;
84 *lock = Some(response.token.0.clone());
85 Ok(response.token.0)
86 }
87}
88
/// Newtype wrapper that lets the app-wide [`RefreshLlmTokenListener`] entity be
/// stored and retrieved as a gpui global.
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);

impl Global for GlobalRefreshLlmTokenListener {}
92
/// Event emitted by [`RefreshLlmTokenListener`] when the server sends
/// `MessageToClient::UserUpdated`.
///
/// Presumably subscribers react by refreshing their LLM API token (TODO:
/// confirm against subscribers).
pub struct RefreshLlmTokenEvent;

/// Entity that turns `MessageToClient::UserUpdated` messages received by the
/// [`Client`] into [`RefreshLlmTokenEvent`]s.
pub struct RefreshLlmTokenListener;

// Lets subscribers observe `RefreshLlmTokenEvent`s emitted by the listener entity.
impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
98
99impl RefreshLlmTokenListener {
100 pub fn register(client: Arc<Client>, cx: &mut App) {
101 let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
102 cx.set_global(GlobalRefreshLlmTokenListener(listener));
103 }
104
105 pub fn global(cx: &App) -> Entity<Self> {
106 GlobalRefreshLlmTokenListener::global(cx).0.clone()
107 }
108
109 fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
110 client.add_message_to_client_handler({
111 let this = cx.entity();
112 move |message, cx| {
113 Self::handle_refresh_llm_token(this.clone(), message, cx);
114 }
115 });
116
117 Self
118 }
119
120 fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
121 match message {
122 MessageToClient::UserUpdated => {
123 this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
124 }
125 }
126 }
127}