1use ai_onboarding::YoungAccountBanner;
2use anyhow::Result;
3use client::{Client, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls};
4use cloud_api_client::LlmApiToken;
5use cloud_api_types::OrganizationId;
6use cloud_api_types::Plan;
7use futures::StreamExt;
8use futures::future::BoxFuture;
9use gpui::{AnyElement, AnyView, App, AppContext, Context, Entity, Subscription, Task};
10use language_model::{
11 AuthenticateError, IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelProviderId,
12 LanguageModelProviderName, LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID,
13 ZED_CLOUD_PROVIDER_NAME,
14};
15use language_models_cloud::{CloudLlmTokenProvider, CloudModelProvider};
16use release_channel::AppVersion;
17
18use settings::SettingsStore;
19pub use settings::ZedDotDevAvailableModel as AvailableModel;
20pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
21use std::sync::Arc;
22use ui::{TintColor, prelude::*};
23
/// Provider id reported by [`CloudLanguageModelProvider`]; an alias of
/// [`ZED_CLOUD_PROVIDER_ID`].
const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
/// Provider name reported by [`CloudLanguageModelProvider`]; an alias of
/// [`ZED_CLOUD_PROVIDER_NAME`].
const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;
26
/// Supplies LLM API tokens to [`CloudModelProvider`] by delegating to the
/// [`Client`], scoping each request to the user's current organization.
struct ClientTokenProvider {
    /// Client whose `acquire_llm_token` / `refresh_llm_token` do the actual work.
    client: Arc<Client>,
    /// Token handle passed through to the client's acquire/refresh calls.
    llm_api_token: LlmApiToken,
    /// Read to determine the current organization (the auth context).
    user_store: Entity<UserStore>,
}
32
33impl CloudLlmTokenProvider for ClientTokenProvider {
34 type AuthContext = Option<OrganizationId>;
35
36 fn auth_context(&self, cx: &impl AppContext) -> Self::AuthContext {
37 self.user_store.read_with(cx, |user_store, _| {
38 user_store
39 .current_organization()
40 .map(|organization| organization.id.clone())
41 })
42 }
43
44 fn acquire_token(
45 &self,
46 organization_id: Self::AuthContext,
47 ) -> BoxFuture<'static, Result<String>> {
48 let client = self.client.clone();
49 let llm_api_token = self.llm_api_token.clone();
50 Box::pin(async move {
51 client
52 .acquire_llm_token(&llm_api_token, organization_id)
53 .await
54 })
55 }
56
57 fn refresh_token(
58 &self,
59 organization_id: Self::AuthContext,
60 ) -> BoxFuture<'static, Result<String>> {
61 let client = self.client.clone();
62 let llm_api_token = self.llm_api_token.clone();
63 Box::pin(async move {
64 client
65 .refresh_llm_token(&llm_api_token, organization_id)
66 .await
67 })
68 }
69}
70
/// Settings for the Zed.dev (Zed Cloud) language model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Models declared in settings (see [`AvailableModel`]).
    /// NOTE(review): not read anywhere in this file; consumers live elsewhere.
    pub available_models: Vec<AvailableModel>,
}
75
/// [`LanguageModelProvider`] for Zed's hosted models.
pub struct CloudLanguageModelProvider {
    /// Shared state; handed out to observers via `observable_entity`.
    state: Entity<State>,
    /// Background task (spawned in `new`) that mirrors client status changes
    /// into `State::status`. Stored, never read, so the task is tied to this
    /// provider's lifetime.
    _maintain_client_status: Task<()>,
}
80
/// Shared, observable state behind [`CloudLanguageModelProvider`].
///
/// Holds the client and user store, the last observed client connection
/// status, and the [`CloudModelProvider`] that lists the hosted models.
pub struct State {
    /// Used to sign in (see `State::authenticate`).
    client: Arc<Client>,
    /// Source of the current user, plan, trial start and account-age flags.
    user_store: Entity<UserStore>,
    /// Most recently observed connection status; kept up to date by the task
    /// spawned in `CloudLanguageModelProvider::new`.
    status: client::Status,
    /// Model catalog, authenticated through a [`ClientTokenProvider`].
    provider: Entity<CloudModelProvider<ClientTokenProvider>>,
    // The subscriptions below are stored only so they stay alive as long as
    // this state does; they are never read directly.
    _user_store_subscription: Subscription,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
    _provider_subscription: Subscription,
}
91
92impl State {
93 fn new(
94 client: Arc<Client>,
95 user_store: Entity<UserStore>,
96 status: client::Status,
97 cx: &mut Context<Self>,
98 ) -> Self {
99 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
100 let token_provider = Arc::new(ClientTokenProvider {
101 client: client.clone(),
102 llm_api_token: global_llm_token(cx),
103 user_store: user_store.clone(),
104 });
105
106 let provider = cx.new(|cx| {
107 CloudModelProvider::new(
108 token_provider.clone(),
109 client.http_client(),
110 Some(AppVersion::global(cx)),
111 )
112 });
113
114 Self {
115 client: client.clone(),
116 user_store: user_store.clone(),
117 status,
118 _provider_subscription: cx.observe(&provider, |_, _, cx| cx.notify()),
119 provider,
120 _user_store_subscription: cx.subscribe(
121 &user_store,
122 move |this, _user_store, event, cx| match event {
123 client::user::Event::PrivateUserInfoUpdated => {
124 let status = *client.status().borrow();
125 if status.is_signed_out() {
126 return;
127 }
128
129 this.refresh_models(cx);
130 }
131 _ => {}
132 },
133 ),
134 _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
135 cx.notify();
136 }),
137 _llm_token_subscription: cx.subscribe(
138 &refresh_llm_token_listener,
139 move |this, _listener, _event, cx| {
140 this.refresh_models(cx);
141 },
142 ),
143 }
144 }
145
146 fn is_signed_out(&self, cx: &App) -> bool {
147 self.user_store.read(cx).current_user().is_none()
148 }
149
150 fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
151 let client = self.client.clone();
152 cx.spawn(async move |state, cx| {
153 client.sign_in_with_optional_connect(true, cx).await?;
154 state.update(cx, |_, cx| cx.notify())
155 })
156 }
157
158 fn refresh_models(&mut self, cx: &mut Context<Self>) {
159 self.provider.update(cx, |provider, cx| {
160 provider.refresh_models(cx).detach_and_log_err(cx);
161 });
162 }
163}
164
165impl CloudLanguageModelProvider {
166 pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
167 let mut status_rx = client.status();
168 let status = *status_rx.borrow();
169
170 let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
171
172 let state_ref = state.downgrade();
173 let maintain_client_status = cx.spawn(async move |cx| {
174 while let Some(status) = status_rx.next().await {
175 if let Some(this) = state_ref.upgrade() {
176 _ = this.update(cx, |this, cx| {
177 if this.status != status {
178 this.status = status;
179 cx.notify();
180 }
181 });
182 } else {
183 break;
184 }
185 }
186 });
187
188 Self {
189 state,
190 _maintain_client_status: maintain_client_status,
191 }
192 }
193}
194
195impl LanguageModelProviderState for CloudLanguageModelProvider {
196 type ObservableEntity = State;
197
198 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
199 Some(self.state.clone())
200 }
201}
202
203impl LanguageModelProvider for CloudLanguageModelProvider {
204 fn id(&self) -> LanguageModelProviderId {
205 PROVIDER_ID
206 }
207
208 fn name(&self) -> LanguageModelProviderName {
209 PROVIDER_NAME
210 }
211
212 fn icon(&self) -> IconOrSvg {
213 IconOrSvg::Icon(IconName::AiZed)
214 }
215
216 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
217 let state = self.state.read(cx);
218 let provider = state.provider.read(cx);
219 let model = provider.default_model()?;
220 Some(provider.create_model(model))
221 }
222
223 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
224 let state = self.state.read(cx);
225 let provider = state.provider.read(cx);
226 let model = provider.default_fast_model()?;
227 Some(provider.create_model(model))
228 }
229
230 fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
231 let state = self.state.read(cx);
232 let provider = state.provider.read(cx);
233 provider
234 .recommended_models()
235 .iter()
236 .map(|model| provider.create_model(model))
237 .collect()
238 }
239
240 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
241 let state = self.state.read(cx);
242 let provider = state.provider.read(cx);
243 provider
244 .models()
245 .iter()
246 .map(|model| provider.create_model(model))
247 .collect()
248 }
249
250 fn is_authenticated(&self, cx: &App) -> bool {
251 let state = self.state.read(cx);
252 !state.is_signed_out(cx)
253 }
254
255 fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
256 Task::ready(Ok(()))
257 }
258
259 fn configuration_view(
260 &self,
261 _target_agent: language_model::ConfigurationViewTargetAgent,
262 _: &mut Window,
263 cx: &mut App,
264 ) -> AnyView {
265 cx.new(|_| ConfigurationView::new(self.state.clone()))
266 .into()
267 }
268
269 fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
270 Task::ready(Ok(()))
271 }
272}
273
/// Render-once view of the Zed AI account section: a sign-in prompt when
/// disconnected, otherwise plan information plus a subscription action.
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    /// Whether a user is signed in; when false only the sign-in prompt renders.
    is_connected: bool,
    /// The user's plan, if known.
    plan: Option<Plan>,
    /// Whether a Pro trial may still be started. `ConfigurationView::render`
    /// sets this from `trial_started_at().is_none()`.
    eligible_for_trial: bool,
    /// When true, the young-account banner and an upgrade button replace the
    /// plan text.
    account_too_young: bool,
    /// Invoked when the "Sign In to use Zed AI" button is clicked.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
