1use ai_onboarding::YoungAccountBanner;
2use anyhow::Result;
3use client::{Client, RefreshLlmTokenListener, UserStore, global_llm_token, zed_urls};
4use cloud_api_client::LlmApiToken;
5use cloud_api_types::OrganizationId;
6use cloud_api_types::Plan;
7use futures::StreamExt;
8use futures::future::BoxFuture;
9use gpui::AsyncApp;
10use gpui::{AnyElement, AnyView, App, Context, Entity, Subscription, Task};
11use language_model::{
12 AuthenticateError, IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelProviderId,
13 LanguageModelProviderName, LanguageModelProviderState, ZED_CLOUD_PROVIDER_ID,
14 ZED_CLOUD_PROVIDER_NAME,
15};
16use language_models_cloud::{CloudLlmTokenProvider, CloudModelProvider};
17use release_channel::AppVersion;
18
19use settings::SettingsStore;
20pub use settings::ZedDotDevAvailableModel as AvailableModel;
21pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
22use std::sync::Arc;
23use ui::{TintColor, prelude::*};
24
/// Identifier this provider reports from `LanguageModelProvider::id`.
const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
/// Display name this provider reports from `LanguageModelProvider::name`.
const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;
27
/// Supplies LLM API tokens to the cloud model provider by delegating to the
/// app's [`Client`], scoping each request to the user's current organization.
struct ClientTokenProvider {
    /// Client used to acquire and refresh LLM tokens.
    client: Arc<Client>,
    /// Token handle passed to the client's acquire/refresh calls.
    llm_api_token: LlmApiToken,
    /// Read to find the currently selected organization (the auth context).
    user_store: Entity<UserStore>,
}
33
34impl CloudLlmTokenProvider for ClientTokenProvider {
35 type AuthContext = Option<OrganizationId>;
36
37 fn auth_context(&self, cx: &AsyncApp) -> Self::AuthContext {
38 self.user_store.read_with(cx, |user_store, _| {
39 user_store
40 .current_organization()
41 .map(|organization| organization.id.clone())
42 })
43 }
44
45 fn acquire_token(
46 &self,
47 organization_id: Self::AuthContext,
48 ) -> BoxFuture<'static, Result<String>> {
49 let client = self.client.clone();
50 let llm_api_token = self.llm_api_token.clone();
51 Box::pin(async move {
52 client
53 .acquire_llm_token(&llm_api_token, organization_id)
54 .await
55 })
56 }
57
58 fn refresh_token(
59 &self,
60 organization_id: Self::AuthContext,
61 ) -> BoxFuture<'static, Result<String>> {
62 let client = self.client.clone();
63 let llm_api_token = self.llm_api_token.clone();
64 Box::pin(async move {
65 client
66 .refresh_llm_token(&llm_api_token, organization_id)
67 .await
68 })
69 }
70}
71
/// Settings for the Zed cloud provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Models declared as available for this provider (presumably populated
    /// from user settings; this file does not read the field — confirm).
    pub available_models: Vec<AvailableModel>,
}
76
/// Language model provider backed by Zed's hosted models.
pub struct CloudLanguageModelProvider {
    /// Shared state read to answer model queries and observed by the UI.
    state: Entity<State>,
    /// Background task mirroring the client's connection status into `state`.
    /// It is stored here to keep it alive alongside the provider (presumably
    /// dropping a gpui `Task` cancels it — confirm).
    _maintain_client_status: Task<()>,
}
81
/// Shared state of the Zed cloud provider: the client, the user store, the
/// last observed connection status, and the underlying cloud model provider,
/// plus the subscriptions that keep them in sync.
pub struct State {
    /// Client used for sign-in and for reading connection status.
    client: Arc<Client>,
    /// User store consulted for the current user, plan and account details.
    user_store: Entity<UserStore>,
    /// Last connection status observed from the client.
    status: client::Status,
    /// Cloud model provider that supplies the default/recommended/all models.
    provider: Entity<CloudModelProvider<ClientTokenProvider>>,
    /// Refreshes the model list when private user info is updated.
    _user_store_subscription: Subscription,
    /// Notifies observers when the settings store changes.
    _settings_subscription: Subscription,
    /// Refreshes the model list when an LLM token refresh is signalled.
    _llm_token_subscription: Subscription,
    /// Forwards change notifications from `provider` to this entity.
    _provider_subscription: Subscription,
}
92
93impl State {
94 fn new(
95 client: Arc<Client>,
96 user_store: Entity<UserStore>,
97 status: client::Status,
98 cx: &mut Context<Self>,
99 ) -> Self {
100 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
101 let token_provider = Arc::new(ClientTokenProvider {
102 client: client.clone(),
103 llm_api_token: global_llm_token(cx),
104 user_store: user_store.clone(),
105 });
106
107 let provider = cx.new(|cx| {
108 CloudModelProvider::new(
109 token_provider.clone(),
110 client.http_client(),
111 Some(AppVersion::global(cx)),
112 )
113 });
114
115 Self {
116 client: client.clone(),
117 user_store: user_store.clone(),
118 status,
119 _provider_subscription: cx.observe(&provider, |_, _, cx| cx.notify()),
120 provider,
121 _user_store_subscription: cx.subscribe(
122 &user_store,
123 move |this, _user_store, event, cx| match event {
124 client::user::Event::PrivateUserInfoUpdated => {
125 let status = *client.status().borrow();
126 if status.is_signed_out() {
127 return;
128 }
129
130 this.refresh_models(cx);
131 }
132 _ => {}
133 },
134 ),
135 _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
136 cx.notify();
137 }),
138 _llm_token_subscription: cx.subscribe(
139 &refresh_llm_token_listener,
140 move |this, _listener, _event, cx| {
141 this.refresh_models(cx);
142 },
143 ),
144 }
145 }
146
147 fn is_signed_out(&self, cx: &App) -> bool {
148 self.user_store.read(cx).current_user().is_none()
149 }
150
151 fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
152 let client = self.client.clone();
153 cx.spawn(async move |state, cx| {
154 client.sign_in_with_optional_connect(true, cx).await?;
155 state.update(cx, |_, cx| cx.notify())
156 })
157 }
158
159 fn refresh_models(&mut self, cx: &mut Context<Self>) {
160 self.provider.update(cx, |provider, cx| {
161 provider.refresh_models(cx).detach_and_log_err(cx);
162 });
163 }
164}
165
166impl CloudLanguageModelProvider {
167 pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
168 let mut status_rx = client.status();
169 let status = *status_rx.borrow();
170
171 let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
172
173 let state_ref = state.downgrade();
174 let maintain_client_status = cx.spawn(async move |cx| {
175 while let Some(status) = status_rx.next().await {
176 if let Some(this) = state_ref.upgrade() {
177 _ = this.update(cx, |this, cx| {
178 if this.status != status {
179 this.status = status;
180 cx.notify();
181 }
182 });
183 } else {
184 break;
185 }
186 }
187 });
188
189 Self {
190 state,
191 _maintain_client_status: maintain_client_status,
192 }
193 }
194}
195
196impl LanguageModelProviderState for CloudLanguageModelProvider {
197 type ObservableEntity = State;
198
199 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
200 Some(self.state.clone())
201 }
202}
203
204impl LanguageModelProvider for CloudLanguageModelProvider {
205 fn id(&self) -> LanguageModelProviderId {
206 PROVIDER_ID
207 }
208
209 fn name(&self) -> LanguageModelProviderName {
210 PROVIDER_NAME
211 }
212
213 fn icon(&self) -> IconOrSvg {
214 IconOrSvg::Icon(IconName::AiZed)
215 }
216
217 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
218 let state = self.state.read(cx);
219 let provider = state.provider.read(cx);
220 let model = provider.default_model()?;
221 Some(provider.create_model(model))
222 }
223
224 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
225 let state = self.state.read(cx);
226 let provider = state.provider.read(cx);
227 let model = provider.default_fast_model()?;
228 Some(provider.create_model(model))
229 }
230
231 fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
232 let state = self.state.read(cx);
233 let provider = state.provider.read(cx);
234 provider
235 .recommended_models()
236 .iter()
237 .map(|model| provider.create_model(model))
238 .collect()
239 }
240
241 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
242 let state = self.state.read(cx);
243 let provider = state.provider.read(cx);
244 provider
245 .models()
246 .iter()
247 .map(|model| provider.create_model(model))
248 .collect()
249 }
250
251 fn is_authenticated(&self, cx: &App) -> bool {
252 let state = self.state.read(cx);
253 !state.is_signed_out(cx)
254 }
255
256 fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
257 Task::ready(Ok(()))
258 }
259
260 fn configuration_view(
261 &self,
262 _target_agent: language_model::ConfigurationViewTargetAgent,
263 _: &mut Window,
264 cx: &mut App,
265 ) -> AnyView {
266 cx.new(|_| ConfigurationView::new(self.state.clone()))
267 .into()
268 }
269
270 fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
271 Task::ready(Ok(()))
272 }
273}
274
/// Stateless element showing the Zed AI sign-in prompt or the user's
/// subscription status with the matching call-to-action button.
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    /// Whether the user is signed in; when `false` only the sign-in prompt renders.
    is_connected: bool,
    /// The user's current plan, if any.
    plan: Option<Plan>,
    /// Whether the user has not yet started a Pro trial (see `ConfigurationView`).
    eligible_for_trial: bool,
    /// When `true`, a `YoungAccountBanner` replaces the subscription details.
    account_too_young: bool,
    /// Invoked when the "Sign In to use Zed AI" button is clicked.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
