cloud.rs

  1use ai_onboarding::YoungAccountBanner;
  2use anyhow::Result;
  3use client::{Client, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls};
  4use cloud_api_client::LlmApiToken;
  5use cloud_api_types::OrganizationId;
  6use cloud_api_types::Plan;
  7use futures::StreamExt;
  8use futures::future::BoxFuture;
  9use gpui::{AnyElement, AnyView, App, AppContext, Context, Entity, Subscription, Task};
 10use language_model::{
 11    AuthenticateError, IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelProviderId,
 12    LanguageModelProviderName, LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID,
 13    ZED_CLOUD_PROVIDER_NAME,
 14};
 15use language_models_cloud::{CloudLlmTokenProvider, CloudModelProvider};
 16use release_channel::AppVersion;
 17
 18use settings::SettingsStore;
 19pub use settings::ZedDotDevAvailableModel as AvailableModel;
 20pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
 21use std::sync::Arc;
 22use ui::{TintColor, prelude::*};
 23
/// Id returned from [`LanguageModelProvider::id`] for this provider (the shared Zed cloud id).
const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
/// Name returned from [`LanguageModelProvider::name`] for this provider (the shared Zed cloud name).
const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;
 26
/// [`CloudLlmTokenProvider`] backed by the Zed [`Client`].
///
/// Tokens are acquired and refreshed in the context of the user store's current
/// organization, if there is one.
struct ClientTokenProvider {
    /// Client that performs the token acquire/refresh requests.
    client: Arc<Client>,
    /// Token handle handed to the client on every acquire/refresh call.
    llm_api_token: LlmApiToken,
    /// Queried for the current organization, which becomes the auth context.
    user_store: Entity<UserStore>,
}
 32
 33impl CloudLlmTokenProvider for ClientTokenProvider {
 34    type AuthContext = Option<OrganizationId>;
 35
 36    fn auth_context(&self, cx: &impl AppContext) -> Self::AuthContext {
 37        self.user_store.read_with(cx, |user_store, _| {
 38            user_store
 39                .current_organization()
 40                .map(|organization| organization.id.clone())
 41        })
 42    }
 43
 44    fn acquire_token(
 45        &self,
 46        organization_id: Self::AuthContext,
 47    ) -> BoxFuture<'static, Result<String>> {
 48        let client = self.client.clone();
 49        let llm_api_token = self.llm_api_token.clone();
 50        Box::pin(async move {
 51            client
 52                .acquire_llm_token(&llm_api_token, organization_id)
 53                .await
 54        })
 55    }
 56
 57    fn refresh_token(
 58        &self,
 59        organization_id: Self::AuthContext,
 60    ) -> BoxFuture<'static, Result<String>> {
 61        let client = self.client.clone();
 62        let llm_api_token = self.llm_api_token.clone();
 63        Box::pin(async move {
 64            client
 65                .refresh_llm_token(&llm_api_token, organization_id)
 66                .await
 67        })
 68    }
 69}
 70
 71#[derive(Default, Clone, Debug, PartialEq)]
 72pub struct ZedDotDevSettings {
 73    pub available_models: Vec<AvailableModel>,
 74}
 75
/// Language model provider for Zed's hosted models.
pub struct CloudLanguageModelProvider {
    /// Shared provider state; also exposed to observers via `observable_entity`.
    state: Entity<State>,
    /// Background task that mirrors the client's connection status into `state`.
    _maintain_client_status: Task<()>,
}
 80
