language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use cloud_llm_client::Plan;
  4use collections::{HashSet, IndexMap};
  5use feature_flags::ZedProFeatureFlag;
  6use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  7use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task};
  8use language_model::{
  9    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
 10    LanguageModelRegistry,
 11};
 12use ordered_float::OrderedFloat;
 13use picker::{Picker, PickerDelegate};
 14use ui::{ListItem, ListItemSpacing, prelude::*};
 15
/// Page opened by the footer's "Try Pro" / "Upgrade to Pro" button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Callback invoked with the model the user confirmed in the selector.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback reporting the currently active model (if any); used to mark it
/// with a checkmark and to pre-select it in the list.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

/// A [`Picker`] listing language models grouped by "Recommended" and provider.
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 22
 23pub fn language_model_selector(
 24    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 25    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 26    window: &mut Window,
 27    cx: &mut Context<LanguageModelSelector>,
 28) -> LanguageModelSelector {
 29    let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx);
 30    Picker::list(delegate, window, cx)
 31        .show_scrollbar(true)
 32        .width(rems(20.))
 33        .max_height(Some(rems(20.).into()))
 34}
 35
 36fn all_models(cx: &App) -> GroupedModels {
 37    let providers = LanguageModelRegistry::global(cx).read(cx).providers();
 38
 39    let recommended = providers
 40        .iter()
 41        .flat_map(|provider| {
 42            provider
 43                .recommended_models(cx)
 44                .into_iter()
 45                .map(|model| ModelInfo {
 46                    model,
 47                    icon: provider.icon(),
 48                })
 49        })
 50        .collect();
 51
 52    let other = providers
 53        .iter()
 54        .flat_map(|provider| {
 55            provider
 56                .provided_models(cx)
 57                .into_iter()
 58                .map(|model| ModelInfo {
 59                    model,
 60                    icon: provider.icon(),
 61                })
 62        })
 63        .collect();
 64
 65    GroupedModels::new(other, recommended)
 66}
 67
/// A language model paired with the icon of the provider that supplies it.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// The owning provider's icon, rendered at the start of the list row.
    icon: IconName,
}
 73
