language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use collections::IndexMap;
  4use futures::{StreamExt, channel::mpsc};
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
  7use language_model::{
  8    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
  9    LanguageModelRegistry,
 10};
 11use ordered_float::OrderedFloat;
 12use picker::{Picker, PickerDelegate};
 13use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*};
 14use zed_actions::agent::OpenSettings;
 15
 16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
 17type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 18
 19pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 20
 21pub fn language_model_selector(
 22    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 23    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 24    popover_styles: bool,
 25    focus_handle: FocusHandle,
 26    window: &mut Window,
 27    cx: &mut Context<LanguageModelSelector>,
 28) -> LanguageModelSelector {
 29    let delegate = LanguageModelPickerDelegate::new(
 30        get_active_model,
 31        on_model_changed,
 32        popover_styles,
 33        focus_handle,
 34        window,
 35        cx,
 36    );
 37
 38    if popover_styles {
 39        Picker::list(delegate, window, cx)
 40            .show_scrollbar(true)
 41            .width(rems(20.))
 42            .max_height(Some(rems(20.).into()))
 43    } else {
 44        Picker::list(delegate, window, cx).show_scrollbar(true)
 45    }
 46}
 47
 48fn all_models(cx: &App) -> GroupedModels {
 49    let providers = LanguageModelRegistry::global(cx).read(cx).providers();
 50
 51    let recommended = providers
 52        .iter()
 53        .flat_map(|provider| {
 54            provider
 55                .recommended_models(cx)
 56                .into_iter()
 57                .map(|model| ModelInfo {
 58                    model,
 59                    icon: provider.icon(),
 60                })
 61        })
 62        .collect();
 63
 64    let all: Vec<ModelInfo> = providers
 65        .iter()
 66        .flat_map(|provider| {
 67            provider
 68                .provided_models(cx)
 69                .into_iter()
 70                .map(|model| ModelInfo {
 71                    model,
 72                    icon: provider.icon(),
 73                })
 74        })
 75        .collect();
 76
 77    GroupedModels::new(all, recommended)
 78}
 79
 80#[derive(Clone)]
 81struct ModelInfo {
 82    model: Arc<dyn LanguageModel>,
 83    icon: IconName,
 84}
 85
