cloud.rs

  1use ai_onboarding::YoungAccountBanner;
  2use anyhow::Result;
  3use client::Status;
  4use client::{Client, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls};
  5use cloud_api_client::LlmApiToken;
  6use cloud_api_types::OrganizationId;
  7use cloud_api_types::Plan;
  8use futures::StreamExt;
  9use futures::future::BoxFuture;
 10use gpui::{AnyElement, AnyView, App, AppContext, Context, Entity, Subscription, Task};
 11use language_model::{
 12    AuthenticateError, IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelProviderId,
 13    LanguageModelProviderName, LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID,
 14    ZED_CLOUD_PROVIDER_NAME,
 15};
 16use language_models_cloud::{CloudLlmTokenProvider, CloudModelProvider};
 17use release_channel::AppVersion;
 18
 19use settings::SettingsStore;
 20pub use settings::ZedDotDevAvailableModel as AvailableModel;
 21pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
 22use std::sync::Arc;
 23use ui::{TintColor, prelude::*};
 24
 25const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
 26const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;
 27
 28struct ClientTokenProvider {
 29    client: Arc<Client>,
 30    llm_api_token: LlmApiToken,
 31    user_store: Entity<UserStore>,
 32}
 33
 34impl CloudLlmTokenProvider for ClientTokenProvider {
 35    type AuthContext = Option<OrganizationId>;
 36
 37    fn auth_context(&self, cx: &impl AppContext) -> Self::AuthContext {
 38        self.user_store.read_with(cx, |user_store, _| {
 39            user_store
 40                .current_organization()
 41                .map(|organization| organization.id.clone())
 42        })
 43    }
 44
 45    fn acquire_token(
 46        &self,
 47        organization_id: Self::AuthContext,
 48    ) -> BoxFuture<'static, Result<String>> {
 49        let client = self.client.clone();
 50        let llm_api_token = self.llm_api_token.clone();
 51        Box::pin(async move {
 52            client
 53                .acquire_llm_token(&llm_api_token, organization_id)
 54                .await
 55        })
 56    }
 57
 58    fn refresh_token(
 59        &self,
 60        organization_id: Self::AuthContext,
 61    ) -> BoxFuture<'static, Result<String>> {
 62        let client = self.client.clone();
 63        let llm_api_token = self.llm_api_token.clone();
 64        Box::pin(async move {
 65            client
 66                .refresh_llm_token(&llm_api_token, organization_id)
 67                .await
 68        })
 69    }
 70}
 71
 72#[derive(Default, Clone, Debug, PartialEq)]
 73pub struct ZedDotDevSettings {
 74    pub available_models: Vec<AvailableModel>,
 75}
 76
 77pub struct CloudLanguageModelProvider {
 78    state: Entity<State>,
 79    _maintain_client_status: Task<()>,
 80}
 81
 82pub struct State {
 83    client: Arc<Client>,
 84    user_store: Entity<UserStore>,
 85    status: client::Status,
 86    provider: Entity<CloudModelProvider<ClientTokenProvider>>,
 87    _user_store_subscription: Subscription,
 88    _settings_subscription: Subscription,
 89    _llm_token_subscription: Subscription,
 90    _provider_subscription: Subscription,
 91}
 92
 93impl State {
 94    fn new(
 95        client: Arc<Client>,
 96        user_store: Entity<UserStore>,
 97        status: client::Status,
 98        cx: &mut Context<Self>,
 99    ) -> Self {
100        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
101        let token_provider = Arc::new(ClientTokenProvider {
102            client: client.clone(),
103            llm_api_token: global_llm_token(cx),
104            user_store: user_store.clone(),
105        });
106
107        let provider = cx.new(|cx| {
108            CloudModelProvider::new(
109                token_provider.clone(),
110                client.http_client(),
111                Some(AppVersion::global(cx)),
112            )
113        });
114
115        Self {
116            client: client.clone(),
117            user_store: user_store.clone(),
118            status,
119            _provider_subscription: cx.observe(&provider, |_, _, cx| cx.notify()),
120            provider,
121            _user_store_subscription: cx.subscribe(
122                &user_store,
123                move |this, _user_store, event, cx| match event {
124                    client::user::Event::PrivateUserInfoUpdated => {
125                        let status = *client.status().borrow();
126                        if status.is_signed_out() {
127                            return;
128                        }
129
130                        this.refresh_models(cx);
131                    }
132                    _ => {}
133                },
134            ),
135            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
136                cx.notify();
137            }),
138            _llm_token_subscription: cx.subscribe(
139                &refresh_llm_token_listener,
140                move |this, _listener, _event, cx| {
141                    this.refresh_models(cx);
142                },
143            ),
144        }
145    }
146
147    fn is_signed_out(&self, cx: &App) -> bool {
148        self.user_store.read(cx).current_user().is_none()
149    }
150
151    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
152        let client = self.client.clone();
153        cx.spawn(async move |state, cx| {
154            client.sign_in_with_optional_connect(true, cx).await?;
155            state.update(cx, |_, cx| cx.notify())
156        })
157    }
158
159    fn refresh_models(&mut self, cx: &mut Context<Self>) {
160        self.provider.update(cx, |provider, cx| {
161            provider.refresh_models(cx).detach_and_log_err(cx);
162        });
163    }
164}
165
166impl CloudLanguageModelProvider {
167    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
168        let mut status_rx = client.status();
169        let status = *status_rx.borrow();
170
171        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
172
173        let state_ref = state.downgrade();
174        let maintain_client_status = cx.spawn(async move |cx| {
175            while let Some(status) = status_rx.next().await {
176                if let Some(this) = state_ref.upgrade() {
177                    _ = this.update(cx, |this, cx| {
178                        if this.status != status {
179                            this.status = status;
180                            cx.notify();
181                        }
182                    });
183                } else {
184                    break;
185                }
186            }
187        });
188
189        Self {
190            state,
191            _maintain_client_status: maintain_client_status,
192        }
193    }
194}
195
196impl LanguageModelProviderState for CloudLanguageModelProvider {
197    type ObservableEntity = State;
198
199    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
200        Some(self.state.clone())
201    }
202}
203
204impl LanguageModelProvider for CloudLanguageModelProvider {
205    fn id(&self) -> LanguageModelProviderId {
206        PROVIDER_ID
207    }
208
209    fn name(&self) -> LanguageModelProviderName {
210        PROVIDER_NAME
211    }
212
213    fn icon(&self) -> IconOrSvg {
214        IconOrSvg::Icon(IconName::AiZed)
215    }
216
217    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
218        let state = self.state.read(cx);
219        let provider = state.provider.read(cx);
220        let model = provider.default_model()?;
221        Some(provider.create_model(model))
222    }
223
224    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
225        let state = self.state.read(cx);
226        let provider = state.provider.read(cx);
227        let model = provider.default_fast_model()?;
228        Some(provider.create_model(model))
229    }
230
231    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
232        let state = self.state.read(cx);
233        let provider = state.provider.read(cx);
234        provider
235            .recommended_models()
236            .iter()
237            .map(|model| provider.create_model(model))
238            .collect()
239    }
240
241    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
242        let state = self.state.read(cx);
243        let provider = state.provider.read(cx);
244        provider
245            .models()
246            .iter()
247            .map(|model| provider.create_model(model))
248            .collect()
249    }
250
251    fn is_authenticated(&self, cx: &App) -> bool {
252        let state = self.state.read(cx);
253        let status = *state.client.status().borrow();
254        matches!(status, Status::Authenticated | Status::Connected { .. })
255    }
256
257    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
258        let mut status = self.state.read(cx).client.status();
259        if !status.borrow().is_signing_in() {
260            return Task::ready(Ok(()));
261        }
262        cx.background_spawn(async move {
263            while status.borrow().is_signing_in() {
264                status.next().await;
265            }
266            Ok(())
267        })
268    }
269
270    fn configuration_view(
271        &self,
272        _target_agent: language_model::ConfigurationViewTargetAgent,
273        _: &mut Window,
274        cx: &mut App,
275    ) -> AnyView {
276        cx.new(|_| ConfigurationView::new(self.state.clone()))
277            .into()
278    }
279
280    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
281        Task::ready(Ok(()))
282    }
283}
284
285#[derive(IntoElement, RegisterComponent)]
286struct ZedAiConfiguration {
287    is_connected: bool,
288    plan: Option<Plan>,
289    is_zed_model_provider_enabled: bool,
290    eligible_for_trial: bool,
291    account_too_young: bool,
292    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
293}
294
295impl RenderOnce for ZedAiConfiguration {
296    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
297        let (subscription_text, has_paid_plan) = match self.plan {
298            Some(Plan::ZedPro) => (
299                "You have access to Zed's hosted models through your Pro subscription.",
300                true,
301            ),
302            Some(Plan::ZedProTrial) => (
303                "You have access to Zed's hosted models through your Pro trial.",
304                false,
305            ),
306            Some(Plan::ZedStudent) => (
307                "You have access to Zed's hosted models through your Student subscription.",
308                true,
309            ),
310            Some(Plan::ZedBusiness) => (
311                if self.is_zed_model_provider_enabled {
312                    "You have access to Zed's hosted models through your organization."
313                } else {
314                    "Zed's hosted models are disabled by your organization's configuration."
315                },
316                true,
317            ),
318            Some(Plan::ZedFree) | None => (
319                if self.eligible_for_trial {
320                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
321                } else {
322                    "Subscribe for access to Zed's hosted models."
323                },
324                false,
325            ),
326        };
327
328        let manage_subscription_buttons = if has_paid_plan {
329            Button::new("manage_settings", "Manage Subscription")
330                .full_width()
331                .label_size(LabelSize::Small)
332                .style(ButtonStyle::Tinted(TintColor::Accent))
333                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
334                .into_any_element()
335        } else if self.plan.is_none() || self.eligible_for_trial {
336            Button::new("start_trial", "Start 14-day Free Pro Trial")
337                .full_width()
338                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
339                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
340                .into_any_element()
341        } else {
342            Button::new("upgrade", "Upgrade to Pro")
343                .full_width()
344                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
345                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
346                .into_any_element()
347        };
348
349        if !self.is_connected {
350            return v_flex()
351                .gap_2()
352                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
353                .child(
354                    Button::new("sign_in", "Sign In to use Zed AI")
355                        .start_icon(Icon::new(IconName::Github).size(IconSize::Small).color(Color::Muted))
356                        .full_width()
357                        .on_click({
358                            let callback = self.sign_in_callback.clone();
359                            move |_, window, cx| (callback)(window, cx)
360                        }),
361                );
362        }
363
364        v_flex().gap_2().w_full().map(|this| {
365            if self.account_too_young {
366                this.child(YoungAccountBanner).child(
367                    Button::new("upgrade", "Upgrade to Pro")
368                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
369                        .full_width()
370                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
371                )
372            } else {
373                this.text_sm()
374                    .child(subscription_text)
375                    .child(manage_subscription_buttons)
376            }
377        })
378    }
379}
380
381struct ConfigurationView {
382    state: Entity<State>,
383    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
384}
385
386impl ConfigurationView {
387    fn new(state: Entity<State>) -> Self {
388        let sign_in_callback = Arc::new({
389            let state = state.clone();
390            move |_window: &mut Window, cx: &mut App| {
391                state.update(cx, |state, cx| {
392                    state.authenticate(cx).detach_and_log_err(cx);
393                });
394            }
395        });
396
397        Self {
398            state,
399            sign_in_callback,
400        }
401    }
402}
403
404impl Render for ConfigurationView {
405    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
406        let state = self.state.read(cx);
407        let user_store = state.user_store.read(cx);
408
409        let is_zed_model_provider_enabled = user_store
410            .current_organization_configuration()
411            .map_or(true, |config| config.is_zed_model_provider_enabled);
412
413        ZedAiConfiguration {
414            is_connected: !state.is_signed_out(cx),
415            plan: user_store.plan(),
416            is_zed_model_provider_enabled,
417            eligible_for_trial: user_store.trial_started_at().is_none(),
418            account_too_young: user_store.account_too_young(),
419            sign_in_callback: self.sign_in_callback.clone(),
420        }
421    }
422}
423
424impl Component for ZedAiConfiguration {
425    fn name() -> &'static str {
426        "AI Configuration Content"
427    }
428
429    fn sort_name() -> &'static str {
430        "AI Configuration Content"
431    }
432
433    fn scope() -> ComponentScope {
434        ComponentScope::Onboarding
435    }
436
437    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
438        struct PreviewConfiguration {
439            plan: Option<Plan>,
440            is_connected: bool,
441            is_zed_model_provider_enabled: bool,
442            eligible_for_trial: bool,
443        }
444
445        let configuration = |config: PreviewConfiguration| -> AnyElement {
446            ZedAiConfiguration {
447                is_connected: config.is_connected,
448                plan: config.plan,
449                is_zed_model_provider_enabled: config.is_zed_model_provider_enabled,
450                eligible_for_trial: config.eligible_for_trial,
451                account_too_young: false,
452                sign_in_callback: Arc::new(|_, _| {}),
453            }
454            .into_any_element()
455        };
456
457        Some(
458            v_flex()
459                .p_4()
460                .gap_4()
461                .children(vec![
462                    single_example(
463                        "Not connected",
464                        configuration(PreviewConfiguration {
465                            plan: None,
466                            is_connected: false,
467                            is_zed_model_provider_enabled: true,
468                            eligible_for_trial: false,
469                        }),
470                    ),
471                    single_example(
472                        "Accept Terms of Service",
473                        configuration(PreviewConfiguration {
474                            plan: None,
475                            is_connected: true,
476                            is_zed_model_provider_enabled: true,
477                            eligible_for_trial: true,
478                        }),
479                    ),
480                    single_example(
481                        "No Plan - Not eligible for trial",
482                        configuration(PreviewConfiguration {
483                            plan: None,
484                            is_connected: true,
485                            is_zed_model_provider_enabled: true,
486                            eligible_for_trial: false,
487                        }),
488                    ),
489                    single_example(
490                        "No Plan - Eligible for trial",
491                        configuration(PreviewConfiguration {
492                            plan: None,
493                            is_connected: true,
494                            is_zed_model_provider_enabled: true,
495                            eligible_for_trial: true,
496                        }),
497                    ),
498                    single_example(
499                        "Free Plan",
500                        configuration(PreviewConfiguration {
501                            plan: Some(Plan::ZedFree),
502                            is_connected: true,
503                            is_zed_model_provider_enabled: true,
504                            eligible_for_trial: true,
505                        }),
506                    ),
507                    single_example(
508                        "Zed Pro Trial Plan",
509                        configuration(PreviewConfiguration {
510                            plan: Some(Plan::ZedProTrial),
511                            is_connected: true,
512                            is_zed_model_provider_enabled: true,
513                            eligible_for_trial: true,
514                        }),
515                    ),
516                    single_example(
517                        "Zed Pro Plan",
518                        configuration(PreviewConfiguration {
519                            plan: Some(Plan::ZedPro),
520                            is_connected: true,
521                            is_zed_model_provider_enabled: true,
522                            eligible_for_trial: true,
523                        }),
524                    ),
525                    single_example(
526                        "Business Plan - Zed models enabled",
527                        configuration(PreviewConfiguration {
528                            plan: Some(Plan::ZedBusiness),
529                            is_connected: true,
530                            is_zed_model_provider_enabled: true,
531                            eligible_for_trial: false,
532                        }),
533                    ),
534                    single_example(
535                        "Business Plan - Zed models disabled",
536                        configuration(PreviewConfiguration {
537                            plan: Some(Plan::ZedBusiness),
538                            is_connected: true,
539                            is_zed_model_provider_enabled: false,
540                            eligible_for_trial: false,
541                        }),
542                    ),
543                ])
544                .into_any_element(),
545        )
546    }
547}