1use std::{cmp::Reverse, sync::Arc};
2
3use collections::IndexMap;
4use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
5use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task};
6use language_model::{
7 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
8 LanguageModelRegistry,
9};
10use ordered_float::OrderedFloat;
11use picker::{Picker, PickerDelegate};
12use ui::{ListItem, ListItemSpacing, prelude::*};
13
/// Callback invoked with the chosen model when the user confirms a selection.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback that reports the currently active model (if any); used to mark it
/// with a check mark and to pre-select its row.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

/// A [`Picker`] listing language models under a "Recommended" heading and one
/// heading per provider.
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
18
19pub fn language_model_selector(
20 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
21 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
22 popover_styles: bool,
23 window: &mut Window,
24 cx: &mut Context<LanguageModelSelector>,
25) -> LanguageModelSelector {
26 let delegate = LanguageModelPickerDelegate::new(
27 get_active_model,
28 on_model_changed,
29 popover_styles,
30 window,
31 cx,
32 );
33
34 if popover_styles {
35 Picker::list(delegate, window, cx)
36 .show_scrollbar(true)
37 .width(rems(20.))
38 .max_height(Some(rems(20.).into()))
39 } else {
40 Picker::list(delegate, window, cx).show_scrollbar(true)
41 }
42}
43
44fn all_models(cx: &App) -> GroupedModels {
45 let providers = LanguageModelRegistry::global(cx).read(cx).providers();
46
47 let recommended = providers
48 .iter()
49 .flat_map(|provider| {
50 provider
51 .recommended_models(cx)
52 .into_iter()
53 .map(|model| ModelInfo {
54 model,
55 icon: provider.icon(),
56 })
57 })
58 .collect();
59
60 let all = providers
61 .iter()
62 .flat_map(|provider| {
63 provider
64 .provided_models(cx)
65 .into_iter()
66 .map(|model| ModelInfo {
67 model,
68 icon: provider.icon(),
69 })
70 })
71 .collect();
72
73 GroupedModels::new(all, recommended)
74}
75
/// A model paired with the icon used to render its row in the picker.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// Row icon; `all_models` fills this from the owning provider's `icon()`.
    icon: IconName,
}
81
/// [`PickerDelegate`] backing the language model selector.
pub struct LanguageModelPickerDelegate {
    /// Invoked with the chosen model when a selection is confirmed.
    on_model_changed: OnModelChanged,
    /// Reports the currently active model; used for the check mark and for
    /// choosing the initially selected row.
    get_active_model: GetActiveModel,
    /// Snapshot of all providers' models; rebuilt when the registry reports a
    /// provider being added, removed, or changing state.
    all_models: Arc<GroupedModels>,
    /// Rows (headings and models) currently shown for the active query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Background provider-authentication task, stored so it lives as long as
    /// the delegate.
    _authenticate_all_providers_task: Task<()>,
    /// Keeps the registry event subscription alive.
    _subscriptions: Vec<Subscription>,
    /// Whether the picker is presented as a popover (enables the footer).
    popover_styles: bool,
}
92
93impl LanguageModelPickerDelegate {
94 fn new(
95 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
96 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
97 popover_styles: bool,
98 window: &mut Window,
99 cx: &mut Context<Picker<Self>>,
100 ) -> Self {
101 let on_model_changed = Arc::new(on_model_changed);
102 let models = all_models(cx);
103 let entries = models.entries();
104
105 Self {
106 on_model_changed,
107 all_models: Arc::new(models),
108 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
109 filtered_entries: entries,
110 get_active_model: Arc::new(get_active_model),
111 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
112 _subscriptions: vec![cx.subscribe_in(
113 &LanguageModelRegistry::global(cx),
114 window,
115 |picker, _, event, window, cx| {
116 match event {
117 language_model::Event::ProviderStateChanged(_)
118 | language_model::Event::AddedProvider(_)
119 | language_model::Event::RemovedProvider(_) => {
120 let query = picker.query(cx);
121 picker.delegate.all_models = Arc::new(all_models(cx));
122 // Update matches will automatically drop the previous task
123 // if we get a provider event again
124 picker.update_matches(query, window, cx)
125 }
126 _ => {}
127 }
128 },
129 )],
130 popover_styles,
131 }
132 }
133
134 fn get_active_model_index(
135 entries: &[LanguageModelPickerEntry],
136 active_model: Option<ConfiguredModel>,
137 ) -> usize {
138 entries
139 .iter()
140 .position(|entry| {
141 if let LanguageModelPickerEntry::Model(model) = entry {
142 active_model
143 .as_ref()
144 .map(|active_model| {
145 active_model.model.id() == model.model.id()
146 && active_model.provider.id() == model.model.provider_id()
147 })
148 .unwrap_or_default()
149 } else {
150 false
151 }
152 })
153 .unwrap_or(0)
154 }
155
156 /// Authenticates all providers in the [`LanguageModelRegistry`].
157 ///
158 /// We do this so that we can populate the language selector with all of the
159 /// models from the configured providers.
160 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
161 let authenticate_all_providers = LanguageModelRegistry::global(cx)
162 .read(cx)
163 .providers()
164 .iter()
165 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
166 .collect::<Vec<_>>();
167
168 cx.spawn(async move |_cx| {
169 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
170 if let Err(err) = authenticate_task.await {
171 if matches!(err, AuthenticateError::CredentialsNotFound) {
172 // Since we're authenticating these providers in the
173 // background for the purposes of populating the
174 // language selector, we don't care about providers
175 // where the credentials are not found.
176 } else {
177 // Some providers have noisy failure states that we
178 // don't want to spam the logs with every time the
179 // language model selector is initialized.
180 //
181 // Ideally these should have more clear failure modes
182 // that we know are safe to ignore here, like what we do
183 // with `CredentialsNotFound` above.
184 match provider_id.0.as_ref() {
185 "lmstudio" | "ollama" => {
186 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
187 //
188 // These fail noisily, so we don't log them.
189 }
190 "copilot_chat" => {
191 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
192 }
193 _ => {
194 log::error!(
195 "Failed to authenticate provider: {}: {err:#}",
196 provider_name.0
197 );
198 }
199 }
200 }
201 }
202 }
203 })
204 }
205
206 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
207 (self.get_active_model)(cx)
208 }
209}
210
/// Models organized for display: a "Recommended" list plus every model keyed
/// by provider.
struct GroupedModels {
    /// Models shown under the "Recommended" heading, in the order given.
    recommended: Vec<ModelInfo>,
    /// All models grouped by provider id. `new` inserts a provider the first
    /// time one of its models is seen, so the map order follows first
    /// appearance.
    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
215
216impl GroupedModels {
217 pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
218 let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
219 for model in all {
220 let provider = model.model.provider_id();
221 if let Some(models) = all_by_provider.get_mut(&provider) {
222 models.push(model);
223 } else {
224 all_by_provider.insert(provider, vec![model]);
225 }
226 }
227
228 Self {
229 recommended,
230 all: all_by_provider,
231 }
232 }
233
234 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
235 let mut entries = Vec::new();
236
237 if !self.recommended.is_empty() {
238 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
239 entries.extend(
240 self.recommended
241 .iter()
242 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
243 );
244 }
245
246 for models in self.all.values() {
247 if models.is_empty() {
248 continue;
249 }
250 entries.push(LanguageModelPickerEntry::Separator(
251 models[0].model.provider_name().0,
252 ));
253 entries.extend(
254 models
255 .iter()
256 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
257 );
258 }
259 entries
260 }
261}
262
/// A row in the picker: a selectable model or a non-selectable heading.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A section heading, e.g. "Recommended" or a provider's name.
    Separator(SharedString),
}
267
/// Searches a fixed list of models, either fuzzily or by name substring.
struct ModelMatcher {
    /// Models being searched; `candidates[i]` describes `models[i]`.
    models: Vec<ModelInfo>,
    /// Executor handed to (and blocked on for) fuzzy matching.
    bg_executor: BackgroundExecutor,
    /// One `"<provider name>/<model name>"` candidate per model, whose id is
    /// the model's index.
    candidates: Vec<StringMatchCandidate>,
}
273
274impl ModelMatcher {
275 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
276 let candidates = Self::make_match_candidates(&models);
277 Self {
278 models,
279 bg_executor,
280 candidates,
281 }
282 }
283
284 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
285 let mut matches = self.bg_executor.block(match_strings(
286 &self.candidates,
287 query,
288 false,
289 true,
290 100,
291 &Default::default(),
292 self.bg_executor.clone(),
293 ));
294
295 let sorting_key = |mat: &StringMatch| {
296 let candidate = &self.candidates[mat.candidate_id];
297 (Reverse(OrderedFloat(mat.score)), candidate.id)
298 };
299 matches.sort_unstable_by_key(sorting_key);
300
301 let matched_models: Vec<_> = matches
302 .into_iter()
303 .map(|mat| self.models[mat.candidate_id].clone())
304 .collect();
305
306 matched_models
307 }
308
309 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
310 self.models
311 .iter()
312 .filter(|m| {
313 m.model
314 .name()
315 .0
316 .to_lowercase()
317 .contains(&query.to_lowercase())
318 })
319 .cloned()
320 .collect::<Vec<_>>()
321 }
322
323 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
324 model_infos
325 .iter()
326 .enumerate()
327 .map(|(index, model)| {
328 StringMatchCandidate::new(
329 index,
330 &format!(
331 "{}/{}",
332 &model.model.provider_name().0,
333 &model.model.name().0
334 ),
335 )
336 })
337 .collect::<Vec<_>>()
338 }
339}
340
341impl PickerDelegate for LanguageModelPickerDelegate {
342 type ListItem = AnyElement;
343
344 fn match_count(&self) -> usize {
345 self.filtered_entries.len()
346 }
347
348 fn selected_index(&self) -> usize {
349 self.selected_index
350 }
351
352 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
353 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
354 cx.notify();
355 }
356
357 fn can_select(
358 &mut self,
359 ix: usize,
360 _window: &mut Window,
361 _cx: &mut Context<Picker<Self>>,
362 ) -> bool {
363 match self.filtered_entries.get(ix) {
364 Some(LanguageModelPickerEntry::Model(_)) => true,
365 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
366 }
367 }
368
369 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
370 "Select a model…".into()
371 }
372
373 fn update_matches(
374 &mut self,
375 query: String,
376 window: &mut Window,
377 cx: &mut Context<Picker<Self>>,
378 ) -> Task<()> {
379 let all_models = self.all_models.clone();
380 let active_model = (self.get_active_model)(cx);
381 let bg_executor = cx.background_executor();
382
383 let language_model_registry = LanguageModelRegistry::global(cx);
384
385 let configured_providers = language_model_registry
386 .read(cx)
387 .providers()
388 .into_iter()
389 .filter(|provider| provider.is_authenticated(cx))
390 .collect::<Vec<_>>();
391
392 let configured_provider_ids = configured_providers
393 .iter()
394 .map(|provider| provider.id())
395 .collect::<Vec<_>>();
396
397 let recommended_models = all_models
398 .recommended
399 .iter()
400 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
401 .cloned()
402 .collect::<Vec<_>>();
403
404 let available_models = all_models
405 .all
406 .values()
407 .flat_map(|models| models.iter())
408 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
409 .cloned()
410 .collect::<Vec<_>>();
411
412 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
413 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
414
415 let recommended = matcher_rec.exact_search(&query);
416 let all = matcher_all.fuzzy_search(&query);
417
418 let filtered_models = GroupedModels::new(all, recommended);
419
420 cx.spawn_in(window, async move |this, cx| {
421 this.update_in(cx, |this, window, cx| {
422 this.delegate.filtered_entries = filtered_models.entries();
423 // Finds the currently selected model in the list
424 let new_index =
425 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
426 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
427 cx.notify();
428 })
429 .ok();
430 })
431 }
432
433 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
434 if let Some(LanguageModelPickerEntry::Model(model_info)) =
435 self.filtered_entries.get(self.selected_index)
436 {
437 let model = model_info.model.clone();
438 (self.on_model_changed)(model.clone(), cx);
439
440 let current_index = self.selected_index;
441 self.set_selected_index(current_index, window, cx);
442
443 cx.emit(DismissEvent);
444 }
445 }
446
447 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
448 cx.emit(DismissEvent);
449 }
450
451 fn render_match(
452 &self,
453 ix: usize,
454 selected: bool,
455 _: &mut Window,
456 cx: &mut Context<Picker<Self>>,
457 ) -> Option<Self::ListItem> {
458 match self.filtered_entries.get(ix)? {
459 LanguageModelPickerEntry::Separator(title) => Some(
460 div()
461 .px_2()
462 .pb_1()
463 .when(ix > 1, |this| {
464 this.mt_1()
465 .pt_2()
466 .border_t_1()
467 .border_color(cx.theme().colors().border_variant)
468 })
469 .child(
470 Label::new(title)
471 .size(LabelSize::XSmall)
472 .color(Color::Muted),
473 )
474 .into_any_element(),
475 ),
476 LanguageModelPickerEntry::Model(model_info) => {
477 let active_model = (self.get_active_model)(cx);
478 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
479 let active_model_id = active_model.map(|m| m.model.id());
480
481 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
482 && Some(model_info.model.id()) == active_model_id;
483
484 let model_icon_color = if is_selected {
485 Color::Accent
486 } else {
487 Color::Muted
488 };
489
490 Some(
491 ListItem::new(ix)
492 .inset(true)
493 .spacing(ListItemSpacing::Sparse)
494 .toggle_state(selected)
495 .child(
496 h_flex()
497 .w_full()
498 .gap_1p5()
499 .child(
500 Icon::new(model_info.icon)
501 .color(model_icon_color)
502 .size(IconSize::Small),
503 )
504 .child(Label::new(model_info.model.name().0).truncate()),
505 )
506 .end_slot(div().pr_3().when(is_selected, |this| {
507 this.child(
508 Icon::new(IconName::Check)
509 .color(Color::Accent)
510 .size(IconSize::Small),
511 )
512 }))
513 .into_any_element(),
514 )
515 }
516 }
517 }
518
519 fn render_footer(
520 &self,
521 _window: &mut Window,
522 cx: &mut Context<Picker<Self>>,
523 ) -> Option<gpui::AnyElement> {
524 if !self.popover_styles {
525 return None;
526 }
527
528 Some(
529 h_flex()
530 .w_full()
531 .border_t_1()
532 .border_color(cx.theme().colors().border_variant)
533 .p_1()
534 .gap_4()
535 .justify_between()
536 .child(
537 Button::new("configure", "Configure")
538 .icon(IconName::Settings)
539 .icon_size(IconSize::Small)
540 .icon_color(Color::Muted)
541 .icon_position(IconPosition::Start)
542 .on_click(|_, window, cx| {
543 window.dispatch_action(
544 zed_actions::agent::OpenSettings.boxed_clone(),
545 cx,
546 );
547 }),
548 )
549 .into_any(),
550 )
551 }
552}
553
// Unit tests for `ModelMatcher` search and `GroupedModels` grouping, backed by
// a stub model that only carries identifying metadata.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Stub model carrying only names and ids; token counting and completion
    /// are `unimplemented!()`.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Creates a model whose id equals `name`, and whose provider id and
        /// provider name both equal `provider`.
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // Formatted as "<provider id>/<name>"; `assert_models_eq` compares
        // against this string.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, all using the `Ai`
    /// icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Asserts that `result` holds exactly the models in `expected`, in order,
    /// comparing each model's `telemetry_id` ("<provider>/<name>").
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in `all`
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in `all`
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in `all`
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }
}