1use std::fmt;
2use std::sync::Arc;
3
4use anyhow::Result;
5use client::Client;
6use cloud_api_types::websocket_protocol::MessageToClient;
7use cloud_llm_client::Plan;
8use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
9use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
10use thiserror::Error;
11
12#[derive(Error, Debug)]
13pub struct PaymentRequiredError;
14
15impl fmt::Display for PaymentRequiredError {
16 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
17 write!(
18 f,
19 "Payment required to use this language model. Please upgrade your account."
20 )
21 }
22}
23
24#[derive(Error, Debug)]
25pub struct ModelRequestLimitReachedError {
26 pub plan: Plan,
27}
28
29impl fmt::Display for ModelRequestLimitReachedError {
30 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
31 let message = match self.plan {
32 Plan::ZedFree => "Model request limit reached. Upgrade to Zed Pro for more requests.",
33 Plan::ZedPro => {
34 "Model request limit reached. Upgrade to usage-based billing for more requests."
35 }
36 Plan::ZedProTrial => {
37 "Model request limit reached. Upgrade to Zed Pro for more requests."
38 }
39 Plan::ZedProV2 | Plan::ZedProTrialV2 => "Model request limit reached.",
40 };
41
42 write!(f, "{message}")
43 }
44}
45
46#[derive(Error, Debug)]
47pub struct ToolUseLimitReachedError;
48
49impl fmt::Display for ToolUseLimitReachedError {
50 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
51 write!(
52 f,
53 "Consecutive tool use limit reached. Enable Burn Mode for unlimited tool use."
54 )
55 }
56}
57
58#[derive(Clone, Default)]
59pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
60
61impl LlmApiToken {
62 pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
63 let lock = self.0.upgradable_read().await;
64 if let Some(token) = lock.as_ref() {
65 Ok(token.to_string())
66 } else {
67 Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
68 }
69 }
70
71 pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
72 Self::fetch(self.0.write().await, client).await
73 }
74
75 async fn fetch(
76 mut lock: RwLockWriteGuard<'_, Option<String>>,
77 client: &Arc<Client>,
78 ) -> Result<String> {
79 let system_id = client
80 .telemetry()
81 .system_id()
82 .map(|system_id| system_id.to_string());
83
84 let response = client.cloud_client().create_llm_token(system_id).await?;
85 *lock = Some(response.token.0.clone());
86 Ok(response.token.0)
87 }
88}
89
/// GPUI global that stores the app-wide [`RefreshLlmTokenListener`] entity.
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);

impl Global for GlobalRefreshLlmTokenListener {}

/// Event emitted by [`RefreshLlmTokenListener`] when a `MessageToClient::UserUpdated` message
/// arrives.
///
/// Subscribers are presumably expected to refresh their LLM token in response. Confirm against
/// the subscribers.
pub struct RefreshLlmTokenEvent;

/// Entity that turns server-pushed client messages into [`RefreshLlmTokenEvent`]s.
pub struct RefreshLlmTokenListener;

impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
99
100impl RefreshLlmTokenListener {
101 pub fn register(client: Arc<Client>, cx: &mut App) {
102 let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
103 cx.set_global(GlobalRefreshLlmTokenListener(listener));
104 }
105
106 pub fn global(cx: &App) -> Entity<Self> {
107 GlobalRefreshLlmTokenListener::global(cx).0.clone()
108 }
109
110 fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
111 client.add_message_to_client_handler({
112 let this = cx.entity();
113 move |message, cx| {
114 Self::handle_refresh_llm_token(this.clone(), message, cx);
115 }
116 });
117
118 Self
119 }
120
121 fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
122 match message {
123 MessageToClient::UserUpdated => {
124 this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
125 }
126 }
127 }
128}