1use std::fmt;
2use std::sync::Arc;
3
4use anyhow::Result;
5use client::Client;
6use cloud_api_types::websocket_protocol::MessageToClient;
7use cloud_llm_client::{Plan, PlanV1};
8use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
9use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
10use thiserror::Error;
11
12#[derive(Error, Debug)]
13pub struct PaymentRequiredError;
14
15impl fmt::Display for PaymentRequiredError {
16 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
17 write!(
18 f,
19 "Payment required to use this language model. Please upgrade your account."
20 )
21 }
22}
23
24#[derive(Error, Debug)]
25pub struct ModelRequestLimitReachedError {
26 pub plan: Plan,
27}
28
29impl fmt::Display for ModelRequestLimitReachedError {
30 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
31 let message = match self.plan {
32 Plan::V1(PlanV1::ZedFree) => {
33 "Model request limit reached. Upgrade to Zed Pro for more requests."
34 }
35 Plan::V1(PlanV1::ZedPro) => {
36 "Model request limit reached. Upgrade to usage-based billing for more requests."
37 }
38 Plan::V1(PlanV1::ZedProTrial) => {
39 "Model request limit reached. Upgrade to Zed Pro for more requests."
40 }
41 Plan::V2(_) => "Model request limit reached.",
42 };
43
44 write!(f, "{message}")
45 }
46}
47
48#[derive(Error, Debug)]
49pub struct ToolUseLimitReachedError;
50
51impl fmt::Display for ToolUseLimitReachedError {
52 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
53 write!(
54 f,
55 "Consecutive tool use limit reached. Enable Burn Mode for unlimited tool use."
56 )
57 }
58}
59
60#[derive(Clone, Default)]
61pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
62
63impl LlmApiToken {
64 pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
65 let lock = self.0.upgradable_read().await;
66 if let Some(token) = lock.as_ref() {
67 Ok(token.to_string())
68 } else {
69 Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
70 }
71 }
72
73 pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
74 Self::fetch(self.0.write().await, client).await
75 }
76
77 async fn fetch(
78 mut lock: RwLockWriteGuard<'_, Option<String>>,
79 client: &Arc<Client>,
80 ) -> Result<String> {
81 let system_id = client
82 .telemetry()
83 .system_id()
84 .map(|system_id| system_id.to_string());
85
86 let response = client.cloud_client().create_llm_token(system_id).await?;
87 *lock = Some(response.token.0.clone());
88 Ok(response.token.0)
89 }
90}
91
92struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
93
94impl Global for GlobalRefreshLlmTokenListener {}
95
96pub struct RefreshLlmTokenEvent;
97
98pub struct RefreshLlmTokenListener;
99
100impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
101
102impl RefreshLlmTokenListener {
103 pub fn register(client: Arc<Client>, cx: &mut App) {
104 let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
105 cx.set_global(GlobalRefreshLlmTokenListener(listener));
106 }
107
108 pub fn global(cx: &App) -> Entity<Self> {
109 GlobalRefreshLlmTokenListener::global(cx).0.clone()
110 }
111
112 fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
113 client.add_message_to_client_handler({
114 let this = cx.entity();
115 move |message, cx| {
116 Self::handle_refresh_llm_token(this.clone(), message, cx);
117 }
118 });
119
120 Self
121 }
122
123 fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
124 match message {
125 MessageToClient::UserUpdated => {
126 this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
127 }
128 }
129 }
130}