1use std::sync::Arc;
2
3use collections::{HashSet, IndexMap};
4use feature_flags::{Assistant2FeatureFlag, ZedPro};
5use gpui::{
6 Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
7 Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
8};
9use language_model::{
10 AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
11};
12use picker::{Picker, PickerDelegate};
13use proto::Plan;
14use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
15
// Declares the `ToggleModelSelector` action in the `assistant` namespace, with
// `assistant2::ToggleModelSelector` registered as a deprecated alias (presumably
// so keymaps still using the old name keep resolving — TODO confirm).
action_with_deprecated_aliases!(
    assistant,
    ToggleModelSelector,
    ["assistant2::ToggleModelSelector"]
);
21
/// Page opened by the "Try Pro" / "Upgrade to Pro" button in the picker footer.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
23
/// Callback invoked with the chosen model when the user confirms a row in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
25
/// A view that lets the user pick a language model from every registered provider.
///
/// It wraps a [`Picker`] driven by [`LanguageModelPickerDelegate`] and refreshes
/// the picker's model list whenever the [`LanguageModelRegistry`] reports a
/// provider change.
pub struct LanguageModelSelector {
    /// The picker entity that renders the model list and handles input.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// Background task that authenticates all providers; stored so it lives as
    /// long as the selector does.
    _authenticate_all_providers_task: Task<()>,
    /// Subscriptions to registry events and to the picker's events; kept alive
    /// for the selector's lifetime.
    _subscriptions: Vec<Subscription>,
}
31
32impl LanguageModelSelector {
33 pub fn new(
34 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
35 window: &mut Window,
36 cx: &mut Context<Self>,
37 ) -> Self {
38 let on_model_changed = Arc::new(on_model_changed);
39
40 let all_models = Self::all_models(cx);
41 let entries = all_models.entries();
42
43 let delegate = LanguageModelPickerDelegate {
44 language_model_selector: cx.entity().downgrade(),
45 on_model_changed: on_model_changed.clone(),
46 all_models: Arc::new(all_models),
47 selected_index: Self::get_active_model_index(&entries, cx),
48 filtered_entries: entries,
49 };
50
51 let picker = cx.new(|cx| {
52 Picker::list(delegate, window, cx)
53 .show_scrollbar(true)
54 .width(rems(20.))
55 .max_height(Some(rems(20.).into()))
56 });
57
58 let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
59
60 LanguageModelSelector {
61 picker,
62 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
63 _subscriptions: vec![
64 cx.subscribe_in(
65 &LanguageModelRegistry::global(cx),
66 window,
67 Self::handle_language_model_registry_event,
68 ),
69 subscription,
70 ],
71 }
72 }
73
74 fn handle_language_model_registry_event(
75 &mut self,
76 _registry: &Entity<LanguageModelRegistry>,
77 event: &language_model::Event,
78 window: &mut Window,
79 cx: &mut Context<Self>,
80 ) {
81 match event {
82 language_model::Event::ProviderStateChanged
83 | language_model::Event::AddedProvider(_)
84 | language_model::Event::RemovedProvider(_) => {
85 self.picker.update(cx, |this, cx| {
86 let query = this.query(cx);
87 this.delegate.all_models = Arc::new(Self::all_models(cx));
88 // Update matches will automatically drop the previous task
89 // if we get a provider event again
90 this.update_matches(query, window, cx)
91 });
92 }
93 _ => {}
94 }
95 }
96
97 /// Authenticates all providers in the [`LanguageModelRegistry`].
98 ///
99 /// We do this so that we can populate the language selector with all of the
100 /// models from the configured providers.
101 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
102 let authenticate_all_providers = LanguageModelRegistry::global(cx)
103 .read(cx)
104 .providers()
105 .iter()
106 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
107 .collect::<Vec<_>>();
108
109 cx.spawn(async move |_cx| {
110 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
111 if let Err(err) = authenticate_task.await {
112 if matches!(err, AuthenticateError::CredentialsNotFound) {
113 // Since we're authenticating these providers in the
114 // background for the purposes of populating the
115 // language selector, we don't care about providers
116 // where the credentials are not found.
117 } else {
118 // Some providers have noisy failure states that we
119 // don't want to spam the logs with every time the
120 // language model selector is initialized.
121 //
122 // Ideally these should have more clear failure modes
123 // that we know are safe to ignore here, like what we do
124 // with `CredentialsNotFound` above.
125 match provider_id.0.as_ref() {
126 "lmstudio" | "ollama" => {
127 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
128 //
129 // These fail noisily, so we don't log them.
130 }
131 "copilot_chat" => {
132 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
133 }
134 _ => {
135 log::error!(
136 "Failed to authenticate provider: {}: {err}",
137 provider_name.0
138 );
139 }
140 }
141 }
142 }
143 }
144 })
145 }
146
147 fn all_models(cx: &App) -> GroupedModels {
148 let mut recommended = Vec::new();
149 let mut recommended_set = HashSet::default();
150 for provider in LanguageModelRegistry::global(cx)
151 .read(cx)
152 .providers()
153 .iter()
154 {
155 let models = provider.recommended_models(cx);
156 recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
157 recommended.extend(
158 provider
159 .recommended_models(cx)
160 .into_iter()
161 .map(move |model| ModelInfo {
162 model: model.clone(),
163 icon: provider.icon(),
164 }),
165 );
166 }
167
168 let other_models = LanguageModelRegistry::global(cx)
169 .read(cx)
170 .providers()
171 .iter()
172 .map(|provider| {
173 (
174 provider.id(),
175 provider
176 .provided_models(cx)
177 .into_iter()
178 .filter_map(|model| {
179 let not_included =
180 !recommended_set.contains(&(model.provider_id(), model.id()));
181 not_included.then(|| ModelInfo {
182 model: model.clone(),
183 icon: provider.icon(),
184 })
185 })
186 .collect::<Vec<_>>(),
187 )
188 })
189 .collect::<IndexMap<_, _>>();
190
191 GroupedModels {
192 recommended,
193 other: other_models,
194 }
195 }
196
197 fn get_active_model_index(entries: &[LanguageModelPickerEntry], cx: &App) -> usize {
198 let active_model = LanguageModelRegistry::read_global(cx).default_model();
199 entries
200 .iter()
201 .position(|entry| {
202 if let LanguageModelPickerEntry::Model(model) = entry {
203 active_model
204 .as_ref()
205 .map(|active_model| {
206 active_model.model.id() == model.model.id()
207 && active_model.model.provider_id() == model.model.provider_id()
208 })
209 .unwrap_or_default()
210 } else {
211 false
212 }
213 })
214 .unwrap_or(0)
215 }
216}
217
// The selector emits `DismissEvent` when its picker emits an event or when the
// picker delegate is dismissed.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
219
impl Focusable for LanguageModelSelector {
    /// Focus is delegated to the inner picker.
    fn focus_handle(&self, cx: &App) -> FocusHandle {
        self.picker.focus_handle(cx)
    }
}
225
impl Render for LanguageModelSelector {
    /// The selector has no chrome of its own; it renders as its picker.
    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
        self.picker.clone()
    }
}
231
/// A popover menu that shows a [`LanguageModelSelector`] when its trigger is clicked.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
    T: PopoverTrigger + ButtonCommon,
    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
    /// The selector shown as the menu's content.
    language_model_selector: Entity<LanguageModelSelector>,
    /// The element that opens the menu.
    trigger: T,
    /// Builds the tooltip shown for the trigger.
    tooltip: TT,
    /// Optional handle forwarded to the underlying `PopoverMenu`.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
    /// Corner the menu is anchored to.
    anchor: Corner,
}
244
245impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
246where
247 T: PopoverTrigger + ButtonCommon,
248 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
249{
250 pub fn new(
251 language_model_selector: Entity<LanguageModelSelector>,
252 trigger: T,
253 tooltip: TT,
254 anchor: Corner,
255 ) -> Self {
256 Self {
257 language_model_selector,
258 trigger,
259 tooltip,
260 handle: None,
261 anchor,
262 }
263 }
264
265 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
266 self.handle = Some(handle);
267 self
268 }
269}
270
271impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
272where
273 T: PopoverTrigger + ButtonCommon,
274 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
275{
276 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
277 let language_model_selector = self.language_model_selector.clone();
278
279 PopoverMenu::new("model-switcher")
280 .menu(move |_window, _cx| Some(language_model_selector.clone()))
281 .trigger_with_tooltip(self.trigger, self.tooltip)
282 .anchor(self.anchor)
283 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
284 .offset(gpui::Point {
285 x: px(0.0),
286 y: px(-2.0),
287 })
288 }
289}
290
/// A model paired with the icon of the provider that supplied it.
#[derive(Clone)]
struct ModelInfo {
    model: Arc<dyn LanguageModel>,
    icon: IconName,
}
296
/// Picker delegate that lists models grouped by "Recommended" and by provider.
pub struct LanguageModelPickerDelegate {
    /// The owning selector; used to emit `DismissEvent` when the picker is dismissed.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Invoked with the chosen model on confirm.
    on_model_changed: OnModelChanged,
    /// Every known model, before query/authentication filtering.
    all_models: Arc<GroupedModels>,
    /// Rows currently shown: separators and model entries after filtering.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
}
304
/// Models split into a shared "Recommended" list and per-provider groups.
struct GroupedModels {
    /// Recommended models from all providers, in provider order.
    recommended: Vec<ModelInfo>,
    /// Remaining models per provider id (recommended ones excluded), in insertion order.
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
309
310impl GroupedModels {
311 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
312 let mut entries = Vec::new();
313
314 if !self.recommended.is_empty() {
315 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
316 entries.extend(
317 self.recommended
318 .iter()
319 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
320 );
321 }
322
323 for models in self.other.values() {
324 if models.is_empty() {
325 continue;
326 }
327 entries.push(LanguageModelPickerEntry::Separator(
328 models[0].model.provider_name().0,
329 ));
330 entries.extend(
331 models
332 .iter()
333 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
334 );
335 }
336 entries
337 }
338}
339
/// One row in the model picker.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable group header with the given title.
    Separator(SharedString),
}
344
345impl PickerDelegate for LanguageModelPickerDelegate {
346 type ListItem = AnyElement;
347
348 fn match_count(&self) -> usize {
349 self.filtered_entries.len()
350 }
351
352 fn selected_index(&self) -> usize {
353 self.selected_index
354 }
355
356 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
357 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
358 cx.notify();
359 }
360
361 fn can_select(
362 &mut self,
363 ix: usize,
364 _window: &mut Window,
365 _cx: &mut Context<Picker<Self>>,
366 ) -> bool {
367 match self.filtered_entries.get(ix) {
368 Some(LanguageModelPickerEntry::Model(_)) => true,
369 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
370 }
371 }
372
373 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
374 "Select a model…".into()
375 }
376
377 fn update_matches(
378 &mut self,
379 query: String,
380 window: &mut Window,
381 cx: &mut Context<Picker<Self>>,
382 ) -> Task<()> {
383 let all_models = self.all_models.clone();
384 let current_index = self.selected_index;
385
386 let language_model_registry = LanguageModelRegistry::global(cx);
387
388 let configured_providers = language_model_registry
389 .read(cx)
390 .providers()
391 .iter()
392 .filter(|provider| provider.is_authenticated(cx))
393 .map(|provider| provider.id())
394 .collect::<Vec<_>>();
395
396 cx.spawn_in(window, async move |this, cx| {
397 let filtered_models = cx
398 .background_spawn(async move {
399 let matches = |info: &ModelInfo| {
400 info.model
401 .name()
402 .0
403 .to_lowercase()
404 .contains(&query.to_lowercase())
405 };
406
407 let recommended_models = all_models
408 .recommended
409 .iter()
410 .filter(|r| {
411 configured_providers.contains(&r.model.provider_id()) && matches(r)
412 })
413 .cloned()
414 .collect();
415 let mut other_models = IndexMap::default();
416 for (provider_id, models) in &all_models.other {
417 if configured_providers.contains(&provider_id) {
418 other_models.insert(
419 provider_id.clone(),
420 models
421 .iter()
422 .filter(|m| matches(m))
423 .cloned()
424 .collect::<Vec<_>>(),
425 );
426 }
427 }
428 GroupedModels {
429 recommended: recommended_models,
430 other: other_models,
431 }
432 })
433 .await;
434
435 this.update_in(cx, |this, window, cx| {
436 this.delegate.filtered_entries = filtered_models.entries();
437 // Preserve selection focus
438 let new_index = if current_index >= this.delegate.filtered_entries.len() {
439 0
440 } else {
441 current_index
442 };
443 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
444 cx.notify();
445 })
446 .ok();
447 })
448 }
449
450 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
451 if let Some(LanguageModelPickerEntry::Model(model_info)) =
452 self.filtered_entries.get(self.selected_index)
453 {
454 let model = model_info.model.clone();
455 (self.on_model_changed)(model.clone(), cx);
456
457 let current_index = self.selected_index;
458 self.set_selected_index(current_index, window, cx);
459
460 cx.emit(DismissEvent);
461 }
462 }
463
464 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
465 self.language_model_selector
466 .update(cx, |_this, cx| cx.emit(DismissEvent))
467 .ok();
468 }
469
470 fn render_match(
471 &self,
472 ix: usize,
473 selected: bool,
474 _: &mut Window,
475 cx: &mut Context<Picker<Self>>,
476 ) -> Option<Self::ListItem> {
477 match self.filtered_entries.get(ix)? {
478 LanguageModelPickerEntry::Separator(title) => Some(
479 div()
480 .px_2()
481 .pb_1()
482 .when(ix > 1, |this| {
483 this.mt_1()
484 .pt_2()
485 .border_t_1()
486 .border_color(cx.theme().colors().border_variant)
487 })
488 .child(
489 Label::new(title)
490 .size(LabelSize::XSmall)
491 .color(Color::Muted),
492 )
493 .into_any_element(),
494 ),
495 LanguageModelPickerEntry::Model(model_info) => {
496 let active_model = LanguageModelRegistry::read_global(cx).default_model();
497
498 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
499 let active_model_id = active_model.map(|m| m.model.id());
500
501 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
502 && Some(model_info.model.id()) == active_model_id;
503
504 let model_icon_color = if is_selected {
505 Color::Accent
506 } else {
507 Color::Muted
508 };
509
510 Some(
511 ListItem::new(ix)
512 .inset(true)
513 .spacing(ListItemSpacing::Sparse)
514 .toggle_state(selected)
515 .start_slot(
516 Icon::new(model_info.icon)
517 .color(model_icon_color)
518 .size(IconSize::Small),
519 )
520 .child(
521 h_flex()
522 .w_full()
523 .pl_0p5()
524 .gap_1p5()
525 .w(px(240.))
526 .child(Label::new(model_info.model.name().0.clone()).truncate()),
527 )
528 .end_slot(div().pr_3().when(is_selected, |this| {
529 this.child(
530 Icon::new(IconName::Check)
531 .color(Color::Accent)
532 .size(IconSize::Small),
533 )
534 }))
535 .into_any_element(),
536 )
537 }
538 }
539 }
540
541 fn render_footer(
542 &self,
543 _: &mut Window,
544 cx: &mut Context<Picker<Self>>,
545 ) -> Option<gpui::AnyElement> {
546 use feature_flags::FeatureFlagAppExt;
547
548 let plan = proto::Plan::ZedPro;
549
550 Some(
551 h_flex()
552 .w_full()
553 .border_t_1()
554 .border_color(cx.theme().colors().border_variant)
555 .p_1()
556 .gap_4()
557 .justify_between()
558 .when(cx.has_flag::<ZedPro>(), |this| {
559 this.child(match plan {
560 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
561 .icon(IconName::ZedAssistant)
562 .icon_size(IconSize::Small)
563 .icon_color(Color::Muted)
564 .icon_position(IconPosition::Start)
565 .on_click(|_, window, cx| {
566 window
567 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
568 }),
569 Plan::Free | Plan::ZedProTrial => Button::new(
570 "try-pro",
571 if plan == Plan::ZedProTrial {
572 "Upgrade to Pro"
573 } else {
574 "Try Pro"
575 },
576 )
577 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
578 })
579 })
580 .child(
581 Button::new("configure", "Configure")
582 .icon(IconName::Settings)
583 .icon_size(IconSize::Small)
584 .icon_color(Color::Muted)
585 .icon_position(IconPosition::Start)
586 .on_click(|_, window, cx| {
587 let configure_action = if cx.has_flag::<Assistant2FeatureFlag>() {
588 zed_actions::agent::OpenConfiguration.boxed_clone()
589 } else {
590 zed_actions::assistant::ShowConfiguration.boxed_clone()
591 };
592
593 window.dispatch_action(configure_action, cx);
594 }),
595 )
596 .into_any(),
597 )
598 }
599}