1use std::{rc::Rc, sync::Arc};
2
3use feature_flags::ZedPro;
4use gpui::{
5 action_with_deprecated_aliases, Action, AnyElement, App, Corner, DismissEvent, Entity,
6 EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
7};
8use language_model::{
9 AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry,
10};
11use picker::{Picker, PickerDelegate};
12use proto::Plan;
13use ui::{
14 prelude::*, ButtonLike, IconButtonShape, ListItem, ListItemSpacing, PopoverMenu,
15 PopoverMenuHandle, Tooltip,
16};
17use workspace::ShowConfiguration;
18
// Declares the `ToggleModelSelector` action under the `assistant` namespace.
// The trailing list carries the older `assistant2::ToggleModelSelector` name,
// which (going by the macro's name) remains accepted as a deprecated alias.
action_with_deprecated_aliases!(
    assistant,
    ToggleModelSelector,
    ["assistant2::ToggleModelSelector"]
);
24
/// URL opened by the "Try Pro" / "Upgrade to Pro" button in the picker footer.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
26
/// Shared callback invoked with the model the user confirmed in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
28
/// A view that lets the user pick a language model from the models offered by
/// the registered providers. It renders as a [`Picker`] and delegates focus to
/// that picker.
pub struct LanguageModelSelector {
    /// The picker entity that lists the models and handles the query/selection.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// The task used to update the picker's matches when there is a change to
    /// the language model registry.
    update_matches_task: Option<Task<()>>,
    /// Handle used by `toggle_model_selector` to toggle the popover menu.
    popover_menu_handle: PopoverMenuHandle<LanguageModelSelector>,
    /// Background task returned by `authenticate_all_providers`.
    /// NOTE(review): presumably stored so the task is not dropped before it
    /// finishes — confirm against gpui's `Task` drop semantics.
    _authenticate_all_providers_task: Task<()>,
    /// Holds the subscription to [`LanguageModelRegistry`] events.
    _subscriptions: Vec<Subscription>,
}
38
39impl LanguageModelSelector {
40 pub fn new(
41 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
42 window: &mut Window,
43 cx: &mut Context<Self>,
44 ) -> Self {
45 let on_model_changed = Arc::new(on_model_changed);
46
47 let all_models = Self::all_models(cx);
48 let delegate = LanguageModelPickerDelegate {
49 language_model_selector: cx.entity().downgrade(),
50 on_model_changed: on_model_changed.clone(),
51 all_models: all_models.clone(),
52 filtered_models: all_models,
53 selected_index: Self::get_active_model_index(cx),
54 };
55
56 let picker = cx.new(|cx| {
57 Picker::uniform_list(delegate, window, cx)
58 .show_scrollbar(true)
59 .width(rems(20.))
60 .max_height(Some(rems(20.).into()))
61 });
62
63 LanguageModelSelector {
64 picker,
65 update_matches_task: None,
66 popover_menu_handle: PopoverMenuHandle::default(),
67 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
68 _subscriptions: vec![cx.subscribe_in(
69 &LanguageModelRegistry::global(cx),
70 window,
71 Self::handle_language_model_registry_event,
72 )],
73 }
74 }
75
76 pub fn toggle_model_selector(
77 &mut self,
78 _: &ToggleModelSelector,
79 window: &mut Window,
80 cx: &mut Context<Self>,
81 ) {
82 self.popover_menu_handle.toggle(window, cx);
83 }
84
85 fn handle_language_model_registry_event(
86 &mut self,
87 _registry: &Entity<LanguageModelRegistry>,
88 event: &language_model::Event,
89 window: &mut Window,
90 cx: &mut Context<Self>,
91 ) {
92 match event {
93 language_model::Event::ProviderStateChanged
94 | language_model::Event::AddedProvider(_)
95 | language_model::Event::RemovedProvider(_) => {
96 let task = self.picker.update(cx, |this, cx| {
97 let query = this.query(cx);
98 this.delegate.all_models = Self::all_models(cx);
99 this.delegate.update_matches(query, window, cx)
100 });
101 self.update_matches_task = Some(task);
102 }
103 _ => {}
104 }
105 }
106
107 /// Authenticates all providers in the [`LanguageModelRegistry`].
108 ///
109 /// We do this so that we can populate the language selector with all of the
110 /// models from the configured providers.
111 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
112 let authenticate_all_providers = LanguageModelRegistry::global(cx)
113 .read(cx)
114 .providers()
115 .iter()
116 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
117 .collect::<Vec<_>>();
118
119 cx.spawn(|_cx| async move {
120 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
121 if let Err(err) = authenticate_task.await {
122 if matches!(err, AuthenticateError::CredentialsNotFound) {
123 // Since we're authenticating these providers in the
124 // background for the purposes of populating the
125 // language selector, we don't care about providers
126 // where the credentials are not found.
127 } else {
128 // Some providers have noisy failure states that we
129 // don't want to spam the logs with every time the
130 // language model selector is initialized.
131 //
132 // Ideally these should have more clear failure modes
133 // that we know are safe to ignore here, like what we do
134 // with `CredentialsNotFound` above.
135 match provider_id.0.as_ref() {
136 "lmstudio" | "ollama" => {
137 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
138 //
139 // These fail noisily, so we don't log them.
140 }
141 "copilot_chat" => {
142 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
143 }
144 _ => {
145 log::error!(
146 "Failed to authenticate provider: {}: {err}",
147 provider_name.0
148 );
149 }
150 }
151 }
152 }
153 }
154 })
155 }
156
157 fn all_models(cx: &App) -> Vec<ModelInfo> {
158 LanguageModelRegistry::global(cx)
159 .read(cx)
160 .providers()
161 .iter()
162 .flat_map(|provider| {
163 let icon = provider.icon();
164
165 provider.provided_models(cx).into_iter().map(move |model| {
166 let model = model.clone();
167 let icon = model.icon().unwrap_or(icon);
168
169 ModelInfo {
170 model: model.clone(),
171 icon,
172 availability: model.availability(),
173 }
174 })
175 })
176 .collect::<Vec<_>>()
177 }
178
179 fn get_active_model_index(cx: &App) -> usize {
180 let active_model = LanguageModelRegistry::read_global(cx).active_model();
181 Self::all_models(cx)
182 .iter()
183 .position(|model_info| {
184 Some(model_info.model.id()) == active_model.as_ref().map(|model| model.id())
185 })
186 .unwrap_or(0)
187 }
188}
189
// The selector emits `DismissEvent`; the picker delegate's `dismissed` hook
// triggers it through its weak handle to this selector.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
191
192impl Focusable for LanguageModelSelector {
193 fn focus_handle(&self, cx: &App) -> FocusHandle {
194 self.picker.focus_handle(cx)
195 }
196}
197
198impl Render for LanguageModelSelector {
199 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
200 self.picker.clone()
201 }
202}
203
/// A model listed in the picker, bundled with the data needed to render its row.
#[derive(Clone)]
struct ModelInfo {
    /// The language model this entry represents.
    model: Arc<dyn LanguageModel>,
    /// Icon shown for the entry: the model's own icon, or its provider's icon
    /// when the model has none.
    icon: IconName,
    /// Availability reported by the model; decides whether a "Pro" badge may
    /// be shown next to it.
    availability: LanguageModelAvailability,
}
210
/// [`PickerDelegate`] backing the [`LanguageModelSelector`]'s picker.
pub struct LanguageModelPickerDelegate {
    /// Weak handle to the owning selector, used to forward dismissal.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Callback invoked with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Every model offered by the registered providers.
    all_models: Vec<ModelInfo>,
    /// The subset of `all_models` currently displayed (after provider and
    /// query filtering).
    filtered_models: Vec<ModelInfo>,
    /// Index of the highlighted row within `filtered_models`.
    selected_index: usize,
}
218
219impl PickerDelegate for LanguageModelPickerDelegate {
220 type ListItem = ListItem;
221
222 fn match_count(&self) -> usize {
223 self.filtered_models.len()
224 }
225
226 fn selected_index(&self) -> usize {
227 self.selected_index
228 }
229
230 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
231 self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
232 cx.notify();
233 }
234
235 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
236 "Select a model...".into()
237 }
238
239 fn update_matches(
240 &mut self,
241 query: String,
242 window: &mut Window,
243 cx: &mut Context<Picker<Self>>,
244 ) -> Task<()> {
245 let all_models = self.all_models.clone();
246 let current_index = self.selected_index;
247
248 let language_model_registry = LanguageModelRegistry::global(cx);
249
250 let configured_providers = language_model_registry
251 .read(cx)
252 .providers()
253 .iter()
254 .filter(|provider| provider.is_authenticated(cx))
255 .map(|provider| provider.id())
256 .collect::<Vec<_>>();
257
258 cx.spawn_in(window, |this, mut cx| async move {
259 let filtered_models = cx
260 .background_spawn(async move {
261 let displayed_models = if configured_providers.is_empty() {
262 all_models
263 } else {
264 all_models
265 .into_iter()
266 .filter(|model_info| {
267 configured_providers.contains(&model_info.model.provider_id())
268 })
269 .collect::<Vec<_>>()
270 };
271
272 if query.is_empty() {
273 displayed_models
274 } else {
275 displayed_models
276 .into_iter()
277 .filter(|model_info| {
278 model_info
279 .model
280 .name()
281 .0
282 .to_lowercase()
283 .contains(&query.to_lowercase())
284 })
285 .collect()
286 }
287 })
288 .await;
289
290 this.update_in(&mut cx, |this, window, cx| {
291 this.delegate.filtered_models = filtered_models;
292 // Preserve selection focus
293 let new_index = if current_index >= this.delegate.filtered_models.len() {
294 0
295 } else {
296 current_index
297 };
298 this.delegate.set_selected_index(new_index, window, cx);
299 cx.notify();
300 })
301 .ok();
302 })
303 }
304
305 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
306 if let Some(model_info) = self.filtered_models.get(self.selected_index) {
307 let model = model_info.model.clone();
308 (self.on_model_changed)(model.clone(), cx);
309
310 let current_index = self.selected_index;
311 self.set_selected_index(current_index, window, cx);
312
313 cx.emit(DismissEvent);
314 }
315 }
316
317 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
318 self.language_model_selector
319 .update(cx, |_this, cx| cx.emit(DismissEvent))
320 .ok();
321 }
322
323 fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
324 let configured_models_count = LanguageModelRegistry::global(cx)
325 .read(cx)
326 .providers()
327 .iter()
328 .filter(|provider| provider.is_authenticated(cx))
329 .count();
330
331 if configured_models_count > 0 {
332 Some(
333 Label::new("Configured Models")
334 .size(LabelSize::Small)
335 .color(Color::Muted)
336 .mt_1()
337 .mb_0p5()
338 .ml_2()
339 .into_any_element(),
340 )
341 } else {
342 None
343 }
344 }
345
346 fn render_match(
347 &self,
348 ix: usize,
349 selected: bool,
350 _: &mut Window,
351 cx: &mut Context<Picker<Self>>,
352 ) -> Option<Self::ListItem> {
353 use feature_flags::FeatureFlagAppExt;
354 let show_badges = cx.has_flag::<ZedPro>();
355
356 let model_info = self.filtered_models.get(ix)?;
357 let provider_name: String = model_info.model.provider_name().0.clone().into();
358
359 let active_provider_id = LanguageModelRegistry::read_global(cx)
360 .active_provider()
361 .map(|m| m.id());
362
363 let active_model_id = LanguageModelRegistry::read_global(cx)
364 .active_model()
365 .map(|m| m.id());
366
367 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
368 && Some(model_info.model.id()) == active_model_id;
369
370 let model_icon_color = if is_selected {
371 Color::Accent
372 } else {
373 Color::Muted
374 };
375
376 Some(
377 ListItem::new(ix)
378 .inset(true)
379 .spacing(ListItemSpacing::Sparse)
380 .toggle_state(selected)
381 .start_slot(
382 Icon::new(model_info.icon)
383 .color(model_icon_color)
384 .size(IconSize::Small),
385 )
386 .child(
387 h_flex()
388 .w_full()
389 .items_center()
390 .gap_1p5()
391 .pl_0p5()
392 .w(px(240.))
393 .child(
394 div()
395 .max_w_40()
396 .child(Label::new(model_info.model.name().0.clone()).truncate()),
397 )
398 .child(
399 h_flex()
400 .gap_0p5()
401 .child(
402 Label::new(provider_name)
403 .size(LabelSize::XSmall)
404 .color(Color::Muted),
405 )
406 .children(match model_info.availability {
407 LanguageModelAvailability::Public => None,
408 LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
409 LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
410 show_badges.then(|| {
411 Label::new("Pro")
412 .size(LabelSize::XSmall)
413 .color(Color::Muted)
414 })
415 }
416 }),
417 ),
418 )
419 .end_slot(div().pr_3().when(is_selected, |this| {
420 this.child(
421 Icon::new(IconName::Check)
422 .color(Color::Accent)
423 .size(IconSize::Small),
424 )
425 })),
426 )
427 }
428
429 fn render_footer(
430 &self,
431 _: &mut Window,
432 cx: &mut Context<Picker<Self>>,
433 ) -> Option<gpui::AnyElement> {
434 use feature_flags::FeatureFlagAppExt;
435
436 let plan = proto::Plan::ZedPro;
437 let is_trial = false;
438
439 Some(
440 h_flex()
441 .w_full()
442 .border_t_1()
443 .border_color(cx.theme().colors().border_variant)
444 .p_1()
445 .gap_4()
446 .justify_between()
447 .when(cx.has_flag::<ZedPro>(), |this| {
448 this.child(match plan {
449 // Already a Zed Pro subscriber
450 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
451 .icon(IconName::ZedAssistant)
452 .icon_size(IconSize::Small)
453 .icon_color(Color::Muted)
454 .icon_position(IconPosition::Start)
455 .on_click(|_, window, cx| {
456 window
457 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
458 }),
459 // Free user
460 Plan::Free => Button::new(
461 "try-pro",
462 if is_trial {
463 "Upgrade to Pro"
464 } else {
465 "Try Pro"
466 },
467 )
468 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
469 })
470 })
471 .child(
472 Button::new("configure", "Configure")
473 .icon(IconName::Settings)
474 .icon_size(IconSize::Small)
475 .icon_color(Color::Muted)
476 .icon_position(IconPosition::Start)
477 .on_click(|_, window, cx| {
478 window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
479 }),
480 )
481 .into_any(),
482 )
483 }
484}
485
486pub fn inline_language_model_selector(
487 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
488) -> PopoverMenu<LanguageModelSelector> {
489 let on_model_changed = Rc::new(on_model_changed);
490 PopoverMenu::new("popover-button")
491 .menu(move |window, cx| {
492 Some(cx.new(|cx| {
493 LanguageModelSelector::new(
494 {
495 let on_model_changed = on_model_changed.clone();
496 move |model, cx| {
497 on_model_changed(model, cx);
498 }
499 },
500 window,
501 cx,
502 )
503 }))
504 })
505 .trigger_with_tooltip(
506 IconButton::new("context", IconName::SettingsAlt)
507 .shape(IconButtonShape::Square)
508 .icon_size(IconSize::Small)
509 .icon_color(Color::Muted),
510 move |window, cx| {
511 Tooltip::with_meta(
512 format!(
513 "Using {}",
514 LanguageModelRegistry::read_global(cx)
515 .active_model()
516 .map(|model| model.name().0)
517 .unwrap_or_else(|| "No model selected".into()),
518 ),
519 None,
520 "Change Model",
521 window,
522 cx,
523 )
524 },
525 )
526 .anchor(gpui::Corner::TopRight)
527 .offset(gpui::Point {
528 x: px(0.0),
529 y: px(-2.0),
530 })
531}
532
533pub fn assistant_language_model_selector(
534 keybinding_target: FocusHandle,
535 menu_handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
536 cx: &App,
537 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
538) -> PopoverMenu<LanguageModelSelector> {
539 let active_model = LanguageModelRegistry::read_global(cx).active_model();
540 let model_name = match active_model {
541 Some(model) => model.name().0,
542 _ => SharedString::from("No model selected"),
543 };
544
545 let on_model_changed = Rc::new(on_model_changed);
546
547 PopoverMenu::new("popover-button")
548 .menu(move |window, cx| {
549 Some(cx.new(|cx| {
550 LanguageModelSelector::new(
551 {
552 let on_model_changed = on_model_changed.clone();
553 move |model, cx| {
554 on_model_changed(model, cx);
555 }
556 },
557 window,
558 cx,
559 )
560 }))
561 })
562 .trigger_with_tooltip(
563 ButtonLike::new("active-model")
564 .style(ButtonStyle::Subtle)
565 .child(
566 h_flex()
567 .gap_0p5()
568 .child(
569 Label::new(model_name)
570 .size(LabelSize::Small)
571 .color(Color::Muted),
572 )
573 .child(
574 Icon::new(IconName::ChevronDown)
575 .color(Color::Muted)
576 .size(IconSize::XSmall),
577 ),
578 ),
579 move |window, cx| {
580 Tooltip::for_action_in(
581 "Change Model",
582 &ToggleModelSelector,
583 &keybinding_target,
584 window,
585 cx,
586 )
587 },
588 )
589 .anchor(Corner::BottomRight)
590 .when_some(menu_handle, |el, handle| el.with_handle(handle))
591 .offset(gpui::Point {
592 x: px(0.0),
593 y: px(-2.0),
594 })
595}