language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use collections::{HashSet, IndexMap};
  4use feature_flags::ZedProFeatureFlag;
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{
  7    Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task, actions,
  8};
  9use language_model::{
 10    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
 11    LanguageModelRegistry,
 12};
 13use ordered_float::OrderedFloat;
 14use picker::{Picker, PickerDelegate};
 15use proto::Plan;
 16use ui::{ListItem, ListItemSpacing, prelude::*};
 17
// Declares the `ToggleModelSelector` action in the `agent` namespace; the
// deprecated aliases listed below are the names it was previously known by.
actions!(
    agent,
    [
        /// Toggles the language model selector dropdown.
        #[action(deprecated_aliases = ["assistant::ToggleModelSelector", "assistant2::ToggleModelSelector"])]
        ToggleModelSelector
    ]
);
 26
/// URL opened by the footer's "try-pro" button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Callback invoked with the model the user confirmed in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback that reports the currently active model, if there is one.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

/// A [`Picker`] driven by a [`LanguageModelPickerDelegate`].
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 33
 34pub fn language_model_selector(
 35    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 36    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 37    window: &mut Window,
 38    cx: &mut Context<LanguageModelSelector>,
 39) -> LanguageModelSelector {
 40    let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx);
 41    Picker::list(delegate, window, cx)
 42        .show_scrollbar(true)
 43        .width(rems(20.))
 44        .max_height(Some(rems(20.).into()))
 45}
 46
 47fn all_models(cx: &App) -> GroupedModels {
 48    let providers = LanguageModelRegistry::global(cx).read(cx).providers();
 49
 50    let recommended = providers
 51        .iter()
 52        .flat_map(|provider| {
 53            provider
 54                .recommended_models(cx)
 55                .into_iter()
 56                .map(|model| ModelInfo {
 57                    model,
 58                    icon: provider.icon(),
 59                })
 60        })
 61        .collect();
 62
 63    let other = providers
 64        .iter()
 65        .flat_map(|provider| {
 66            provider
 67                .provided_models(cx)
 68                .into_iter()
 69                .map(|model| ModelInfo {
 70                    model,
 71                    icon: provider.icon(),
 72                })
 73        })
 74        .collect();
 75
 76    GroupedModels::new(other, recommended)
 77}
 78
/// A language model paired with the icon displayed next to it in the picker.
#[derive(Clone)]
struct ModelInfo {
    model: Arc<dyn LanguageModel>,
    // Icon rendered in the list item's start slot; `all_models` fills it from
    // the provider's icon.
    icon: IconName,
}
 84
