cloud.rs

  1use ai_onboarding::YoungAccountBanner;
  2use anyhow::Result;
  3use client::{Client, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls};
  4use cloud_api_client::LlmApiToken;
  5use cloud_api_types::OrganizationId;
  6use cloud_api_types::Plan;
  7use futures::StreamExt;
  8use futures::future::BoxFuture;
  9use gpui::AsyncApp;
 10use gpui::{AnyElement, AnyView, App, Context, Entity, Subscription, Task};
 11use language_model::{
 12    AuthenticateError, IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelProviderId,
 13    LanguageModelProviderName, LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID,
 14    ZED_CLOUD_PROVIDER_NAME,
 15};
 16use language_models_cloud::{CloudLlmTokenProvider, CloudModelProvider};
 17use release_channel::AppVersion;
 18
 19use settings::SettingsStore;
 20pub use settings::ZedDotDevAvailableModel as AvailableModel;
 21pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
 22use std::sync::Arc;
 23use ui::{TintColor, prelude::*};
 24
 25const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
 26const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;
 27
 28struct ClientTokenProvider {
 29    client: Arc<Client>,
 30    llm_api_token: LlmApiToken,
 31    user_store: Entity<UserStore>,
 32}
 33
 34impl CloudLlmTokenProvider for ClientTokenProvider {
 35    type AuthContext = Option<OrganizationId>;
 36
 37    fn auth_context(&self, cx: &AsyncApp) -> Self::AuthContext {
 38        self.user_store.read_with(cx, |user_store, _| {
 39            user_store
 40                .current_organization()
 41                .map(|organization| organization.id.clone())
 42        })
 43    }
 44
 45    fn acquire_token(
 46        &self,
 47        organization_id: Self::AuthContext,
 48    ) -> BoxFuture<'static, Result<String>> {
 49        let client = self.client.clone();
 50        let llm_api_token = self.llm_api_token.clone();
 51        Box::pin(async move {
 52            client
 53                .acquire_llm_token(&llm_api_token, organization_id)
 54                .await
 55        })
 56    }
 57
 58    fn refresh_token(
 59        &self,
 60        organization_id: Self::AuthContext,
 61    ) -> BoxFuture<'static, Result<String>> {
 62        let client = self.client.clone();
 63        let llm_api_token = self.llm_api_token.clone();
 64        Box::pin(async move {
 65            client
 66                .refresh_llm_token(&llm_api_token, organization_id)
 67                .await
 68        })
 69    }
 70}
 71
 72#[derive(Default, Clone, Debug, PartialEq)]
 73pub struct ZedDotDevSettings {
 74    pub available_models: Vec<AvailableModel>,
 75}
 76
 77pub struct CloudLanguageModelProvider {
 78    state: Entity<State>,
 79    _maintain_client_status: Task<()>,
 80}
 81
 82pub struct State {
 83    client: Arc<Client>,
 84    user_store: Entity<UserStore>,
 85    status: client::Status,
 86    provider: Entity<CloudModelProvider<ClientTokenProvider>>,
 87    _user_store_subscription: Subscription,
 88    _settings_subscription: Subscription,
 89    _llm_token_subscription: Subscription,
 90    _provider_subscription: Subscription,
 91}
 92
 93impl State {
 94    fn new(
 95        client: Arc<Client>,
 96        user_store: Entity<UserStore>,
 97        status: client::Status,
 98        cx: &mut Context<Self>,
 99    ) -> Self {
100        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
101        let token_provider = Arc::new(ClientTokenProvider {
102            client: client.clone(),
103            llm_api_token: global_llm_token(cx),
104            user_store: user_store.clone(),
105        });
106
107        let provider = cx.new(|cx| {
108            CloudModelProvider::new(
109                token_provider.clone(),
110                client.http_client(),
111                Some(AppVersion::global(cx)),
112            )
113        });
114
115        Self {
116            client: client.clone(),
117            user_store: user_store.clone(),
118            status,
119            _provider_subscription: cx.observe(&provider, |_, _, cx| cx.notify()),
120            provider,
121            _user_store_subscription: cx.subscribe(
122                &user_store,
123                move |this, _user_store, event, cx| match event {
124                    client::user::Event::PrivateUserInfoUpdated => {
125                        let status = *client.status().borrow();
126                        if status.is_signed_out() {
127                            return;
128                        }
129
130                        this.refresh_models(cx);
131                    }
132                    _ => {}
133                },
134            ),
135            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
136                cx.notify();
137            }),
138            _llm_token_subscription: cx.subscribe(
139                &refresh_llm_token_listener,
140                move |this, _listener, _event, cx| {
141                    this.refresh_models(cx);
142                },
143            ),
144        }
145    }
146
147    fn is_signed_out(&self, cx: &App) -> bool {
148        self.user_store.read(cx).current_user().is_none()
149    }
150
151    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
152        let client = self.client.clone();
153        cx.spawn(async move |state, cx| {
154            client.sign_in_with_optional_connect(true, cx).await?;
155            state.update(cx, |_, cx| cx.notify())
156        })
157    }
158
159    fn refresh_models(&mut self, cx: &mut Context<Self>) {
160        self.provider.update(cx, |provider, cx| {
161            provider.refresh_models(cx).detach_and_log_err(cx);
162        });
163    }
164}
165
166impl CloudLanguageModelProvider {
167    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
168        let mut status_rx = client.status();
169        let status = *status_rx.borrow();
170
171        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
172
173        let state_ref = state.downgrade();
174        let maintain_client_status = cx.spawn(async move |cx| {
175            while let Some(status) = status_rx.next().await {
176                if let Some(this) = state_ref.upgrade() {
177                    _ = this.update(cx, |this, cx| {
178                        if this.status != status {
179                            this.status = status;
180                            cx.notify();
181                        }
182                    });
183                } else {
184                    break;
185                }
186            }
187        });
188
189        Self {
190            state,
191            _maintain_client_status: maintain_client_status,
192        }
193    }
194}
195
196impl LanguageModelProviderState for CloudLanguageModelProvider {
197    type ObservableEntity = State;
198
199    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
200        Some(self.state.clone())
201    }
202}
203
204impl LanguageModelProvider for CloudLanguageModelProvider {
205    fn id(&self) -> LanguageModelProviderId {
206        PROVIDER_ID
207    }
208
209    fn name(&self) -> LanguageModelProviderName {
210        PROVIDER_NAME
211    }
212
213    fn icon(&self) -> IconOrSvg {
214        IconOrSvg::Icon(IconName::AiZed)
215    }
216
217    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
218        let state = self.state.read(cx);
219        let provider = state.provider.read(cx);
220        let model = provider.default_model()?;
221        Some(provider.create_model(model))
222    }
223
224    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
225        let state = self.state.read(cx);
226        let provider = state.provider.read(cx);
227        let model = provider.default_fast_model()?;
228        Some(provider.create_model(model))
229    }
230
231    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
232        let state = self.state.read(cx);
233        let provider = state.provider.read(cx);
234        provider
235            .recommended_models()
236            .iter()
237            .map(|model| provider.create_model(model))
238            .collect()
239    }
240
241    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
242        let state = self.state.read(cx);
243        let provider = state.provider.read(cx);
244        provider
245            .models()
246            .iter()
247            .map(|model| provider.create_model(model))
248            .collect()
249    }
250
251    fn is_authenticated(&self, cx: &App) -> bool {
252        let state = self.state.read(cx);
253        !state.is_signed_out(cx)
254    }
255
256    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
257        Task::ready(Ok(()))
258    }
259
260    fn configuration_view(
261        &self,
262        _target_agent: language_model::ConfigurationViewTargetAgent,
263        _: &mut Window,
264        cx: &mut App,
265    ) -> AnyView {
266        cx.new(|_| ConfigurationView::new(self.state.clone()))
267            .into()
268    }
269
270    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
271        Task::ready(Ok(()))
272    }
273}
274
275#[derive(IntoElement, RegisterComponent)]
276struct ZedAiConfiguration {
277    is_connected: bool,
278    plan: Option<Plan>,
279    eligible_for_trial: bool,
280    account_too_young: bool,
281    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
282}
283
284impl RenderOnce for ZedAiConfiguration {
285    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
286        let (subscription_text, has_paid_plan) = match self.plan {
287            Some(Plan::ZedPro) => (
288                "You have access to Zed's hosted models through your Pro subscription.",
289                true,
290            ),
291            Some(Plan::ZedProTrial) => (
292                "You have access to Zed's hosted models through your Pro trial.",
293                false,
294            ),
295            Some(Plan::ZedStudent) => (
296                "You have access to Zed's hosted models through your Student subscription.",
297                true,
298            ),
299            Some(Plan::ZedBusiness) => (
300                "You have access to Zed's hosted models through your Organization.",
301                true,
302            ),
303            Some(Plan::ZedFree) | None => (
304                if self.eligible_for_trial {
305                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
306                } else {
307                    "Subscribe for access to Zed's hosted models."
308                },
309                false,
310            ),
311        };
312
313        let manage_subscription_buttons = if has_paid_plan {
314            Button::new("manage_settings", "Manage Subscription")
315                .full_width()
316                .label_size(LabelSize::Small)
317                .style(ButtonStyle::Tinted(TintColor::Accent))
318                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
319                .into_any_element()
320        } else if self.plan.is_none() || self.eligible_for_trial {
321            Button::new("start_trial", "Start 14-day Free Pro Trial")
322                .full_width()
323                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
324                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
325                .into_any_element()
326        } else {
327            Button::new("upgrade", "Upgrade to Pro")
328                .full_width()
329                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
330                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
331                .into_any_element()
332        };
333
334        if !self.is_connected {
335            return v_flex()
336                .gap_2()
337                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
338                .child(
339                    Button::new("sign_in", "Sign In to use Zed AI")
340                        .start_icon(Icon::new(IconName::Github).size(IconSize::Small).color(Color::Muted))
341                        .full_width()
342                        .on_click({
343                            let callback = self.sign_in_callback.clone();
344                            move |_, window, cx| (callback)(window, cx)
345                        }),
346                );
347        }
348
349        v_flex().gap_2().w_full().map(|this| {
350            if self.account_too_young {
351                this.child(YoungAccountBanner).child(
352                    Button::new("upgrade", "Upgrade to Pro")
353                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
354                        .full_width()
355                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
356                )
357            } else {
358                this.text_sm()
359                    .child(subscription_text)
360                    .child(manage_subscription_buttons)
361            }
362        })
363    }
364}
365
366struct ConfigurationView {
367    state: Entity<State>,
368    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
369}
370
371impl ConfigurationView {
372    fn new(state: Entity<State>) -> Self {
373        let sign_in_callback = Arc::new({
374            let state = state.clone();
375            move |_window: &mut Window, cx: &mut App| {
376                state.update(cx, |state, cx| {
377                    state.authenticate(cx).detach_and_log_err(cx);
378                });
379            }
380        });
381
382        Self {
383            state,
384            sign_in_callback,
385        }
386    }
387}
388
389impl Render for ConfigurationView {
390    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
391        let state = self.state.read(cx);
392        let user_store = state.user_store.read(cx);
393
394        ZedAiConfiguration {
395            is_connected: !state.is_signed_out(cx),
396            plan: user_store.plan(),
397            eligible_for_trial: user_store.trial_started_at().is_none(),
398            account_too_young: user_store.account_too_young(),
399            sign_in_callback: self.sign_in_callback.clone(),
400        }
401    }
402}
403
404impl Component for ZedAiConfiguration {
405    fn name() -> &'static str {
406        "AI Configuration Content"
407    }
408
409    fn sort_name() -> &'static str {
410        "AI Configuration Content"
411    }
412
413    fn scope() -> ComponentScope {
414        ComponentScope::Onboarding
415    }
416
417    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
418        fn configuration(
419            is_connected: bool,
420            plan: Option<Plan>,
421            eligible_for_trial: bool,
422            account_too_young: bool,
423        ) -> AnyElement {
424            ZedAiConfiguration {
425                is_connected,
426                plan,
427                eligible_for_trial,
428                account_too_young,
429                sign_in_callback: Arc::new(|_, _| {}),
430            }
431            .into_any_element()
432        }
433
434        Some(
435            v_flex()
436                .p_4()
437                .gap_4()
438                .children(vec![
439                    single_example("Not connected", configuration(false, None, false, false)),
440                    single_example(
441                        "Accept Terms of Service",
442                        configuration(true, None, true, false),
443                    ),
444                    single_example(
445                        "No Plan - Not eligible for trial",
446                        configuration(true, None, false, false),
447                    ),
448                    single_example(
449                        "No Plan - Eligible for trial",
450                        configuration(true, None, true, false),
451                    ),
452                    single_example(
453                        "Free Plan",
454                        configuration(true, Some(Plan::ZedFree), true, false),
455                    ),
456                    single_example(
457                        "Zed Pro Trial Plan",
458                        configuration(true, Some(Plan::ZedProTrial), true, false),
459                    ),
460                    single_example(
461                        "Zed Pro Plan",
462                        configuration(true, Some(Plan::ZedPro), true, false),
463                    ),
464                ])
465                .into_any_element(),
466        )
467    }
468}