1use std::fmt;
2use std::sync::Arc;
3
4use anyhow::Result;
5use client::Client;
6use gpui::{
7 App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
8};
9use proto::{Plan, TypedEnvelope};
10use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
11use thiserror::Error;
12
13#[derive(Error, Debug)]
14pub struct PaymentRequiredError;
15
16impl fmt::Display for PaymentRequiredError {
17 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
18 write!(
19 f,
20 "Payment required to use this language model. Please upgrade your account."
21 )
22 }
23}
24
25#[derive(Error, Debug)]
26pub struct ModelRequestLimitReachedError {
27 pub plan: Plan,
28}
29
30impl fmt::Display for ModelRequestLimitReachedError {
31 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
32 let message = match self.plan {
33 Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.",
34 Plan::ZedPro => {
35 "Model request limit reached. Upgrade to usage-based billing for more requests."
36 }
37 Plan::ZedProTrial => {
38 "Model request limit reached. Upgrade to Zed Pro for more requests."
39 }
40 };
41
42 write!(f, "{message}")
43 }
44}
45
46#[derive(Clone, Default)]
47pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
48
49impl LlmApiToken {
50 pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
51 let lock = self.0.upgradable_read().await;
52 if let Some(token) = lock.as_ref() {
53 Ok(token.to_string())
54 } else {
55 Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
56 }
57 }
58
59 pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
60 Self::fetch(self.0.write().await, client).await
61 }
62
63 async fn fetch(
64 mut lock: RwLockWriteGuard<'_, Option<String>>,
65 client: &Arc<Client>,
66 ) -> Result<String> {
67 let system_id = client
68 .telemetry()
69 .system_id()
70 .map(|system_id| system_id.to_string());
71
72 let response = client.cloud_client().create_llm_token(system_id).await?;
73 *lock = Some(response.token.0.clone());
74 Ok(response.token.0.clone())
75 }
76}
77
/// Newtype wrapping the `RefreshLlmTokenListener` entity so it can be stored
/// as a gpui global (see the `Global` impl below).
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);

// Marker impl: allows this wrapper to be set and read as an app-wide global.
impl Global for GlobalRefreshLlmTokenListener {}
81
/// Event emitted by `RefreshLlmTokenListener` when a `proto::RefreshLlmToken`
/// message is handled. It carries no payload.
pub struct RefreshLlmTokenEvent;
83
84pub struct RefreshLlmTokenListener {
85 _llm_token_subscription: client::Subscription,
86}
87
88impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
89
90impl RefreshLlmTokenListener {
91 pub fn register(client: Arc<Client>, cx: &mut App) {
92 let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
93 cx.set_global(GlobalRefreshLlmTokenListener(listener));
94 }
95
96 pub fn global(cx: &App) -> Entity<Self> {
97 GlobalRefreshLlmTokenListener::global(cx).0.clone()
98 }
99
100 fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
101 Self {
102 _llm_token_subscription: client
103 .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
104 }
105 }
106
107 async fn handle_refresh_llm_token(
108 this: Entity<Self>,
109 _: TypedEnvelope<proto::RefreshLlmToken>,
110 mut cx: AsyncApp,
111 ) -> Result<()> {
112 this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
113 }
114}