1use super::{Client, UserStore};
2use cloud_api_client::LlmApiToken;
3use cloud_api_types::websocket_protocol::MessageToClient;
4use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
5use gpui::{
6 App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
7 TaskExt,
8};
9use std::sync::Arc;
10
/// Implemented by types that can report whether the LLM API token must be
/// refreshed (for example, HTTP responses that carry a token-status header).
pub trait NeedsLlmTokenRefresh {
    /// Returns whether the LLM token needs to be refreshed.
    fn needs_llm_token_refresh(&self) -> bool;
}
15
16impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
17 fn needs_llm_token_refresh(&self) -> bool {
18 self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
19 || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
20 }
21}
22
/// Selects which `Client` call `RefreshLlmTokenListener::refresh` uses to
/// obtain a new LLM token.
enum TokenRefreshMode {
    /// Uses `Client::refresh_llm_token`; triggered by a `UserUpdated` message
    /// from the server.
    Refresh,
    /// Uses `Client::clear_and_refresh_llm_token` (presumably discards the
    /// current token before fetching a new one — confirm in `Client`);
    /// triggered when the user store reports an organization change.
    ClearAndRefresh,
}
27
28pub fn global_llm_token(cx: &App) -> LlmApiToken {
29 RefreshLlmTokenListener::global(cx)
30 .read(cx)
31 .llm_api_token
32 .clone()
33}
34
/// gpui global slot holding the listener entity created by
/// `RefreshLlmTokenListener::register` and read back by
/// `RefreshLlmTokenListener::global`.
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);

impl Global for GlobalRefreshLlmTokenListener {}
38
/// Emitted by [`RefreshLlmTokenListener`] after a token refresh request has
/// completed without error.
pub struct LlmTokenRefreshedEvent;
40
/// Keeps the LLM API token up to date: it refreshes the token when the server
/// sends a `UserUpdated` message and clears-and-refreshes it when the current
/// organization changes.
///
/// Install it once with `RefreshLlmTokenListener::register`; afterwards use
/// `RefreshLlmTokenListener::global` or `global_llm_token` to reach it.
pub struct RefreshLlmTokenListener {
    // Client that performs the refresh requests.
    client: Arc<Client>,
    // Read on every refresh to obtain the current organization's ID.
    user_store: Entity<UserStore>,
    // Token handle cloned out by `global_llm_token` and passed to the client on
    // refresh (presumably a shared handle, so clones observe refreshes — confirm
    // against `LlmApiToken`).
    llm_api_token: LlmApiToken,
    // Held so the `UserStore` subscription stays active for the listener's
    // lifetime (assumes dropping a gpui `Subscription` cancels it).
    _subscription: Subscription,
}
47
// Allows the listener to emit `LlmTokenRefreshedEvent` to its subscribers.
impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
49
50impl RefreshLlmTokenListener {
51 pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
52 let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
53 cx.set_global(GlobalRefreshLlmTokenListener(listener));
54 }
55
56 pub fn global(cx: &App) -> Entity<Self> {
57 GlobalRefreshLlmTokenListener::global(cx).0.clone()
58 }
59
60 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
61 client.add_message_to_client_handler({
62 let this = cx.weak_entity();
63 move |message, cx| {
64 if let Some(this) = this.upgrade() {
65 Self::handle_refresh_llm_token(this, message, cx);
66 }
67 }
68 });
69
70 let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
71 if matches!(event, super::user::Event::OrganizationChanged) {
72 this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
73 }
74 });
75
76 Self {
77 client,
78 user_store,
79 llm_api_token: LlmApiToken::default(),
80 _subscription: subscription,
81 }
82 }
83
84 fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
85 let client = self.client.clone();
86 let llm_api_token = self.llm_api_token.clone();
87 let organization_id = self
88 .user_store
89 .read(cx)
90 .current_organization()
91 .map(|organization| organization.id.clone());
92 cx.spawn(async move |this, cx| {
93 match mode {
94 TokenRefreshMode::Refresh => {
95 client
96 .refresh_llm_token(&llm_api_token, organization_id)
97 .await?;
98 }
99 TokenRefreshMode::ClearAndRefresh => {
100 client
101 .clear_and_refresh_llm_token(&llm_api_token, organization_id)
102 .await?;
103 }
104 }
105 this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
106 })
107 .detach_and_log_err(cx);
108 }
109
110 fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
111 match message {
112 MessageToClient::UserUpdated => {
113 this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
114 }
115 }
116 }
117}