1use super::{Client, UserStore};
2use cloud_api_client::LlmApiToken;
3use cloud_api_types::websocket_protocol::MessageToClient;
4use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
5use gpui::{
6 App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
7};
8use std::sync::Arc;
9
/// Implemented by values, such as HTTP responses, that can report whether the
/// LLM API token used to obtain them should be refreshed.
pub trait NeedsLlmTokenRefresh {
    /// Returns `true` if the LLM token should be refreshed, `false` otherwise.
    fn needs_llm_token_refresh(&self) -> bool;
}
14
15impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
16 fn needs_llm_token_refresh(&self) -> bool {
17 self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
18 || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
19 }
20}
21
/// Selects how [`RefreshLlmTokenListener::refresh`] obtains a new token.
#[derive(Clone, Copy, Debug)]
enum TokenRefreshMode {
    /// Calls `Client::refresh_llm_token`.
    Refresh,
    /// Calls `Client::clear_and_refresh_llm_token`. Used when the organization changes.
    ClearAndRefresh,
}
26
27pub fn global_llm_token(cx: &App) -> LlmApiToken {
28 RefreshLlmTokenListener::global(cx)
29 .read(cx)
30 .llm_api_token
31 .clone()
32}
33
34struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
35
36impl Global for GlobalRefreshLlmTokenListener {}
37
/// Emitted by [`RefreshLlmTokenListener`] after a token refresh completes
/// without error.
pub struct LlmTokenRefreshedEvent;
39
40pub struct RefreshLlmTokenListener {
41 client: Arc<Client>,
42 user_store: Entity<UserStore>,
43 llm_api_token: LlmApiToken,
44 _subscription: Subscription,
45}
46
47impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
48
49impl RefreshLlmTokenListener {
50 pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
51 let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
52 cx.set_global(GlobalRefreshLlmTokenListener(listener));
53 }
54
55 pub fn global(cx: &App) -> Entity<Self> {
56 GlobalRefreshLlmTokenListener::global(cx).0.clone()
57 }
58
59 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
60 client.add_message_to_client_handler({
61 let this = cx.weak_entity();
62 move |message, cx| {
63 if let Some(this) = this.upgrade() {
64 Self::handle_refresh_llm_token(this, message, cx);
65 }
66 }
67 });
68
69 let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
70 if matches!(event, super::user::Event::OrganizationChanged) {
71 this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
72 }
73 });
74
75 Self {
76 client,
77 user_store,
78 llm_api_token: LlmApiToken::default(),
79 _subscription: subscription,
80 }
81 }
82
83 fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
84 let client = self.client.clone();
85 let llm_api_token = self.llm_api_token.clone();
86 let organization_id = self
87 .user_store
88 .read(cx)
89 .current_organization()
90 .map(|organization| organization.id.clone());
91 cx.spawn(async move |this, cx| {
92 match mode {
93 TokenRefreshMode::Refresh => {
94 client
95 .refresh_llm_token(&llm_api_token, organization_id)
96 .await?;
97 }
98 TokenRefreshMode::ClearAndRefresh => {
99 client
100 .clear_and_refresh_llm_token(&llm_api_token, organization_id)
101 .await?;
102 }
103 }
104 this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
105 })
106 .detach_and_log_err(cx);
107 }
108
109 fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
110 match message {
111 MessageToClient::UserUpdated => {
112 this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
113 }
114 }
115 }
116}