/// Observable state shared by [`CloudLanguageModelProvider`] and its configuration view.
pub struct State {
    /// Client used for sign-in and for the inner provider's HTTP requests.
    client: Arc<Client>,
    /// Source of the current user, plan, trial and organization information.
    user_store: Entity<UserStore>,
    /// Last client status written by the watcher task in `CloudLanguageModelProvider::new`.
    status: client::Status,
    /// Inner provider that lists, refreshes and instantiates the cloud models.
    provider: Entity<CloudModelProvider<ClientTokenProvider>>,
    // The subscriptions below are created in `State::new` and stored only so they
    // live as long as the state itself.
    _user_store_subscription: Subscription,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
    _provider_subscription: Subscription,
}
 91
 92impl State {
 93    fn new(
 94        client: Arc<Client>,
 95        user_store: Entity<UserStore>,
 96        status: client::Status,
 97        cx: &mut Context<Self>,
 98    ) -> Self {
 99        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
100        let token_provider = Arc::new(ClientTokenProvider {
101            client: client.clone(),
102            llm_api_token: global_llm_token(cx),
103            user_store: user_store.clone(),
104        });
105
106        let provider = cx.new(|cx| {
107            CloudModelProvider::new(
108                token_provider.clone(),
109                client.http_client(),
110                Some(AppVersion::global(cx)),
111            )
112        });
113
114        Self {
115            client: client.clone(),
116            user_store: user_store.clone(),
117            status,
118            _provider_subscription: cx.observe(&provider, |_, _, cx| cx.notify()),
119            provider,
120            _user_store_subscription: cx.subscribe(
121                &user_store,
122                move |this, _user_store, event, cx| match event {
123                    client::user::Event::PrivateUserInfoUpdated => {
124                        let status = *client.status().borrow();
125                        if status.is_signed_out() {
126                            return;
127                        }
128
129                        this.refresh_models(cx);
130                    }
131                    _ => {}
132                },
133            ),
134            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
135                cx.notify();
136            }),
137            _llm_token_subscription: cx.subscribe(
138                &refresh_llm_token_listener,
139                move |this, _listener, _event, cx| {
140                    this.refresh_models(cx);
141                },
142            ),
143        }
144    }
145
146    fn is_signed_out(&self, cx: &App) -> bool {
147        self.user_store.read(cx).current_user().is_none()
148    }
149
150    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
151        let client = self.client.clone();
152        cx.spawn(async move |state, cx| {
153            client.sign_in_with_optional_connect(true, cx).await?;
154            state.update(cx, |_, cx| cx.notify())
155        })
156    }
157
158    fn refresh_models(&mut self, cx: &mut Context<Self>) {
159        self.provider.update(cx, |provider, cx| {
160            provider.refresh_models(cx).detach_and_log_err(cx);
161        });
162    }
163}
164
165impl CloudLanguageModelProvider {
166    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
167        let mut status_rx = client.status();
168        let status = *status_rx.borrow();
169
170        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
171
172        let state_ref = state.downgrade();
173        let maintain_client_status = cx.spawn(async move |cx| {
174            while let Some(status) = status_rx.next().await {
175                if let Some(this) = state_ref.upgrade() {
176                    _ = this.update(cx, |this, cx| {
177                        if this.status != status {
178                            this.status = status;
179                            cx.notify();
180                        }
181                    });
182                } else {
183                    break;
184                }
185            }
186        });
187
188        Self {
189            state,
190            _maintain_client_status: maintain_client_status,
191        }
192    }
193}
194
195impl LanguageModelProviderState for CloudLanguageModelProvider {
196    type ObservableEntity = State;
197
198    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
199        Some(self.state.clone())
200    }
201}
202
203impl LanguageModelProvider for CloudLanguageModelProvider {
204    fn id(&self) -> LanguageModelProviderId {
205        PROVIDER_ID
206    }
207
208    fn name(&self) -> LanguageModelProviderName {
209        PROVIDER_NAME
210    }
211
212    fn icon(&self) -> IconOrSvg {
213        IconOrSvg::Icon(IconName::AiZed)
214    }
215
216    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
217        let state = self.state.read(cx);
218        let provider = state.provider.read(cx);
219        let model = provider.default_model()?;
220        Some(provider.create_model(model))
221    }
222
223    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
224        let state = self.state.read(cx);
225        let provider = state.provider.read(cx);
226        let model = provider.default_fast_model()?;
227        Some(provider.create_model(model))
228    }
229
230    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
231        let state = self.state.read(cx);
232        let provider = state.provider.read(cx);
233        provider
234            .recommended_models()
235            .iter()
236            .map(|model| provider.create_model(model))
237            .collect()
238    }
239
240    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
241        let state = self.state.read(cx);
242        let provider = state.provider.read(cx);
243        provider
244            .models()
245            .iter()
246            .map(|model| provider.create_model(model))
247            .collect()
248    }
249
250    fn is_authenticated(&self, cx: &App) -> bool {
251        let state = self.state.read(cx);
252        !state.is_signed_out(cx)
253    }
254
255    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
256        Task::ready(Ok(()))
257    }
258
259    fn configuration_view(
260        &self,
261        _target_agent: language_model::ConfigurationViewTargetAgent,
262        _: &mut Window,
263        cx: &mut App,
264    ) -> AnyView {
265        cx.new(|_| ConfigurationView::new(self.state.clone()))
266            .into()
267    }
268
269    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
270        Task::ready(Ok(()))
271    }
272}
273
/// Stateless element rendering the Zed AI sign-in / subscription panel; also
/// exposed as a component preview.
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    /// Whether the user is signed in; when `false` only the sign-in prompt is rendered.
    is_connected: bool,
    /// The user's current plan, if known.
    plan: Option<Plan>,
    /// Whether a Pro trial may still be offered (the configuration view sets this
    /// when no trial start is recorded).
    eligible_for_trial: bool,
    /// When `true`, the young-account banner replaces the plan summary.
    account_too_young: bool,
    /// Invoked when the "Sign In" button is clicked.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
