language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use collections::{HashSet, IndexMap};
  4use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  5use gpui::{
  6    Action, AnyElement, App, BackgroundExecutor, DismissEvent, ForegroundExecutor, Subscription,
  7    Task,
  8};
  9use language_model::{
 10    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
 11    LanguageModelRegistry,
 12};
 13use ordered_float::OrderedFloat;
 14use picker::{Picker, PickerDelegate};
 15use ui::{ListItem, ListItemSpacing, prelude::*};
 16
/// Callback invoked with the model the user confirmed in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback reporting the currently active model, if any; used to highlight
/// and initially select that model in the list.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

/// A list [`Picker`] over the language models of all registered providers.
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 21
 22pub fn language_model_selector(
 23    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 24    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 25    window: &mut Window,
 26    cx: &mut Context<LanguageModelSelector>,
 27) -> LanguageModelSelector {
 28    let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx);
 29    Picker::list(delegate, window, cx)
 30        .show_scrollbar(true)
 31        .width(rems(20.))
 32        .max_height(Some(rems(20.).into()))
 33}
 34
 35fn all_models(cx: &App) -> GroupedModels {
 36    let providers = LanguageModelRegistry::global(cx).read(cx).providers();
 37
 38    let recommended = providers
 39        .iter()
 40        .flat_map(|provider| {
 41            provider
 42                .recommended_models(cx)
 43                .into_iter()
 44                .map(|model| ModelInfo {
 45                    model,
 46                    icon: provider.icon(),
 47                })
 48        })
 49        .collect();
 50
 51    let other = providers
 52        .iter()
 53        .flat_map(|provider| {
 54            provider
 55                .provided_models(cx)
 56                .into_iter()
 57                .map(|model| ModelInfo {
 58                    model,
 59                    icon: provider.icon(),
 60                })
 61        })
 62        .collect();
 63
 64    GroupedModels::new(other, recommended)
 65}
 66
/// A language model together with the icon shown next to it in the picker.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// Icon rendered in the row's start slot; `all_models` fills it from the
    /// supplying provider's `icon()`.
    icon: IconName,
}
 72
