1use std::sync::Arc;
2
3use collections::{HashSet, IndexMap};
4use feature_flags::{Assistant2FeatureFlag, ZedProFeatureFlag};
5use gpui::{
6 Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
7 Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
8};
9use language_model::{
10 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
11 LanguageModelRegistry,
12};
13use picker::{Picker, PickerDelegate};
14use proto::Plan;
15use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
16
17action_with_deprecated_aliases!(
18 assistant,
19 ToggleModelSelector,
20 ["assistant2::ToggleModelSelector"]
21);
22
23const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
24
25type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
26
/// A model picker populated from the global [`LanguageModelRegistry`].
///
/// Emits [`DismissEvent`] when its picker is dismissed or a model is confirmed.
pub struct LanguageModelSelector {
    /// The picker that renders, filters and confirms the model list.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// Background authentication of every provider, held for the selector's
    /// lifetime (presumably dropping a gpui `Task` cancels it — confirm).
    _authenticate_all_providers_task: Task<()>,
    /// Registry-change and picker-dismissal subscriptions kept alive with the selector.
    _subscriptions: Vec<Subscription>,
}
32
/// Which model slot of the `LanguageModelRegistry` a selector treats as the
/// "active" model (used for the initial selection and the check mark).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ModelType {
    /// The registry's `default_model()`.
    Default,
    /// The registry's `inline_assistant_model()`.
    InlineAssistant,
}
38
39impl LanguageModelSelector {
40 pub fn new(
41 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
42 model_type: ModelType,
43 window: &mut Window,
44 cx: &mut Context<Self>,
45 ) -> Self {
46 let on_model_changed = Arc::new(on_model_changed);
47
48 let all_models = Self::all_models(cx);
49 let entries = all_models.entries();
50
51 let delegate = LanguageModelPickerDelegate {
52 language_model_selector: cx.entity().downgrade(),
53 on_model_changed: on_model_changed.clone(),
54 all_models: Arc::new(all_models),
55 selected_index: Self::get_active_model_index(&entries, model_type, cx),
56 filtered_entries: entries,
57 model_type,
58 };
59
60 let picker = cx.new(|cx| {
61 Picker::list(delegate, window, cx)
62 .show_scrollbar(true)
63 .width(rems(20.))
64 .max_height(Some(rems(20.).into()))
65 });
66
67 let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
68
69 LanguageModelSelector {
70 picker,
71 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
72 _subscriptions: vec![
73 cx.subscribe_in(
74 &LanguageModelRegistry::global(cx),
75 window,
76 Self::handle_language_model_registry_event,
77 ),
78 subscription,
79 ],
80 }
81 }
82
83 fn handle_language_model_registry_event(
84 &mut self,
85 _registry: &Entity<LanguageModelRegistry>,
86 event: &language_model::Event,
87 window: &mut Window,
88 cx: &mut Context<Self>,
89 ) {
90 match event {
91 language_model::Event::ProviderStateChanged
92 | language_model::Event::AddedProvider(_)
93 | language_model::Event::RemovedProvider(_) => {
94 self.picker.update(cx, |this, cx| {
95 let query = this.query(cx);
96 this.delegate.all_models = Arc::new(Self::all_models(cx));
97 // Update matches will automatically drop the previous task
98 // if we get a provider event again
99 this.update_matches(query, window, cx)
100 });
101 }
102 _ => {}
103 }
104 }
105
106 /// Authenticates all providers in the [`LanguageModelRegistry`].
107 ///
108 /// We do this so that we can populate the language selector with all of the
109 /// models from the configured providers.
110 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
111 let authenticate_all_providers = LanguageModelRegistry::global(cx)
112 .read(cx)
113 .providers()
114 .iter()
115 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
116 .collect::<Vec<_>>();
117
118 cx.spawn(async move |_cx| {
119 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
120 if let Err(err) = authenticate_task.await {
121 if matches!(err, AuthenticateError::CredentialsNotFound) {
122 // Since we're authenticating these providers in the
123 // background for the purposes of populating the
124 // language selector, we don't care about providers
125 // where the credentials are not found.
126 } else {
127 // Some providers have noisy failure states that we
128 // don't want to spam the logs with every time the
129 // language model selector is initialized.
130 //
131 // Ideally these should have more clear failure modes
132 // that we know are safe to ignore here, like what we do
133 // with `CredentialsNotFound` above.
134 match provider_id.0.as_ref() {
135 "lmstudio" | "ollama" => {
136 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
137 //
138 // These fail noisily, so we don't log them.
139 }
140 "copilot_chat" => {
141 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
142 }
143 _ => {
144 log::error!(
145 "Failed to authenticate provider: {}: {err}",
146 provider_name.0
147 );
148 }
149 }
150 }
151 }
152 }
153 })
154 }
155
156 fn all_models(cx: &App) -> GroupedModels {
157 let mut recommended = Vec::new();
158 let mut recommended_set = HashSet::default();
159 for provider in LanguageModelRegistry::global(cx)
160 .read(cx)
161 .providers()
162 .iter()
163 {
164 let models = provider.recommended_models(cx);
165 recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
166 recommended.extend(
167 provider
168 .recommended_models(cx)
169 .into_iter()
170 .map(move |model| ModelInfo {
171 model: model.clone(),
172 icon: provider.icon(),
173 }),
174 );
175 }
176
177 let other_models = LanguageModelRegistry::global(cx)
178 .read(cx)
179 .providers()
180 .iter()
181 .map(|provider| {
182 (
183 provider.id(),
184 provider
185 .provided_models(cx)
186 .into_iter()
187 .filter_map(|model| {
188 let not_included =
189 !recommended_set.contains(&(model.provider_id(), model.id()));
190 not_included.then(|| ModelInfo {
191 model: model.clone(),
192 icon: provider.icon(),
193 })
194 })
195 .collect::<Vec<_>>(),
196 )
197 })
198 .collect::<IndexMap<_, _>>();
199
200 GroupedModels {
201 recommended,
202 other: other_models,
203 }
204 }
205
206 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
207 let model_type = self.picker.read(cx).delegate.model_type;
208 Self::active_model_by_type(model_type, cx)
209 }
210
211 fn active_model_by_type(model_type: ModelType, cx: &App) -> Option<ConfiguredModel> {
212 match model_type {
213 ModelType::Default => LanguageModelRegistry::read_global(cx).default_model(),
214 ModelType::InlineAssistant => {
215 LanguageModelRegistry::read_global(cx).inline_assistant_model()
216 }
217 }
218 }
219
220 fn get_active_model_index(
221 entries: &[LanguageModelPickerEntry],
222 model_type: ModelType,
223 cx: &App,
224 ) -> usize {
225 let active_model = Self::active_model_by_type(model_type, cx);
226
227 entries
228 .iter()
229 .position(|entry| {
230 if let LanguageModelPickerEntry::Model(model) = entry {
231 active_model
232 .as_ref()
233 .map(|active_model| {
234 active_model.model.id() == model.model.id()
235 && active_model.model.provider_id() == model.model.provider_id()
236 })
237 .unwrap_or_default()
238 } else {
239 false
240 }
241 })
242 .unwrap_or(0)
243 }
244}
245
246impl EventEmitter<DismissEvent> for LanguageModelSelector {}
247
248impl Focusable for LanguageModelSelector {
249 fn focus_handle(&self, cx: &App) -> FocusHandle {
250 self.picker.focus_handle(cx)
251 }
252}
253
254impl Render for LanguageModelSelector {
255 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
256 self.picker.clone()
257 }
258}
259
/// A `PopoverMenu` whose content is a [`LanguageModelSelector`], opened from `trigger`.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
    T: PopoverTrigger + ButtonCommon,
    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
    /// The selector shown as the menu's content.
    language_model_selector: Entity<LanguageModelSelector>,
    /// The element used as the menu's trigger.
    trigger: T,
    /// Builds the tooltip attached to the trigger.
    tooltip: TT,
    /// Optional handle forwarded to the menu; `None` unless set via `with_handle`.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
    /// Corner the menu is anchored at.
    anchor: Corner,
}
272
273impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
274where
275 T: PopoverTrigger + ButtonCommon,
276 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
277{
278 pub fn new(
279 language_model_selector: Entity<LanguageModelSelector>,
280 trigger: T,
281 tooltip: TT,
282 anchor: Corner,
283 ) -> Self {
284 Self {
285 language_model_selector,
286 trigger,
287 tooltip,
288 handle: None,
289 anchor,
290 }
291 }
292
293 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
294 self.handle = Some(handle);
295 self
296 }
297}
298
299impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
300where
301 T: PopoverTrigger + ButtonCommon,
302 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
303{
304 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
305 let language_model_selector = self.language_model_selector.clone();
306
307 PopoverMenu::new("model-switcher")
308 .menu(move |_window, _cx| Some(language_model_selector.clone()))
309 .trigger_with_tooltip(self.trigger, self.tooltip)
310 .anchor(self.anchor)
311 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
312 .offset(gpui::Point {
313 x: px(0.0),
314 y: px(-2.0),
315 })
316 }
317}
318
319#[derive(Clone)]
320struct ModelInfo {
321 model: Arc<dyn LanguageModel>,
322 icon: IconName,
323}
324
/// Picker delegate that lists, filters and confirms language models.
pub struct LanguageModelPickerDelegate {
    /// Weak back-reference to the owning selector, used to forward dismissal.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Callback invoked with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Complete, unfiltered model groups; re-filtered on every query.
    all_models: Arc<GroupedModels>,
    /// Rows currently displayed (group separators and models).
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Which registry slot counts as the active model for this picker.
    model_type: ModelType,
}
333
/// Models split into the picker's display groups.
struct GroupedModels {
    /// Models that providers report as recommended, across all providers.
    recommended: Vec<ModelInfo>,
    /// Remaining models per provider, keyed by provider id in insertion order.
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
338
339impl GroupedModels {
340 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
341 let mut entries = Vec::new();
342
343 if !self.recommended.is_empty() {
344 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
345 entries.extend(
346 self.recommended
347 .iter()
348 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
349 );
350 }
351
352 for models in self.other.values() {
353 if models.is_empty() {
354 continue;
355 }
356 entries.push(LanguageModelPickerEntry::Separator(
357 models[0].model.provider_name().0,
358 ));
359 entries.extend(
360 models
361 .iter()
362 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
363 );
364 }
365 entries
366 }
367}
368
/// A single row in the model picker.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable group header carrying its title.
    Separator(SharedString),
}
373
374impl PickerDelegate for LanguageModelPickerDelegate {
375 type ListItem = AnyElement;
376
377 fn match_count(&self) -> usize {
378 self.filtered_entries.len()
379 }
380
381 fn selected_index(&self) -> usize {
382 self.selected_index
383 }
384
385 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
386 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
387 cx.notify();
388 }
389
390 fn can_select(
391 &mut self,
392 ix: usize,
393 _window: &mut Window,
394 _cx: &mut Context<Picker<Self>>,
395 ) -> bool {
396 match self.filtered_entries.get(ix) {
397 Some(LanguageModelPickerEntry::Model(_)) => true,
398 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
399 }
400 }
401
402 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
403 "Select a model…".into()
404 }
405
406 fn update_matches(
407 &mut self,
408 query: String,
409 window: &mut Window,
410 cx: &mut Context<Picker<Self>>,
411 ) -> Task<()> {
412 let all_models = self.all_models.clone();
413 let current_index = self.selected_index;
414
415 let language_model_registry = LanguageModelRegistry::global(cx);
416
417 let configured_providers = language_model_registry
418 .read(cx)
419 .providers()
420 .iter()
421 .filter(|provider| provider.is_authenticated(cx))
422 .map(|provider| provider.id())
423 .collect::<Vec<_>>();
424
425 cx.spawn_in(window, async move |this, cx| {
426 let filtered_models = cx
427 .background_spawn(async move {
428 let matches = |info: &ModelInfo| {
429 info.model
430 .name()
431 .0
432 .to_lowercase()
433 .contains(&query.to_lowercase())
434 };
435
436 let recommended_models = all_models
437 .recommended
438 .iter()
439 .filter(|r| {
440 configured_providers.contains(&r.model.provider_id()) && matches(r)
441 })
442 .cloned()
443 .collect();
444 let mut other_models = IndexMap::default();
445 for (provider_id, models) in &all_models.other {
446 if configured_providers.contains(&provider_id) {
447 other_models.insert(
448 provider_id.clone(),
449 models
450 .iter()
451 .filter(|m| matches(m))
452 .cloned()
453 .collect::<Vec<_>>(),
454 );
455 }
456 }
457 GroupedModels {
458 recommended: recommended_models,
459 other: other_models,
460 }
461 })
462 .await;
463
464 this.update_in(cx, |this, window, cx| {
465 this.delegate.filtered_entries = filtered_models.entries();
466 // Preserve selection focus
467 let new_index = if current_index >= this.delegate.filtered_entries.len() {
468 0
469 } else {
470 current_index
471 };
472 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
473 cx.notify();
474 })
475 .ok();
476 })
477 }
478
479 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
480 if let Some(LanguageModelPickerEntry::Model(model_info)) =
481 self.filtered_entries.get(self.selected_index)
482 {
483 let model = model_info.model.clone();
484 (self.on_model_changed)(model.clone(), cx);
485
486 let current_index = self.selected_index;
487 self.set_selected_index(current_index, window, cx);
488
489 cx.emit(DismissEvent);
490 }
491 }
492
493 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
494 self.language_model_selector
495 .update(cx, |_this, cx| cx.emit(DismissEvent))
496 .ok();
497 }
498
499 fn render_match(
500 &self,
501 ix: usize,
502 selected: bool,
503 _: &mut Window,
504 cx: &mut Context<Picker<Self>>,
505 ) -> Option<Self::ListItem> {
506 match self.filtered_entries.get(ix)? {
507 LanguageModelPickerEntry::Separator(title) => Some(
508 div()
509 .px_2()
510 .pb_1()
511 .when(ix > 1, |this| {
512 this.mt_1()
513 .pt_2()
514 .border_t_1()
515 .border_color(cx.theme().colors().border_variant)
516 })
517 .child(
518 Label::new(title)
519 .size(LabelSize::XSmall)
520 .color(Color::Muted),
521 )
522 .into_any_element(),
523 ),
524 LanguageModelPickerEntry::Model(model_info) => {
525 let active_model = LanguageModelSelector::active_model_by_type(self.model_type, cx);
526
527 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
528 let active_model_id = active_model.map(|m| m.model.id());
529
530 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
531 && Some(model_info.model.id()) == active_model_id;
532
533 let model_icon_color = if is_selected {
534 Color::Accent
535 } else {
536 Color::Muted
537 };
538
539 Some(
540 ListItem::new(ix)
541 .inset(true)
542 .spacing(ListItemSpacing::Sparse)
543 .toggle_state(selected)
544 .start_slot(
545 Icon::new(model_info.icon)
546 .color(model_icon_color)
547 .size(IconSize::Small),
548 )
549 .child(
550 h_flex()
551 .w_full()
552 .pl_0p5()
553 .gap_1p5()
554 .w(px(240.))
555 .child(Label::new(model_info.model.name().0.clone()).truncate()),
556 )
557 .end_slot(div().pr_3().when(is_selected, |this| {
558 this.child(
559 Icon::new(IconName::Check)
560 .color(Color::Accent)
561 .size(IconSize::Small),
562 )
563 }))
564 .into_any_element(),
565 )
566 }
567 }
568 }
569
570 fn render_footer(
571 &self,
572 _: &mut Window,
573 cx: &mut Context<Picker<Self>>,
574 ) -> Option<gpui::AnyElement> {
575 use feature_flags::FeatureFlagAppExt;
576
577 let plan = proto::Plan::ZedPro;
578
579 Some(
580 h_flex()
581 .w_full()
582 .border_t_1()
583 .border_color(cx.theme().colors().border_variant)
584 .p_1()
585 .gap_4()
586 .justify_between()
587 .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
588 this.child(match plan {
589 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
590 .icon(IconName::ZedAssistant)
591 .icon_size(IconSize::Small)
592 .icon_color(Color::Muted)
593 .icon_position(IconPosition::Start)
594 .on_click(|_, window, cx| {
595 window
596 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
597 }),
598 Plan::Free | Plan::ZedProTrial => Button::new(
599 "try-pro",
600 if plan == Plan::ZedProTrial {
601 "Upgrade to Pro"
602 } else {
603 "Try Pro"
604 },
605 )
606 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
607 })
608 })
609 .child(
610 Button::new("configure", "Configure")
611 .icon(IconName::Settings)
612 .icon_size(IconSize::Small)
613 .icon_color(Color::Muted)
614 .icon_position(IconPosition::Start)
615 .on_click(|_, window, cx| {
616 let configure_action = if cx.has_flag::<Assistant2FeatureFlag>() {
617 zed_actions::agent::OpenConfiguration.boxed_clone()
618 } else {
619 zed_actions::assistant::ShowConfiguration.boxed_clone()
620 };
621
622 window.dispatch_action(configure_action, cx);
623 }),
624 )
625 .into_any(),
626 )
627 }
628}