1use std::{cmp::Reverse, sync::Arc};
2
3use collections::{HashSet, IndexMap};
4use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
5use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task};
6use language_model::{
7 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
8 LanguageModelRegistry,
9};
10use ordered_float::OrderedFloat;
11use picker::{Picker, PickerDelegate};
12use ui::{ListItem, ListItemSpacing, prelude::*};
13
/// Callback invoked with the model the user confirms in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback that returns the currently active model, if any.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

/// The language model selector: a [`Picker`] driven by [`LanguageModelPickerDelegate`].
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
18
19pub fn language_model_selector(
20 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
21 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
22 window: &mut Window,
23 cx: &mut Context<LanguageModelSelector>,
24) -> LanguageModelSelector {
25 let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx);
26 Picker::list(delegate, window, cx)
27 .show_scrollbar(true)
28 .width(rems(20.))
29 .max_height(Some(rems(20.).into()))
30}
31
32fn all_models(cx: &App) -> GroupedModels {
33 let providers = LanguageModelRegistry::global(cx).read(cx).providers();
34
35 let recommended = providers
36 .iter()
37 .flat_map(|provider| {
38 provider
39 .recommended_models(cx)
40 .into_iter()
41 .map(|model| ModelInfo {
42 model,
43 icon: provider.icon(),
44 })
45 })
46 .collect();
47
48 let other = providers
49 .iter()
50 .flat_map(|provider| {
51 provider
52 .provided_models(cx)
53 .into_iter()
54 .map(|model| ModelInfo {
55 model,
56 icon: provider.icon(),
57 })
58 })
59 .collect();
60
61 GroupedModels::new(other, recommended)
62}
63
/// A language model paired with the icon of the provider it came from.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// Icon rendered in the row's start slot; `all_models` fills it from `provider.icon()`.
    icon: IconName,
}
69
/// [`PickerDelegate`] that lists language models, grouped into a
/// "Recommended" section followed by one section per provider.
pub struct LanguageModelPickerDelegate {
    /// Called with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Returns the currently active model; used to preselect and check-mark it.
    get_active_model: GetActiveModel,
    /// Snapshot of every provider's models; rebuilt when the registry reports
    /// that a provider was added, removed, or changed state.
    all_models: Arc<GroupedModels>,
    /// Rows currently displayed (separators and models), restricted to
    /// authenticated providers and filtered by the query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Background authentication of all providers. Stored so it lives as long
    /// as the delegate (presumably dropping a gpui `Task` cancels it — confirm).
    _authenticate_all_providers_task: Task<()>,
    /// Keeps the `LanguageModelRegistry` subscription alive.
    _subscriptions: Vec<Subscription>,
}
79
80impl LanguageModelPickerDelegate {
81 fn new(
82 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
83 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
84 window: &mut Window,
85 cx: &mut Context<Picker<Self>>,
86 ) -> Self {
87 let on_model_changed = Arc::new(on_model_changed);
88 let models = all_models(cx);
89 let entries = models.entries();
90
91 Self {
92 on_model_changed,
93 all_models: Arc::new(models),
94 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
95 filtered_entries: entries,
96 get_active_model: Arc::new(get_active_model),
97 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
98 _subscriptions: vec![cx.subscribe_in(
99 &LanguageModelRegistry::global(cx),
100 window,
101 |picker, _, event, window, cx| {
102 match event {
103 language_model::Event::ProviderStateChanged(_)
104 | language_model::Event::AddedProvider(_)
105 | language_model::Event::RemovedProvider(_) => {
106 let query = picker.query(cx);
107 picker.delegate.all_models = Arc::new(all_models(cx));
108 // Update matches will automatically drop the previous task
109 // if we get a provider event again
110 picker.update_matches(query, window, cx)
111 }
112 _ => {}
113 }
114 },
115 )],
116 }
117 }
118
119 fn get_active_model_index(
120 entries: &[LanguageModelPickerEntry],
121 active_model: Option<ConfiguredModel>,
122 ) -> usize {
123 entries
124 .iter()
125 .position(|entry| {
126 if let LanguageModelPickerEntry::Model(model) = entry {
127 active_model
128 .as_ref()
129 .map(|active_model| {
130 active_model.model.id() == model.model.id()
131 && active_model.provider.id() == model.model.provider_id()
132 })
133 .unwrap_or_default()
134 } else {
135 false
136 }
137 })
138 .unwrap_or(0)
139 }
140
141 /// Authenticates all providers in the [`LanguageModelRegistry`].
142 ///
143 /// We do this so that we can populate the language selector with all of the
144 /// models from the configured providers.
145 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
146 let authenticate_all_providers = LanguageModelRegistry::global(cx)
147 .read(cx)
148 .providers()
149 .iter()
150 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
151 .collect::<Vec<_>>();
152
153 cx.spawn(async move |_cx| {
154 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
155 if let Err(err) = authenticate_task.await {
156 if matches!(err, AuthenticateError::CredentialsNotFound) {
157 // Since we're authenticating these providers in the
158 // background for the purposes of populating the
159 // language selector, we don't care about providers
160 // where the credentials are not found.
161 } else {
162 // Some providers have noisy failure states that we
163 // don't want to spam the logs with every time the
164 // language model selector is initialized.
165 //
166 // Ideally these should have more clear failure modes
167 // that we know are safe to ignore here, like what we do
168 // with `CredentialsNotFound` above.
169 match provider_id.0.as_ref() {
170 "lmstudio" | "ollama" => {
171 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
172 //
173 // These fail noisily, so we don't log them.
174 }
175 "copilot_chat" => {
176 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
177 }
178 _ => {
179 log::error!(
180 "Failed to authenticate provider: {}: {err}",
181 provider_name.0
182 );
183 }
184 }
185 }
186 }
187 }
188 })
189 }
190
191 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
192 (self.get_active_model)(cx)
193 }
194}
195
/// Models split into a "recommended" list and the remaining models keyed by provider id.
struct GroupedModels {
    /// Recommended models, in the order they were supplied.
    recommended: Vec<ModelInfo>,
    /// Non-recommended models bucketed by provider id (`IndexMap` keeps first-seen order).
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
200
201impl GroupedModels {
202 pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
203 let recommended_ids = recommended
204 .iter()
205 .map(|info| (info.model.provider_id(), info.model.id()))
206 .collect::<HashSet<_>>();
207
208 let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
209 for model in other {
210 if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
211 continue;
212 }
213
214 let provider = model.model.provider_id();
215 if let Some(models) = other_by_provider.get_mut(&provider) {
216 models.push(model);
217 } else {
218 other_by_provider.insert(provider, vec![model]);
219 }
220 }
221
222 Self {
223 recommended,
224 other: other_by_provider,
225 }
226 }
227
228 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
229 let mut entries = Vec::new();
230
231 if !self.recommended.is_empty() {
232 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
233 entries.extend(
234 self.recommended
235 .iter()
236 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
237 );
238 }
239
240 for models in self.other.values() {
241 if models.is_empty() {
242 continue;
243 }
244 entries.push(LanguageModelPickerEntry::Separator(
245 models[0].model.provider_name().0,
246 ));
247 entries.extend(
248 models
249 .iter()
250 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
251 );
252 }
253 entries
254 }
255
256 fn model_infos(&self) -> Vec<ModelInfo> {
257 let other = self
258 .other
259 .values()
260 .flat_map(|model| model.iter())
261 .cloned()
262 .collect::<Vec<_>>();
263 self.recommended
264 .iter()
265 .chain(&other)
266 .cloned()
267 .collect::<Vec<_>>()
268 }
269}
270
/// A single row in the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable section header with the given title.
    Separator(SharedString),
}
275
/// Searches a fixed list of models by query string.
struct ModelMatcher {
    /// Models to search; `candidates[i]` describes `models[i]`.
    models: Vec<ModelInfo>,
    /// Executor used to run (and block on) fuzzy matching.
    bg_executor: BackgroundExecutor,
    /// Fuzzy-match candidates of the form `"<provider name>/<model name>"`, id = index.
    candidates: Vec<StringMatchCandidate>,
}
281
282impl ModelMatcher {
283 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
284 let candidates = Self::make_match_candidates(&models);
285 Self {
286 models,
287 bg_executor,
288 candidates,
289 }
290 }
291
292 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
293 let mut matches = self.bg_executor.block(match_strings(
294 &self.candidates,
295 query,
296 false,
297 true,
298 100,
299 &Default::default(),
300 self.bg_executor.clone(),
301 ));
302
303 let sorting_key = |mat: &StringMatch| {
304 let candidate = &self.candidates[mat.candidate_id];
305 (Reverse(OrderedFloat(mat.score)), candidate.id)
306 };
307 matches.sort_unstable_by_key(sorting_key);
308
309 let matched_models: Vec<_> = matches
310 .into_iter()
311 .map(|mat| self.models[mat.candidate_id].clone())
312 .collect();
313
314 matched_models
315 }
316
317 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
318 self.models
319 .iter()
320 .filter(|m| {
321 m.model
322 .name()
323 .0
324 .to_lowercase()
325 .contains(&query.to_lowercase())
326 })
327 .cloned()
328 .collect::<Vec<_>>()
329 }
330
331 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
332 model_infos
333 .iter()
334 .enumerate()
335 .map(|(index, model)| {
336 StringMatchCandidate::new(
337 index,
338 &format!(
339 "{}/{}",
340 &model.model.provider_name().0,
341 &model.model.name().0
342 ),
343 )
344 })
345 .collect::<Vec<_>>()
346 }
347}
348
349impl PickerDelegate for LanguageModelPickerDelegate {
350 type ListItem = AnyElement;
351
352 fn match_count(&self) -> usize {
353 self.filtered_entries.len()
354 }
355
356 fn selected_index(&self) -> usize {
357 self.selected_index
358 }
359
360 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
361 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
362 cx.notify();
363 }
364
365 fn can_select(
366 &mut self,
367 ix: usize,
368 _window: &mut Window,
369 _cx: &mut Context<Picker<Self>>,
370 ) -> bool {
371 match self.filtered_entries.get(ix) {
372 Some(LanguageModelPickerEntry::Model(_)) => true,
373 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
374 }
375 }
376
377 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
378 "Select a model…".into()
379 }
380
381 fn update_matches(
382 &mut self,
383 query: String,
384 window: &mut Window,
385 cx: &mut Context<Picker<Self>>,
386 ) -> Task<()> {
387 let all_models = self.all_models.clone();
388 let active_model = (self.get_active_model)(cx);
389 let bg_executor = cx.background_executor();
390
391 let language_model_registry = LanguageModelRegistry::global(cx);
392
393 let configured_providers = language_model_registry
394 .read(cx)
395 .providers()
396 .into_iter()
397 .filter(|provider| provider.is_authenticated(cx))
398 .collect::<Vec<_>>();
399
400 let configured_provider_ids = configured_providers
401 .iter()
402 .map(|provider| provider.id())
403 .collect::<Vec<_>>();
404
405 let recommended_models = all_models
406 .recommended
407 .iter()
408 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
409 .cloned()
410 .collect::<Vec<_>>();
411
412 let available_models = all_models
413 .model_infos()
414 .iter()
415 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
416 .cloned()
417 .collect::<Vec<_>>();
418
419 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
420 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
421
422 let recommended = matcher_rec.exact_search(&query);
423 let all = matcher_all.fuzzy_search(&query);
424
425 let filtered_models = GroupedModels::new(all, recommended);
426
427 cx.spawn_in(window, async move |this, cx| {
428 this.update_in(cx, |this, window, cx| {
429 this.delegate.filtered_entries = filtered_models.entries();
430 // Finds the currently selected model in the list
431 let new_index =
432 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
433 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
434 cx.notify();
435 })
436 .ok();
437 })
438 }
439
440 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
441 if let Some(LanguageModelPickerEntry::Model(model_info)) =
442 self.filtered_entries.get(self.selected_index)
443 {
444 let model = model_info.model.clone();
445 (self.on_model_changed)(model.clone(), cx);
446
447 let current_index = self.selected_index;
448 self.set_selected_index(current_index, window, cx);
449
450 cx.emit(DismissEvent);
451 }
452 }
453
454 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
455 cx.emit(DismissEvent);
456 }
457
458 fn render_match(
459 &self,
460 ix: usize,
461 selected: bool,
462 _: &mut Window,
463 cx: &mut Context<Picker<Self>>,
464 ) -> Option<Self::ListItem> {
465 match self.filtered_entries.get(ix)? {
466 LanguageModelPickerEntry::Separator(title) => Some(
467 div()
468 .px_2()
469 .pb_1()
470 .when(ix > 1, |this| {
471 this.mt_1()
472 .pt_2()
473 .border_t_1()
474 .border_color(cx.theme().colors().border_variant)
475 })
476 .child(
477 Label::new(title)
478 .size(LabelSize::XSmall)
479 .color(Color::Muted),
480 )
481 .into_any_element(),
482 ),
483 LanguageModelPickerEntry::Model(model_info) => {
484 let active_model = (self.get_active_model)(cx);
485 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
486 let active_model_id = active_model.map(|m| m.model.id());
487
488 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
489 && Some(model_info.model.id()) == active_model_id;
490
491 let model_icon_color = if is_selected {
492 Color::Accent
493 } else {
494 Color::Muted
495 };
496
497 Some(
498 ListItem::new(ix)
499 .inset(true)
500 .spacing(ListItemSpacing::Sparse)
501 .toggle_state(selected)
502 .start_slot(
503 Icon::new(model_info.icon)
504 .color(model_icon_color)
505 .size(IconSize::Small),
506 )
507 .child(
508 h_flex()
509 .w_full()
510 .pl_0p5()
511 .gap_1p5()
512 .w(px(240.))
513 .child(Label::new(model_info.model.name().0).truncate()),
514 )
515 .end_slot(div().pr_3().when(is_selected, |this| {
516 this.child(
517 Icon::new(IconName::Check)
518 .color(Color::Accent)
519 .size(IconSize::Small),
520 )
521 }))
522 .into_any_element(),
523 )
524 }
525 }
526 }
527
528 fn render_footer(
529 &self,
530 _window: &mut Window,
531 cx: &mut Context<Picker<Self>>,
532 ) -> Option<gpui::AnyElement> {
533 Some(
534 h_flex()
535 .w_full()
536 .border_t_1()
537 .border_color(cx.theme().colors().border_variant)
538 .p_1()
539 .gap_4()
540 .justify_between()
541 .child(
542 Button::new("configure", "Configure")
543 .icon(IconName::Settings)
544 .icon_size(IconSize::Small)
545 .icon_color(Color::Muted)
546 .icon_position(IconPosition::Start)
547 .on_click(|_, window, cx| {
548 window.dispatch_action(
549 zed_actions::agent::OpenSettings.boxed_clone(),
550 cx,
551 );
552 }),
553 )
554 .into_any(),
555 )
556 }
557}
558
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal in-memory `LanguageModel` for exercising matching and grouping.
    ///
    /// The model name doubles as its id, and the provider string doubles as
    /// both provider id and provider name.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // "<provider_id>/<name>"; `assert_models_eq` compares against this.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        // Not needed by these tests; panics if called.
        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        // Not needed by these tests; panics if called.
        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, all with the `Ai` icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Asserts that `result` lists exactly the `expected` `"provider/name"` ids, in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should not appear in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
    }

    #[gpui::test]
    fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Should not be filtered out from "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Only the exact (provider, model) pair that is recommended is removed;
        // the same model id from another provider stays in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
    }
}