/// [`PickerDelegate`] that backs [`LanguageModelSelector`].
pub struct LanguageModelPickerDelegate {
    /// Called with the chosen model when a selection is confirmed.
    on_model_changed: OnModelChanged,
    /// Reports the model currently in use; used for the check mark / accent
    /// icon and to pick the initially selected row.
    get_active_model: GetActiveModel,
    /// Every known model; rebuilt whenever the registry reports that a provider
    /// was added, removed, or changed state.
    all_models: Arc<GroupedModels>,
    /// Rows (separators and models) shown for the current query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Never read; held so the background provider-authentication task lives
    /// as long as the delegate.
    _authenticate_all_providers_task: Task<()>,
    /// Never read; held so the registry event subscription lives as long as
    /// the delegate.
    _subscriptions: Vec<Subscription>,
}
 82
 83impl LanguageModelPickerDelegate {
 84    fn new(
 85        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 86        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 87        window: &mut Window,
 88        cx: &mut Context<Picker<Self>>,
 89    ) -> Self {
 90        let on_model_changed = Arc::new(on_model_changed);
 91        let models = all_models(cx);
 92        let entries = models.entries();
 93
 94        Self {
 95            on_model_changed,
 96            all_models: Arc::new(models),
 97            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
 98            filtered_entries: entries,
 99            get_active_model: Arc::new(get_active_model),
100            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
101            _subscriptions: vec![cx.subscribe_in(
102                &LanguageModelRegistry::global(cx),
103                window,
104                |picker, _, event, window, cx| {
105                    match event {
106                        language_model::Event::ProviderStateChanged(_)
107                        | language_model::Event::AddedProvider(_)
108                        | language_model::Event::RemovedProvider(_) => {
109                            let query = picker.query(cx);
110                            picker.delegate.all_models = Arc::new(all_models(cx));
111                            // Update matches will automatically drop the previous task
112                            // if we get a provider event again
113                            picker.update_matches(query, window, cx)
114                        }
115                        _ => {}
116                    }
117                },
118            )],
119        }
120    }
121
122    fn get_active_model_index(
123        entries: &[LanguageModelPickerEntry],
124        active_model: Option<ConfiguredModel>,
125    ) -> usize {
126        entries
127            .iter()
128            .position(|entry| {
129                if let LanguageModelPickerEntry::Model(model) = entry {
130                    active_model
131                        .as_ref()
132                        .map(|active_model| {
133                            active_model.model.id() == model.model.id()
134                                && active_model.provider.id() == model.model.provider_id()
135                        })
136                        .unwrap_or_default()
137                } else {
138                    false
139                }
140            })
141            .unwrap_or(0)
142    }
143
144    /// Authenticates all providers in the [`LanguageModelRegistry`].
145    ///
146    /// We do this so that we can populate the language selector with all of the
147    /// models from the configured providers.
148    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
149        let authenticate_all_providers = LanguageModelRegistry::global(cx)
150            .read(cx)
151            .providers()
152            .iter()
153            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
154            .collect::<Vec<_>>();
155
156        cx.spawn(async move |_cx| {
157            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
158                if let Err(err) = authenticate_task.await {
159                    if matches!(err, AuthenticateError::CredentialsNotFound) {
160                        // Since we're authenticating these providers in the
161                        // background for the purposes of populating the
162                        // language selector, we don't care about providers
163                        // where the credentials are not found.
164                    } else {
165                        // Some providers have noisy failure states that we
166                        // don't want to spam the logs with every time the
167                        // language model selector is initialized.
168                        //
169                        // Ideally these should have more clear failure modes
170                        // that we know are safe to ignore here, like what we do
171                        // with `CredentialsNotFound` above.
172                        match provider_id.0.as_ref() {
173                            "lmstudio" | "ollama" => {
174                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
175                                //
176                                // These fail noisily, so we don't log them.
177                            }
178                            "copilot_chat" => {
179                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
180                            }
181                            _ => {
182                                log::error!(
183                                    "Failed to authenticate provider: {}: {err}",
184                                    provider_name.0
185                                );
186                            }
187                        }
188                    }
189                }
190            }
191        })
192    }
193
194    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
195        (self.get_active_model)(cx)
196    }
197}
198
/// Models split into a "Recommended" group and per-provider groups of the rest.
struct GroupedModels {
    /// Recommended models, in the order they were supplied to `new`.
    recommended: Vec<ModelInfo>,
    /// Non-recommended models keyed by provider id. Providers appear in the
    /// order `new` first encountered them; models that are also in
    /// `recommended` (same provider id and model id) are excluded.
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
203
204impl GroupedModels {
205    pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
206        let recommended_ids = recommended
207            .iter()
208            .map(|info| (info.model.provider_id(), info.model.id()))
209            .collect::<HashSet<_>>();
210
211        let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
212        for model in other {
213            if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
214                continue;
215            }
216
217            let provider = model.model.provider_id();
218            if let Some(models) = other_by_provider.get_mut(&provider) {
219                models.push(model);
220            } else {
221                other_by_provider.insert(provider, vec![model]);
222            }
223        }
224
225        Self {
226            recommended,
227            other: other_by_provider,
228        }
229    }
230
231    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
232        let mut entries = Vec::new();
233
234        if !self.recommended.is_empty() {
235            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
236            entries.extend(
237                self.recommended
238                    .iter()
239                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
240            );
241        }
242
243        for models in self.other.values() {
244            if models.is_empty() {
245                continue;
246            }
247            entries.push(LanguageModelPickerEntry::Separator(
248                models[0].model.provider_name().0,
249            ));
250            entries.extend(
251                models
252                    .iter()
253                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
254            );
255        }
256        entries
257    }
258
259    fn model_infos(&self) -> Vec<ModelInfo> {
260        let other = self
261            .other
262            .values()
263            .flat_map(|model| model.iter())
264            .cloned()
265            .collect::<Vec<_>>();
266        self.recommended
267            .iter()
268            .chain(&other)
269            .cloned()
270            .collect::<Vec<_>>()
271    }
272}
273
/// One row of the picker list.
enum LanguageModelPickerEntry {
    /// A selectable row for a single model.
    Model(ModelInfo),
    /// A non-selectable group header carrying its title.
    Separator(SharedString),
}
278
/// Searches a fixed list of models, either by case-insensitive substring on the
/// model name or by fuzzy match on "provider/name".
struct ModelMatcher {
    /// Models being searched; each fuzzy candidate's id is its index here.
    models: Vec<ModelInfo>,
    /// Executor that `fuzzy_search` blocks on while matching.
    fg_executor: ForegroundExecutor,
    /// Executor handed to `match_strings` for the matching work.
    bg_executor: BackgroundExecutor,
    /// Precomputed "provider/name" candidates, parallel to `models`.
    candidates: Vec<StringMatchCandidate>,
}
285
286impl ModelMatcher {
287    fn new(
288        models: Vec<ModelInfo>,
289        fg_executor: ForegroundExecutor,
290        bg_executor: BackgroundExecutor,
291    ) -> ModelMatcher {
292        let candidates = Self::make_match_candidates(&models);
293        Self {
294            models,
295            fg_executor,
296            bg_executor,
297            candidates,
298        }
299    }
300
301    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
302        let mut matches = self.fg_executor.block_on(match_strings(
303            &self.candidates,
304            query,
305            false,
306            true,
307            100,
308            &Default::default(),
309            self.bg_executor.clone(),
310        ));
311
312        let sorting_key = |mat: &StringMatch| {
313            let candidate = &self.candidates[mat.candidate_id];
314            (Reverse(OrderedFloat(mat.score)), candidate.id)
315        };
316        matches.sort_unstable_by_key(sorting_key);
317
318        let matched_models: Vec<_> = matches
319            .into_iter()
320            .map(|mat| self.models[mat.candidate_id].clone())
321            .collect();
322
323        matched_models
324    }
325
326    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
327        self.models
328            .iter()
329            .filter(|m| {
330                m.model
331                    .name()
332                    .0
333                    .to_lowercase()
334                    .contains(&query.to_lowercase())
335            })
336            .cloned()
337            .collect::<Vec<_>>()
338    }
339
340    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
341        model_infos
342            .iter()
343            .enumerate()
344            .map(|(index, model)| {
345                StringMatchCandidate::new(
346                    index,
347                    &format!(
348                        "{}/{}",
349                        &model.model.provider_name().0,
350                        &model.model.name().0
351                    ),
352                )
353            })
354            .collect::<Vec<_>>()
355    }
356}
357
358impl PickerDelegate for LanguageModelPickerDelegate {
359    type ListItem = AnyElement;
360
361    fn match_count(&self) -> usize {
362        self.filtered_entries.len()
363    }
364
365    fn selected_index(&self) -> usize {
366        self.selected_index
367    }
368
369    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
370        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
371        cx.notify();
372    }
373
374    fn can_select(
375        &mut self,
376        ix: usize,
377        _window: &mut Window,
378        _cx: &mut Context<Picker<Self>>,
379    ) -> bool {
380        match self.filtered_entries.get(ix) {
381            Some(LanguageModelPickerEntry::Model(_)) => true,
382            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
383        }
384    }
385
386    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
387        "Select a model…".into()
388    }
389
390    fn update_matches(
391        &mut self,
392        query: String,
393        window: &mut Window,
394        cx: &mut Context<Picker<Self>>,
395    ) -> Task<()> {
396        let all_models = self.all_models.clone();
397        let active_model = (self.get_active_model)(cx);
398        let fg_executor = cx.foreground_executor();
399        let bg_executor = cx.background_executor();
400
401        let language_model_registry = LanguageModelRegistry::global(cx);
402
403        let configured_providers = language_model_registry
404            .read(cx)
405            .providers()
406            .into_iter()
407            .filter(|provider| provider.is_authenticated(cx))
408            .collect::<Vec<_>>();
409
410        let configured_provider_ids = configured_providers
411            .iter()
412            .map(|provider| provider.id())
413            .collect::<Vec<_>>();
414
415        let recommended_models = all_models
416            .recommended
417            .iter()
418            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
419            .cloned()
420            .collect::<Vec<_>>();
421
422        let available_models = all_models
423            .model_infos()
424            .iter()
425            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
426            .cloned()
427            .collect::<Vec<_>>();
428
429        let matcher_rec =
430            ModelMatcher::new(recommended_models, fg_executor.clone(), bg_executor.clone());
431        let matcher_all =
432            ModelMatcher::new(available_models, fg_executor.clone(), bg_executor.clone());
433
434        let recommended = matcher_rec.exact_search(&query);
435        let all = matcher_all.fuzzy_search(&query);
436
437        let filtered_models = GroupedModels::new(all, recommended);
438
439        cx.spawn_in(window, async move |this, cx| {
440            this.update_in(cx, |this, window, cx| {
441                this.delegate.filtered_entries = filtered_models.entries();
442                // Finds the currently selected model in the list
443                let new_index =
444                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
445                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
446                cx.notify();
447            })
448            .ok();
449        })
450    }
451
452    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
453        if let Some(LanguageModelPickerEntry::Model(model_info)) =
454            self.filtered_entries.get(self.selected_index)
455        {
456            let model = model_info.model.clone();
457            (self.on_model_changed)(model.clone(), cx);
458
459            let current_index = self.selected_index;
460            self.set_selected_index(current_index, window, cx);
461
462            cx.emit(DismissEvent);
463        }
464    }
465
466    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
467        cx.emit(DismissEvent);
468    }
469
470    fn render_match(
471        &self,
472        ix: usize,
473        selected: bool,
474        _: &mut Window,
475        cx: &mut Context<Picker<Self>>,
476    ) -> Option<Self::ListItem> {
477        match self.filtered_entries.get(ix)? {
478            LanguageModelPickerEntry::Separator(title) => Some(
479                div()
480                    .px_2()
481                    .pb_1()
482                    .when(ix > 1, |this| {
483                        this.mt_1()
484                            .pt_2()
485                            .border_t_1()
486                            .border_color(cx.theme().colors().border_variant)
487                    })
488                    .child(
489                        Label::new(title)
490                            .size(LabelSize::XSmall)
491                            .color(Color::Muted),
492                    )
493                    .into_any_element(),
494            ),
495            LanguageModelPickerEntry::Model(model_info) => {
496                let active_model = (self.get_active_model)(cx);
497                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
498                let active_model_id = active_model.map(|m| m.model.id());
499
500                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
501                    && Some(model_info.model.id()) == active_model_id;
502
503                let model_icon_color = if is_selected {
504                    Color::Accent
505                } else {
506                    Color::Muted
507                };
508
509                Some(
510                    ListItem::new(ix)
511                        .inset(true)
512                        .spacing(ListItemSpacing::Sparse)
513                        .toggle_state(selected)
514                        .start_slot(
515                            Icon::new(model_info.icon)
516                                .color(model_icon_color)
517                                .size(IconSize::Small),
518                        )
519                        .child(
520                            h_flex()
521                                .w_full()
522                                .pl_0p5()
523                                .gap_1p5()
524                                .w(px(240.))
525                                .child(Label::new(model_info.model.name().0).truncate()),
526                        )
527                        .end_slot(div().pr_3().when(is_selected, |this| {
528                            this.child(
529                                Icon::new(IconName::Check)
530                                    .color(Color::Accent)
531                                    .size(IconSize::Small),
532                            )
533                        }))
534                        .into_any_element(),
535                )
536            }
537        }
538    }
539
540    fn render_footer(
541        &self,
542        _window: &mut Window,
543        cx: &mut Context<Picker<Self>>,
544    ) -> Option<gpui::AnyElement> {
545        Some(
546            h_flex()
547                .w_full()
548                .border_t_1()
549                .border_color(cx.theme().colors().border_variant)
550                .p_1()
551                .gap_4()
552                .justify_between()
553                .child(
554                    Button::new("configure", "Configure")
555                        .icon(IconName::Settings)
556                        .icon_size(IconSize::Small)
557                        .icon_color(Color::Muted)
558                        .icon_position(IconPosition::Start)
559                        .on_click(|_, window, cx| {
560                            window.dispatch_action(
561                                zed_actions::agent::OpenSettings.boxed_clone(),
562                                cx,
563                            );
564                        }),
565                )
566                .into_any(),
567        )
568    }
569}
570
// Unit tests for `ModelMatcher` (exact and fuzzy search ordering) and for how
// `GroupedModels` separates recommended models from the per-provider groups.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal `LanguageModel` stand-in: the model id equals its name, and the
    /// provider id equals the provider name.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Note the argument order: model name first, then provider.
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    // Only the identity accessors are meaningful; token counting and
    // completion are never exercised by these tests and are left unimplemented.
    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // "provider/name"; the assertions below identify models by this string.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs (provider first, unlike
    /// `TestLanguageModel::new`), all using a placeholder icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Asserts that `result` lists exactly the `expected` "provider/name"
    /// strings, in the same order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    // `exact_search` is a case-insensitive substring match on the model name
    // that keeps the input order.
    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.foreground_executor(), cx.executor());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    // `fuzzy_search` ranks by score, breaks ties by input order, and matches
    // against "provider/name" so providers are searchable too.
    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.foreground_executor(), cx.executor());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    // A model listed as recommended must not be repeated in its provider group.
    #[gpui::test]
    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should not appear in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
    }

    // Exclusion matches on provider *and* model id: the same model name from a
    // different provider stays in "other".
    #[gpui::test]
    fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Should not be filtered out from "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should not appear in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
    }
}