1use std::sync::Arc;
2
3use collections::{HashSet, IndexMap};
4use feature_flags::{Assistant2FeatureFlag, ZedPro};
5use gpui::{
6 Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
7 Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
8};
9use language_model::{
10 AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
11};
12use picker::{Picker, PickerDelegate};
13use proto::Plan;
14use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
15
// Declares the `ToggleModelSelector` action in the `assistant` namespace.
// The bracketed list holds deprecated alias names, so the old
// `assistant2::ToggleModelSelector` name keeps resolving to this action
// (per the macro's name — confirm against gpui's macro definition).
action_with_deprecated_aliases!(
    assistant,
    ToggleModelSelector,
    ["assistant2::ToggleModelSelector"]
);
21
22const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
23
24type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
25
/// A view that lets the user pick a language model from the providers
/// registered in the global [`LanguageModelRegistry`], rendered as a
/// [`Picker`].
pub struct LanguageModelSelector {
    /// The picker that renders and filters the model list.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// Background task authenticating every provider at construction time.
    /// Held so it lives as long as the selector (presumably dropping a gpui
    /// `Task` cancels it — confirm).
    _authenticate_all_providers_task: Task<()>,
    /// Subscriptions to the registry (to rebuild the list) and to the picker
    /// (to re-emit dismissal); kept alive for the selector's lifetime.
    _subscriptions: Vec<Subscription>,
}
31
32impl LanguageModelSelector {
33 pub fn new(
34 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
35 window: &mut Window,
36 cx: &mut Context<Self>,
37 ) -> Self {
38 let on_model_changed = Arc::new(on_model_changed);
39
40 let all_models = Self::all_models(cx);
41 let entries = all_models.entries();
42
43 let delegate = LanguageModelPickerDelegate {
44 language_model_selector: cx.entity().downgrade(),
45 on_model_changed: on_model_changed.clone(),
46 all_models: Arc::new(all_models),
47 selected_index: Self::get_active_model_index(&entries, cx),
48 filtered_entries: entries,
49 };
50
51 let picker = cx.new(|cx| {
52 Picker::list(delegate, window, cx)
53 .show_scrollbar(true)
54 .width(rems(20.))
55 .max_height(Some(rems(20.).into()))
56 });
57
58 let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
59
60 LanguageModelSelector {
61 picker,
62 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
63 _subscriptions: vec![
64 cx.subscribe_in(
65 &LanguageModelRegistry::global(cx),
66 window,
67 Self::handle_language_model_registry_event,
68 ),
69 subscription,
70 ],
71 }
72 }
73
74 fn handle_language_model_registry_event(
75 &mut self,
76 _registry: &Entity<LanguageModelRegistry>,
77 event: &language_model::Event,
78 window: &mut Window,
79 cx: &mut Context<Self>,
80 ) {
81 match event {
82 language_model::Event::ProviderStateChanged
83 | language_model::Event::AddedProvider(_)
84 | language_model::Event::RemovedProvider(_) => {
85 self.picker.update(cx, |this, cx| {
86 let query = this.query(cx);
87 this.delegate.all_models = Arc::new(Self::all_models(cx));
88 // Update matches will automatically drop the previous task
89 // if we get a provider event again
90 this.update_matches(query, window, cx)
91 });
92 }
93 _ => {}
94 }
95 }
96
97 /// Authenticates all providers in the [`LanguageModelRegistry`].
98 ///
99 /// We do this so that we can populate the language selector with all of the
100 /// models from the configured providers.
101 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
102 let authenticate_all_providers = LanguageModelRegistry::global(cx)
103 .read(cx)
104 .providers()
105 .iter()
106 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
107 .collect::<Vec<_>>();
108
109 cx.spawn(async move |_cx| {
110 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
111 if let Err(err) = authenticate_task.await {
112 if matches!(err, AuthenticateError::CredentialsNotFound) {
113 // Since we're authenticating these providers in the
114 // background for the purposes of populating the
115 // language selector, we don't care about providers
116 // where the credentials are not found.
117 } else {
118 // Some providers have noisy failure states that we
119 // don't want to spam the logs with every time the
120 // language model selector is initialized.
121 //
122 // Ideally these should have more clear failure modes
123 // that we know are safe to ignore here, like what we do
124 // with `CredentialsNotFound` above.
125 match provider_id.0.as_ref() {
126 "lmstudio" | "ollama" => {
127 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
128 //
129 // These fail noisily, so we don't log them.
130 }
131 "copilot_chat" => {
132 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
133 }
134 _ => {
135 log::error!(
136 "Failed to authenticate provider: {}: {err}",
137 provider_name.0
138 );
139 }
140 }
141 }
142 }
143 }
144 })
145 }
146
147 fn all_models(cx: &App) -> GroupedModels {
148 let mut recommended = Vec::new();
149 let mut recommended_set = HashSet::default();
150 for provider in LanguageModelRegistry::global(cx)
151 .read(cx)
152 .providers()
153 .iter()
154 {
155 let models = provider.recommended_models(cx);
156 recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
157 recommended.extend(
158 provider
159 .recommended_models(cx)
160 .into_iter()
161 .map(move |model| ModelInfo {
162 model: model.clone(),
163 icon: provider.icon(),
164 }),
165 );
166 }
167
168 let other_models = LanguageModelRegistry::global(cx)
169 .read(cx)
170 .providers()
171 .iter()
172 .map(|provider| {
173 (
174 provider.id(),
175 provider
176 .provided_models(cx)
177 .into_iter()
178 .filter_map(|model| {
179 let not_included =
180 !recommended_set.contains(&(model.provider_id(), model.id()));
181 not_included.then(|| ModelInfo {
182 model: model.clone(),
183 icon: provider.icon(),
184 })
185 })
186 .collect::<Vec<_>>(),
187 )
188 })
189 .collect::<IndexMap<_, _>>();
190
191 GroupedModels {
192 recommended,
193 other: other_models,
194 }
195 }
196
197 fn get_active_model_index(entries: &[LanguageModelPickerEntry], cx: &App) -> usize {
198 let active_model = LanguageModelRegistry::read_global(cx).default_model();
199 entries
200 .iter()
201 .position(|entry| {
202 if let LanguageModelPickerEntry::Model(model) = entry {
203 active_model
204 .as_ref()
205 .map(|active_model| {
206 active_model.model.id() == model.model.id()
207 && active_model.model.provider_id() == model.model.provider_id()
208 })
209 .unwrap_or_default()
210 } else {
211 false
212 }
213 })
214 .unwrap_or(0)
215 }
216}
217
/// The selector emits [`DismissEvent`] when it should be closed; `new`
/// re-emits the picker's events as this event.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
219
220impl Focusable for LanguageModelSelector {
221 fn focus_handle(&self, cx: &App) -> FocusHandle {
222 self.picker.focus_handle(cx)
223 }
224}
225
226impl Render for LanguageModelSelector {
227 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
228 self.picker.clone()
229 }
230}
231
/// A popover menu that opens a [`LanguageModelSelector`] from a trigger
/// button.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
    T: PopoverTrigger + ButtonCommon,
    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
    /// The selector shown as the popover's menu content.
    language_model_selector: Entity<LanguageModelSelector>,
    /// The element that opens the popover.
    trigger: T,
    /// Builds the tooltip view shown on the trigger.
    tooltip: TT,
    /// Optional handle forwarded to the underlying `PopoverMenu`.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
    /// Corner the popover is anchored to.
    anchor: Corner,
}
244
245impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
246where
247 T: PopoverTrigger + ButtonCommon,
248 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
249{
250 pub fn new(
251 language_model_selector: Entity<LanguageModelSelector>,
252 trigger: T,
253 tooltip: TT,
254 anchor: Corner,
255 ) -> Self {
256 Self {
257 language_model_selector,
258 trigger,
259 tooltip,
260 handle: None,
261 anchor,
262 }
263 }
264
265 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
266 self.handle = Some(handle);
267 self
268 }
269}
270
271impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
272where
273 T: PopoverTrigger + ButtonCommon,
274 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
275{
276 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
277 let language_model_selector = self.language_model_selector.clone();
278
279 PopoverMenu::new("model-switcher")
280 .menu(move |_window, _cx| Some(language_model_selector.clone()))
281 .trigger_with_tooltip(self.trigger, self.tooltip)
282 .anchor(self.anchor)
283 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
284 .offset(gpui::Point {
285 x: px(0.0),
286 y: px(-2.0),
287 })
288 }
289}
290
/// A model shown in the picker, paired with its provider's icon.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// Icon of the provider that supplied this model.
    icon: IconName,
}
296
/// [`PickerDelegate`] backing [`LanguageModelSelector`]'s picker.
pub struct LanguageModelPickerDelegate {
    /// Weak handle to the owning selector; used to emit `DismissEvent` on it
    /// when the picker is dismissed.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Called with the chosen model when the user confirms a model row.
    on_model_changed: OnModelChanged,
    /// The full, unfiltered model list. Behind an `Arc` so `update_matches`
    /// can hand a cheap clone to its background filtering task.
    all_models: Arc<GroupedModels>,
    /// Rows currently displayed (separators and models) after filtering.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
}
304
/// Models grouped for display in the picker.
struct GroupedModels {
    /// Recommended models from all providers, shown first.
    recommended: Vec<ModelInfo>,
    /// Remaining models of each provider (excluding recommended ones), keyed
    /// by provider id in the order the providers were visited.
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
309
310impl GroupedModels {
311 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
312 let mut entries = Vec::new();
313
314 if !self.recommended.is_empty() {
315 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
316 entries.extend(
317 self.recommended
318 .iter()
319 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
320 );
321 }
322
323 for models in self.other.values() {
324 if models.is_empty() {
325 continue;
326 }
327 entries.push(LanguageModelPickerEntry::Separator(
328 models[0].model.provider_name().0,
329 ));
330 entries.extend(
331 models
332 .iter()
333 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
334 );
335 }
336 entries
337 }
338}
339
/// One row of the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable group header displaying the given title.
    Separator(SharedString),
}
344
345impl PickerDelegate for LanguageModelPickerDelegate {
346 type ListItem = AnyElement;
347
348 fn match_count(&self) -> usize {
349 self.filtered_entries.len()
350 }
351
352 fn selected_index(&self) -> usize {
353 self.selected_index
354 }
355
356 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
357 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
358 cx.notify();
359 }
360
361 fn can_select(
362 &mut self,
363 ix: usize,
364 _window: &mut Window,
365 _cx: &mut Context<Picker<Self>>,
366 ) -> bool {
367 match self.filtered_entries.get(ix) {
368 Some(LanguageModelPickerEntry::Model(_)) => true,
369 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
370 }
371 }
372
373 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
374 "Select a model…".into()
375 }
376
377 fn update_matches(
378 &mut self,
379 query: String,
380 window: &mut Window,
381 cx: &mut Context<Picker<Self>>,
382 ) -> Task<()> {
383 let all_models = self.all_models.clone();
384 let current_index = self.selected_index;
385
386 let language_model_registry = LanguageModelRegistry::global(cx);
387
388 let configured_providers = language_model_registry
389 .read(cx)
390 .providers()
391 .iter()
392 .filter(|provider| provider.is_authenticated(cx))
393 .map(|provider| provider.id())
394 .collect::<Vec<_>>();
395
396 cx.spawn_in(window, async move |this, cx| {
397 let filtered_models = cx
398 .background_spawn(async move {
399 let filter_models = |model_infos: &[ModelInfo]| {
400 model_infos
401 .iter()
402 .filter(|model_info| {
403 model_info
404 .model
405 .name()
406 .0
407 .to_lowercase()
408 .contains(&query.to_lowercase())
409 })
410 .cloned()
411 .collect::<Vec<_>>()
412 };
413
414 let recommended_models = filter_models(&all_models.recommended);
415 let mut other_models = IndexMap::default();
416 for (provider_id, models) in &all_models.other {
417 if configured_providers.contains(&provider_id) {
418 other_models.insert(provider_id.clone(), filter_models(models));
419 }
420 }
421 GroupedModels {
422 recommended: recommended_models,
423 other: other_models,
424 }
425 })
426 .await;
427
428 this.update_in(cx, |this, window, cx| {
429 this.delegate.filtered_entries = filtered_models.entries();
430 // Preserve selection focus
431 let new_index = if current_index >= this.delegate.filtered_entries.len() {
432 0
433 } else {
434 current_index
435 };
436 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
437 cx.notify();
438 })
439 .ok();
440 })
441 }
442
443 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
444 if let Some(LanguageModelPickerEntry::Model(model_info)) =
445 self.filtered_entries.get(self.selected_index)
446 {
447 let model = model_info.model.clone();
448 (self.on_model_changed)(model.clone(), cx);
449
450 let current_index = self.selected_index;
451 self.set_selected_index(current_index, window, cx);
452
453 cx.emit(DismissEvent);
454 }
455 }
456
457 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
458 self.language_model_selector
459 .update(cx, |_this, cx| cx.emit(DismissEvent))
460 .ok();
461 }
462
463 fn render_match(
464 &self,
465 ix: usize,
466 selected: bool,
467 _: &mut Window,
468 cx: &mut Context<Picker<Self>>,
469 ) -> Option<Self::ListItem> {
470 match self.filtered_entries.get(ix)? {
471 LanguageModelPickerEntry::Separator(title) => Some(
472 div()
473 .px_2()
474 .pb_1()
475 .when(ix > 1, |this| {
476 this.mt_1()
477 .pt_2()
478 .border_t_1()
479 .border_color(cx.theme().colors().border_variant)
480 })
481 .child(
482 Label::new(title)
483 .size(LabelSize::XSmall)
484 .color(Color::Muted),
485 )
486 .into_any_element(),
487 ),
488 LanguageModelPickerEntry::Model(model_info) => {
489 let active_model = LanguageModelRegistry::read_global(cx).default_model();
490
491 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
492 let active_model_id = active_model.map(|m| m.model.id());
493
494 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
495 && Some(model_info.model.id()) == active_model_id;
496
497 let model_icon_color = if is_selected {
498 Color::Accent
499 } else {
500 Color::Muted
501 };
502
503 Some(
504 ListItem::new(ix)
505 .inset(true)
506 .spacing(ListItemSpacing::Sparse)
507 .toggle_state(selected)
508 .start_slot(
509 Icon::new(model_info.icon)
510 .color(model_icon_color)
511 .size(IconSize::Small),
512 )
513 .child(
514 h_flex()
515 .w_full()
516 .pl_0p5()
517 .gap_1p5()
518 .w(px(240.))
519 .child(Label::new(model_info.model.name().0.clone()).truncate()),
520 )
521 .end_slot(div().pr_3().when(is_selected, |this| {
522 this.child(
523 Icon::new(IconName::Check)
524 .color(Color::Accent)
525 .size(IconSize::Small),
526 )
527 }))
528 .into_any_element(),
529 )
530 }
531 }
532 }
533
534 fn render_footer(
535 &self,
536 _: &mut Window,
537 cx: &mut Context<Picker<Self>>,
538 ) -> Option<gpui::AnyElement> {
539 use feature_flags::FeatureFlagAppExt;
540
541 let plan = proto::Plan::ZedPro;
542 let is_trial = false;
543
544 Some(
545 h_flex()
546 .w_full()
547 .border_t_1()
548 .border_color(cx.theme().colors().border_variant)
549 .p_1()
550 .gap_4()
551 .justify_between()
552 .when(cx.has_flag::<ZedPro>(), |this| {
553 this.child(match plan {
554 // Already a Zed Pro subscriber
555 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
556 .icon(IconName::ZedAssistant)
557 .icon_size(IconSize::Small)
558 .icon_color(Color::Muted)
559 .icon_position(IconPosition::Start)
560 .on_click(|_, window, cx| {
561 window
562 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
563 }),
564 // Free user
565 Plan::Free => Button::new(
566 "try-pro",
567 if is_trial {
568 "Upgrade to Pro"
569 } else {
570 "Try Pro"
571 },
572 )
573 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
574 })
575 })
576 .child(
577 Button::new("configure", "Configure")
578 .icon(IconName::Settings)
579 .icon_size(IconSize::Small)
580 .icon_color(Color::Muted)
581 .icon_position(IconPosition::Start)
582 .on_click(|_, window, cx| {
583 let configure_action = if cx.has_flag::<Assistant2FeatureFlag>() {
584 zed_actions::agent::OpenConfiguration.boxed_clone()
585 } else {
586 zed_actions::assistant::ShowConfiguration.boxed_clone()
587 };
588
589 window.dispatch_action(configure_action, cx);
590 }),
591 )
592 .into_any(),
593 )
594 }
595}