1use std::fmt;
2use std::sync::Arc;
3
4use anyhow::Result;
5use client::Client;
6use cloud_api_types::websocket_protocol::MessageToClient;
7use cloud_llm_client::Plan;
8use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
9use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
10use thiserror::Error;
11
/// Error returned when the user's account must be upgraded before this
/// language model can be used.
///
/// No `#[error(...)]` attribute is given, so the user-facing message comes from
/// the hand-written `Display` impl for this type.
#[derive(Error, Debug)]
pub struct PaymentRequiredError;
14
15impl fmt::Display for PaymentRequiredError {
16 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
17 write!(
18 f,
19 "Payment required to use this language model. Please upgrade your account."
20 )
21 }
22}
23
/// Error returned when the user has used up the model requests allowed on
/// their current plan.
///
/// The user-facing message comes from the hand-written `Display` impl, which
/// tailors the upgrade suggestion to `plan`.
#[derive(Error, Debug)]
pub struct ModelRequestLimitReachedError {
    /// The plan the user was on when the limit was hit; selects which upgrade
    /// path the `Display` message suggests.
    pub plan: Plan,
}
28
29impl fmt::Display for ModelRequestLimitReachedError {
30 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
31 let message = match self.plan {
32 Plan::ZedFree => "Model request limit reached. Upgrade to Zed Pro for more requests.",
33 Plan::ZedPro => {
34 "Model request limit reached. Upgrade to usage-based billing for more requests."
35 }
36 Plan::ZedProTrial => {
37 "Model request limit reached. Upgrade to Zed Pro for more requests."
38 }
39 };
40
41 write!(f, "{message}")
42 }
43}
44
/// Lazily populated, shared cache for the LLM API token.
///
/// The token lives behind an `Arc<RwLock<Option<String>>>`. Cloning an
/// `LlmApiToken` clones the `Arc`, so every clone sees the same cached value.
/// `Default` starts the cache empty (`None`).
#[derive(Clone, Default)]
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
47
48impl LlmApiToken {
49 pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
50 let lock = self.0.upgradable_read().await;
51 if let Some(token) = lock.as_ref() {
52 Ok(token.to_string())
53 } else {
54 Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
55 }
56 }
57
58 pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
59 Self::fetch(self.0.write().await, client).await
60 }
61
62 async fn fetch(
63 mut lock: RwLockWriteGuard<'_, Option<String>>,
64 client: &Arc<Client>,
65 ) -> Result<String> {
66 let system_id = client
67 .telemetry()
68 .system_id()
69 .map(|system_id| system_id.to_string());
70
71 let response = client.cloud_client().create_llm_token(system_id).await?;
72 *lock = Some(response.token.0.clone());
73 Ok(response.token.0.clone())
74 }
75}
76
/// GPUI global wrapper that holds the app-wide [`RefreshLlmTokenListener`] entity.
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
78
// Marks the wrapper as a GPUI global so it can be stored with `set_global` and
// read back through `ReadGlobal::global`.
impl Global for GlobalRefreshLlmTokenListener {}
80
/// Event emitted by [`RefreshLlmTokenListener`] when the server sends a
/// `UserUpdated` message.
///
/// Presumably subscribers react by refreshing their cached LLM token; confirm
/// against the subscribers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RefreshLlmTokenEvent;
82
/// Entity that turns server `UserUpdated` messages into [`RefreshLlmTokenEvent`]s.
///
/// Install it with [`RefreshLlmTokenListener::register`] and obtain the shared
/// entity with [`RefreshLlmTokenListener::global`] in order to subscribe.
pub struct RefreshLlmTokenListener;
84
// Lets the listener entity emit `RefreshLlmTokenEvent`s to its subscribers.
impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
86
87impl RefreshLlmTokenListener {
88 pub fn register(client: Arc<Client>, cx: &mut App) {
89 let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
90 cx.set_global(GlobalRefreshLlmTokenListener(listener));
91 }
92
93 pub fn global(cx: &App) -> Entity<Self> {
94 GlobalRefreshLlmTokenListener::global(cx).0.clone()
95 }
96
97 fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
98 client.add_message_to_client_handler({
99 let this = cx.entity();
100 move |message, cx| {
101 Self::handle_refresh_llm_token(this.clone(), message, cx);
102 }
103 });
104
105 Self
106 }
107
108 fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
109 match message {
110 MessageToClient::UserUpdated => {
111 this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
112 }
113 }
114 }
115}