282
283impl RenderOnce for ZedAiConfiguration {
284 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
285 let (subscription_text, has_paid_plan) = match self.plan {
286 Some(Plan::ZedPro) => (
287 "You have access to Zed's hosted models through your Pro subscription.",
288 true,
289 ),
290 Some(Plan::ZedProTrial) => (
291 "You have access to Zed's hosted models through your Pro trial.",
292 false,
293 ),
294 Some(Plan::ZedStudent) => (
295 "You have access to Zed's hosted models through your Student subscription.",
296 true,
297 ),
298 Some(Plan::ZedBusiness) => (
299 "You have access to Zed's hosted models through your Organization.",
300 true,
301 ),
302 Some(Plan::ZedFree) | None => (
303 if self.eligible_for_trial {
304 "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
305 } else {
306 "Subscribe for access to Zed's hosted models."
307 },
308 false,
309 ),
310 };
311
312 let manage_subscription_buttons = if has_paid_plan {
313 Button::new("manage_settings", "Manage Subscription")
314 .full_width()
315 .label_size(LabelSize::Small)
316 .style(ButtonStyle::Tinted(TintColor::Accent))
317 .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
318 .into_any_element()
319 } else if self.plan.is_none() || self.eligible_for_trial {
320 Button::new("start_trial", "Start 14-day Free Pro Trial")
321 .full_width()
322 .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
323 .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
324 .into_any_element()
325 } else {
326 Button::new("upgrade", "Upgrade to Pro")
327 .full_width()
328 .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
329 .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
330 .into_any_element()
331 };
332
333 if !self.is_connected {
334 return v_flex()
335 .gap_2()
336 .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
337 .child(
338 Button::new("sign_in", "Sign In to use Zed AI")
339 .start_icon(Icon::new(IconName::Github).size(IconSize::Small).color(Color::Muted))
340 .full_width()
341 .on_click({
342 let callback = self.sign_in_callback.clone();
343 move |_, window, cx| (callback)(window, cx)
344 }),
345 );
346 }
347
348 v_flex().gap_2().w_full().map(|this| {
349 if self.account_too_young {
350 this.child(YoungAccountBanner).child(
351 Button::new("upgrade", "Upgrade to Pro")
352 .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
353 .full_width()
354 .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
355 )
356 } else {
357 this.text_sm()
358 .child(subscription_text)
359 .child(manage_subscription_buttons)
360 }
361 })
362 }
363}
364
/// Settings-panel view for the Zed Cloud provider. It renders a
/// [`ZedAiConfiguration`] built from the shared [`State`].
struct ConfigurationView {
    /// Shared provider state read on every render.
    state: Entity<State>,
    /// Callback given to the sign-in button; starts `State::authenticate`.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
369
370impl ConfigurationView {
371 fn new(state: Entity<State>) -> Self {
372 let sign_in_callback = Arc::new({
373 let state = state.clone();
374 move |_window: &mut Window, cx: &mut App| {
375 state.update(cx, |state, cx| {
376 state.authenticate(cx).detach_and_log_err(cx);
377 });
378 }
379 });
380
381 Self {
382 state,
383 sign_in_callback,
384 }
385 }
386}
387
388impl Render for ConfigurationView {
389 fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
390 let state = self.state.read(cx);
391 let user_store = state.user_store.read(cx);
392
393 ZedAiConfiguration {
394 is_connected: !state.is_signed_out(cx),
395 plan: user_store.plan(),
396 eligible_for_trial: user_store.trial_started_at().is_none(),
397 account_too_young: user_store.account_too_young(),
398 sign_in_callback: self.sign_in_callback.clone(),
399 }
400 }
401}
402
403impl Component for ZedAiConfiguration {
404 fn name() -> &'static str {
405 "AI Configuration Content"
406 }
407
408 fn sort_name() -> &'static str {
409 "AI Configuration Content"
410 }
411
412 fn scope() -> ComponentScope {
413 ComponentScope::Onboarding
414 }
415
416 fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
417 fn configuration(
418 is_connected: bool,
419 plan: Option<Plan>,
420 eligible_for_trial: bool,
421 account_too_young: bool,
422 ) -> AnyElement {
423 ZedAiConfiguration {
424 is_connected,
425 plan,
426 eligible_for_trial,
427 account_too_young,
428 sign_in_callback: Arc::new(|_, _| {}),
429 }
430 .into_any_element()
431 }
432
433 Some(
434 v_flex()
435 .p_4()
436 .gap_4()
437 .children(vec![
438 single_example("Not connected", configuration(false, None, false, false)),
439 single_example(
440 "Accept Terms of Service",
441 configuration(true, None, true, false),
442 ),
443 single_example(
444 "No Plan - Not eligible for trial",
445 configuration(true, None, false, false),
446 ),
447 single_example(
448 "No Plan - Eligible for trial",
449 configuration(true, None, true, false),
450 ),
451 single_example(
452 "Free Plan",
453 configuration(true, Some(Plan::ZedFree), true, false),
454 ),
455 single_example(
456 "Zed Pro Trial Plan",
457 configuration(true, Some(Plan::ZedProTrial), true, false),
458 ),
459 single_example(
460 "Zed Pro Plan",
461 configuration(true, Some(Plan::ZedPro), true, false),
462 ),
463 ])
464 .into_any_element(),
465 )
466 }
467}