282
283impl RenderOnce for ZedAiConfiguration {
284    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
285        let (subscription_text, has_paid_plan) = match self.plan {
286            Some(Plan::ZedPro) => (
287                "You have access to Zed's hosted models through your Pro subscription.",
288                true,
289            ),
290            Some(Plan::ZedProTrial) => (
291                "You have access to Zed's hosted models through your Pro trial.",
292                false,
293            ),
294            Some(Plan::ZedStudent) => (
295                "You have access to Zed's hosted models through your Student subscription.",
296                true,
297            ),
298            Some(Plan::ZedBusiness) => (
299                "You have access to Zed's hosted models through your Organization.",
300                true,
301            ),
302            Some(Plan::ZedFree) | None => (
303                if self.eligible_for_trial {
304                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
305                } else {
306                    "Subscribe for access to Zed's hosted models."
307                },
308                false,
309            ),
310        };
311
312        let manage_subscription_buttons = if has_paid_plan {
313            Button::new("manage_settings", "Manage Subscription")
314                .full_width()
315                .label_size(LabelSize::Small)
316                .style(ButtonStyle::Tinted(TintColor::Accent))
317                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
318                .into_any_element()
319        } else if self.plan.is_none() || self.eligible_for_trial {
320            Button::new("start_trial", "Start 14-day Free Pro Trial")
321                .full_width()
322                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
323                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
324                .into_any_element()
325        } else {
326            Button::new("upgrade", "Upgrade to Pro")
327                .full_width()
328                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
329                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
330                .into_any_element()
331        };
332
333        if !self.is_connected {
334            return v_flex()
335                .gap_2()
336                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
337                .child(
338                    Button::new("sign_in", "Sign In to use Zed AI")
339                        .start_icon(Icon::new(IconName::Github).size(IconSize::Small).color(Color::Muted))
340                        .full_width()
341                        .on_click({
342                            let callback = self.sign_in_callback.clone();
343                            move |_, window, cx| (callback)(window, cx)
344                        }),
345                );
346        }
347
348        v_flex().gap_2().w_full().map(|this| {
349            if self.account_too_young {
350                this.child(YoungAccountBanner).child(
351                    Button::new("upgrade", "Upgrade to Pro")
352                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
353                        .full_width()
354                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
355                )
356            } else {
357                this.text_sm()
358                    .child(subscription_text)
359                    .child(manage_subscription_buttons)
360            }
361        })
362    }
363}
364
/// Configuration view for the Zed cloud provider; renders a [`ZedAiConfiguration`]
/// built from the current [`State`].
struct ConfigurationView {
    /// Provider state read on every render.
    state: Entity<State>,
    /// Callback handed to [`ZedAiConfiguration`] for its "Sign In" button.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
369
370impl ConfigurationView {
371    fn new(state: Entity<State>) -> Self {
372        let sign_in_callback = Arc::new({
373            let state = state.clone();
374            move |_window: &mut Window, cx: &mut App| {
375                state.update(cx, |state, cx| {
376                    state.authenticate(cx).detach_and_log_err(cx);
377                });
378            }
379        });
380
381        Self {
382            state,
383            sign_in_callback,
384        }
385    }
386}
387
388impl Render for ConfigurationView {
389    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
390        let state = self.state.read(cx);
391        let user_store = state.user_store.read(cx);
392
393        ZedAiConfiguration {
394            is_connected: !state.is_signed_out(cx),
395            plan: user_store.plan(),
396            eligible_for_trial: user_store.trial_started_at().is_none(),
397            account_too_young: user_store.account_too_young(),
398            sign_in_callback: self.sign_in_callback.clone(),
399        }
400    }
401}
402
403impl Component for ZedAiConfiguration {
404    fn name() -> &'static str {
405        "AI Configuration Content"
406    }
407
408    fn sort_name() -> &'static str {
409        "AI Configuration Content"
410    }
411
412    fn scope() -> ComponentScope {
413        ComponentScope::Onboarding
414    }
415
416    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
417        fn configuration(
418            is_connected: bool,
419            plan: Option<Plan>,
420            eligible_for_trial: bool,
421            account_too_young: bool,
422        ) -> AnyElement {
423            ZedAiConfiguration {
424                is_connected,
425                plan,
426                eligible_for_trial,
427                account_too_young,
428                sign_in_callback: Arc::new(|_, _| {}),
429            }
430            .into_any_element()
431        }
432
433        Some(
434            v_flex()
435                .p_4()
436                .gap_4()
437                .children(vec![
438                    single_example("Not connected", configuration(false, None, false, false)),
439                    single_example(
440                        "Accept Terms of Service",
441                        configuration(true, None, true, false),
442                    ),
443                    single_example(
444                        "No Plan - Not eligible for trial",
445                        configuration(true, None, false, false),
446                    ),
447                    single_example(
448                        "No Plan - Eligible for trial",
449                        configuration(true, None, true, false),
450                    ),
451                    single_example(
452                        "Free Plan",
453                        configuration(true, Some(Plan::ZedFree), true, false),
454                    ),
455                    single_example(
456                        "Zed Pro Trial Plan",
457                        configuration(true, Some(Plan::ZedProTrial), true, false),
458                    ),
459                    single_example(
460                        "Zed Pro Plan",
461                        configuration(true, Some(Plan::ZedPro), true, false),
462                    ),
463                ])
464                .into_any_element(),
465        )
466    }
467}