/// [`PickerDelegate`] that backs the [`LanguageModelSelector`].
pub struct LanguageModelPickerDelegate {
    /// Invoked with the chosen model when the user confirms a model row.
    on_model_changed: OnModelChanged,
    /// Reports the currently active model; used for the checkmark and the
    /// initial selection.
    get_active_model: GetActiveModel,
    /// Every model from every registered provider. Rebuilt when the registry
    /// reports a provider being added, removed, or changing state.
    all_models: Arc<GroupedModels>,
    /// Rows currently displayed: section separators followed by their models.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Background authentication of all providers. Kept so the task lives as
    /// long as the delegate (presumably dropping a gpui `Task` cancels it).
    _authenticate_all_providers_task: Task<()>,
    /// Registry subscription that refreshes `all_models` and the matches.
    _subscriptions: Vec<Subscription>,
}
 83
 84impl LanguageModelPickerDelegate {
 85    fn new(
 86        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 87        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 88        window: &mut Window,
 89        cx: &mut Context<Picker<Self>>,
 90    ) -> Self {
 91        let on_model_changed = Arc::new(on_model_changed);
 92        let models = all_models(cx);
 93        let entries = models.entries();
 94
 95        Self {
 96            on_model_changed,
 97            all_models: Arc::new(models),
 98            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
 99            filtered_entries: entries,
100            get_active_model: Arc::new(get_active_model),
101            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
102            _subscriptions: vec![cx.subscribe_in(
103                &LanguageModelRegistry::global(cx),
104                window,
105                |picker, _, event, window, cx| {
106                    match event {
107                        language_model::Event::ProviderStateChanged(_)
108                        | language_model::Event::AddedProvider(_)
109                        | language_model::Event::RemovedProvider(_) => {
110                            let query = picker.query(cx);
111                            picker.delegate.all_models = Arc::new(all_models(cx));
112                            // Update matches will automatically drop the previous task
113                            // if we get a provider event again
114                            picker.update_matches(query, window, cx)
115                        }
116                        _ => {}
117                    }
118                },
119            )],
120        }
121    }
122
123    fn get_active_model_index(
124        entries: &[LanguageModelPickerEntry],
125        active_model: Option<ConfiguredModel>,
126    ) -> usize {
127        entries
128            .iter()
129            .position(|entry| {
130                if let LanguageModelPickerEntry::Model(model) = entry {
131                    active_model
132                        .as_ref()
133                        .map(|active_model| {
134                            active_model.model.id() == model.model.id()
135                                && active_model.provider.id() == model.model.provider_id()
136                        })
137                        .unwrap_or_default()
138                } else {
139                    false
140                }
141            })
142            .unwrap_or(0)
143    }
144
145    /// Authenticates all providers in the [`LanguageModelRegistry`].
146    ///
147    /// We do this so that we can populate the language selector with all of the
148    /// models from the configured providers.
149    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
150        let authenticate_all_providers = LanguageModelRegistry::global(cx)
151            .read(cx)
152            .providers()
153            .iter()
154            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
155            .collect::<Vec<_>>();
156
157        cx.spawn(async move |_cx| {
158            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
159                if let Err(err) = authenticate_task.await {
160                    if matches!(err, AuthenticateError::CredentialsNotFound) {
161                        // Since we're authenticating these providers in the
162                        // background for the purposes of populating the
163                        // language selector, we don't care about providers
164                        // where the credentials are not found.
165                    } else {
166                        // Some providers have noisy failure states that we
167                        // don't want to spam the logs with every time the
168                        // language model selector is initialized.
169                        //
170                        // Ideally these should have more clear failure modes
171                        // that we know are safe to ignore here, like what we do
172                        // with `CredentialsNotFound` above.
173                        match provider_id.0.as_ref() {
174                            "lmstudio" | "ollama" => {
175                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
176                                //
177                                // These fail noisily, so we don't log them.
178                            }
179                            "copilot_chat" => {
180                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
181                            }
182                            _ => {
183                                log::error!(
184                                    "Failed to authenticate provider: {}: {err}",
185                                    provider_name.0
186                                );
187                            }
188                        }
189                    }
190                }
191            }
192        })
193    }
194
195    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
196        (self.get_active_model)(cx)
197    }
198}
199
/// Models split into a "Recommended" section and per-provider sections.
struct GroupedModels {
    /// Listed first, under a "Recommended" separator.
    recommended: Vec<ModelInfo>,
    /// Remaining models keyed by provider, in the order providers were first
    /// seen; excludes any (provider id, model id) pair already in `recommended`.
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
204
205impl GroupedModels {
206    pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
207        let recommended_ids = recommended
208            .iter()
209            .map(|info| (info.model.provider_id(), info.model.id()))
210            .collect::<HashSet<_>>();
211
212        let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
213        for model in other {
214            if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
215                continue;
216            }
217
218            let provider = model.model.provider_id();
219            if let Some(models) = other_by_provider.get_mut(&provider) {
220                models.push(model);
221            } else {
222                other_by_provider.insert(provider, vec![model]);
223            }
224        }
225
226        Self {
227            recommended,
228            other: other_by_provider,
229        }
230    }
231
232    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
233        let mut entries = Vec::new();
234
235        if !self.recommended.is_empty() {
236            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
237            entries.extend(
238                self.recommended
239                    .iter()
240                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
241            );
242        }
243
244        for models in self.other.values() {
245            if models.is_empty() {
246                continue;
247            }
248            entries.push(LanguageModelPickerEntry::Separator(
249                models[0].model.provider_name().0,
250            ));
251            entries.extend(
252                models
253                    .iter()
254                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
255            );
256        }
257        entries
258    }
259
260    fn model_infos(&self) -> Vec<ModelInfo> {
261        let other = self
262            .other
263            .values()
264            .flat_map(|model| model.iter())
265            .cloned()
266            .collect::<Vec<_>>();
267        self.recommended
268            .iter()
269            .chain(&other)
270            .cloned()
271            .collect::<Vec<_>>()
272    }
273}
274
/// A single row of the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable section header carrying its title.
    Separator(SharedString),
}
279
/// Searches a fixed list of models by name.
struct ModelMatcher {
    /// Models to search; `candidates[i]` describes `models[i]`.
    models: Vec<ModelInfo>,
    /// Executor that fuzzy matching runs on (and is blocked on).
    bg_executor: BackgroundExecutor,
    /// One "provider name/model name" candidate per model, with id = index.
    candidates: Vec<StringMatchCandidate>,
}
285
286impl ModelMatcher {
287    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
288        let candidates = Self::make_match_candidates(&models);
289        Self {
290            models,
291            bg_executor,
292            candidates,
293        }
294    }
295
296    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
297        let mut matches = self.bg_executor.block(match_strings(
298            &self.candidates,
299            query,
300            false,
301            true,
302            100,
303            &Default::default(),
304            self.bg_executor.clone(),
305        ));
306
307        let sorting_key = |mat: &StringMatch| {
308            let candidate = &self.candidates[mat.candidate_id];
309            (Reverse(OrderedFloat(mat.score)), candidate.id)
310        };
311        matches.sort_unstable_by_key(sorting_key);
312
313        let matched_models: Vec<_> = matches
314            .into_iter()
315            .map(|mat| self.models[mat.candidate_id].clone())
316            .collect();
317
318        matched_models
319    }
320
321    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
322        self.models
323            .iter()
324            .filter(|m| {
325                m.model
326                    .name()
327                    .0
328                    .to_lowercase()
329                    .contains(&query.to_lowercase())
330            })
331            .cloned()
332            .collect::<Vec<_>>()
333    }
334
335    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
336        model_infos
337            .iter()
338            .enumerate()
339            .map(|(index, model)| {
340                StringMatchCandidate::new(
341                    index,
342                    &format!(
343                        "{}/{}",
344                        &model.model.provider_name().0,
345                        &model.model.name().0
346                    ),
347                )
348            })
349            .collect::<Vec<_>>()
350    }
351}
352
353impl PickerDelegate for LanguageModelPickerDelegate {
354    type ListItem = AnyElement;
355
356    fn match_count(&self) -> usize {
357        self.filtered_entries.len()
358    }
359
360    fn selected_index(&self) -> usize {
361        self.selected_index
362    }
363
364    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
365        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
366        cx.notify();
367    }
368
369    fn can_select(
370        &mut self,
371        ix: usize,
372        _window: &mut Window,
373        _cx: &mut Context<Picker<Self>>,
374    ) -> bool {
375        match self.filtered_entries.get(ix) {
376            Some(LanguageModelPickerEntry::Model(_)) => true,
377            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
378        }
379    }
380
381    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
382        "Select a model…".into()
383    }
384
385    fn update_matches(
386        &mut self,
387        query: String,
388        window: &mut Window,
389        cx: &mut Context<Picker<Self>>,
390    ) -> Task<()> {
391        let all_models = self.all_models.clone();
392        let active_model = (self.get_active_model)(cx);
393        let bg_executor = cx.background_executor();
394
395        let language_model_registry = LanguageModelRegistry::global(cx);
396
397        let configured_providers = language_model_registry
398            .read(cx)
399            .providers()
400            .into_iter()
401            .filter(|provider| provider.is_authenticated(cx))
402            .collect::<Vec<_>>();
403
404        let configured_provider_ids = configured_providers
405            .iter()
406            .map(|provider| provider.id())
407            .collect::<Vec<_>>();
408
409        let recommended_models = all_models
410            .recommended
411            .iter()
412            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
413            .cloned()
414            .collect::<Vec<_>>();
415
416        let available_models = all_models
417            .model_infos()
418            .iter()
419            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
420            .cloned()
421            .collect::<Vec<_>>();
422
423        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
424        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
425
426        let recommended = matcher_rec.exact_search(&query);
427        let all = matcher_all.fuzzy_search(&query);
428
429        let filtered_models = GroupedModels::new(all, recommended);
430
431        cx.spawn_in(window, async move |this, cx| {
432            this.update_in(cx, |this, window, cx| {
433                this.delegate.filtered_entries = filtered_models.entries();
434                // Finds the currently selected model in the list
435                let new_index =
436                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
437                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
438                cx.notify();
439            })
440            .ok();
441        })
442    }
443
444    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
445        if let Some(LanguageModelPickerEntry::Model(model_info)) =
446            self.filtered_entries.get(self.selected_index)
447        {
448            let model = model_info.model.clone();
449            (self.on_model_changed)(model.clone(), cx);
450
451            let current_index = self.selected_index;
452            self.set_selected_index(current_index, window, cx);
453
454            cx.emit(DismissEvent);
455        }
456    }
457
458    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
459        cx.emit(DismissEvent);
460    }
461
462    fn render_match(
463        &self,
464        ix: usize,
465        selected: bool,
466        _: &mut Window,
467        cx: &mut Context<Picker<Self>>,
468    ) -> Option<Self::ListItem> {
469        match self.filtered_entries.get(ix)? {
470            LanguageModelPickerEntry::Separator(title) => Some(
471                div()
472                    .px_2()
473                    .pb_1()
474                    .when(ix > 1, |this| {
475                        this.mt_1()
476                            .pt_2()
477                            .border_t_1()
478                            .border_color(cx.theme().colors().border_variant)
479                    })
480                    .child(
481                        Label::new(title)
482                            .size(LabelSize::XSmall)
483                            .color(Color::Muted),
484                    )
485                    .into_any_element(),
486            ),
487            LanguageModelPickerEntry::Model(model_info) => {
488                let active_model = (self.get_active_model)(cx);
489                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
490                let active_model_id = active_model.map(|m| m.model.id());
491
492                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
493                    && Some(model_info.model.id()) == active_model_id;
494
495                let model_icon_color = if is_selected {
496                    Color::Accent
497                } else {
498                    Color::Muted
499                };
500
501                Some(
502                    ListItem::new(ix)
503                        .inset(true)
504                        .spacing(ListItemSpacing::Sparse)
505                        .toggle_state(selected)
506                        .start_slot(
507                            Icon::new(model_info.icon)
508                                .color(model_icon_color)
509                                .size(IconSize::Small),
510                        )
511                        .child(
512                            h_flex()
513                                .w_full()
514                                .pl_0p5()
515                                .gap_1p5()
516                                .w(px(240.))
517                                .child(Label::new(model_info.model.name().0).truncate()),
518                        )
519                        .end_slot(div().pr_3().when(is_selected, |this| {
520                            this.child(
521                                Icon::new(IconName::Check)
522                                    .color(Color::Accent)
523                                    .size(IconSize::Small),
524                            )
525                        }))
526                        .into_any_element(),
527                )
528            }
529        }
530    }
531
532    fn render_footer(
533        &self,
534        _: &mut Window,
535        cx: &mut Context<Picker<Self>>,
536    ) -> Option<gpui::AnyElement> {
537        use feature_flags::FeatureFlagAppExt;
538
539        let plan = Plan::ZedPro;
540
541        Some(
542            h_flex()
543                .w_full()
544                .border_t_1()
545                .border_color(cx.theme().colors().border_variant)
546                .p_1()
547                .gap_4()
548                .justify_between()
549                .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
550                    this.child(match plan {
551                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
552                            .icon(IconName::ZedAssistant)
553                            .icon_size(IconSize::Small)
554                            .icon_color(Color::Muted)
555                            .icon_position(IconPosition::Start)
556                            .on_click(|_, window, cx| {
557                                window
558                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
559                            }),
560                        Plan::ZedFree | Plan::ZedProTrial => Button::new(
561                            "try-pro",
562                            if plan == Plan::ZedProTrial {
563                                "Upgrade to Pro"
564                            } else {
565                                "Try Pro"
566                            },
567                        )
568                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
569                    })
570                })
571                .child(
572                    Button::new("configure", "Configure")
573                        .icon(IconName::Settings)
574                        .icon_size(IconSize::Small)
575                        .icon_color(Color::Muted)
576                        .icon_position(IconPosition::Start)
577                        .on_click(|_, window, cx| {
578                            window.dispatch_action(
579                                zed_actions::agent::OpenSettings.boxed_clone(),
580                                cx,
581                            );
582                        }),
583                )
584                .into_any(),
585        )
586    }
587}
588
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal model whose id and name are both `name`, and whose provider id
    /// and provider name are both `provider`.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        /// "provider/name"; the tests use this to identify models.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Asserts that `result`, identified by "provider/name", equals `expected`
    /// in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        let actual = result
            .iter()
            .map(|info| info.model.telemetry_id())
            .collect::<Vec<_>>();
        // Comparing the whole lists prints both orderings on failure, instead
        // of only a count mismatch or a single differing position.
        assert_eq!(
            actual, expected,
            "Models don't match the expected list (order matters)"
        );
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should not appear in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
    }

    #[gpui::test]
    fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Should not be filtered out from "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Exclusion matches on provider *and* model id: a model that shares a
        // recommended model's name but comes from another provider stays in
        // "other".
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
    }
}