/// [`PickerDelegate`] backing the language model selector.
pub struct LanguageModelPickerDelegate {
    // Invoked with the chosen model when the user confirms a selection.
    on_model_changed: OnModelChanged,
    // Reports the currently active model, if any.
    get_active_model: GetActiveModel,
    // Every known model, grouped; rebuilt when the registry reports a provider
    // change.
    all_models: Arc<GroupedModels>,
    // Rows currently shown in the picker (separators and models).
    filtered_entries: Vec<LanguageModelPickerEntry>,
    // Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    // Task returned by `authenticate_all_providers`; stored but never read.
    _authenticate_all_providers_task: Task<()>,
    // Subscription to the language model registry; stored but never read.
    _subscriptions: Vec<Subscription>,
}
 94
 95impl LanguageModelPickerDelegate {
 96    fn new(
 97        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 98        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 99        window: &mut Window,
100        cx: &mut Context<Picker<Self>>,
101    ) -> Self {
102        let on_model_changed = Arc::new(on_model_changed);
103        let models = all_models(cx);
104        let entries = models.entries();
105
106        Self {
107            on_model_changed: on_model_changed.clone(),
108            all_models: Arc::new(models),
109            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
110            filtered_entries: entries,
111            get_active_model: Arc::new(get_active_model),
112            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
113            _subscriptions: vec![cx.subscribe_in(
114                &LanguageModelRegistry::global(cx),
115                window,
116                |picker, _, event, window, cx| {
117                    match event {
118                        language_model::Event::ProviderStateChanged
119                        | language_model::Event::AddedProvider(_)
120                        | language_model::Event::RemovedProvider(_) => {
121                            let query = picker.query(cx);
122                            picker.delegate.all_models = Arc::new(all_models(cx));
123                            // Update matches will automatically drop the previous task
124                            // if we get a provider event again
125                            picker.update_matches(query, window, cx)
126                        }
127                        _ => {}
128                    }
129                },
130            )],
131        }
132    }
133
134    fn get_active_model_index(
135        entries: &[LanguageModelPickerEntry],
136        active_model: Option<ConfiguredModel>,
137    ) -> usize {
138        entries
139            .iter()
140            .position(|entry| {
141                if let LanguageModelPickerEntry::Model(model) = entry {
142                    active_model
143                        .as_ref()
144                        .map(|active_model| {
145                            active_model.model.id() == model.model.id()
146                                && active_model.provider.id() == model.model.provider_id()
147                        })
148                        .unwrap_or_default()
149                } else {
150                    false
151                }
152            })
153            .unwrap_or(0)
154    }
155
156    /// Authenticates all providers in the [`LanguageModelRegistry`].
157    ///
158    /// We do this so that we can populate the language selector with all of the
159    /// models from the configured providers.
160    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
161        let authenticate_all_providers = LanguageModelRegistry::global(cx)
162            .read(cx)
163            .providers()
164            .iter()
165            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
166            .collect::<Vec<_>>();
167
168        cx.spawn(async move |_cx| {
169            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
170                if let Err(err) = authenticate_task.await {
171                    if matches!(err, AuthenticateError::CredentialsNotFound) {
172                        // Since we're authenticating these providers in the
173                        // background for the purposes of populating the
174                        // language selector, we don't care about providers
175                        // where the credentials are not found.
176                    } else {
177                        // Some providers have noisy failure states that we
178                        // don't want to spam the logs with every time the
179                        // language model selector is initialized.
180                        //
181                        // Ideally these should have more clear failure modes
182                        // that we know are safe to ignore here, like what we do
183                        // with `CredentialsNotFound` above.
184                        match provider_id.0.as_ref() {
185                            "lmstudio" | "ollama" => {
186                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
187                                //
188                                // These fail noisily, so we don't log them.
189                            }
190                            "copilot_chat" => {
191                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
192                            }
193                            _ => {
194                                log::error!(
195                                    "Failed to authenticate provider: {}: {err}",
196                                    provider_name.0
197                                );
198                            }
199                        }
200                    }
201                }
202            }
203        })
204    }
205
    /// Returns the currently active model, as reported by the
    /// `get_active_model` callback supplied at construction.
    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
        (self.get_active_model)(cx)
    }
