1use std::sync::Arc;
2
3use collections::{HashSet, IndexMap};
4use feature_flags::{Assistant2FeatureFlag, ZedProFeatureFlag};
5use gpui::{
6 Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
7 Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
8};
9use language_model::{
10 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
11 LanguageModelRegistry,
12};
13use picker::{Picker, PickerDelegate};
14use proto::Plan;
15use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
16
// Declares the `assistant::ToggleModelSelector` action (its handler lives outside
// this file). The alias list presumably keeps the legacy
// `assistant2::ToggleModelSelector` name resolving to this action — confirm
// against the macro definition in gpui.
action_with_deprecated_aliases!(
    assistant,
    ToggleModelSelector,
    ["assistant2::ToggleModelSelector"]
);
22
/// Page opened by the footer's "Try Pro" / "Upgrade to Pro" button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
24
/// Callback invoked with the model the user confirms in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Returns the currently configured model (if any); used to choose the initial
/// selection and to mark the active row in the list.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
27
/// A view that lets the user pick a language model from every registered provider.
pub struct LanguageModelSelector {
    /// The picker that renders and filters the model list.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// Background task authenticating all providers; held so it is not dropped.
    _authenticate_all_providers_task: Task<()>,
    /// Registry and picker subscriptions, held for the lifetime of the selector.
    _subscriptions: Vec<Subscription>,
}
33
34impl LanguageModelSelector {
35 pub fn new(
36 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
37 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
38 window: &mut Window,
39 cx: &mut Context<Self>,
40 ) -> Self {
41 let on_model_changed = Arc::new(on_model_changed);
42
43 let all_models = Self::all_models(cx);
44 let entries = all_models.entries();
45
46 let delegate = LanguageModelPickerDelegate {
47 language_model_selector: cx.entity().downgrade(),
48 on_model_changed: on_model_changed.clone(),
49 all_models: Arc::new(all_models),
50 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
51 filtered_entries: entries,
52 get_active_model: Arc::new(get_active_model),
53 };
54
55 let picker = cx.new(|cx| {
56 Picker::list(delegate, window, cx)
57 .show_scrollbar(true)
58 .width(rems(20.))
59 .max_height(Some(rems(20.).into()))
60 });
61
62 let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
63
64 LanguageModelSelector {
65 picker,
66 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
67 _subscriptions: vec![
68 cx.subscribe_in(
69 &LanguageModelRegistry::global(cx),
70 window,
71 Self::handle_language_model_registry_event,
72 ),
73 subscription,
74 ],
75 }
76 }
77
78 fn handle_language_model_registry_event(
79 &mut self,
80 _registry: &Entity<LanguageModelRegistry>,
81 event: &language_model::Event,
82 window: &mut Window,
83 cx: &mut Context<Self>,
84 ) {
85 match event {
86 language_model::Event::ProviderStateChanged
87 | language_model::Event::AddedProvider(_)
88 | language_model::Event::RemovedProvider(_) => {
89 self.picker.update(cx, |this, cx| {
90 let query = this.query(cx);
91 this.delegate.all_models = Arc::new(Self::all_models(cx));
92 // Update matches will automatically drop the previous task
93 // if we get a provider event again
94 this.update_matches(query, window, cx)
95 });
96 }
97 _ => {}
98 }
99 }
100
101 /// Authenticates all providers in the [`LanguageModelRegistry`].
102 ///
103 /// We do this so that we can populate the language selector with all of the
104 /// models from the configured providers.
105 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
106 let authenticate_all_providers = LanguageModelRegistry::global(cx)
107 .read(cx)
108 .providers()
109 .iter()
110 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
111 .collect::<Vec<_>>();
112
113 cx.spawn(async move |_cx| {
114 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
115 if let Err(err) = authenticate_task.await {
116 if matches!(err, AuthenticateError::CredentialsNotFound) {
117 // Since we're authenticating these providers in the
118 // background for the purposes of populating the
119 // language selector, we don't care about providers
120 // where the credentials are not found.
121 } else {
122 // Some providers have noisy failure states that we
123 // don't want to spam the logs with every time the
124 // language model selector is initialized.
125 //
126 // Ideally these should have more clear failure modes
127 // that we know are safe to ignore here, like what we do
128 // with `CredentialsNotFound` above.
129 match provider_id.0.as_ref() {
130 "lmstudio" | "ollama" => {
131 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
132 //
133 // These fail noisily, so we don't log them.
134 }
135 "copilot_chat" => {
136 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
137 }
138 _ => {
139 log::error!(
140 "Failed to authenticate provider: {}: {err}",
141 provider_name.0
142 );
143 }
144 }
145 }
146 }
147 }
148 })
149 }
150
151 fn all_models(cx: &App) -> GroupedModels {
152 let mut recommended = Vec::new();
153 let mut recommended_set = HashSet::default();
154 for provider in LanguageModelRegistry::global(cx)
155 .read(cx)
156 .providers()
157 .iter()
158 {
159 let models = provider.recommended_models(cx);
160 recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
161 recommended.extend(
162 provider
163 .recommended_models(cx)
164 .into_iter()
165 .map(move |model| ModelInfo {
166 model: model.clone(),
167 icon: provider.icon(),
168 }),
169 );
170 }
171
172 let other_models = LanguageModelRegistry::global(cx)
173 .read(cx)
174 .providers()
175 .iter()
176 .map(|provider| {
177 (
178 provider.id(),
179 provider
180 .provided_models(cx)
181 .into_iter()
182 .filter_map(|model| {
183 let not_included =
184 !recommended_set.contains(&(model.provider_id(), model.id()));
185 not_included.then(|| ModelInfo {
186 model: model.clone(),
187 icon: provider.icon(),
188 })
189 })
190 .collect::<Vec<_>>(),
191 )
192 })
193 .collect::<IndexMap<_, _>>();
194
195 GroupedModels {
196 recommended,
197 other: other_models,
198 }
199 }
200
201 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
202 (self.picker.read(cx).delegate.get_active_model)(cx)
203 }
204
205 fn get_active_model_index(
206 entries: &[LanguageModelPickerEntry],
207 active_model: Option<ConfiguredModel>,
208 ) -> usize {
209 entries
210 .iter()
211 .position(|entry| {
212 if let LanguageModelPickerEntry::Model(model) = entry {
213 active_model
214 .as_ref()
215 .map(|active_model| {
216 active_model.model.id() == model.model.id()
217 && active_model.provider.id() == model.model.provider_id()
218 })
219 .unwrap_or_default()
220 } else {
221 false
222 }
223 })
224 .unwrap_or(0)
225 }
226}
227
// The selector emits `DismissEvent`: it re-emits any picker event (see the
// subscription in `new`) and is notified by the delegate's `dismissed`.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
229
230impl Focusable for LanguageModelSelector {
231 fn focus_handle(&self, cx: &App) -> FocusHandle {
232 self.picker.focus_handle(cx)
233 }
234}
235
236impl Render for LanguageModelSelector {
237 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
238 self.picker.clone()
239 }
240}
241
/// A [`PopoverMenu`] that shows a [`LanguageModelSelector`] when its trigger is clicked.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
    T: PopoverTrigger + ButtonCommon,
    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
    /// The selector displayed inside the popover.
    language_model_selector: Entity<LanguageModelSelector>,
    /// The element that opens the popover.
    trigger: T,
    /// Builds the tooltip attached to the trigger.
    tooltip: TT,
    /// Optional menu handle; presumably lets the owner open/close the menu
    /// programmatically — confirm against `PopoverMenuHandle`.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
    /// Corner the menu is anchored to.
    anchor: Corner,
}
254
255impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
256where
257 T: PopoverTrigger + ButtonCommon,
258 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
259{
260 pub fn new(
261 language_model_selector: Entity<LanguageModelSelector>,
262 trigger: T,
263 tooltip: TT,
264 anchor: Corner,
265 ) -> Self {
266 Self {
267 language_model_selector,
268 trigger,
269 tooltip,
270 handle: None,
271 anchor,
272 }
273 }
274
275 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
276 self.handle = Some(handle);
277 self
278 }
279}
280
281impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
282where
283 T: PopoverTrigger + ButtonCommon,
284 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
285{
286 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
287 let language_model_selector = self.language_model_selector.clone();
288
289 PopoverMenu::new("model-switcher")
290 .menu(move |_window, _cx| Some(language_model_selector.clone()))
291 .trigger_with_tooltip(self.trigger, self.tooltip)
292 .anchor(self.anchor)
293 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
294 .offset(gpui::Point {
295 x: px(0.0),
296 y: px(-2.0),
297 })
298 }
299}
300
/// A model paired with the icon of the provider that supplies it.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// Provider icon, drawn at the start of the model's row.
    icon: IconName,
}
306
/// [`PickerDelegate`] listing language models grouped into sections.
pub struct LanguageModelPickerDelegate {
    /// The selector owning this picker; `dismissed` forwards `DismissEvent` to it.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Called with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Returns the currently configured model, used to mark the active row.
    get_active_model: GetActiveModel,
    /// Every known model, unfiltered; rebuilt on language model registry events.
    all_models: Arc<GroupedModels>,
    /// Rows currently shown: models from authenticated providers whose names
    /// match the query, interleaved with section separators.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
}
315
/// Models organized into the sections the picker displays.
struct GroupedModels {
    /// Contents of the "Recommended" section.
    recommended: Vec<ModelInfo>,
    /// Remaining models keyed by provider, kept in insertion order.
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
320
321impl GroupedModels {
322 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
323 let mut entries = Vec::new();
324
325 if !self.recommended.is_empty() {
326 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
327 entries.extend(
328 self.recommended
329 .iter()
330 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
331 );
332 }
333
334 for models in self.other.values() {
335 if models.is_empty() {
336 continue;
337 }
338 entries.push(LanguageModelPickerEntry::Separator(
339 models[0].model.provider_name().0,
340 ));
341 entries.extend(
342 models
343 .iter()
344 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
345 );
346 }
347 entries
348 }
349}
350
/// One row of the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable section header with the given title.
    Separator(SharedString),
}
355
356impl PickerDelegate for LanguageModelPickerDelegate {
357 type ListItem = AnyElement;
358
359 fn match_count(&self) -> usize {
360 self.filtered_entries.len()
361 }
362
363 fn selected_index(&self) -> usize {
364 self.selected_index
365 }
366
367 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
368 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
369 cx.notify();
370 }
371
372 fn can_select(
373 &mut self,
374 ix: usize,
375 _window: &mut Window,
376 _cx: &mut Context<Picker<Self>>,
377 ) -> bool {
378 match self.filtered_entries.get(ix) {
379 Some(LanguageModelPickerEntry::Model(_)) => true,
380 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
381 }
382 }
383
384 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
385 "Select a model…".into()
386 }
387
388 fn update_matches(
389 &mut self,
390 query: String,
391 window: &mut Window,
392 cx: &mut Context<Picker<Self>>,
393 ) -> Task<()> {
394 let all_models = self.all_models.clone();
395 let current_index = self.selected_index;
396
397 let language_model_registry = LanguageModelRegistry::global(cx);
398
399 let configured_providers = language_model_registry
400 .read(cx)
401 .providers()
402 .iter()
403 .filter(|provider| provider.is_authenticated(cx))
404 .map(|provider| provider.id())
405 .collect::<Vec<_>>();
406
407 cx.spawn_in(window, async move |this, cx| {
408 let filtered_models = cx
409 .background_spawn(async move {
410 let matches = |info: &ModelInfo| {
411 info.model
412 .name()
413 .0
414 .to_lowercase()
415 .contains(&query.to_lowercase())
416 };
417
418 let recommended_models = all_models
419 .recommended
420 .iter()
421 .filter(|r| {
422 configured_providers.contains(&r.model.provider_id()) && matches(r)
423 })
424 .cloned()
425 .collect();
426 let mut other_models = IndexMap::default();
427 for (provider_id, models) in &all_models.other {
428 if configured_providers.contains(&provider_id) {
429 other_models.insert(
430 provider_id.clone(),
431 models
432 .iter()
433 .filter(|m| matches(m))
434 .cloned()
435 .collect::<Vec<_>>(),
436 );
437 }
438 }
439 GroupedModels {
440 recommended: recommended_models,
441 other: other_models,
442 }
443 })
444 .await;
445
446 this.update_in(cx, |this, window, cx| {
447 this.delegate.filtered_entries = filtered_models.entries();
448 // Preserve selection focus
449 let new_index = if current_index >= this.delegate.filtered_entries.len() {
450 0
451 } else {
452 current_index
453 };
454 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
455 cx.notify();
456 })
457 .ok();
458 })
459 }
460
461 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
462 if let Some(LanguageModelPickerEntry::Model(model_info)) =
463 self.filtered_entries.get(self.selected_index)
464 {
465 let model = model_info.model.clone();
466 (self.on_model_changed)(model.clone(), cx);
467
468 let current_index = self.selected_index;
469 self.set_selected_index(current_index, window, cx);
470
471 cx.emit(DismissEvent);
472 }
473 }
474
475 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
476 self.language_model_selector
477 .update(cx, |_this, cx| cx.emit(DismissEvent))
478 .ok();
479 }
480
481 fn render_match(
482 &self,
483 ix: usize,
484 selected: bool,
485 _: &mut Window,
486 cx: &mut Context<Picker<Self>>,
487 ) -> Option<Self::ListItem> {
488 match self.filtered_entries.get(ix)? {
489 LanguageModelPickerEntry::Separator(title) => Some(
490 div()
491 .px_2()
492 .pb_1()
493 .when(ix > 1, |this| {
494 this.mt_1()
495 .pt_2()
496 .border_t_1()
497 .border_color(cx.theme().colors().border_variant)
498 })
499 .child(
500 Label::new(title)
501 .size(LabelSize::XSmall)
502 .color(Color::Muted),
503 )
504 .into_any_element(),
505 ),
506 LanguageModelPickerEntry::Model(model_info) => {
507 let active_model = (self.get_active_model)(cx);
508 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
509 let active_model_id = active_model.map(|m| m.model.id());
510
511 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
512 && Some(model_info.model.id()) == active_model_id;
513
514 let model_icon_color = if is_selected {
515 Color::Accent
516 } else {
517 Color::Muted
518 };
519
520 Some(
521 ListItem::new(ix)
522 .inset(true)
523 .spacing(ListItemSpacing::Sparse)
524 .toggle_state(selected)
525 .start_slot(
526 Icon::new(model_info.icon)
527 .color(model_icon_color)
528 .size(IconSize::Small),
529 )
530 .child(
531 h_flex()
532 .w_full()
533 .pl_0p5()
534 .gap_1p5()
535 .w(px(240.))
536 .child(Label::new(model_info.model.name().0.clone()).truncate()),
537 )
538 .end_slot(div().pr_3().when(is_selected, |this| {
539 this.child(
540 Icon::new(IconName::Check)
541 .color(Color::Accent)
542 .size(IconSize::Small),
543 )
544 }))
545 .into_any_element(),
546 )
547 }
548 }
549 }
550
551 fn render_footer(
552 &self,
553 _: &mut Window,
554 cx: &mut Context<Picker<Self>>,
555 ) -> Option<gpui::AnyElement> {
556 use feature_flags::FeatureFlagAppExt;
557
558 let plan = proto::Plan::ZedPro;
559
560 Some(
561 h_flex()
562 .w_full()
563 .border_t_1()
564 .border_color(cx.theme().colors().border_variant)
565 .p_1()
566 .gap_4()
567 .justify_between()
568 .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
569 this.child(match plan {
570 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
571 .icon(IconName::ZedAssistant)
572 .icon_size(IconSize::Small)
573 .icon_color(Color::Muted)
574 .icon_position(IconPosition::Start)
575 .on_click(|_, window, cx| {
576 window
577 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
578 }),
579 Plan::Free | Plan::ZedProTrial => Button::new(
580 "try-pro",
581 if plan == Plan::ZedProTrial {
582 "Upgrade to Pro"
583 } else {
584 "Try Pro"
585 },
586 )
587 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
588 })
589 })
590 .child(
591 Button::new("configure", "Configure")
592 .icon(IconName::Settings)
593 .icon_size(IconSize::Small)
594 .icon_color(Color::Muted)
595 .icon_position(IconPosition::Start)
596 .on_click(|_, window, cx| {
597 let configure_action = if cx.has_flag::<Assistant2FeatureFlag>() {
598 zed_actions::agent::OpenConfiguration.boxed_clone()
599 } else {
600 zed_actions::assistant::ShowConfiguration.boxed_clone()
601 };
602
603 window.dispatch_action(configure_action, cx);
604 }),
605 )
606 .into_any(),
607 )
608 }
609}