/// Picker delegate that lists language models in a "Recommended" section followed by one
/// section per provider.
// NOTE: field order is also drop order; keep the task fields where they are.
pub struct LanguageModelPickerDelegate {
    /// Called with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Reports the currently active model so it can be highlighted and preselected.
    get_active_model: GetActiveModel,
    /// Every known model; rebuilt when the registry reports provider changes.
    all_models: Arc<GroupedModels>,
    /// Rows currently displayed (recomputed by `update_matches`).
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    // Never read; held so these background tasks live as long as the delegate.
    _authenticate_all_providers_task: Task<()>,
    _refresh_models_task: Task<()>,
    /// Whether the picker is presented as a popover (fixed size plus a "Configure" footer).
    popover_styles: bool,
    /// Focus handle used to resolve the key binding shown on the footer button.
    focus_handle: FocusHandle,
}
 97
 98impl LanguageModelPickerDelegate {
 99    fn new(
100        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
101        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
102        popover_styles: bool,
103        focus_handle: FocusHandle,
104        window: &mut Window,
105        cx: &mut Context<Picker<Self>>,
106    ) -> Self {
107        let on_model_changed = Arc::new(on_model_changed);
108        let models = all_models(cx);
109        let entries = models.entries();
110
111        Self {
112            on_model_changed,
113            all_models: Arc::new(models),
114            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
115            filtered_entries: entries,
116            get_active_model: Arc::new(get_active_model),
117            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
118            _refresh_models_task: {
119                // Create a channel to signal when models need refreshing
120                let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
121
122                // Subscribe to registry events and send refresh signals through the channel
123                let registry = LanguageModelRegistry::global(cx);
124                cx.subscribe(&registry, move |_picker, _, event, _cx| match event {
125                    language_model::Event::ProviderStateChanged(_) => {
126                        refresh_tx.unbounded_send(()).ok();
127                    }
128                    language_model::Event::AddedProvider(_) => {
129                        refresh_tx.unbounded_send(()).ok();
130                    }
131                    language_model::Event::RemovedProvider(_) => {
132                        refresh_tx.unbounded_send(()).ok();
133                    }
134                    _ => {}
135                })
136                .detach();
137
138                // Spawn a task that listens for refresh signals and updates the picker
139                cx.spawn_in(window, async move |this, cx| {
140                    while let Some(()) = refresh_rx.next().await {
141                        let result = this.update_in(cx, |picker, window, cx| {
142                            picker.delegate.all_models = Arc::new(all_models(cx));
143                            picker.refresh(window, cx);
144                        });
145                        if result.is_err() {
146                            // Picker was dropped, exit the loop
147                            break;
148                        }
149                    }
150                })
151            },
152            popover_styles,
153            focus_handle,
154        }
155    }
156
157    fn get_active_model_index(
158        entries: &[LanguageModelPickerEntry],
159        active_model: Option<ConfiguredModel>,
160    ) -> usize {
161        entries
162            .iter()
163            .position(|entry| {
164                if let LanguageModelPickerEntry::Model(model) = entry {
165                    active_model
166                        .as_ref()
167                        .map(|active_model| {
168                            active_model.model.id() == model.model.id()
169                                && active_model.provider.id() == model.model.provider_id()
170                        })
171                        .unwrap_or_default()
172                } else {
173                    false
174                }
175            })
176            .unwrap_or(0)
177    }
178
179    /// Authenticates all providers in the [`LanguageModelRegistry`].
180    ///
181    /// We do this so that we can populate the language selector with all of the
182    /// models from the configured providers.
183    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
184        let authenticate_all_providers = LanguageModelRegistry::global(cx)
185            .read(cx)
186            .providers()
187            .iter()
188            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
189            .collect::<Vec<_>>();
190
191        cx.spawn(async move |_cx| {
192            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
193                if let Err(err) = authenticate_task.await {
194                    if matches!(err, AuthenticateError::CredentialsNotFound) {
195                        // Since we're authenticating these providers in the
196                        // background for the purposes of populating the
197                        // language selector, we don't care about providers
198                        // where the credentials are not found.
199                    } else {
200                        // Some providers have noisy failure states that we
201                        // don't want to spam the logs with every time the
202                        // language model selector is initialized.
203                        //
204                        // Ideally these should have more clear failure modes
205                        // that we know are safe to ignore here, like what we do
206                        // with `CredentialsNotFound` above.
207                        match provider_id.0.as_ref() {
208                            "lmstudio" | "ollama" => {
209                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
210                                //
211                                // These fail noisily, so we don't log them.
212                            }
213                            "copilot_chat" => {
214                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
215                            }
216                            _ => {
217                                log::error!(
218                                    "Failed to authenticate provider: {}: {err:#}",
219                                    provider_name.0
220                                );
221                            }
222                        }
223                    }
224                }
225            }
226        })
227    }
228
229    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
230        (self.get_active_model)(cx)
231    }
232}
233
/// Models organized for display: the recommended models plus every model grouped by provider.
struct GroupedModels {
    /// Recommended models, listed in their own section ahead of the per-provider sections.
    recommended: Vec<ModelInfo>,
    /// All models keyed by provider id; providers keep the order in which they were first seen.
    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
238
239impl GroupedModels {
240    pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
241        let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
242        for model in all {
243            let provider = model.model.provider_id();
244            if let Some(models) = all_by_provider.get_mut(&provider) {
245                models.push(model);
246            } else {
247                all_by_provider.insert(provider, vec![model]);
248            }
249        }
250
251        Self {
252            recommended,
253            all: all_by_provider,
254        }
255    }
256
257    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
258        let mut entries = Vec::new();
259
260        if !self.recommended.is_empty() {
261            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
262            entries.extend(
263                self.recommended
264                    .iter()
265                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
266            );
267        }
268
269        for models in self.all.values() {
270            if models.is_empty() {
271                continue;
272            }
273            entries.push(LanguageModelPickerEntry::Separator(
274                models[0].model.provider_name().0,
275            ));
276            entries.extend(
277                models
278                    .iter()
279                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
280            );
281        }
282        entries
283    }
284}
285
286enum LanguageModelPickerEntry {
287    Model(ModelInfo),
288    Separator(SharedString),
289}
290
291struct ModelMatcher {
292    models: Vec<ModelInfo>,
293    bg_executor: BackgroundExecutor,
294    candidates: Vec<StringMatchCandidate>,
295}
296
297impl ModelMatcher {
298    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
299        let candidates = Self::make_match_candidates(&models);
300        Self {
301            models,
302            bg_executor,
303            candidates,
304        }
305    }
306
307    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
308        let mut matches = self.bg_executor.block(match_strings(
309            &self.candidates,
310            query,
311            false,
312            true,
313            100,
314            &Default::default(),
315            self.bg_executor.clone(),
316        ));
317
318        let sorting_key = |mat: &StringMatch| {
319            let candidate = &self.candidates[mat.candidate_id];
320            (Reverse(OrderedFloat(mat.score)), candidate.id)
321        };
322        matches.sort_unstable_by_key(sorting_key);
323
324        let matched_models: Vec<_> = matches
325            .into_iter()
326            .map(|mat| self.models[mat.candidate_id].clone())
327            .collect();
328
329        matched_models
330    }
331
332    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
333        self.models
334            .iter()
335            .filter(|m| {
336                m.model
337                    .name()
338                    .0
339                    .to_lowercase()
340                    .contains(&query.to_lowercase())
341            })
342            .cloned()
343            .collect::<Vec<_>>()
344    }
345
346    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
347        model_infos
348            .iter()
349            .enumerate()
350            .map(|(index, model)| {
351                StringMatchCandidate::new(
352                    index,
353                    &format!(
354                        "{}/{}",
355                        &model.model.provider_name().0,
356                        &model.model.name().0
357                    ),
358                )
359            })
360            .collect::<Vec<_>>()
361    }
362}
363
364impl PickerDelegate for LanguageModelPickerDelegate {
365    type ListItem = AnyElement;
366
367    fn match_count(&self) -> usize {
368        self.filtered_entries.len()
369    }
370
371    fn selected_index(&self) -> usize {
372        self.selected_index
373    }
374
375    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
376        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
377        cx.notify();
378    }
379
380    fn can_select(
381        &mut self,
382        ix: usize,
383        _window: &mut Window,
384        _cx: &mut Context<Picker<Self>>,
385    ) -> bool {
386        match self.filtered_entries.get(ix) {
387            Some(LanguageModelPickerEntry::Model(_)) => true,
388            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
389        }
390    }
391
392    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
393        "Select a model…".into()
394    }
395
396    fn update_matches(
397        &mut self,
398        query: String,
399        window: &mut Window,
400        cx: &mut Context<Picker<Self>>,
401    ) -> Task<()> {
402        let all_models = self.all_models.clone();
403        let active_model = (self.get_active_model)(cx);
404        let bg_executor = cx.background_executor();
405
406        let language_model_registry = LanguageModelRegistry::global(cx);
407
408        let configured_providers = language_model_registry
409            .read(cx)
410            .providers()
411            .into_iter()
412            .filter(|provider| provider.is_authenticated(cx))
413            .collect::<Vec<_>>();
414
415        let configured_provider_ids = configured_providers
416            .iter()
417            .map(|provider| provider.id())
418            .collect::<Vec<_>>();
419
420        let recommended_models = all_models
421            .recommended
422            .iter()
423            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
424            .cloned()
425            .collect::<Vec<_>>();
426
427        let available_models = all_models
428            .all
429            .values()
430            .flat_map(|models| models.iter())
431            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
432            .cloned()
433            .collect::<Vec<_>>();
434
435        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
436        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
437
438        let recommended = matcher_rec.exact_search(&query);
439        let all = matcher_all.fuzzy_search(&query);
440
441        let filtered_models = GroupedModels::new(all, recommended);
442
443        cx.spawn_in(window, async move |this, cx| {
444            this.update_in(cx, |this, window, cx| {
445                this.delegate.filtered_entries = filtered_models.entries();
446                // Finds the currently selected model in the list
447                let new_index =
448                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
449                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
450                cx.notify();
451            })
452            .ok();
453        })
454    }
455
456    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
457        if let Some(LanguageModelPickerEntry::Model(model_info)) =
458            self.filtered_entries.get(self.selected_index)
459        {
460            let model = model_info.model.clone();
461            (self.on_model_changed)(model.clone(), cx);
462
463            let current_index = self.selected_index;
464            self.set_selected_index(current_index, window, cx);
465
466            cx.emit(DismissEvent);
467        }
468    }
469
470    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
471        cx.emit(DismissEvent);
472    }
473
474    fn render_match(
475        &self,
476        ix: usize,
477        selected: bool,
478        _: &mut Window,
479        cx: &mut Context<Picker<Self>>,
480    ) -> Option<Self::ListItem> {
481        match self.filtered_entries.get(ix)? {
482            LanguageModelPickerEntry::Separator(title) => Some(
483                div()
484                    .px_2()
485                    .pb_1()
486                    .when(ix > 1, |this| {
487                        this.mt_1()
488                            .pt_2()
489                            .border_t_1()
490                            .border_color(cx.theme().colors().border_variant)
491                    })
492                    .child(
493                        Label::new(title)
494                            .size(LabelSize::XSmall)
495                            .color(Color::Muted),
496                    )
497                    .into_any_element(),
498            ),
499            LanguageModelPickerEntry::Model(model_info) => {
500                let active_model = (self.get_active_model)(cx);
501                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
502                let active_model_id = active_model.map(|m| m.model.id());
503
504                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
505                    && Some(model_info.model.id()) == active_model_id;
506
507                let model_icon_color = if is_selected {
508                    Color::Accent
509                } else {
510                    Color::Muted
511                };
512
513                Some(
514                    ListItem::new(ix)
515                        .inset(true)
516                        .spacing(ListItemSpacing::Sparse)
517                        .toggle_state(selected)
518                        .child(
519                            h_flex()
520                                .w_full()
521                                .gap_1p5()
522                                .child(
523                                    Icon::new(model_info.icon)
524                                        .color(model_icon_color)
525                                        .size(IconSize::Small),
526                                )
527                                .child(Label::new(model_info.model.name().0).truncate()),
528                        )
529                        .end_slot(div().pr_3().when(is_selected, |this| {
530                            this.child(
531                                Icon::new(IconName::Check)
532                                    .color(Color::Accent)
533                                    .size(IconSize::Small),
534                            )
535                        }))
536                        .into_any_element(),
537                )
538            }
539        }
540    }
541
542    fn render_footer(
543        &self,
544        _window: &mut Window,
545        cx: &mut Context<Picker<Self>>,
546    ) -> Option<gpui::AnyElement> {
547        let focus_handle = self.focus_handle.clone();
548
549        if !self.popover_styles {
550            return None;
551        }
552
553        Some(
554            h_flex()
555                .w_full()
556                .p_1p5()
557                .border_t_1()
558                .border_color(cx.theme().colors().border_variant)
559                .child(
560                    Button::new("configure", "Configure")
561                        .full_width()
562                        .style(ButtonStyle::Outlined)
563                        .key_binding(
564                            KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
565                                .map(|kb| kb.size(rems_from_px(12.))),
566                        )
567                        .on_click(|_, window, cx| {
568                            window.dispatch_action(OpenSettings.boxed_clone(), cx);
569                        }),
570                )
571                .into_any(),
572        )
573    }
574}
575
/// Unit tests for `ModelMatcher` search ordering and `GroupedModels` grouping.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal `LanguageModel` stub carrying only identity data; completion and token
    /// counting are not implemented.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Uses `name` as both the model name and id, and `provider` as both the provider id
        /// and provider name.
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // Formatted as "provider/name"; `assert_models_eq` compares against this.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, all using the `Ai` icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Asserts that `result` matches `expected` ("provider/name" strings) in length and order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }
}