209}
210
/// Models split into a "recommended" list and the remaining models keyed by
/// provider.
struct GroupedModels {
    // Shown first, under the "Recommended" separator.
    recommended: Vec<ModelInfo>,
    // Non-recommended models per provider, in the order providers were first
    // encountered.
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
215
216impl GroupedModels {
217    pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
218        let recommended_ids = recommended
219            .iter()
220            .map(|info| (info.model.provider_id(), info.model.id()))
221            .collect::<HashSet<_>>();
222
223        let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
224        for model in other {
225            if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
226                continue;
227            }
228
229            let provider = model.model.provider_id();
230            if let Some(models) = other_by_provider.get_mut(&provider) {
231                models.push(model);
232            } else {
233                other_by_provider.insert(provider, vec![model]);
234            }
235        }
236
237        Self {
238            recommended,
239            other: other_by_provider,
240        }
241    }
242
243    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
244        let mut entries = Vec::new();
245
246        if !self.recommended.is_empty() {
247            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
248            entries.extend(
249                self.recommended
250                    .iter()
251                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
252            );
253        }
254
255        for models in self.other.values() {
256            if models.is_empty() {
257                continue;
258            }
259            entries.push(LanguageModelPickerEntry::Separator(
260                models[0].model.provider_name().0,
261            ));
262            entries.extend(
263                models
264                    .iter()
265                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
266            );
267        }
268        entries
269    }
270
271    fn model_infos(&self) -> Vec<ModelInfo> {
272        let other = self
273            .other
274            .values()
275            .flat_map(|model| model.iter())
276            .cloned()
277            .collect::<Vec<_>>();
278        self.recommended
279            .iter()
280            .chain(&other)
281            .cloned()
282            .collect::<Vec<_>>()
283    }
284}
285
/// A single row in the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model.
    Model(ModelInfo),
    /// A non-selectable heading; the string is the title that gets rendered.
    Separator(SharedString),
}
290
/// Searches a fixed list of models by exact substring or by fuzzy match.
struct ModelMatcher {
    models: Vec<ModelInfo>,
    // Executor used to run (and block on) the fuzzy match.
    bg_executor: BackgroundExecutor,
    // One "provider/model" candidate per entry in `models`; a candidate's id is
    // the index of its model.
    candidates: Vec<StringMatchCandidate>,
}
296
297impl ModelMatcher {
298    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
299        let candidates = Self::make_match_candidates(&models);
300        Self {
301            models,
302            bg_executor,
303            candidates,
304        }
305    }
306
307    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
308        let mut matches = self.bg_executor.block(match_strings(
309            &self.candidates,
310            &query,
311            false,
312            true,
313            100,
314            &Default::default(),
315            self.bg_executor.clone(),
316        ));
317
318        let sorting_key = |mat: &StringMatch| {
319            let candidate = &self.candidates[mat.candidate_id];
320            (Reverse(OrderedFloat(mat.score)), candidate.id)
321        };
322        matches.sort_unstable_by_key(sorting_key);
323
324        let matched_models: Vec<_> = matches
325            .into_iter()
326            .map(|mat| self.models[mat.candidate_id].clone())
327            .collect();
328
329        matched_models
330    }
331
332    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
333        self.models
334            .iter()
335            .filter(|m| {
336                m.model
337                    .name()
338                    .0
339                    .to_lowercase()
340                    .contains(&query.to_lowercase())
341            })
342            .cloned()
343            .collect::<Vec<_>>()
344    }
345
346    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
347        model_infos
348            .iter()
349            .enumerate()
350            .map(|(index, model)| {
351                StringMatchCandidate::new(
352                    index,
353                    &format!(
354                        "{}/{}",
355                        &model.model.provider_name().0,
356                        &model.model.name().0
357                    ),
358                )
359            })
360            .collect::<Vec<_>>()
361    }
362}
363
364impl PickerDelegate for LanguageModelPickerDelegate {
365    type ListItem = AnyElement;
366
    /// Number of rows currently listed, counting separators as well as models.
    fn match_count(&self) -> usize {
        self.filtered_entries.len()
    }
370
    /// Index into the listed rows of the currently highlighted row.
    fn selected_index(&self) -> usize {
        self.selected_index
    }
374
375    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
376        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
377        cx.notify();
378    }
379
380    fn can_select(
381        &mut self,
382        ix: usize,
383        _window: &mut Window,
384        _cx: &mut Context<Picker<Self>>,
385    ) -> bool {
386        match self.filtered_entries.get(ix) {
387            Some(LanguageModelPickerEntry::Model(_)) => true,
388            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
389        }
390    }
391
    /// Returns the placeholder text for this picker.
    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
        "Select a model…".into()
    }
395
396    fn update_matches(
397        &mut self,
398        query: String,
399        window: &mut Window,
400        cx: &mut Context<Picker<Self>>,
401    ) -> Task<()> {
402        let all_models = self.all_models.clone();
403        let active_model = (self.get_active_model)(cx);
404        let bg_executor = cx.background_executor();
405
406        let language_model_registry = LanguageModelRegistry::global(cx);
407
408        let configured_providers = language_model_registry
409            .read(cx)
410            .providers()
411            .into_iter()
412            .filter(|provider| provider.is_authenticated(cx))
413            .collect::<Vec<_>>();
414
415        let configured_provider_ids = configured_providers
416            .iter()
417            .map(|provider| provider.id())
418            .collect::<Vec<_>>();
419
420        let recommended_models = all_models
421            .recommended
422            .iter()
423            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
424            .cloned()
425            .collect::<Vec<_>>();
426
427        let available_models = all_models
428            .model_infos()
429            .iter()
430            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
431            .cloned()
432            .collect::<Vec<_>>();
433
434        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
435        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
436
437        let recommended = matcher_rec.exact_search(&query);
438        let all = matcher_all.fuzzy_search(&query);
439
440        let filtered_models = GroupedModels::new(all, recommended);
441
442        cx.spawn_in(window, async move |this, cx| {
443            this.update_in(cx, |this, window, cx| {
444                this.delegate.filtered_entries = filtered_models.entries();
445                // Finds the currently selected model in the list
446                let new_index =
447                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
448                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
449                cx.notify();
450            })
451            .ok();
452        })
453    }
454
455    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
456        if let Some(LanguageModelPickerEntry::Model(model_info)) =
457            self.filtered_entries.get(self.selected_index)
458        {
459            let model = model_info.model.clone();
460            (self.on_model_changed)(model.clone(), cx);
461
462            let current_index = self.selected_index;
463            self.set_selected_index(current_index, window, cx);
464
465            cx.emit(DismissEvent);
466        }
467    }
468
    /// Emits [`DismissEvent`] from the picker.
    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
        cx.emit(DismissEvent);
    }
