1use std::{cmp::Reverse, sync::Arc};
2
3use collections::IndexMap;
4use futures::{StreamExt, channel::mpsc};
5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
6use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
7use language_model::{
8 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProvider,
9 LanguageModelProviderId, LanguageModelRegistry,
10};
11use ordered_float::OrderedFloat;
12use picker::{Picker, PickerDelegate};
13use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*};
14use zed_actions::agent::OpenSettings;
15
16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
17type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
18
19pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
20
21pub fn language_model_selector(
22 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
23 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
24 popover_styles: bool,
25 focus_handle: FocusHandle,
26 window: &mut Window,
27 cx: &mut Context<LanguageModelSelector>,
28) -> LanguageModelSelector {
29 let delegate = LanguageModelPickerDelegate::new(
30 get_active_model,
31 on_model_changed,
32 popover_styles,
33 focus_handle,
34 window,
35 cx,
36 );
37
38 if popover_styles {
39 Picker::list(delegate, window, cx)
40 .show_scrollbar(true)
41 .width(rems(20.))
42 .max_height(Some(rems(20.).into()))
43 } else {
44 Picker::list(delegate, window, cx).show_scrollbar(true)
45 }
46}
47
48fn all_models(cx: &App) -> GroupedModels {
49 let providers = LanguageModelRegistry::global(cx)
50 .read(cx)
51 .visible_providers();
52
53 let recommended = providers
54 .iter()
55 .flat_map(|provider| {
56 provider
57 .recommended_models(cx)
58 .into_iter()
59 .map(|model| ModelInfo {
60 model,
61 icon: ProviderIcon::from_provider(provider.as_ref()),
62 })
63 })
64 .collect();
65
66 let all: Vec<ModelInfo> = providers
67 .iter()
68 .flat_map(|provider| {
69 provider
70 .provided_models(cx)
71 .into_iter()
72 .map(|model| ModelInfo {
73 model,
74 icon: ProviderIcon::from_provider(provider.as_ref()),
75 })
76 })
77 .collect();
78
79 GroupedModels::new(all, recommended)
80}
81
82#[derive(Clone)]
83enum ProviderIcon {
84 Name(IconName),
85 Path(SharedString),
86}
87
88impl ProviderIcon {
89 fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
90 if let Some(path) = provider.icon_path() {
91 Self::Path(path)
92 } else {
93 Self::Name(provider.icon())
94 }
95 }
96}
97
98#[derive(Clone)]
99struct ModelInfo {
100 model: Arc<dyn LanguageModel>,
101 icon: ProviderIcon,
102}
103
104pub struct LanguageModelPickerDelegate {
105 on_model_changed: OnModelChanged,
106 get_active_model: GetActiveModel,
107 all_models: Arc<GroupedModels>,
108 filtered_entries: Vec<LanguageModelPickerEntry>,
109 selected_index: usize,
110 _authenticate_all_providers_task: Task<()>,
111 _refresh_models_task: Task<()>,
112 popover_styles: bool,
113 focus_handle: FocusHandle,
114}
115
116impl LanguageModelPickerDelegate {
117 fn new(
118 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
119 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
120 popover_styles: bool,
121 focus_handle: FocusHandle,
122 window: &mut Window,
123 cx: &mut Context<Picker<Self>>,
124 ) -> Self {
125 let on_model_changed = Arc::new(on_model_changed);
126 let models = all_models(cx);
127 let entries = models.entries();
128
129 Self {
130 on_model_changed,
131 all_models: Arc::new(models),
132 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
133 filtered_entries: entries,
134 get_active_model: Arc::new(get_active_model),
135 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
136 _refresh_models_task: {
137 // Create a channel to signal when models need refreshing
138 let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
139
140 // Subscribe to registry events and send refresh signals through the channel
141 let registry = LanguageModelRegistry::global(cx);
142 cx.subscribe(®istry, move |_picker, _, event, _cx| match event {
143 language_model::Event::ProviderStateChanged(_) => {
144 refresh_tx.unbounded_send(()).ok();
145 }
146 language_model::Event::AddedProvider(_) => {
147 refresh_tx.unbounded_send(()).ok();
148 }
149 language_model::Event::RemovedProvider(_) => {
150 refresh_tx.unbounded_send(()).ok();
151 }
152 _ => {}
153 })
154 .detach();
155
156 // Spawn a task that listens for refresh signals and updates the picker
157 cx.spawn_in(window, async move |this, cx| {
158 while let Some(()) = refresh_rx.next().await {
159 let result = this.update_in(cx, |picker, window, cx| {
160 picker.delegate.all_models = Arc::new(all_models(cx));
161 picker.refresh(window, cx);
162 });
163 if result.is_err() {
164 // Picker was dropped, exit the loop
165 break;
166 }
167 }
168 })
169 },
170 popover_styles,
171 focus_handle,
172 }
173 }
174
175 fn get_active_model_index(
176 entries: &[LanguageModelPickerEntry],
177 active_model: Option<ConfiguredModel>,
178 ) -> usize {
179 entries
180 .iter()
181 .position(|entry| {
182 if let LanguageModelPickerEntry::Model(model) = entry {
183 active_model
184 .as_ref()
185 .map(|active_model| {
186 active_model.model.id() == model.model.id()
187 && active_model.provider.id() == model.model.provider_id()
188 })
189 .unwrap_or_default()
190 } else {
191 false
192 }
193 })
194 .unwrap_or(0)
195 }
196
197 /// Authenticates all providers in the [`LanguageModelRegistry`].
198 ///
199 /// We do this so that we can populate the language selector with all of the
200 /// models from the configured providers.
201 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
202 let authenticate_all_providers = LanguageModelRegistry::global(cx)
203 .read(cx)
204 .providers()
205 .iter()
206 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
207 .collect::<Vec<_>>();
208
209 cx.spawn(async move |_cx| {
210 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
211 if let Err(err) = authenticate_task.await {
212 if matches!(err, AuthenticateError::CredentialsNotFound) {
213 // Since we're authenticating these providers in the
214 // background for the purposes of populating the
215 // language selector, we don't care about providers
216 // where the credentials are not found.
217 } else {
218 // Some providers have noisy failure states that we
219 // don't want to spam the logs with every time the
220 // language model selector is initialized.
221 //
222 // Ideally these should have more clear failure modes
223 // that we know are safe to ignore here, like what we do
224 // with `CredentialsNotFound` above.
225 match provider_id.0.as_ref() {
226 "lmstudio" | "ollama" => {
227 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
228 //
229 // These fail noisily, so we don't log them.
230 }
231 "copilot_chat" => {
232 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
233 }
234 _ => {
235 log::error!(
236 "Failed to authenticate provider: {}: {err:#}",
237 provider_name.0
238 );
239 }
240 }
241 }
242 }
243 }
244 })
245 }
246
247 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
248 (self.get_active_model)(cx)
249 }
250}
251
252struct GroupedModels {
253 recommended: Vec<ModelInfo>,
254 all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
255}
256
257impl GroupedModels {
258 pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
259 let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
260 for model in all {
261 let provider = model.model.provider_id();
262 if let Some(models) = all_by_provider.get_mut(&provider) {
263 models.push(model);
264 } else {
265 all_by_provider.insert(provider, vec![model]);
266 }
267 }
268
269 Self {
270 recommended,
271 all: all_by_provider,
272 }
273 }
274
275 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
276 let mut entries = Vec::new();
277
278 if !self.recommended.is_empty() {
279 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
280 entries.extend(
281 self.recommended
282 .iter()
283 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
284 );
285 }
286
287 for models in self.all.values() {
288 if models.is_empty() {
289 continue;
290 }
291 entries.push(LanguageModelPickerEntry::Separator(
292 models[0].model.provider_name().0,
293 ));
294 entries.extend(
295 models
296 .iter()
297 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
298 );
299 }
300 entries
301 }
302}
303
304enum LanguageModelPickerEntry {
305 Model(ModelInfo),
306 Separator(SharedString),
307}
308
309struct ModelMatcher {
310 models: Vec<ModelInfo>,
311 bg_executor: BackgroundExecutor,
312 candidates: Vec<StringMatchCandidate>,
313}
314
315impl ModelMatcher {
316 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
317 let candidates = Self::make_match_candidates(&models);
318 Self {
319 models,
320 bg_executor,
321 candidates,
322 }
323 }
324
325 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
326 let mut matches = self.bg_executor.block(match_strings(
327 &self.candidates,
328 query,
329 false,
330 true,
331 100,
332 &Default::default(),
333 self.bg_executor.clone(),
334 ));
335
336 let sorting_key = |mat: &StringMatch| {
337 let candidate = &self.candidates[mat.candidate_id];
338 (Reverse(OrderedFloat(mat.score)), candidate.id)
339 };
340 matches.sort_unstable_by_key(sorting_key);
341
342 let matched_models: Vec<_> = matches
343 .into_iter()
344 .map(|mat| self.models[mat.candidate_id].clone())
345 .collect();
346
347 matched_models
348 }
349
350 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
351 self.models
352 .iter()
353 .filter(|m| {
354 m.model
355 .name()
356 .0
357 .to_lowercase()
358 .contains(&query.to_lowercase())
359 })
360 .cloned()
361 .collect::<Vec<_>>()
362 }
363
364 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
365 model_infos
366 .iter()
367 .enumerate()
368 .map(|(index, model)| {
369 StringMatchCandidate::new(
370 index,
371 &format!(
372 "{}/{}",
373 &model.model.provider_name().0,
374 &model.model.name().0
375 ),
376 )
377 })
378 .collect::<Vec<_>>()
379 }
380}
381
382impl PickerDelegate for LanguageModelPickerDelegate {
383 type ListItem = AnyElement;
384
385 fn match_count(&self) -> usize {
386 self.filtered_entries.len()
387 }
388
389 fn selected_index(&self) -> usize {
390 self.selected_index
391 }
392
393 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
394 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
395 cx.notify();
396 }
397
398 fn can_select(
399 &mut self,
400 ix: usize,
401 _window: &mut Window,
402 _cx: &mut Context<Picker<Self>>,
403 ) -> bool {
404 match self.filtered_entries.get(ix) {
405 Some(LanguageModelPickerEntry::Model(_)) => true,
406 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
407 }
408 }
409
410 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
411 "Select a model…".into()
412 }
413
414 fn update_matches(
415 &mut self,
416 query: String,
417 window: &mut Window,
418 cx: &mut Context<Picker<Self>>,
419 ) -> Task<()> {
420 let all_models = self.all_models.clone();
421 let active_model = (self.get_active_model)(cx);
422 let bg_executor = cx.background_executor();
423
424 let language_model_registry = LanguageModelRegistry::global(cx);
425
426 let configured_providers = language_model_registry
427 .read(cx)
428 .visible_providers()
429 .into_iter()
430 .filter(|provider| provider.is_authenticated(cx))
431 .collect::<Vec<_>>();
432
433 let configured_provider_ids = configured_providers
434 .iter()
435 .map(|provider| provider.id())
436 .collect::<Vec<_>>();
437
438 let recommended_models = all_models
439 .recommended
440 .iter()
441 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
442 .cloned()
443 .collect::<Vec<_>>();
444
445 let available_models = all_models
446 .all
447 .values()
448 .flat_map(|models| models.iter())
449 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
450 .cloned()
451 .collect::<Vec<_>>();
452
453 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
454 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
455
456 let recommended = matcher_rec.exact_search(&query);
457 let all = matcher_all.fuzzy_search(&query);
458
459 let filtered_models = GroupedModels::new(all, recommended);
460
461 cx.spawn_in(window, async move |this, cx| {
462 this.update_in(cx, |this, window, cx| {
463 this.delegate.filtered_entries = filtered_models.entries();
464 // Finds the currently selected model in the list
465 let new_index =
466 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
467 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
468 cx.notify();
469 })
470 .ok();
471 })
472 }
473
474 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
475 if let Some(LanguageModelPickerEntry::Model(model_info)) =
476 self.filtered_entries.get(self.selected_index)
477 {
478 let model = model_info.model.clone();
479 (self.on_model_changed)(model.clone(), cx);
480
481 let current_index = self.selected_index;
482 self.set_selected_index(current_index, window, cx);
483
484 cx.emit(DismissEvent);
485 }
486 }
487
488 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
489 cx.emit(DismissEvent);
490 }
491
492 fn render_match(
493 &self,
494 ix: usize,
495 selected: bool,
496 _: &mut Window,
497 cx: &mut Context<Picker<Self>>,
498 ) -> Option<Self::ListItem> {
499 match self.filtered_entries.get(ix)? {
500 LanguageModelPickerEntry::Separator(title) => Some(
501 div()
502 .px_2()
503 .pb_1()
504 .when(ix > 1, |this| {
505 this.mt_1()
506 .pt_2()
507 .border_t_1()
508 .border_color(cx.theme().colors().border_variant)
509 })
510 .child(
511 Label::new(title)
512 .size(LabelSize::XSmall)
513 .color(Color::Muted),
514 )
515 .into_any_element(),
516 ),
517 LanguageModelPickerEntry::Model(model_info) => {
518 let active_model = (self.get_active_model)(cx);
519 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
520 let active_model_id = active_model.map(|m| m.model.id());
521
522 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
523 && Some(model_info.model.id()) == active_model_id;
524
525 let model_icon_color = if is_selected {
526 Color::Accent
527 } else {
528 Color::Muted
529 };
530
531 Some(
532 ListItem::new(ix)
533 .inset(true)
534 .spacing(ListItemSpacing::Sparse)
535 .toggle_state(selected)
536 .child(
537 h_flex()
538 .w_full()
539 .gap_1p5()
540 .child(match &model_info.icon {
541 ProviderIcon::Name(icon_name) => Icon::new(*icon_name)
542 .color(model_icon_color)
543 .size(IconSize::Small),
544 ProviderIcon::Path(icon_path) => {
545 Icon::from_external_svg(icon_path.clone())
546 .color(model_icon_color)
547 .size(IconSize::Small)
548 }
549 })
550 .child(Label::new(model_info.model.name().0).truncate()),
551 )
552 .end_slot(div().pr_3().when(is_selected, |this| {
553 this.child(
554 Icon::new(IconName::Check)
555 .color(Color::Accent)
556 .size(IconSize::Small),
557 )
558 }))
559 .into_any_element(),
560 )
561 }
562 }
563 }
564
565 fn render_footer(
566 &self,
567 _window: &mut Window,
568 cx: &mut Context<Picker<Self>>,
569 ) -> Option<gpui::AnyElement> {
570 let focus_handle = self.focus_handle.clone();
571
572 if !self.popover_styles {
573 return None;
574 }
575
576 Some(
577 h_flex()
578 .w_full()
579 .p_1p5()
580 .border_t_1()
581 .border_color(cx.theme().colors().border_variant)
582 .child(
583 Button::new("configure", "Configure")
584 .full_width()
585 .style(ButtonStyle::Outlined)
586 .key_binding(
587 KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
588 .map(|kb| kb.size(rems_from_px(12.))),
589 )
590 .on_click(|_, window, cx| {
591 window.dispatch_action(OpenSettings.boxed_clone(), cx);
592 }),
593 )
594 .into_any(),
595 )
596 }
597}
598
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal `LanguageModel` used in tests.
    ///
    /// Its provider id and provider name are the same string, and its id equals
    /// its name.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, all sharing a placeholder icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: ProviderIcon::Name(IconName::Ai),
            })
            .collect()
    }

    /// Fixture shared by the matcher tests: overlapping model names across three providers.
    fn sample_models() -> Vec<ModelInfo> {
        create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ])
    }

    /// Asserts that `result` is exactly `expected`, in order.
    ///
    /// Models are identified by their `provider/name` telemetry id. Comparing
    /// whole vectors makes a failure print both lists rather than just a
    /// length mismatch.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        let actual = result
            .iter()
            .map(|info| info.model.telemetry_id())
            .collect::<Vec<_>>();
        assert_eq!(actual, expected, "models (as `provider/name`) don't match");
    }

    /// Flattens `GroupedModels::all` back into provider-then-model order.
    fn flatten_all(grouped: &GroupedModels) -> Vec<ModelInfo> {
        grouped.all.values().flatten().cloned().collect()
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "all"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // Recommended models should also appear in "all"
        assert_models_eq(
            flatten_all(&grouped_models),
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "all"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "all"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            flatten_all(&grouped_models),
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }
}