1use std::{cmp::Reverse, sync::Arc};
2
3use collections::IndexMap;
4use futures::{StreamExt, channel::mpsc};
5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
6use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
7use language_model::{
8 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProvider,
9 LanguageModelProviderId, LanguageModelRegistry,
10};
11use ordered_float::OrderedFloat;
12use picker::{Picker, PickerDelegate};
13use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*};
14use zed_actions::agent::OpenSettings;
15
16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
17type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
18
19pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
20
21pub fn language_model_selector(
22 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
23 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
24 popover_styles: bool,
25 focus_handle: FocusHandle,
26 window: &mut Window,
27 cx: &mut Context<LanguageModelSelector>,
28) -> LanguageModelSelector {
29 let delegate = LanguageModelPickerDelegate::new(
30 get_active_model,
31 on_model_changed,
32 popover_styles,
33 focus_handle,
34 window,
35 cx,
36 );
37
38 if popover_styles {
39 Picker::list(delegate, window, cx)
40 .show_scrollbar(true)
41 .width(rems(20.))
42 .max_height(Some(rems(20.).into()))
43 } else {
44 Picker::list(delegate, window, cx).show_scrollbar(true)
45 }
46}
47
48fn all_models(cx: &App) -> GroupedModels {
49 let providers = LanguageModelRegistry::global(cx)
50 .read(cx)
51 .visible_providers();
52
53 let recommended = providers
54 .iter()
55 .flat_map(|provider| {
56 provider
57 .recommended_models(cx)
58 .into_iter()
59 .map(|model| ModelInfo {
60 model,
61 icon: ProviderIcon::from_provider(provider.as_ref()),
62 })
63 })
64 .collect();
65
66 let all: Vec<ModelInfo> = providers
67 .iter()
68 .flat_map(|provider| {
69 provider
70 .provided_models(cx)
71 .into_iter()
72 .map(|model| ModelInfo {
73 model,
74 icon: ProviderIcon::from_provider(provider.as_ref()),
75 })
76 })
77 .collect();
78
79 GroupedModels::new(all, recommended)
80}
81
82#[derive(Clone)]
83enum ProviderIcon {
84 Name(IconName),
85 Path(SharedString),
86}
87
88impl ProviderIcon {
89 fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
90 if let Some(path) = provider.icon_path() {
91 Self::Path(path)
92 } else {
93 Self::Name(provider.icon())
94 }
95 }
96}
97
98#[derive(Clone)]
99struct ModelInfo {
100 model: Arc<dyn LanguageModel>,
101 icon: ProviderIcon,
102}
103
104pub struct LanguageModelPickerDelegate {
105 on_model_changed: OnModelChanged,
106 get_active_model: GetActiveModel,
107 all_models: Arc<GroupedModels>,
108 filtered_entries: Vec<LanguageModelPickerEntry>,
109 selected_index: usize,
110 _authenticate_all_providers_task: Task<()>,
111 _refresh_models_task: Task<()>,
112 popover_styles: bool,
113 focus_handle: FocusHandle,
114}
115
116impl LanguageModelPickerDelegate {
117 fn new(
118 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
119 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
120 popover_styles: bool,
121 focus_handle: FocusHandle,
122 window: &mut Window,
123 cx: &mut Context<Picker<Self>>,
124 ) -> Self {
125 let on_model_changed = Arc::new(on_model_changed);
126 let models = all_models(cx);
127 let entries = models.entries();
128
129 Self {
130 on_model_changed,
131 all_models: Arc::new(models),
132 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
133 filtered_entries: entries,
134 get_active_model: Arc::new(get_active_model),
135 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
136 _refresh_models_task: {
137 // Create a channel to signal when models need refreshing
138 let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
139
140 // Subscribe to registry events and send refresh signals through the channel
141 let registry = LanguageModelRegistry::global(cx);
142 cx.subscribe(®istry, move |_picker, _, event, _cx| match event {
143 language_model::Event::ProviderStateChanged(_)
144 | language_model::Event::AddedProvider(_)
145 | language_model::Event::RemovedProvider(_) => {
146 refresh_tx.unbounded_send(()).ok();
147 }
148 _ => {}
149 })
150 .detach();
151
152 // Spawn a task that listens for refresh signals and updates the picker
153 cx.spawn_in(window, async move |this, cx| {
154 while let Some(()) = refresh_rx.next().await {
155 let result = this.update_in(cx, |picker, window, cx| {
156 picker.delegate.all_models = Arc::new(all_models(cx));
157 picker.refresh(window, cx);
158 });
159 if result.is_err() {
160 // Picker was dropped, exit the loop
161 break;
162 }
163 }
164 })
165 },
166 popover_styles,
167 focus_handle,
168 }
169 }
170
171 fn get_active_model_index(
172 entries: &[LanguageModelPickerEntry],
173 active_model: Option<ConfiguredModel>,
174 ) -> usize {
175 entries
176 .iter()
177 .position(|entry| {
178 if let LanguageModelPickerEntry::Model(model) = entry {
179 active_model
180 .as_ref()
181 .map(|active_model| {
182 active_model.model.id() == model.model.id()
183 && active_model.provider.id() == model.model.provider_id()
184 })
185 .unwrap_or_default()
186 } else {
187 false
188 }
189 })
190 .unwrap_or(0)
191 }
192
193 /// Authenticates all providers in the [`LanguageModelRegistry`].
194 ///
195 /// We do this so that we can populate the language selector with all of the
196 /// models from the configured providers.
197 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
198 let authenticate_all_providers = LanguageModelRegistry::global(cx)
199 .read(cx)
200 .providers()
201 .iter()
202 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
203 .collect::<Vec<_>>();
204
205 cx.spawn(async move |_cx| {
206 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
207 if let Err(err) = authenticate_task.await {
208 if matches!(err, AuthenticateError::CredentialsNotFound) {
209 // Since we're authenticating these providers in the
210 // background for the purposes of populating the
211 // language selector, we don't care about providers
212 // where the credentials are not found.
213 } else {
214 // Some providers have noisy failure states that we
215 // don't want to spam the logs with every time the
216 // language model selector is initialized.
217 //
218 // Ideally these should have more clear failure modes
219 // that we know are safe to ignore here, like what we do
220 // with `CredentialsNotFound` above.
221 match provider_id.0.as_ref() {
222 "lmstudio" | "ollama" => {
223 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
224 //
225 // These fail noisily, so we don't log them.
226 }
227 "copilot_chat" => {
228 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
229 }
230 _ => {
231 log::error!(
232 "Failed to authenticate provider: {}: {err:#}",
233 provider_name.0
234 );
235 }
236 }
237 }
238 }
239 }
240 })
241 }
242
243 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
244 (self.get_active_model)(cx)
245 }
246}
247
248struct GroupedModels {
249 recommended: Vec<ModelInfo>,
250 all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
251}
252
253impl GroupedModels {
254 pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
255 let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
256 for model in all {
257 let provider = model.model.provider_id();
258 if let Some(models) = all_by_provider.get_mut(&provider) {
259 models.push(model);
260 } else {
261 all_by_provider.insert(provider, vec![model]);
262 }
263 }
264
265 Self {
266 recommended,
267 all: all_by_provider,
268 }
269 }
270
271 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
272 let mut entries = Vec::new();
273
274 if !self.recommended.is_empty() {
275 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
276 entries.extend(
277 self.recommended
278 .iter()
279 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
280 );
281 }
282
283 for models in self.all.values() {
284 if models.is_empty() {
285 continue;
286 }
287 entries.push(LanguageModelPickerEntry::Separator(
288 models[0].model.provider_name().0,
289 ));
290 entries.extend(
291 models
292 .iter()
293 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
294 );
295 }
296 entries
297 }
298}
299
300enum LanguageModelPickerEntry {
301 Model(ModelInfo),
302 Separator(SharedString),
303}
304
305struct ModelMatcher {
306 models: Vec<ModelInfo>,
307 bg_executor: BackgroundExecutor,
308 candidates: Vec<StringMatchCandidate>,
309}
310
311impl ModelMatcher {
312 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
313 let candidates = Self::make_match_candidates(&models);
314 Self {
315 models,
316 bg_executor,
317 candidates,
318 }
319 }
320
321 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
322 let mut matches = self.bg_executor.block(match_strings(
323 &self.candidates,
324 query,
325 false,
326 true,
327 100,
328 &Default::default(),
329 self.bg_executor.clone(),
330 ));
331
332 let sorting_key = |mat: &StringMatch| {
333 let candidate = &self.candidates[mat.candidate_id];
334 (Reverse(OrderedFloat(mat.score)), candidate.id)
335 };
336 matches.sort_unstable_by_key(sorting_key);
337
338 let matched_models: Vec<_> = matches
339 .into_iter()
340 .map(|mat| self.models[mat.candidate_id].clone())
341 .collect();
342
343 matched_models
344 }
345
346 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
347 self.models
348 .iter()
349 .filter(|m| {
350 m.model
351 .name()
352 .0
353 .to_lowercase()
354 .contains(&query.to_lowercase())
355 })
356 .cloned()
357 .collect::<Vec<_>>()
358 }
359
360 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
361 model_infos
362 .iter()
363 .enumerate()
364 .map(|(index, model)| {
365 StringMatchCandidate::new(
366 index,
367 &format!(
368 "{}/{}",
369 &model.model.provider_name().0,
370 &model.model.name().0
371 ),
372 )
373 })
374 .collect::<Vec<_>>()
375 }
376}
377
378impl PickerDelegate for LanguageModelPickerDelegate {
379 type ListItem = AnyElement;
380
381 fn match_count(&self) -> usize {
382 self.filtered_entries.len()
383 }
384
385 fn selected_index(&self) -> usize {
386 self.selected_index
387 }
388
389 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
390 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
391 cx.notify();
392 }
393
394 fn can_select(
395 &mut self,
396 ix: usize,
397 _window: &mut Window,
398 _cx: &mut Context<Picker<Self>>,
399 ) -> bool {
400 match self.filtered_entries.get(ix) {
401 Some(LanguageModelPickerEntry::Model(_)) => true,
402 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
403 }
404 }
405
406 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
407 "Select a model…".into()
408 }
409
410 fn update_matches(
411 &mut self,
412 query: String,
413 window: &mut Window,
414 cx: &mut Context<Picker<Self>>,
415 ) -> Task<()> {
416 let all_models = self.all_models.clone();
417 let active_model = (self.get_active_model)(cx);
418 let bg_executor = cx.background_executor();
419
420 let language_model_registry = LanguageModelRegistry::global(cx);
421
422 let configured_providers = language_model_registry
423 .read(cx)
424 .visible_providers()
425 .into_iter()
426 .filter(|provider| provider.is_authenticated(cx))
427 .collect::<Vec<_>>();
428
429 let configured_provider_ids = configured_providers
430 .iter()
431 .map(|provider| provider.id())
432 .collect::<Vec<_>>();
433
434 let recommended_models = all_models
435 .recommended
436 .iter()
437 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
438 .cloned()
439 .collect::<Vec<_>>();
440
441 let available_models = all_models
442 .all
443 .values()
444 .flat_map(|models| models.iter())
445 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
446 .cloned()
447 .collect::<Vec<_>>();
448
449 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
450 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
451
452 let recommended = matcher_rec.exact_search(&query);
453 let all = matcher_all.fuzzy_search(&query);
454
455 let filtered_models = GroupedModels::new(all, recommended);
456
457 cx.spawn_in(window, async move |this, cx| {
458 this.update_in(cx, |this, window, cx| {
459 this.delegate.filtered_entries = filtered_models.entries();
460 // Finds the currently selected model in the list
461 let new_index =
462 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
463 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
464 cx.notify();
465 })
466 .ok();
467 })
468 }
469
470 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
471 if let Some(LanguageModelPickerEntry::Model(model_info)) =
472 self.filtered_entries.get(self.selected_index)
473 {
474 let model = model_info.model.clone();
475 (self.on_model_changed)(model.clone(), cx);
476
477 let current_index = self.selected_index;
478 self.set_selected_index(current_index, window, cx);
479
480 cx.emit(DismissEvent);
481 }
482 }
483
484 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
485 cx.emit(DismissEvent);
486 }
487
488 fn render_match(
489 &self,
490 ix: usize,
491 selected: bool,
492 _: &mut Window,
493 cx: &mut Context<Picker<Self>>,
494 ) -> Option<Self::ListItem> {
495 match self.filtered_entries.get(ix)? {
496 LanguageModelPickerEntry::Separator(title) => Some(
497 div()
498 .px_2()
499 .pb_1()
500 .when(ix > 1, |this| {
501 this.mt_1()
502 .pt_2()
503 .border_t_1()
504 .border_color(cx.theme().colors().border_variant)
505 })
506 .child(
507 Label::new(title)
508 .size(LabelSize::XSmall)
509 .color(Color::Muted),
510 )
511 .into_any_element(),
512 ),
513 LanguageModelPickerEntry::Model(model_info) => {
514 let active_model = (self.get_active_model)(cx);
515 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
516 let active_model_id = active_model.map(|m| m.model.id());
517
518 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
519 && Some(model_info.model.id()) == active_model_id;
520
521 let model_icon_color = if is_selected {
522 Color::Accent
523 } else {
524 Color::Muted
525 };
526
527 Some(
528 ListItem::new(ix)
529 .inset(true)
530 .spacing(ListItemSpacing::Sparse)
531 .toggle_state(selected)
532 .child(
533 h_flex()
534 .w_full()
535 .gap_1p5()
536 .child(match &model_info.icon {
537 ProviderIcon::Name(icon_name) => Icon::new(*icon_name)
538 .color(model_icon_color)
539 .size(IconSize::Small),
540 ProviderIcon::Path(icon_path) => {
541 Icon::from_external_svg(icon_path.clone())
542 .color(model_icon_color)
543 .size(IconSize::Small)
544 }
545 })
546 .child(Label::new(model_info.model.name().0).truncate()),
547 )
548 .end_slot(div().pr_3().when(is_selected, |this| {
549 this.child(
550 Icon::new(IconName::Check)
551 .color(Color::Accent)
552 .size(IconSize::Small),
553 )
554 }))
555 .into_any_element(),
556 )
557 }
558 }
559 }
560
561 fn render_footer(
562 &self,
563 _window: &mut Window,
564 cx: &mut Context<Picker<Self>>,
565 ) -> Option<gpui::AnyElement> {
566 let focus_handle = self.focus_handle.clone();
567
568 if !self.popover_styles {
569 return None;
570 }
571
572 Some(
573 h_flex()
574 .w_full()
575 .p_1p5()
576 .border_t_1()
577 .border_color(cx.theme().colors().border_variant)
578 .child(
579 Button::new("configure", "Configure")
580 .full_width()
581 .style(ButtonStyle::Outlined)
582 .key_binding(
583 KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
584 .map(|kb| kb.size(rems_from_px(12.))),
585 )
586 .on_click(|_, window, cx| {
587 window.dispatch_action(OpenSettings.boxed_clone(), cx);
588 }),
589 )
590 .into_any(),
591 )
592 }
593}
594
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal [`LanguageModel`] that only carries identifying metadata.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Uses `name` for both the model's name and id, and `provider` for
        /// both the provider's id and name.
        fn new(name: &str, provider: &str) -> Self {
            let model_name = name.to_string();
            let provider = provider.to_string();
            Self {
                name: LanguageModelName::from(model_name.clone()),
                id: LanguageModelId::from(model_name),
                provider_id: LanguageModelProviderId::from(provider.clone()),
                provider_name: LanguageModelProviderName::from(provider),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        /// `provider/name`, which the assertions below compare against.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Turns `(provider, name)` pairs into [`ModelInfo`]s, keeping their order.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        let mut infos = Vec::with_capacity(model_specs.len());
        for (provider, name) in model_specs {
            infos.push(ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: ProviderIcon::Name(IconName::Ai),
            });
        }
        infos
    }

    /// Model list shared by the matcher tests.
    fn sample_models() -> Vec<ModelInfo> {
        create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ])
    }

    /// Asserts that `result` consists of exactly the models named by
    /// `expected` (as `provider/name`), in that order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, (actual, expected_name)) in result.iter().zip(expected.iter()).enumerate() {
            assert_eq!(
                actual.model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let found = matcher.exact_search("GPT-4.1");
        let expected = vec![
            "zed/gpt-4.1",
            "zed/gpt-4.1-nano",
            "openai/gpt-4.1",
            "openai/gpt-4.1-nano",
        ];
        assert_models_eq(found, expected);
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let found = matcher.fuzzy_search("41");
        let expected = vec![
            "zed/gpt-4.1",
            "openai/gpt-4.1",
            "zed/gpt-4.1-nano",
            "openai/gpt-4.1-nano",
        ];
        assert_models_eq(found, expected);

        // Model provider should be searchable as well
        let found = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(found, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let found = matcher.fuzzy_search("z4n");
        assert_models_eq(found, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended = create_models(vec![("zed", "claude")]);
        let everything = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped = GroupedModels::new(everything, recommended);

        let mut flattened = Vec::new();
        for group in grouped.all.values() {
            flattened.extend(group.iter().cloned());
        }

        // Recommended models should also appear in "all"
        assert_models_eq(flattened, vec!["zed/claude", "zed/gemini", "copilot/o3"]);
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended = create_models(vec![("zed", "claude")]);
        let everything = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "other"
        ]);

        let grouped = GroupedModels::new(everything, recommended);

        let mut flattened = Vec::new();
        for group in grouped.all.values() {
            flattened.extend(group.iter().cloned());
        }

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            flattened,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }
}