472
473    fn render_match(
474        &self,
475        ix: usize,
476        selected: bool,
477        _: &mut Window,
478        cx: &mut Context<Picker<Self>>,
479    ) -> Option<Self::ListItem> {
480        match self.filtered_entries.get(ix)? {
481            LanguageModelPickerEntry::Separator(title) => Some(
482                div()
483                    .px_2()
484                    .pb_1()
485                    .when(ix > 1, |this| {
486                        this.mt_1()
487                            .pt_2()
488                            .border_t_1()
489                            .border_color(cx.theme().colors().border_variant)
490                    })
491                    .child(
492                        Label::new(title)
493                            .size(LabelSize::XSmall)
494                            .color(Color::Muted),
495                    )
496                    .into_any_element(),
497            ),
498            LanguageModelPickerEntry::Model(model_info) => {
499                let active_model = (self.get_active_model)(cx);
500                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
501                let active_model_id = active_model.map(|m| m.model.id());
502
503                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
504                    && Some(model_info.model.id()) == active_model_id;
505
506                let model_icon_color = if is_selected {
507                    Color::Accent
508                } else {
509                    Color::Muted
510                };
511
512                Some(
513                    ListItem::new(ix)
514                        .inset(true)
515                        .spacing(ListItemSpacing::Sparse)
516                        .toggle_state(selected)
517                        .start_slot(
518                            Icon::new(model_info.icon)
519                                .color(model_icon_color)
520                                .size(IconSize::Small),
521                        )
522                        .child(
523                            h_flex()
524                                .w_full()
525                                .pl_0p5()
526                                .gap_1p5()
527                                .w(px(240.))
528                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
529                        )
530                        .end_slot(div().pr_3().when(is_selected, |this| {
531                            this.child(
532                                Icon::new(IconName::Check)
533                                    .color(Color::Accent)
534                                    .size(IconSize::Small),
535                            )
536                        }))
537                        .into_any_element(),
538                )
539            }
540        }
541    }
542
543    fn render_footer(
544        &self,
545        _: &mut Window,
546        cx: &mut Context<Picker<Self>>,
547    ) -> Option<gpui::AnyElement> {
548        use feature_flags::FeatureFlagAppExt;
549
550        let plan = proto::Plan::ZedPro;
551
552        Some(
553            h_flex()
554                .w_full()
555                .border_t_1()
556                .border_color(cx.theme().colors().border_variant)
557                .p_1()
558                .gap_4()
559                .justify_between()
560                .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
561                    this.child(match plan {
562                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
563                            .icon(IconName::ZedAssistant)
564                            .icon_size(IconSize::Small)
565                            .icon_color(Color::Muted)
566                            .icon_position(IconPosition::Start)
567                            .on_click(|_, window, cx| {
568                                window
569                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
570                            }),
571                        Plan::Free | Plan::ZedProTrial => Button::new(
572                            "try-pro",
573                            if plan == Plan::ZedProTrial {
574                                "Upgrade to Pro"
575                            } else {
576                                "Try Pro"
577                            },
578                        )
579                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
580                    })
581                })
582                .child(
583                    Button::new("configure", "Configure")
584                        .icon(IconName::Settings)
585                        .icon_size(IconSize::Small)
586                        .icon_color(Color::Muted)
587                        .icon_position(IconPosition::Start)
588                        .on_click(|_, window, cx| {
589                            window.dispatch_action(
590                                zed_actions::agent::OpenConfiguration.boxed_clone(),
591                                cx,
592                            );
593                        }),
594                )
595                .into_any(),
596        )
597    }
598}
599
600#[cfg(test)]
601mod tests {
602    use super::*;
603    use futures::{future::BoxFuture, stream::BoxStream};
604    use gpui::{AsyncApp, TestAppContext, http_client};
605    use language_model::{
606        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
607        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
608        LanguageModelRequest, LanguageModelToolChoice,
609    };
610    use ui::IconName;
611
    /// Minimal [`LanguageModel`] used by the tests in this module; it only
    /// carries identifying names.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }
619
620    impl TestLanguageModel {
621        fn new(name: &str, provider: &str) -> Self {
622            Self {
623                name: LanguageModelName::from(name.to_string()),
624                id: LanguageModelId::from(name.to_string()),
625                provider_id: LanguageModelProviderId::from(provider.to_string()),
626                provider_name: LanguageModelProviderName::from(provider.to_string()),
627            }
628        }
629    }
630
    // Stub implementation: identity getters return the stored values,
    // capability queries report `false`, and the request methods are left
    // unimplemented.
    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // Formatted as "<provider id>/<name>"; `assert_models_eq` compares
        // against this string.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        // Left unimplemented; calling it panics.
        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        // Left unimplemented; calling it panics.
        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }
