1use std::sync::Arc;
2
3use collections::{HashSet, IndexMap};
4use feature_flags::{Assistant2FeatureFlag, ZedProFeatureFlag};
5use gpui::{
6 Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
7 Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
8};
9use language_model::{
10 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
11 LanguageModelRegistry,
12};
13use picker::{Picker, PickerDelegate};
14use proto::Plan;
15use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
16
17action_with_deprecated_aliases!(
18 agent,
19 ToggleModelSelector,
20 [
21 "assistant::ToggleModelSelector",
22 "assistant2::ToggleModelSelector"
23 ]
24);
25
/// URL opened by the picker footer's "Try Pro" / "Upgrade to Pro" button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
27
28type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
29type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
30
31pub struct LanguageModelSelector {
32 picker: Entity<Picker<LanguageModelPickerDelegate>>,
33 _authenticate_all_providers_task: Task<()>,
34 _subscriptions: Vec<Subscription>,
35}
36
37impl LanguageModelSelector {
38 pub fn new(
39 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
40 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
41 window: &mut Window,
42 cx: &mut Context<Self>,
43 ) -> Self {
44 let on_model_changed = Arc::new(on_model_changed);
45
46 let all_models = Self::all_models(cx);
47 let entries = all_models.entries();
48
49 let delegate = LanguageModelPickerDelegate {
50 language_model_selector: cx.entity().downgrade(),
51 on_model_changed: on_model_changed.clone(),
52 all_models: Arc::new(all_models),
53 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
54 filtered_entries: entries,
55 get_active_model: Arc::new(get_active_model),
56 };
57
58 let picker = cx.new(|cx| {
59 Picker::list(delegate, window, cx)
60 .show_scrollbar(true)
61 .width(rems(20.))
62 .max_height(Some(rems(20.).into()))
63 });
64
65 let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
66
67 LanguageModelSelector {
68 picker,
69 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
70 _subscriptions: vec![
71 cx.subscribe_in(
72 &LanguageModelRegistry::global(cx),
73 window,
74 Self::handle_language_model_registry_event,
75 ),
76 subscription,
77 ],
78 }
79 }
80
81 fn handle_language_model_registry_event(
82 &mut self,
83 _registry: &Entity<LanguageModelRegistry>,
84 event: &language_model::Event,
85 window: &mut Window,
86 cx: &mut Context<Self>,
87 ) {
88 match event {
89 language_model::Event::ProviderStateChanged
90 | language_model::Event::AddedProvider(_)
91 | language_model::Event::RemovedProvider(_) => {
92 self.picker.update(cx, |this, cx| {
93 let query = this.query(cx);
94 this.delegate.all_models = Arc::new(Self::all_models(cx));
95 // Update matches will automatically drop the previous task
96 // if we get a provider event again
97 this.update_matches(query, window, cx)
98 });
99 }
100 _ => {}
101 }
102 }
103
104 /// Authenticates all providers in the [`LanguageModelRegistry`].
105 ///
106 /// We do this so that we can populate the language selector with all of the
107 /// models from the configured providers.
108 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
109 let authenticate_all_providers = LanguageModelRegistry::global(cx)
110 .read(cx)
111 .providers()
112 .iter()
113 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
114 .collect::<Vec<_>>();
115
116 cx.spawn(async move |_cx| {
117 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
118 if let Err(err) = authenticate_task.await {
119 if matches!(err, AuthenticateError::CredentialsNotFound) {
120 // Since we're authenticating these providers in the
121 // background for the purposes of populating the
122 // language selector, we don't care about providers
123 // where the credentials are not found.
124 } else {
125 // Some providers have noisy failure states that we
126 // don't want to spam the logs with every time the
127 // language model selector is initialized.
128 //
129 // Ideally these should have more clear failure modes
130 // that we know are safe to ignore here, like what we do
131 // with `CredentialsNotFound` above.
132 match provider_id.0.as_ref() {
133 "lmstudio" | "ollama" => {
134 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
135 //
136 // These fail noisily, so we don't log them.
137 }
138 "copilot_chat" => {
139 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
140 }
141 _ => {
142 log::error!(
143 "Failed to authenticate provider: {}: {err}",
144 provider_name.0
145 );
146 }
147 }
148 }
149 }
150 }
151 })
152 }
153
154 fn all_models(cx: &App) -> GroupedModels {
155 let mut recommended = Vec::new();
156 let mut recommended_set = HashSet::default();
157 for provider in LanguageModelRegistry::global(cx)
158 .read(cx)
159 .providers()
160 .iter()
161 {
162 let models = provider.recommended_models(cx);
163 recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
164 recommended.extend(
165 provider
166 .recommended_models(cx)
167 .into_iter()
168 .map(move |model| ModelInfo {
169 model: model.clone(),
170 icon: provider.icon(),
171 }),
172 );
173 }
174
175 let other_models = LanguageModelRegistry::global(cx)
176 .read(cx)
177 .providers()
178 .iter()
179 .map(|provider| {
180 (
181 provider.id(),
182 provider
183 .provided_models(cx)
184 .into_iter()
185 .filter_map(|model| {
186 let not_included =
187 !recommended_set.contains(&(model.provider_id(), model.id()));
188 not_included.then(|| ModelInfo {
189 model: model.clone(),
190 icon: provider.icon(),
191 })
192 })
193 .collect::<Vec<_>>(),
194 )
195 })
196 .collect::<IndexMap<_, _>>();
197
198 GroupedModels {
199 recommended,
200 other: other_models,
201 }
202 }
203
204 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
205 (self.picker.read(cx).delegate.get_active_model)(cx)
206 }
207
208 fn get_active_model_index(
209 entries: &[LanguageModelPickerEntry],
210 active_model: Option<ConfiguredModel>,
211 ) -> usize {
212 entries
213 .iter()
214 .position(|entry| {
215 if let LanguageModelPickerEntry::Model(model) = entry {
216 active_model
217 .as_ref()
218 .map(|active_model| {
219 active_model.model.id() == model.model.id()
220 && active_model.provider.id() == model.model.provider_id()
221 })
222 .unwrap_or_default()
223 } else {
224 false
225 }
226 })
227 .unwrap_or(0)
228 }
229}
230
231impl EventEmitter<DismissEvent> for LanguageModelSelector {}
232
233impl Focusable for LanguageModelSelector {
234 fn focus_handle(&self, cx: &App) -> FocusHandle {
235 self.picker.focus_handle(cx)
236 }
237}
238
239impl Render for LanguageModelSelector {
240 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
241 self.picker.clone()
242 }
243}
244
245#[derive(IntoElement)]
246pub struct LanguageModelSelectorPopoverMenu<T, TT>
247where
248 T: PopoverTrigger + ButtonCommon,
249 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
250{
251 language_model_selector: Entity<LanguageModelSelector>,
252 trigger: T,
253 tooltip: TT,
254 handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
255 anchor: Corner,
256}
257
258impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
259where
260 T: PopoverTrigger + ButtonCommon,
261 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
262{
263 pub fn new(
264 language_model_selector: Entity<LanguageModelSelector>,
265 trigger: T,
266 tooltip: TT,
267 anchor: Corner,
268 ) -> Self {
269 Self {
270 language_model_selector,
271 trigger,
272 tooltip,
273 handle: None,
274 anchor,
275 }
276 }
277
278 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
279 self.handle = Some(handle);
280 self
281 }
282}
283
284impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
285where
286 T: PopoverTrigger + ButtonCommon,
287 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
288{
289 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
290 let language_model_selector = self.language_model_selector.clone();
291
292 PopoverMenu::new("model-switcher")
293 .menu(move |_window, _cx| Some(language_model_selector.clone()))
294 .trigger_with_tooltip(self.trigger, self.tooltip)
295 .anchor(self.anchor)
296 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
297 .offset(gpui::Point {
298 x: px(0.0),
299 y: px(-2.0),
300 })
301 }
302}
303
304#[derive(Clone)]
305struct ModelInfo {
306 model: Arc<dyn LanguageModel>,
307 icon: IconName,
308}
309
310pub struct LanguageModelPickerDelegate {
311 language_model_selector: WeakEntity<LanguageModelSelector>,
312 on_model_changed: OnModelChanged,
313 get_active_model: GetActiveModel,
314 all_models: Arc<GroupedModels>,
315 filtered_entries: Vec<LanguageModelPickerEntry>,
316 selected_index: usize,
317}
318
319struct GroupedModels {
320 recommended: Vec<ModelInfo>,
321 other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
322}
323
324impl GroupedModels {
325 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
326 let mut entries = Vec::new();
327
328 if !self.recommended.is_empty() {
329 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
330 entries.extend(
331 self.recommended
332 .iter()
333 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
334 );
335 }
336
337 for models in self.other.values() {
338 if models.is_empty() {
339 continue;
340 }
341 entries.push(LanguageModelPickerEntry::Separator(
342 models[0].model.provider_name().0,
343 ));
344 entries.extend(
345 models
346 .iter()
347 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
348 );
349 }
350 entries
351 }
352}
353
354enum LanguageModelPickerEntry {
355 Model(ModelInfo),
356 Separator(SharedString),
357}
358
359impl PickerDelegate for LanguageModelPickerDelegate {
360 type ListItem = AnyElement;
361
362 fn match_count(&self) -> usize {
363 self.filtered_entries.len()
364 }
365
366 fn selected_index(&self) -> usize {
367 self.selected_index
368 }
369
370 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
371 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
372 cx.notify();
373 }
374
375 fn can_select(
376 &mut self,
377 ix: usize,
378 _window: &mut Window,
379 _cx: &mut Context<Picker<Self>>,
380 ) -> bool {
381 match self.filtered_entries.get(ix) {
382 Some(LanguageModelPickerEntry::Model(_)) => true,
383 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
384 }
385 }
386
387 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
388 "Select a model…".into()
389 }
390
391 fn update_matches(
392 &mut self,
393 query: String,
394 window: &mut Window,
395 cx: &mut Context<Picker<Self>>,
396 ) -> Task<()> {
397 let all_models = self.all_models.clone();
398 let current_index = self.selected_index;
399
400 let language_model_registry = LanguageModelRegistry::global(cx);
401
402 let configured_providers = language_model_registry
403 .read(cx)
404 .providers()
405 .iter()
406 .filter(|provider| provider.is_authenticated(cx))
407 .map(|provider| provider.id())
408 .collect::<Vec<_>>();
409
410 cx.spawn_in(window, async move |this, cx| {
411 let filtered_models = cx
412 .background_spawn(async move {
413 let matches = |info: &ModelInfo| {
414 info.model
415 .name()
416 .0
417 .to_lowercase()
418 .contains(&query.to_lowercase())
419 };
420
421 let recommended_models = all_models
422 .recommended
423 .iter()
424 .filter(|r| {
425 configured_providers.contains(&r.model.provider_id()) && matches(r)
426 })
427 .cloned()
428 .collect();
429 let mut other_models = IndexMap::default();
430 for (provider_id, models) in &all_models.other {
431 if configured_providers.contains(&provider_id) {
432 other_models.insert(
433 provider_id.clone(),
434 models
435 .iter()
436 .filter(|m| matches(m))
437 .cloned()
438 .collect::<Vec<_>>(),
439 );
440 }
441 }
442 GroupedModels {
443 recommended: recommended_models,
444 other: other_models,
445 }
446 })
447 .await;
448
449 this.update_in(cx, |this, window, cx| {
450 this.delegate.filtered_entries = filtered_models.entries();
451 // Preserve selection focus
452 let new_index = if current_index >= this.delegate.filtered_entries.len() {
453 0
454 } else {
455 current_index
456 };
457 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
458 cx.notify();
459 })
460 .ok();
461 })
462 }
463
464 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
465 if let Some(LanguageModelPickerEntry::Model(model_info)) =
466 self.filtered_entries.get(self.selected_index)
467 {
468 let model = model_info.model.clone();
469 (self.on_model_changed)(model.clone(), cx);
470
471 let current_index = self.selected_index;
472 self.set_selected_index(current_index, window, cx);
473
474 cx.emit(DismissEvent);
475 }
476 }
477
478 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
479 self.language_model_selector
480 .update(cx, |_this, cx| cx.emit(DismissEvent))
481 .ok();
482 }
483
484 fn render_match(
485 &self,
486 ix: usize,
487 selected: bool,
488 _: &mut Window,
489 cx: &mut Context<Picker<Self>>,
490 ) -> Option<Self::ListItem> {
491 match self.filtered_entries.get(ix)? {
492 LanguageModelPickerEntry::Separator(title) => Some(
493 div()
494 .px_2()
495 .pb_1()
496 .when(ix > 1, |this| {
497 this.mt_1()
498 .pt_2()
499 .border_t_1()
500 .border_color(cx.theme().colors().border_variant)
501 })
502 .child(
503 Label::new(title)
504 .size(LabelSize::XSmall)
505 .color(Color::Muted),
506 )
507 .into_any_element(),
508 ),
509 LanguageModelPickerEntry::Model(model_info) => {
510 let active_model = (self.get_active_model)(cx);
511 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
512 let active_model_id = active_model.map(|m| m.model.id());
513
514 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
515 && Some(model_info.model.id()) == active_model_id;
516
517 let model_icon_color = if is_selected {
518 Color::Accent
519 } else {
520 Color::Muted
521 };
522
523 Some(
524 ListItem::new(ix)
525 .inset(true)
526 .spacing(ListItemSpacing::Sparse)
527 .toggle_state(selected)
528 .start_slot(
529 Icon::new(model_info.icon)
530 .color(model_icon_color)
531 .size(IconSize::Small),
532 )
533 .child(
534 h_flex()
535 .w_full()
536 .pl_0p5()
537 .gap_1p5()
538 .w(px(240.))
539 .child(Label::new(model_info.model.name().0.clone()).truncate()),
540 )
541 .end_slot(div().pr_3().when(is_selected, |this| {
542 this.child(
543 Icon::new(IconName::Check)
544 .color(Color::Accent)
545 .size(IconSize::Small),
546 )
547 }))
548 .into_any_element(),
549 )
550 }
551 }
552 }
553
554 fn render_footer(
555 &self,
556 _: &mut Window,
557 cx: &mut Context<Picker<Self>>,
558 ) -> Option<gpui::AnyElement> {
559 use feature_flags::FeatureFlagAppExt;
560
561 let plan = proto::Plan::ZedPro;
562
563 Some(
564 h_flex()
565 .w_full()
566 .border_t_1()
567 .border_color(cx.theme().colors().border_variant)
568 .p_1()
569 .gap_4()
570 .justify_between()
571 .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
572 this.child(match plan {
573 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
574 .icon(IconName::ZedAssistant)
575 .icon_size(IconSize::Small)
576 .icon_color(Color::Muted)
577 .icon_position(IconPosition::Start)
578 .on_click(|_, window, cx| {
579 window
580 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
581 }),
582 Plan::Free | Plan::ZedProTrial => Button::new(
583 "try-pro",
584 if plan == Plan::ZedProTrial {
585 "Upgrade to Pro"
586 } else {
587 "Try Pro"
588 },
589 )
590 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
591 })
592 })
593 .child(
594 Button::new("configure", "Configure")
595 .icon(IconName::Settings)
596 .icon_size(IconSize::Small)
597 .icon_color(Color::Muted)
598 .icon_position(IconPosition::Start)
599 .on_click(|_, window, cx| {
600 let configure_action = if cx.has_flag::<Assistant2FeatureFlag>() {
601 zed_actions::agent::OpenConfiguration.boxed_clone()
602 } else {
603 zed_actions::assistant::ShowConfiguration.boxed_clone()
604 };
605
606 window.dispatch_action(configure_action, cx);
607 }),
608 )
609 .into_any(),
610 )
611 }
612}