1use std::{cmp::Reverse, sync::Arc};
2
3use collections::IndexMap;
4use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
5use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task};
6use language_model::{
7 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
8 LanguageModelRegistry,
9};
10use ordered_float::OrderedFloat;
11use picker::{Picker, PickerDelegate};
12use ui::{ListItem, ListItemSpacing, prelude::*};
13
/// Callback invoked with the model the user confirmed in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback reporting the currently active model, if there is one.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

/// A [`Picker`] that lets the user choose a language model.
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
18
19pub fn language_model_selector(
20 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
21 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
22 popover_styles: bool,
23 window: &mut Window,
24 cx: &mut Context<LanguageModelSelector>,
25) -> LanguageModelSelector {
26 let delegate = LanguageModelPickerDelegate::new(
27 get_active_model,
28 on_model_changed,
29 popover_styles,
30 window,
31 cx,
32 );
33
34 if popover_styles {
35 Picker::list(delegate, window, cx)
36 .show_scrollbar(true)
37 .width(rems(20.))
38 .max_height(Some(rems(20.).into()))
39 } else {
40 Picker::list(delegate, window, cx).show_scrollbar(true)
41 }
42}
43
44fn all_models(cx: &App) -> GroupedModels {
45 let providers = LanguageModelRegistry::global(cx).read(cx).providers();
46
47 let recommended = providers
48 .iter()
49 .flat_map(|provider| {
50 provider
51 .recommended_models(cx)
52 .into_iter()
53 .map(|model| ModelInfo {
54 model,
55 icon: provider.icon(),
56 })
57 })
58 .collect();
59
60 let all = providers
61 .iter()
62 .flat_map(|provider| {
63 provider
64 .provided_models(cx)
65 .into_iter()
66 .map(|model| ModelInfo {
67 model,
68 icon: provider.icon(),
69 })
70 })
71 .collect();
72
73 GroupedModels::new(all, recommended)
74}
75
/// A model paired with the icon of the provider that supplied it.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// The provider's icon, shown in the row's start slot in the picker.
    icon: IconName,
}
81
/// [`PickerDelegate`] listing language models grouped by provider, preceded by
/// a "Recommended" section whenever recommended models are available.
pub struct LanguageModelPickerDelegate {
    /// Invoked with the chosen model when the user confirms a selection.
    on_model_changed: OnModelChanged,
    /// Reports the active model, used to highlight and pre-select its row.
    get_active_model: GetActiveModel,
    /// Every known model; rebuilt when the registry reports a provider change.
    all_models: Arc<GroupedModels>,
    /// Rows currently displayed (models and section separators).
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Background authentication of all providers. Stored so it lives as long
    /// as the delegate (presumably dropping a gpui `Task` cancels it — confirm).
    _authenticate_all_providers_task: Task<()>,
    /// Registry event subscription(s), kept alive with the delegate.
    _subscriptions: Vec<Subscription>,
    /// Whether the picker is shown as a popover (fixed size plus footer).
    popover_styles: bool,
}
92
93impl LanguageModelPickerDelegate {
94 fn new(
95 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
96 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
97 popover_styles: bool,
98 window: &mut Window,
99 cx: &mut Context<Picker<Self>>,
100 ) -> Self {
101 let on_model_changed = Arc::new(on_model_changed);
102 let models = all_models(cx);
103 let entries = models.entries();
104
105 Self {
106 on_model_changed,
107 all_models: Arc::new(models),
108 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
109 filtered_entries: entries,
110 get_active_model: Arc::new(get_active_model),
111 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
112 _subscriptions: vec![cx.subscribe_in(
113 &LanguageModelRegistry::global(cx),
114 window,
115 |picker, _, event, window, cx| {
116 match event {
117 language_model::Event::ProviderStateChanged(_)
118 | language_model::Event::AddedProvider(_)
119 | language_model::Event::RemovedProvider(_) => {
120 let query = picker.query(cx);
121 picker.delegate.all_models = Arc::new(all_models(cx));
122 // Update matches will automatically drop the previous task
123 // if we get a provider event again
124 picker.update_matches(query, window, cx)
125 }
126 _ => {}
127 }
128 },
129 )],
130 popover_styles,
131 }
132 }
133
134 fn get_active_model_index(
135 entries: &[LanguageModelPickerEntry],
136 active_model: Option<ConfiguredModel>,
137 ) -> usize {
138 entries
139 .iter()
140 .position(|entry| {
141 if let LanguageModelPickerEntry::Model(model) = entry {
142 active_model
143 .as_ref()
144 .map(|active_model| {
145 active_model.model.id() == model.model.id()
146 && active_model.provider.id() == model.model.provider_id()
147 })
148 .unwrap_or_default()
149 } else {
150 false
151 }
152 })
153 .unwrap_or(0)
154 }
155
156 /// Authenticates all providers in the [`LanguageModelRegistry`].
157 ///
158 /// We do this so that we can populate the language selector with all of the
159 /// models from the configured providers.
160 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
161 let authenticate_all_providers = LanguageModelRegistry::global(cx)
162 .read(cx)
163 .providers()
164 .iter()
165 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
166 .collect::<Vec<_>>();
167
168 cx.spawn(async move |_cx| {
169 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
170 if let Err(err) = authenticate_task.await {
171 if matches!(err, AuthenticateError::CredentialsNotFound) {
172 // Since we're authenticating these providers in the
173 // background for the purposes of populating the
174 // language selector, we don't care about providers
175 // where the credentials are not found.
176 } else {
177 // Some providers have noisy failure states that we
178 // don't want to spam the logs with every time the
179 // language model selector is initialized.
180 //
181 // Ideally these should have more clear failure modes
182 // that we know are safe to ignore here, like what we do
183 // with `CredentialsNotFound` above.
184 match provider_id.0.as_ref() {
185 "lmstudio" | "ollama" => {
186 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
187 //
188 // These fail noisily, so we don't log them.
189 }
190 "copilot_chat" => {
191 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
192 }
193 _ => {
194 log::error!(
195 "Failed to authenticate provider: {}: {err:#}",
196 provider_name.0
197 );
198 }
199 }
200 }
201 }
202 }
203 })
204 }
205
206 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
207 (self.get_active_model)(cx)
208 }
209}
210
/// Models split into a recommended subset and the full list bucketed by provider.
struct GroupedModels {
    /// Recommended models, in the order they were supplied.
    recommended: Vec<ModelInfo>,
    /// All models keyed by provider id, in the order each provider first appeared.
    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
215
216impl GroupedModels {
217 pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
218 let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
219 for model in all {
220 let provider = model.model.provider_id();
221 if let Some(models) = all_by_provider.get_mut(&provider) {
222 models.push(model);
223 } else {
224 all_by_provider.insert(provider, vec![model]);
225 }
226 }
227
228 Self {
229 recommended,
230 all: all_by_provider,
231 }
232 }
233
234 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
235 let mut entries = Vec::new();
236
237 if !self.recommended.is_empty() {
238 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
239 entries.extend(
240 self.recommended
241 .iter()
242 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
243 );
244 }
245
246 for models in self.all.values() {
247 if models.is_empty() {
248 continue;
249 }
250 entries.push(LanguageModelPickerEntry::Separator(
251 models[0].model.provider_name().0,
252 ));
253 entries.extend(
254 models
255 .iter()
256 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
257 );
258 }
259 entries
260 }
261}
262
/// A single row in the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable section header carrying its title.
    Separator(SharedString),
}
267
/// Searches a fixed list of models either fuzzily or by name substring.
struct ModelMatcher {
    /// Models being searched; index `i` corresponds to `candidates[i]`.
    models: Vec<ModelInfo>,
    /// Executor used to run (and block on) fuzzy matching.
    bg_executor: BackgroundExecutor,
    /// Precomputed `provider/model` match strings, one per entry in `models`.
    candidates: Vec<StringMatchCandidate>,
}
273
274impl ModelMatcher {
275 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
276 let candidates = Self::make_match_candidates(&models);
277 Self {
278 models,
279 bg_executor,
280 candidates,
281 }
282 }
283
284 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
285 let mut matches = self.bg_executor.block(match_strings(
286 &self.candidates,
287 query,
288 false,
289 true,
290 100,
291 &Default::default(),
292 self.bg_executor.clone(),
293 ));
294
295 let sorting_key = |mat: &StringMatch| {
296 let candidate = &self.candidates[mat.candidate_id];
297 (Reverse(OrderedFloat(mat.score)), candidate.id)
298 };
299 matches.sort_unstable_by_key(sorting_key);
300
301 let matched_models: Vec<_> = matches
302 .into_iter()
303 .map(|mat| self.models[mat.candidate_id].clone())
304 .collect();
305
306 matched_models
307 }
308
309 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
310 self.models
311 .iter()
312 .filter(|m| {
313 m.model
314 .name()
315 .0
316 .to_lowercase()
317 .contains(&query.to_lowercase())
318 })
319 .cloned()
320 .collect::<Vec<_>>()
321 }
322
323 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
324 model_infos
325 .iter()
326 .enumerate()
327 .map(|(index, model)| {
328 StringMatchCandidate::new(
329 index,
330 &format!(
331 "{}/{}",
332 &model.model.provider_name().0,
333 &model.model.name().0
334 ),
335 )
336 })
337 .collect::<Vec<_>>()
338 }
339}
340
341impl PickerDelegate for LanguageModelPickerDelegate {
342 type ListItem = AnyElement;
343
344 fn match_count(&self) -> usize {
345 self.filtered_entries.len()
346 }
347
348 fn selected_index(&self) -> usize {
349 self.selected_index
350 }
351
352 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
353 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
354 cx.notify();
355 }
356
357 fn can_select(
358 &mut self,
359 ix: usize,
360 _window: &mut Window,
361 _cx: &mut Context<Picker<Self>>,
362 ) -> bool {
363 match self.filtered_entries.get(ix) {
364 Some(LanguageModelPickerEntry::Model(_)) => true,
365 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
366 }
367 }
368
369 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
370 "Select a model…".into()
371 }
372
373 fn update_matches(
374 &mut self,
375 query: String,
376 window: &mut Window,
377 cx: &mut Context<Picker<Self>>,
378 ) -> Task<()> {
379 let all_models = self.all_models.clone();
380 let active_model = (self.get_active_model)(cx);
381 let bg_executor = cx.background_executor();
382
383 let language_model_registry = LanguageModelRegistry::global(cx);
384
385 let configured_providers = language_model_registry
386 .read(cx)
387 .providers()
388 .into_iter()
389 .filter(|provider| provider.is_authenticated(cx))
390 .collect::<Vec<_>>();
391
392 let configured_provider_ids = configured_providers
393 .iter()
394 .map(|provider| provider.id())
395 .collect::<Vec<_>>();
396
397 let recommended_models = all_models
398 .recommended
399 .iter()
400 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
401 .cloned()
402 .collect::<Vec<_>>();
403
404 let available_models = all_models
405 .all
406 .values()
407 .flat_map(|models| models.iter())
408 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
409 .cloned()
410 .collect::<Vec<_>>();
411
412 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
413 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
414
415 let recommended = matcher_rec.exact_search(&query);
416 let all = matcher_all.fuzzy_search(&query);
417
418 let filtered_models = GroupedModels::new(all, recommended);
419
420 cx.spawn_in(window, async move |this, cx| {
421 this.update_in(cx, |this, window, cx| {
422 this.delegate.filtered_entries = filtered_models.entries();
423 // Finds the currently selected model in the list
424 let new_index =
425 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
426 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
427 cx.notify();
428 })
429 .ok();
430 })
431 }
432
433 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
434 if let Some(LanguageModelPickerEntry::Model(model_info)) =
435 self.filtered_entries.get(self.selected_index)
436 {
437 let model = model_info.model.clone();
438 (self.on_model_changed)(model.clone(), cx);
439
440 let current_index = self.selected_index;
441 self.set_selected_index(current_index, window, cx);
442
443 cx.emit(DismissEvent);
444 }
445 }
446
447 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
448 cx.emit(DismissEvent);
449 }
450
451 fn render_match(
452 &self,
453 ix: usize,
454 selected: bool,
455 _: &mut Window,
456 cx: &mut Context<Picker<Self>>,
457 ) -> Option<Self::ListItem> {
458 match self.filtered_entries.get(ix)? {
459 LanguageModelPickerEntry::Separator(title) => Some(
460 div()
461 .px_2()
462 .pb_1()
463 .when(ix > 1, |this| {
464 this.mt_1()
465 .pt_2()
466 .border_t_1()
467 .border_color(cx.theme().colors().border_variant)
468 })
469 .child(
470 Label::new(title)
471 .size(LabelSize::XSmall)
472 .color(Color::Muted),
473 )
474 .into_any_element(),
475 ),
476 LanguageModelPickerEntry::Model(model_info) => {
477 let active_model = (self.get_active_model)(cx);
478 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
479 let active_model_id = active_model.map(|m| m.model.id());
480
481 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
482 && Some(model_info.model.id()) == active_model_id;
483
484 let model_icon_color = if is_selected {
485 Color::Accent
486 } else {
487 Color::Muted
488 };
489
490 Some(
491 ListItem::new(ix)
492 .inset(true)
493 .spacing(ListItemSpacing::Sparse)
494 .toggle_state(selected)
495 .start_slot(
496 Icon::new(model_info.icon)
497 .color(model_icon_color)
498 .size(IconSize::Small),
499 )
500 .child(
501 h_flex()
502 .w_full()
503 .pl_0p5()
504 .gap_1p5()
505 .w(px(240.))
506 .child(Label::new(model_info.model.name().0).truncate()),
507 )
508 .end_slot(div().pr_3().when(is_selected, |this| {
509 this.child(
510 Icon::new(IconName::Check)
511 .color(Color::Accent)
512 .size(IconSize::Small),
513 )
514 }))
515 .into_any_element(),
516 )
517 }
518 }
519 }
520
521 fn render_footer(
522 &self,
523 _window: &mut Window,
524 cx: &mut Context<Picker<Self>>,
525 ) -> Option<gpui::AnyElement> {
526 if !self.popover_styles {
527 return None;
528 }
529
530 Some(
531 h_flex()
532 .w_full()
533 .border_t_1()
534 .border_color(cx.theme().colors().border_variant)
535 .p_1()
536 .gap_4()
537 .justify_between()
538 .child(
539 Button::new("configure", "Configure")
540 .icon(IconName::Settings)
541 .icon_size(IconSize::Small)
542 .icon_color(Color::Muted)
543 .icon_position(IconPosition::Start)
544 .on_click(|_, window, cx| {
545 window.dispatch_action(
546 zed_actions::agent::OpenSettings.boxed_clone(),
547 cx,
548 );
549 }),
550 )
551 .into_any(),
552 )
553 }
554}
555
#[cfg(test)]
mod tests {
    // Tests for model grouping and for the exact/fuzzy model search used by the
    // language model picker.
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal [`LanguageModel`] stub identified only by name and provider.
    /// Token counting and completion are unimplemented.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Uses `name` for both the model name and id, and `provider` for both
        /// the provider id and provider name.
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // `provider_id/name`; the assertions below identify models by this string.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, all with the `Ai` icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Asserts `result` contains exactly the `provider/name` ids in `expected`, in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Recommended, but should also appear in "all"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Recommended, but should also appear in "all"
            ("zed", "gemini"),
            ("copilot", "claude"), // Same name, different provider: also in "all"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }
}