language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use collections::{HashSet, IndexMap};
  4use feature_flags::ZedProFeatureFlag;
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{
  7    Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task, actions,
  8};
  9use language_model::{
 10    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
 11    LanguageModelRegistry,
 12};
 13use ordered_float::OrderedFloat;
 14use picker::{Picker, PickerDelegate};
 15use proto::Plan;
 16use ui::{ListItem, ListItemSpacing, prelude::*};
 17
 18actions!(
 19    agent,
 20    [
 21        #[action(deprecated_aliases = ["assistant::ToggleModelSelector", "assistant2::ToggleModelSelector"])]
 22        ToggleModelSelector
 23    ]
 24);
 25
 26const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 27
 28type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
 29type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 30
 31pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 32
 33pub fn language_model_selector(
 34    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 35    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 36    window: &mut Window,
 37    cx: &mut Context<LanguageModelSelector>,
 38) -> LanguageModelSelector {
 39    let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx);
 40    Picker::list(delegate, window, cx)
 41        .show_scrollbar(true)
 42        .width(rems(20.))
 43        .max_height(Some(rems(20.).into()))
 44}
 45
 46fn all_models(cx: &App) -> GroupedModels {
 47    let providers = LanguageModelRegistry::global(cx).read(cx).providers();
 48
 49    let recommended = providers
 50        .iter()
 51        .flat_map(|provider| {
 52            provider
 53                .recommended_models(cx)
 54                .into_iter()
 55                .map(|model| ModelInfo {
 56                    model,
 57                    icon: provider.icon(),
 58                })
 59        })
 60        .collect();
 61
 62    let other = providers
 63        .iter()
 64        .flat_map(|provider| {
 65            provider
 66                .provided_models(cx)
 67                .into_iter()
 68                .map(|model| ModelInfo {
 69                    model,
 70                    icon: provider.icon(),
 71                })
 72        })
 73        .collect();
 74
 75    GroupedModels::new(other, recommended)
 76}
 77
 78#[derive(Clone)]
 79struct ModelInfo {
 80    model: Arc<dyn LanguageModel>,
 81    icon: IconName,
 82}
 83
 84pub struct LanguageModelPickerDelegate {
 85    on_model_changed: OnModelChanged,
 86    get_active_model: GetActiveModel,
 87    all_models: Arc<GroupedModels>,
 88    filtered_entries: Vec<LanguageModelPickerEntry>,
 89    selected_index: usize,
 90    _authenticate_all_providers_task: Task<()>,
 91    _subscriptions: Vec<Subscription>,
 92}
 93
 94impl LanguageModelPickerDelegate {
 95    fn new(
 96        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 97        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 98        window: &mut Window,
 99        cx: &mut Context<Picker<Self>>,
100    ) -> Self {
101        let on_model_changed = Arc::new(on_model_changed);
102        let models = all_models(cx);
103        let entries = models.entries();
104
105        Self {
106            on_model_changed: on_model_changed.clone(),
107            all_models: Arc::new(models),
108            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
109            filtered_entries: entries,
110            get_active_model: Arc::new(get_active_model),
111            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
112            _subscriptions: vec![cx.subscribe_in(
113                &LanguageModelRegistry::global(cx),
114                window,
115                |picker, _, event, window, cx| {
116                    match event {
117                        language_model::Event::ProviderStateChanged
118                        | language_model::Event::AddedProvider(_)
119                        | language_model::Event::RemovedProvider(_) => {
120                            let query = picker.query(cx);
121                            picker.delegate.all_models = Arc::new(all_models(cx));
122                            // Update matches will automatically drop the previous task
123                            // if we get a provider event again
124                            picker.update_matches(query, window, cx)
125                        }
126                        _ => {}
127                    }
128                },
129            )],
130        }
131    }
132
133    fn get_active_model_index(
134        entries: &[LanguageModelPickerEntry],
135        active_model: Option<ConfiguredModel>,
136    ) -> usize {
137        entries
138            .iter()
139            .position(|entry| {
140                if let LanguageModelPickerEntry::Model(model) = entry {
141                    active_model
142                        .as_ref()
143                        .map(|active_model| {
144                            active_model.model.id() == model.model.id()
145                                && active_model.provider.id() == model.model.provider_id()
146                        })
147                        .unwrap_or_default()
148                } else {
149                    false
150                }
151            })
152            .unwrap_or(0)
153    }
154
155    /// Authenticates all providers in the [`LanguageModelRegistry`].
156    ///
157    /// We do this so that we can populate the language selector with all of the
158    /// models from the configured providers.
159    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
160        let authenticate_all_providers = LanguageModelRegistry::global(cx)
161            .read(cx)
162            .providers()
163            .iter()
164            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
165            .collect::<Vec<_>>();
166
167        cx.spawn(async move |_cx| {
168            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
169                if let Err(err) = authenticate_task.await {
170                    if matches!(err, AuthenticateError::CredentialsNotFound) {
171                        // Since we're authenticating these providers in the
172                        // background for the purposes of populating the
173                        // language selector, we don't care about providers
174                        // where the credentials are not found.
175                    } else {
176                        // Some providers have noisy failure states that we
177                        // don't want to spam the logs with every time the
178                        // language model selector is initialized.
179                        //
180                        // Ideally these should have more clear failure modes
181                        // that we know are safe to ignore here, like what we do
182                        // with `CredentialsNotFound` above.
183                        match provider_id.0.as_ref() {
184                            "lmstudio" | "ollama" => {
185                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
186                                //
187                                // These fail noisily, so we don't log them.
188                            }
189                            "copilot_chat" => {
190                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
191                            }
192                            _ => {
193                                log::error!(
194                                    "Failed to authenticate provider: {}: {err}",
195                                    provider_name.0
196                                );
197                            }
198                        }
199                    }
200                }
201            }
202        })
203    }
204
205    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
206        (self.get_active_model)(cx)
207    }
208}
209
210struct GroupedModels {
211    recommended: Vec<ModelInfo>,
212    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
213}
214
215impl GroupedModels {
216    pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
217        let recommended_ids = recommended
218            .iter()
219            .map(|info| (info.model.provider_id(), info.model.id()))
220            .collect::<HashSet<_>>();
221
222        let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
223        for model in other {
224            if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
225                continue;
226            }
227
228            let provider = model.model.provider_id();
229            if let Some(models) = other_by_provider.get_mut(&provider) {
230                models.push(model);
231            } else {
232                other_by_provider.insert(provider, vec![model]);
233            }
234        }
235
236        Self {
237            recommended,
238            other: other_by_provider,
239        }
240    }
241
242    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
243        let mut entries = Vec::new();
244
245        if !self.recommended.is_empty() {
246            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
247            entries.extend(
248                self.recommended
249                    .iter()
250                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
251            );
252        }
253
254        for models in self.other.values() {
255            if models.is_empty() {
256                continue;
257            }
258            entries.push(LanguageModelPickerEntry::Separator(
259                models[0].model.provider_name().0,
260            ));
261            entries.extend(
262                models
263                    .iter()
264                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
265            );
266        }
267        entries
268    }
269
270    fn model_infos(&self) -> Vec<ModelInfo> {
271        let other = self
272            .other
273            .values()
274            .flat_map(|model| model.iter())
275            .cloned()
276            .collect::<Vec<_>>();
277        self.recommended
278            .iter()
279            .chain(&other)
280            .cloned()
281            .collect::<Vec<_>>()
282    }
283}
284
285enum LanguageModelPickerEntry {
286    Model(ModelInfo),
287    Separator(SharedString),
288}
289
290struct ModelMatcher {
291    models: Vec<ModelInfo>,
292    bg_executor: BackgroundExecutor,
293    candidates: Vec<StringMatchCandidate>,
294}
295
296impl ModelMatcher {
297    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
298        let candidates = Self::make_match_candidates(&models);
299        Self {
300            models,
301            bg_executor,
302            candidates,
303        }
304    }
305
306    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
307        let mut matches = self.bg_executor.block(match_strings(
308            &self.candidates,
309            &query,
310            false,
311            true,
312            100,
313            &Default::default(),
314            self.bg_executor.clone(),
315        ));
316
317        let sorting_key = |mat: &StringMatch| {
318            let candidate = &self.candidates[mat.candidate_id];
319            (Reverse(OrderedFloat(mat.score)), candidate.id)
320        };
321        matches.sort_unstable_by_key(sorting_key);
322
323        let matched_models: Vec<_> = matches
324            .into_iter()
325            .map(|mat| self.models[mat.candidate_id].clone())
326            .collect();
327
328        matched_models
329    }
330
331    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
332        self.models
333            .iter()
334            .filter(|m| {
335                m.model
336                    .name()
337                    .0
338                    .to_lowercase()
339                    .contains(&query.to_lowercase())
340            })
341            .cloned()
342            .collect::<Vec<_>>()
343    }
344
345    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
346        model_infos
347            .iter()
348            .enumerate()
349            .map(|(index, model)| {
350                StringMatchCandidate::new(
351                    index,
352                    &format!(
353                        "{}/{}",
354                        &model.model.provider_name().0,
355                        &model.model.name().0
356                    ),
357                )
358            })
359            .collect::<Vec<_>>()
360    }
361}
362
363impl PickerDelegate for LanguageModelPickerDelegate {
364    type ListItem = AnyElement;
365
366    fn match_count(&self) -> usize {
367        self.filtered_entries.len()
368    }
369
370    fn selected_index(&self) -> usize {
371        self.selected_index
372    }
373
374    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
375        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
376        cx.notify();
377    }
378
379    fn can_select(
380        &mut self,
381        ix: usize,
382        _window: &mut Window,
383        _cx: &mut Context<Picker<Self>>,
384    ) -> bool {
385        match self.filtered_entries.get(ix) {
386            Some(LanguageModelPickerEntry::Model(_)) => true,
387            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
388        }
389    }
390
391    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
392        "Select a model…".into()
393    }
394
395    fn update_matches(
396        &mut self,
397        query: String,
398        window: &mut Window,
399        cx: &mut Context<Picker<Self>>,
400    ) -> Task<()> {
401        let all_models = self.all_models.clone();
402        let current_index = self.selected_index;
403        let bg_executor = cx.background_executor();
404
405        let language_model_registry = LanguageModelRegistry::global(cx);
406
407        let configured_providers = language_model_registry
408            .read(cx)
409            .providers()
410            .into_iter()
411            .filter(|provider| provider.is_authenticated(cx))
412            .collect::<Vec<_>>();
413
414        let configured_provider_ids = configured_providers
415            .iter()
416            .map(|provider| provider.id())
417            .collect::<Vec<_>>();
418
419        let recommended_models = all_models
420            .recommended
421            .iter()
422            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
423            .cloned()
424            .collect::<Vec<_>>();
425
426        let available_models = all_models
427            .model_infos()
428            .iter()
429            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
430            .cloned()
431            .collect::<Vec<_>>();
432
433        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
434        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
435
436        let recommended = matcher_rec.exact_search(&query);
437        let all = matcher_all.fuzzy_search(&query);
438
439        let filtered_models = GroupedModels::new(all, recommended);
440
441        cx.spawn_in(window, async move |this, cx| {
442            this.update_in(cx, |this, window, cx| {
443                this.delegate.filtered_entries = filtered_models.entries();
444                // Preserve selection focus
445                let new_index = if current_index >= this.delegate.filtered_entries.len() {
446                    0
447                } else {
448                    current_index
449                };
450                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
451                cx.notify();
452            })
453            .ok();
454        })
455    }
456
457    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
458        if let Some(LanguageModelPickerEntry::Model(model_info)) =
459            self.filtered_entries.get(self.selected_index)
460        {
461            let model = model_info.model.clone();
462            (self.on_model_changed)(model.clone(), cx);
463
464            let current_index = self.selected_index;
465            self.set_selected_index(current_index, window, cx);
466
467            cx.emit(DismissEvent);
468        }
469    }
470
471    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
472        cx.emit(DismissEvent);
473    }
474
475    fn render_match(
476        &self,
477        ix: usize,
478        selected: bool,
479        _: &mut Window,
480        cx: &mut Context<Picker<Self>>,
481    ) -> Option<Self::ListItem> {
482        match self.filtered_entries.get(ix)? {
483            LanguageModelPickerEntry::Separator(title) => Some(
484                div()
485                    .px_2()
486                    .pb_1()
487                    .when(ix > 1, |this| {
488                        this.mt_1()
489                            .pt_2()
490                            .border_t_1()
491                            .border_color(cx.theme().colors().border_variant)
492                    })
493                    .child(
494                        Label::new(title)
495                            .size(LabelSize::XSmall)
496                            .color(Color::Muted),
497                    )
498                    .into_any_element(),
499            ),
500            LanguageModelPickerEntry::Model(model_info) => {
501                let active_model = (self.get_active_model)(cx);
502                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
503                let active_model_id = active_model.map(|m| m.model.id());
504
505                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
506                    && Some(model_info.model.id()) == active_model_id;
507
508                let model_icon_color = if is_selected {
509                    Color::Accent
510                } else {
511                    Color::Muted
512                };
513
514                Some(
515                    ListItem::new(ix)
516                        .inset(true)
517                        .spacing(ListItemSpacing::Sparse)
518                        .toggle_state(selected)
519                        .start_slot(
520                            Icon::new(model_info.icon)
521                                .color(model_icon_color)
522                                .size(IconSize::Small),
523                        )
524                        .child(
525                            h_flex()
526                                .w_full()
527                                .pl_0p5()
528                                .gap_1p5()
529                                .w(px(240.))
530                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
531                        )
532                        .end_slot(div().pr_3().when(is_selected, |this| {
533                            this.child(
534                                Icon::new(IconName::Check)
535                                    .color(Color::Accent)
536                                    .size(IconSize::Small),
537                            )
538                        }))
539                        .into_any_element(),
540                )
541            }
542        }
543    }
544
545    fn render_footer(
546        &self,
547        _: &mut Window,
548        cx: &mut Context<Picker<Self>>,
549    ) -> Option<gpui::AnyElement> {
550        use feature_flags::FeatureFlagAppExt;
551
552        let plan = proto::Plan::ZedPro;
553
554        Some(
555            h_flex()
556                .w_full()
557                .border_t_1()
558                .border_color(cx.theme().colors().border_variant)
559                .p_1()
560                .gap_4()
561                .justify_between()
562                .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
563                    this.child(match plan {
564                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
565                            .icon(IconName::ZedAssistant)
566                            .icon_size(IconSize::Small)
567                            .icon_color(Color::Muted)
568                            .icon_position(IconPosition::Start)
569                            .on_click(|_, window, cx| {
570                                window
571                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
572                            }),
573                        Plan::Free | Plan::ZedProTrial => Button::new(
574                            "try-pro",
575                            if plan == Plan::ZedProTrial {
576                                "Upgrade to Pro"
577                            } else {
578                                "Try Pro"
579                            },
580                        )
581                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
582                    })
583                })
584                .child(
585                    Button::new("configure", "Configure")
586                        .icon(IconName::Settings)
587                        .icon_size(IconSize::Small)
588                        .icon_color(Color::Muted)
589                        .icon_position(IconPosition::Start)
590                        .on_click(|_, window, cx| {
591                            window.dispatch_action(
592                                zed_actions::agent::OpenConfiguration.boxed_clone(),
593                                cx,
594                            );
595                        }),
596                )
597                .into_any(),
598        )
599    }
600}
601
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal [`LanguageModel`] whose identity is just a provider and a name.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Uses `name` as both model name and id, and `provider` as both
        /// provider id and provider name.
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, model name)` pairs.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Asserts that `result` holds exactly the models named by `expected`
    /// (as `"<provider>/<model>"` telemetry ids), in order. Comparing whole lists
    /// makes a failure print both sequences in full.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        let actual = result
            .iter()
            .map(|info| info.model.telemetry_id())
            .collect::<Vec<_>>();
        assert_eq!(actual, expected, "Models don't match the expected list");
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should not appear in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
    }

    #[gpui::test]
    fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Should not be filtered out from "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // A model with the same name but a different provider than a recommended
        // model must stay in "other".
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
    }

    #[gpui::test]
    fn test_grouped_models_entries_layout(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"),
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // Separators are rendered as "# <title>" so the whole layout can be
        // compared at once.
        let layout = grouped_models
            .entries()
            .iter()
            .map(|entry| match entry {
                LanguageModelPickerEntry::Separator(title) => format!("# {title}"),
                LanguageModelPickerEntry::Model(info) => info.model.telemetry_id(),
            })
            .collect::<Vec<_>>();
        assert_eq!(
            layout,
            vec![
                "# Recommended",
                "zed/claude",
                "# zed",
                "zed/gemini",
                "# copilot",
                "copilot/o3",
            ]
        );

        // `model_infos` lists the recommended models first, then each provider group.
        assert_models_eq(
            grouped_models.model_infos(),
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }
}