1use ai_onboarding::YoungAccountBanner;
2use anyhow::Result;
3use client::Status;
4use client::{Client, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls};
5use cloud_api_client::LlmApiToken;
6use cloud_api_types::OrganizationId;
7use cloud_api_types::Plan;
8use futures::StreamExt;
9use futures::future::BoxFuture;
10use gpui::{AnyElement, AnyView, App, AppContext, Context, Entity, Subscription, Task};
11use language_model::{
12 AuthenticateError, IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelProviderId,
13 LanguageModelProviderName, LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID,
14 ZED_CLOUD_PROVIDER_NAME,
15};
16use language_models_cloud::{CloudLlmTokenProvider, CloudModelProvider};
17use release_channel::AppVersion;
18
19use settings::SettingsStore;
20pub use settings::ZedDotDevAvailableModel as AvailableModel;
21pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
22use std::sync::Arc;
23use ui::{TintColor, prelude::*};
24
25const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
26const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;
27
28struct ClientTokenProvider {
29 client: Arc<Client>,
30 llm_api_token: LlmApiToken,
31 user_store: Entity<UserStore>,
32}
33
34impl CloudLlmTokenProvider for ClientTokenProvider {
35 type AuthContext = Option<OrganizationId>;
36
37 fn auth_context(&self, cx: &impl AppContext) -> Self::AuthContext {
38 self.user_store.read_with(cx, |user_store, _| {
39 user_store
40 .current_organization()
41 .map(|organization| organization.id.clone())
42 })
43 }
44
45 fn acquire_token(
46 &self,
47 organization_id: Self::AuthContext,
48 ) -> BoxFuture<'static, Result<String>> {
49 let client = self.client.clone();
50 let llm_api_token = self.llm_api_token.clone();
51 Box::pin(async move {
52 client
53 .acquire_llm_token(&llm_api_token, organization_id)
54 .await
55 })
56 }
57
58 fn refresh_token(
59 &self,
60 organization_id: Self::AuthContext,
61 ) -> BoxFuture<'static, Result<String>> {
62 let client = self.client.clone();
63 let llm_api_token = self.llm_api_token.clone();
64 Box::pin(async move {
65 client
66 .refresh_llm_token(&llm_api_token, organization_id)
67 .await
68 })
69 }
70}
71
72#[derive(Default, Clone, Debug, PartialEq)]
73pub struct ZedDotDevSettings {
74 pub available_models: Vec<AvailableModel>,
75}
76
77pub struct CloudLanguageModelProvider {
78 state: Entity<State>,
79 _maintain_client_status: Task<()>,
80}
81
82pub struct State {
83 client: Arc<Client>,
84 user_store: Entity<UserStore>,
85 status: client::Status,
86 provider: Entity<CloudModelProvider<ClientTokenProvider>>,
87 _user_store_subscription: Subscription,
88 _settings_subscription: Subscription,
89 _llm_token_subscription: Subscription,
90 _provider_subscription: Subscription,
91}
92
93impl State {
94 fn new(
95 client: Arc<Client>,
96 user_store: Entity<UserStore>,
97 status: client::Status,
98 cx: &mut Context<Self>,
99 ) -> Self {
100 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
101 let token_provider = Arc::new(ClientTokenProvider {
102 client: client.clone(),
103 llm_api_token: global_llm_token(cx),
104 user_store: user_store.clone(),
105 });
106
107 let provider = cx.new(|cx| {
108 CloudModelProvider::new(
109 token_provider.clone(),
110 client.http_client(),
111 Some(AppVersion::global(cx)),
112 )
113 });
114
115 Self {
116 client: client.clone(),
117 user_store: user_store.clone(),
118 status,
119 _provider_subscription: cx.observe(&provider, |_, _, cx| cx.notify()),
120 provider,
121 _user_store_subscription: cx.subscribe(
122 &user_store,
123 move |this, _user_store, event, cx| match event {
124 client::user::Event::PrivateUserInfoUpdated => {
125 let status = *client.status().borrow();
126 if status.is_signed_out() {
127 return;
128 }
129
130 this.refresh_models(cx);
131 }
132 _ => {}
133 },
134 ),
135 _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
136 cx.notify();
137 }),
138 _llm_token_subscription: cx.subscribe(
139 &refresh_llm_token_listener,
140 move |this, _listener, _event, cx| {
141 this.refresh_models(cx);
142 },
143 ),
144 }
145 }
146
147 fn is_signed_out(&self, cx: &App) -> bool {
148 self.user_store.read(cx).current_user().is_none()
149 }
150
151 fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
152 let client = self.client.clone();
153 cx.spawn(async move |state, cx| {
154 client.sign_in_with_optional_connect(true, cx).await?;
155 state.update(cx, |_, cx| cx.notify())
156 })
157 }
158
159 fn refresh_models(&mut self, cx: &mut Context<Self>) {
160 self.provider.update(cx, |provider, cx| {
161 provider.refresh_models(cx).detach_and_log_err(cx);
162 });
163 }
164}
165
166impl CloudLanguageModelProvider {
167 pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
168 let mut status_rx = client.status();
169 let status = *status_rx.borrow();
170
171 let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
172
173 let state_ref = state.downgrade();
174 let maintain_client_status = cx.spawn(async move |cx| {
175 while let Some(status) = status_rx.next().await {
176 if let Some(this) = state_ref.upgrade() {
177 _ = this.update(cx, |this, cx| {
178 if this.status != status {
179 this.status = status;
180 cx.notify();
181 }
182 });
183 } else {
184 break;
185 }
186 }
187 });
188
189 Self {
190 state,
191 _maintain_client_status: maintain_client_status,
192 }
193 }
194}
195
196impl LanguageModelProviderState for CloudLanguageModelProvider {
197 type ObservableEntity = State;
198
199 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
200 Some(self.state.clone())
201 }
202}
203
204impl LanguageModelProvider for CloudLanguageModelProvider {
205 fn id(&self) -> LanguageModelProviderId {
206 PROVIDER_ID
207 }
208
209 fn name(&self) -> LanguageModelProviderName {
210 PROVIDER_NAME
211 }
212
213 fn icon(&self) -> IconOrSvg {
214 IconOrSvg::Icon(IconName::AiZed)
215 }
216
217 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
218 let state = self.state.read(cx);
219 let provider = state.provider.read(cx);
220 let model = provider.default_model()?;
221 Some(provider.create_model(model))
222 }
223
224 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
225 let state = self.state.read(cx);
226 let provider = state.provider.read(cx);
227 let model = provider.default_fast_model()?;
228 Some(provider.create_model(model))
229 }
230
231 fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
232 let state = self.state.read(cx);
233 let provider = state.provider.read(cx);
234 provider
235 .recommended_models()
236 .iter()
237 .map(|model| provider.create_model(model))
238 .collect()
239 }
240
241 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
242 let state = self.state.read(cx);
243 let provider = state.provider.read(cx);
244 provider
245 .models()
246 .iter()
247 .map(|model| provider.create_model(model))
248 .collect()
249 }
250
251 fn is_authenticated(&self, cx: &App) -> bool {
252 let state = self.state.read(cx);
253 let status = *state.client.status().borrow();
254 matches!(status, Status::Authenticated | Status::Connected { .. })
255 }
256
257 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
258 let mut status = self.state.read(cx).client.status();
259 if !status.borrow().is_signing_in() {
260 return Task::ready(Ok(()));
261 }
262 cx.background_spawn(async move {
263 while status.borrow().is_signing_in() {
264 status.next().await;
265 }
266 Ok(())
267 })
268 }
269
270 fn configuration_view(
271 &self,
272 _target_agent: language_model::ConfigurationViewTargetAgent,
273 _: &mut Window,
274 cx: &mut App,
275 ) -> AnyView {
276 cx.new(|_| ConfigurationView::new(self.state.clone()))
277 .into()
278 }
279
280 fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
281 Task::ready(Ok(()))
282 }
283}
284
285#[derive(IntoElement, RegisterComponent)]
286struct ZedAiConfiguration {
287 is_connected: bool,
288 plan: Option<Plan>,
289 is_zed_model_provider_enabled: bool,
290 eligible_for_trial: bool,
291 account_too_young: bool,
292 sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
293}
294
295impl RenderOnce for ZedAiConfiguration {
296 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
297 let (subscription_text, has_paid_plan) = match self.plan {
298 Some(Plan::ZedPro) => (
299 "You have access to Zed's hosted models through your Pro subscription.",
300 true,
301 ),
302 Some(Plan::ZedProTrial) => (
303 "You have access to Zed's hosted models through your Pro trial.",
304 false,
305 ),
306 Some(Plan::ZedStudent) => (
307 "You have access to Zed's hosted models through your Student subscription.",
308 true,
309 ),
310 Some(Plan::ZedBusiness) => (
311 if self.is_zed_model_provider_enabled {
312 "You have access to Zed's hosted models through your organization."
313 } else {
314 "Zed's hosted models are disabled by your organization's configuration."
315 },
316 true,
317 ),
318 Some(Plan::ZedFree) | None => (
319 if self.eligible_for_trial {
320 "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
321 } else {
322 "Subscribe for access to Zed's hosted models."
323 },
324 false,
325 ),
326 };
327
328 let manage_subscription_buttons = if has_paid_plan {
329 Button::new("manage_settings", "Manage Subscription")
330 .full_width()
331 .label_size(LabelSize::Small)
332 .style(ButtonStyle::Tinted(TintColor::Accent))
333 .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
334 .into_any_element()
335 } else if self.plan.is_none() || self.eligible_for_trial {
336 Button::new("start_trial", "Start 14-day Free Pro Trial")
337 .full_width()
338 .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
339 .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
340 .into_any_element()
341 } else {
342 Button::new("upgrade", "Upgrade to Pro")
343 .full_width()
344 .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
345 .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
346 .into_any_element()
347 };
348
349 if !self.is_connected {
350 return v_flex()
351 .gap_2()
352 .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
353 .child(
354 Button::new("sign_in", "Sign In to use Zed AI")
355 .start_icon(Icon::new(IconName::Github).size(IconSize::Small).color(Color::Muted))
356 .full_width()
357 .on_click({
358 let callback = self.sign_in_callback.clone();
359 move |_, window, cx| (callback)(window, cx)
360 }),
361 );
362 }
363
364 v_flex().gap_2().w_full().map(|this| {
365 if self.account_too_young {
366 this.child(YoungAccountBanner).child(
367 Button::new("upgrade", "Upgrade to Pro")
368 .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
369 .full_width()
370 .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
371 )
372 } else {
373 this.text_sm()
374 .child(subscription_text)
375 .child(manage_subscription_buttons)
376 }
377 })
378 }
379}
380
381struct ConfigurationView {
382 state: Entity<State>,
383 sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
384}
385
386impl ConfigurationView {
387 fn new(state: Entity<State>) -> Self {
388 let sign_in_callback = Arc::new({
389 let state = state.clone();
390 move |_window: &mut Window, cx: &mut App| {
391 state.update(cx, |state, cx| {
392 state.authenticate(cx).detach_and_log_err(cx);
393 });
394 }
395 });
396
397 Self {
398 state,
399 sign_in_callback,
400 }
401 }
402}
403
404impl Render for ConfigurationView {
405 fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
406 let state = self.state.read(cx);
407 let user_store = state.user_store.read(cx);
408
409 let is_zed_model_provider_enabled = user_store
410 .current_organization_configuration()
411 .map_or(true, |config| config.is_zed_model_provider_enabled);
412
413 ZedAiConfiguration {
414 is_connected: !state.is_signed_out(cx),
415 plan: user_store.plan(),
416 is_zed_model_provider_enabled,
417 eligible_for_trial: user_store.trial_started_at().is_none(),
418 account_too_young: user_store.account_too_young(),
419 sign_in_callback: self.sign_in_callback.clone(),
420 }
421 }
422}
423
424impl Component for ZedAiConfiguration {
425 fn name() -> &'static str {
426 "AI Configuration Content"
427 }
428
429 fn sort_name() -> &'static str {
430 "AI Configuration Content"
431 }
432
433 fn scope() -> ComponentScope {
434 ComponentScope::Onboarding
435 }
436
437 fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
438 struct PreviewConfiguration {
439 plan: Option<Plan>,
440 is_connected: bool,
441 is_zed_model_provider_enabled: bool,
442 eligible_for_trial: bool,
443 }
444
445 let configuration = |config: PreviewConfiguration| -> AnyElement {
446 ZedAiConfiguration {
447 is_connected: config.is_connected,
448 plan: config.plan,
449 is_zed_model_provider_enabled: config.is_zed_model_provider_enabled,
450 eligible_for_trial: config.eligible_for_trial,
451 account_too_young: false,
452 sign_in_callback: Arc::new(|_, _| {}),
453 }
454 .into_any_element()
455 };
456
457 Some(
458 v_flex()
459 .p_4()
460 .gap_4()
461 .children(vec![
462 single_example(
463 "Not connected",
464 configuration(PreviewConfiguration {
465 plan: None,
466 is_connected: false,
467 is_zed_model_provider_enabled: true,
468 eligible_for_trial: false,
469 }),
470 ),
471 single_example(
472 "Accept Terms of Service",
473 configuration(PreviewConfiguration {
474 plan: None,
475 is_connected: true,
476 is_zed_model_provider_enabled: true,
477 eligible_for_trial: true,
478 }),
479 ),
480 single_example(
481 "No Plan - Not eligible for trial",
482 configuration(PreviewConfiguration {
483 plan: None,
484 is_connected: true,
485 is_zed_model_provider_enabled: true,
486 eligible_for_trial: false,
487 }),
488 ),
489 single_example(
490 "No Plan - Eligible for trial",
491 configuration(PreviewConfiguration {
492 plan: None,
493 is_connected: true,
494 is_zed_model_provider_enabled: true,
495 eligible_for_trial: true,
496 }),
497 ),
498 single_example(
499 "Free Plan",
500 configuration(PreviewConfiguration {
501 plan: Some(Plan::ZedFree),
502 is_connected: true,
503 is_zed_model_provider_enabled: true,
504 eligible_for_trial: true,
505 }),
506 ),
507 single_example(
508 "Zed Pro Trial Plan",
509 configuration(PreviewConfiguration {
510 plan: Some(Plan::ZedProTrial),
511 is_connected: true,
512 is_zed_model_provider_enabled: true,
513 eligible_for_trial: true,
514 }),
515 ),
516 single_example(
517 "Zed Pro Plan",
518 configuration(PreviewConfiguration {
519 plan: Some(Plan::ZedPro),
520 is_connected: true,
521 is_zed_model_provider_enabled: true,
522 eligible_for_trial: true,
523 }),
524 ),
525 single_example(
526 "Business Plan - Zed models enabled",
527 configuration(PreviewConfiguration {
528 plan: Some(Plan::ZedBusiness),
529 is_connected: true,
530 is_zed_model_provider_enabled: true,
531 eligible_for_trial: false,
532 }),
533 ),
534 single_example(
535 "Business Plan - Zed models disabled",
536 configuration(PreviewConfiguration {
537 plan: Some(Plan::ZedBusiness),
538 is_connected: true,
539 is_zed_model_provider_enabled: false,
540 eligible_for_trial: false,
541 }),
542 ),
543 ])
544 .into_any_element(),
545 )
546 }
547}