283
284impl RenderOnce for ZedAiConfiguration {
285 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
286 let (subscription_text, has_paid_plan) = match self.plan {
287 Some(Plan::ZedPro) => (
288 "You have access to Zed's hosted models through your Pro subscription.",
289 true,
290 ),
291 Some(Plan::ZedProTrial) => (
292 "You have access to Zed's hosted models through your Pro trial.",
293 false,
294 ),
295 Some(Plan::ZedStudent) => (
296 "You have access to Zed's hosted models through your Student subscription.",
297 true,
298 ),
299 Some(Plan::ZedBusiness) => (
300 "You have access to Zed's hosted models through your Organization.",
301 true,
302 ),
303 Some(Plan::ZedFree) | None => (
304 if self.eligible_for_trial {
305 "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
306 } else {
307 "Subscribe for access to Zed's hosted models."
308 },
309 false,
310 ),
311 };
312
313 let manage_subscription_buttons = if has_paid_plan {
314 Button::new("manage_settings", "Manage Subscription")
315 .full_width()
316 .label_size(LabelSize::Small)
317 .style(ButtonStyle::Tinted(TintColor::Accent))
318 .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
319 .into_any_element()
320 } else if self.plan.is_none() || self.eligible_for_trial {
321 Button::new("start_trial", "Start 14-day Free Pro Trial")
322 .full_width()
323 .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
324 .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
325 .into_any_element()
326 } else {
327 Button::new("upgrade", "Upgrade to Pro")
328 .full_width()
329 .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
330 .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
331 .into_any_element()
332 };
333
334 if !self.is_connected {
335 return v_flex()
336 .gap_2()
337 .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
338 .child(
339 Button::new("sign_in", "Sign In to use Zed AI")
340 .start_icon(Icon::new(IconName::Github).size(IconSize::Small).color(Color::Muted))
341 .full_width()
342 .on_click({
343 let callback = self.sign_in_callback.clone();
344 move |_, window, cx| (callback)(window, cx)
345 }),
346 );
347 }
348
349 v_flex().gap_2().w_full().map(|this| {
350 if self.account_too_young {
351 this.child(YoungAccountBanner).child(
352 Button::new("upgrade", "Upgrade to Pro")
353 .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
354 .full_width()
355 .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
356 )
357 } else {
358 this.text_sm()
359 .child(subscription_text)
360 .child(manage_subscription_buttons)
361 }
362 })
363 }
364}
365
/// Configuration view shown for the Zed cloud provider; renders a
/// [`ZedAiConfiguration`] from the shared provider state.
struct ConfigurationView {
    /// Shared provider state read on every render.
    state: Entity<State>,
    /// Callback that starts `State::authenticate`; passed to the rendered element.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
370
371impl ConfigurationView {
372 fn new(state: Entity<State>) -> Self {
373 let sign_in_callback = Arc::new({
374 let state = state.clone();
375 move |_window: &mut Window, cx: &mut App| {
376 state.update(cx, |state, cx| {
377 state.authenticate(cx).detach_and_log_err(cx);
378 });
379 }
380 });
381
382 Self {
383 state,
384 sign_in_callback,
385 }
386 }
387}
388
389impl Render for ConfigurationView {
390 fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
391 let state = self.state.read(cx);
392 let user_store = state.user_store.read(cx);
393
394 ZedAiConfiguration {
395 is_connected: !state.is_signed_out(cx),
396 plan: user_store.plan(),
397 eligible_for_trial: user_store.trial_started_at().is_none(),
398 account_too_young: user_store.account_too_young(),
399 sign_in_callback: self.sign_in_callback.clone(),
400 }
401 }
402}
403
404impl Component for ZedAiConfiguration {
405 fn name() -> &'static str {
406 "AI Configuration Content"
407 }
408
409 fn sort_name() -> &'static str {
410 "AI Configuration Content"
411 }
412
413 fn scope() -> ComponentScope {
414 ComponentScope::Onboarding
415 }
416
417 fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
418 fn configuration(
419 is_connected: bool,
420 plan: Option<Plan>,
421 eligible_for_trial: bool,
422 account_too_young: bool,
423 ) -> AnyElement {
424 ZedAiConfiguration {
425 is_connected,
426 plan,
427 eligible_for_trial,
428 account_too_young,
429 sign_in_callback: Arc::new(|_, _| {}),
430 }
431 .into_any_element()
432 }
433
434 Some(
435 v_flex()
436 .p_4()
437 .gap_4()
438 .children(vec![
439 single_example("Not connected", configuration(false, None, false, false)),
440 single_example(
441 "Accept Terms of Service",
442 configuration(true, None, true, false),
443 ),
444 single_example(
445 "No Plan - Not eligible for trial",
446 configuration(true, None, false, false),
447 ),
448 single_example(
449 "No Plan - Eligible for trial",
450 configuration(true, None, true, false),
451 ),
452 single_example(
453 "Free Plan",
454 configuration(true, Some(Plan::ZedFree), true, false),
455 ),
456 single_example(
457 "Zed Pro Trial Plan",
458 configuration(true, Some(Plan::ZedProTrial), true, false),
459 ),
460 single_example(
461 "Zed Pro Plan",
462 configuration(true, Some(Plan::ZedPro), true, false),
463 ),
464 ])
465 .into_any_element(),
466 )
467 }
468}