693
694    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
695        model_specs
696            .into_iter()
697            .map(|(provider, name)| ModelInfo {
698                model: Arc::new(TestLanguageModel::new(name, provider)),
699                icon: IconName::Ai,
700            })
701            .collect()
702    }
703
704    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
705        assert_eq!(
706            result.len(),
707            expected.len(),
708            "Number of models doesn't match"
709        );
710
711        for (i, expected_name) in expected.iter().enumerate() {
712            assert_eq!(
713                result[i].model.telemetry_id(),
714                *expected_name,
715                "Model at position {} doesn't match expected model",
716                i
717            );
718        }
719    }
720
721    #[gpui::test]
722    fn test_exact_match(cx: &mut TestAppContext) {
723        let models = create_models(vec![
724            ("zed", "Claude 3.7 Sonnet"),
725            ("zed", "Claude 3.7 Sonnet Thinking"),
726            ("zed", "gpt-4.1"),
727            ("zed", "gpt-4.1-nano"),
728            ("openai", "gpt-3.5-turbo"),
729            ("openai", "gpt-4.1"),
730            ("openai", "gpt-4.1-nano"),
731            ("ollama", "mistral"),
732            ("ollama", "deepseek"),
733        ]);
734        let matcher = ModelMatcher::new(models, cx.background_executor.clone());
735
736        // The order of models should be maintained, case doesn't matter
737        let results = matcher.exact_search("GPT-4.1");
738        assert_models_eq(
739            results,
740            vec![
741                "zed/gpt-4.1",
742                "zed/gpt-4.1-nano",
743                "openai/gpt-4.1",
744                "openai/gpt-4.1-nano",
745            ],
746        );
747    }
748
749    #[gpui::test]
750    fn test_fuzzy_match(cx: &mut TestAppContext) {
751        let models = create_models(vec![
752            ("zed", "Claude 3.7 Sonnet"),
753            ("zed", "Claude 3.7 Sonnet Thinking"),
754            ("zed", "gpt-4.1"),
755            ("zed", "gpt-4.1-nano"),
756            ("openai", "gpt-3.5-turbo"),
757            ("openai", "gpt-4.1"),
758            ("openai", "gpt-4.1-nano"),
759            ("ollama", "mistral"),
760            ("ollama", "deepseek"),
761        ]);
762        let matcher = ModelMatcher::new(models, cx.background_executor.clone());
763
764        // Results should preserve models order whenever possible.
765        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
766        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
767        // so it should appear first in the results.
768        let results = matcher.fuzzy_search("41");
769        assert_models_eq(
770            results,
771            vec![
772                "zed/gpt-4.1",
773                "openai/gpt-4.1",
774                "zed/gpt-4.1-nano",
775                "openai/gpt-4.1-nano",
776            ],
777        );
778
779        // Model provider should be searchable as well
780        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
781        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);
782
783        // Fuzzy search
784        let results = matcher.fuzzy_search("z4n");
785        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
786    }
787
788    #[gpui::test]
789    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
790        let recommended_models = create_models(vec![("zed", "claude")]);
791        let all_models = create_models(vec![
792            ("zed", "claude"), // Should be filtered out from "other"
793            ("zed", "gemini"),
794            ("copilot", "o3"),
795        ]);
796
797        let grouped_models = GroupedModels::new(all_models, recommended_models);
798
799        let actual_other_models = grouped_models
800            .other
801            .values()
802            .flatten()
803            .cloned()
804            .collect::<Vec<_>>();
805
806        // Recommended models should not appear in "other"
807        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
808    }
809
810    #[gpui::test]
811    fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
812        let recommended_models = create_models(vec![("zed", "claude")]);
813        let all_models = create_models(vec![
814            ("zed", "claude"), // Should be filtered out from "other"
815            ("zed", "gemini"),
816            ("copilot", "claude"), // Should not be filtered out from "other"
817        ]);
818
819        let grouped_models = GroupedModels::new(all_models, recommended_models);
820
821        let actual_other_models = grouped_models
822            .other
823            .values()
824            .flatten()
825            .cloned()
826            .collect::<Vec<_>>();
827
828        // Recommended models should not appear in "other"
829        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
830    }
831}