1use std::fmt;
2use std::sync::Arc;
3
4use anyhow::Result;
5use client::Client;
6use cloud_llm_client::Plan;
7use gpui::{
8 App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
9};
10use proto::TypedEnvelope;
11use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
12use thiserror::Error;
13
14#[derive(Error, Debug)]
15pub struct PaymentRequiredError;
16
17impl fmt::Display for PaymentRequiredError {
18 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
19 write!(
20 f,
21 "Payment required to use this language model. Please upgrade your account."
22 )
23 }
24}
25
26#[derive(Error, Debug)]
27pub struct ModelRequestLimitReachedError {
28 pub plan: Plan,
29}
30
31impl fmt::Display for ModelRequestLimitReachedError {
32 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
33 let message = match self.plan {
34 Plan::ZedFree => "Model request limit reached. Upgrade to Zed Pro for more requests.",
35 Plan::ZedPro => {
36 "Model request limit reached. Upgrade to usage-based billing for more requests."
37 }
38 Plan::ZedProTrial => {
39 "Model request limit reached. Upgrade to Zed Pro for more requests."
40 }
41 };
42
43 write!(f, "{message}")
44 }
45}
46
47#[derive(Clone, Default)]
48pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
49
50impl LlmApiToken {
51 pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
52 let lock = self.0.upgradable_read().await;
53 if let Some(token) = lock.as_ref() {
54 Ok(token.to_string())
55 } else {
56 Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
57 }
58 }
59
60 pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
61 Self::fetch(self.0.write().await, client).await
62 }
63
64 async fn fetch(
65 mut lock: RwLockWriteGuard<'_, Option<String>>,
66 client: &Arc<Client>,
67 ) -> Result<String> {
68 let system_id = client
69 .telemetry()
70 .system_id()
71 .map(|system_id| system_id.to_string());
72
73 let response = client.cloud_client().create_llm_token(system_id).await?;
74 *lock = Some(response.token.0.clone());
75 Ok(response.token.0.clone())
76 }
77}
78
79struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
80
81impl Global for GlobalRefreshLlmTokenListener {}
82
83pub struct RefreshLlmTokenEvent;
84
85pub struct RefreshLlmTokenListener {
86 _llm_token_subscription: client::Subscription,
87}
88
89impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
90
91impl RefreshLlmTokenListener {
92 pub fn register(client: Arc<Client>, cx: &mut App) {
93 let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
94 cx.set_global(GlobalRefreshLlmTokenListener(listener));
95 }
96
97 pub fn global(cx: &App) -> Entity<Self> {
98 GlobalRefreshLlmTokenListener::global(cx).0.clone()
99 }
100
101 fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
102 Self {
103 _llm_token_subscription: client
104 .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
105 }
106 }
107
108 async fn handle_refresh_llm_token(
109 this: Entity<Self>,
110 _: TypedEnvelope<proto::RefreshLlmToken>,
111 mut cx: AsyncApp,
112 ) -> Result<()> {
113 this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
114 }
115}