language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use collections::{HashSet, IndexMap};
  4use feature_flags::ZedProFeatureFlag;
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{
  7    Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task,
  8    action_with_deprecated_aliases,
  9};
 10use language_model::{
 11    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
 12    LanguageModelRegistry,
 13};
 14use ordered_float::OrderedFloat;
 15use picker::{Picker, PickerDelegate};
 16use proto::Plan;
 17use ui::{ListItem, ListItemSpacing, prelude::*};
 18
// Declares the `ToggleModelSelector` action in the `agent` namespace. The
// listed `assistant::` / `assistant2::` names are registered as deprecated
// aliases of it (presumably so older keymaps keep resolving — confirm against
// the macro definition in gpui).
action_with_deprecated_aliases!(
    agent,
    ToggleModelSelector,
    [
        "assistant::ToggleModelSelector",
        "assistant2::ToggleModelSelector"
    ]
);

/// Page opened by the footer's "Try Pro" / "Upgrade to Pro" button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Callback invoked with the model the user confirmed in the selector.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback reporting the currently active model, if any. The selector uses it
/// to pre-select that model and to draw its check mark.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

/// A [`Picker`] listing language models, grouped into a "Recommended" section
/// followed by one section per provider.
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 34
 35pub fn language_model_selector(
 36    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 37    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 38    window: &mut Window,
 39    cx: &mut Context<LanguageModelSelector>,
 40) -> LanguageModelSelector {
 41    let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx);
 42    Picker::list(delegate, window, cx)
 43        .show_scrollbar(true)
 44        .width(rems(20.))
 45        .max_height(Some(rems(20.).into()))
 46}
 47
 48fn all_models(cx: &App) -> GroupedModels {
 49    let providers = LanguageModelRegistry::global(cx).read(cx).providers();
 50
 51    let recommended = providers
 52        .iter()
 53        .flat_map(|provider| {
 54            provider
 55                .recommended_models(cx)
 56                .into_iter()
 57                .map(|model| ModelInfo {
 58                    model,
 59                    icon: provider.icon(),
 60                })
 61        })
 62        .collect();
 63
 64    let other = providers
 65        .iter()
 66        .flat_map(|provider| {
 67            provider
 68                .provided_models(cx)
 69                .into_iter()
 70                .map(|model| ModelInfo {
 71                    model,
 72                    icon: provider.icon(),
 73                })
 74        })
 75        .collect();
 76
 77    GroupedModels::new(other, recommended)
 78}
 79
/// A language model paired with the icon of the provider that supplied it.
#[derive(Clone)]
struct ModelInfo {
    model: Arc<dyn LanguageModel>,
    /// Provider icon, rendered in the row's start slot.
    icon: IconName,
}
 85
/// [`PickerDelegate`] backing the language model selector.
pub struct LanguageModelPickerDelegate {
    /// Invoked with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Reports the currently active model (used for initial selection and the
    /// check mark next to the active row).
    get_active_model: GetActiveModel,
    /// Every model known to the registry; rebuilt when providers change.
    all_models: Arc<GroupedModels>,
    /// Rows currently shown (group headers and models) for the active query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Held only to keep the background authentication task alive for the
    /// delegate's lifetime.
    _authenticate_all_providers_task: Task<()>,
    /// Held only to keep the registry subscription alive.
    _subscriptions: Vec<Subscription>,
}
 95
 96impl LanguageModelPickerDelegate {
 97    fn new(
 98        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 99        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
100        window: &mut Window,
101        cx: &mut Context<Picker<Self>>,
102    ) -> Self {
103        let on_model_changed = Arc::new(on_model_changed);
104        let models = all_models(cx);
105        let entries = models.entries();
106
107        Self {
108            on_model_changed: on_model_changed.clone(),
109            all_models: Arc::new(models),
110            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
111            filtered_entries: entries,
112            get_active_model: Arc::new(get_active_model),
113            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
114            _subscriptions: vec![cx.subscribe_in(
115                &LanguageModelRegistry::global(cx),
116                window,
117                |picker, _, event, window, cx| {
118                    match event {
119                        language_model::Event::ProviderStateChanged
120                        | language_model::Event::AddedProvider(_)
121                        | language_model::Event::RemovedProvider(_) => {
122                            let query = picker.query(cx);
123                            picker.delegate.all_models = Arc::new(all_models(cx));
124                            // Update matches will automatically drop the previous task
125                            // if we get a provider event again
126                            picker.update_matches(query, window, cx)
127                        }
128                        _ => {}
129                    }
130                },
131            )],
132        }
133    }
134
135    fn get_active_model_index(
136        entries: &[LanguageModelPickerEntry],
137        active_model: Option<ConfiguredModel>,
138    ) -> usize {
139        entries
140            .iter()
141            .position(|entry| {
142                if let LanguageModelPickerEntry::Model(model) = entry {
143                    active_model
144                        .as_ref()
145                        .map(|active_model| {
146                            active_model.model.id() == model.model.id()
147                                && active_model.provider.id() == model.model.provider_id()
148                        })
149                        .unwrap_or_default()
150                } else {
151                    false
152                }
153            })
154            .unwrap_or(0)
155    }
156
157    /// Authenticates all providers in the [`LanguageModelRegistry`].
158    ///
159    /// We do this so that we can populate the language selector with all of the
160    /// models from the configured providers.
161    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
162        let authenticate_all_providers = LanguageModelRegistry::global(cx)
163            .read(cx)
164            .providers()
165            .iter()
166            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
167            .collect::<Vec<_>>();
168
169        cx.spawn(async move |_cx| {
170            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
171                if let Err(err) = authenticate_task.await {
172                    if matches!(err, AuthenticateError::CredentialsNotFound) {
173                        // Since we're authenticating these providers in the
174                        // background for the purposes of populating the
175                        // language selector, we don't care about providers
176                        // where the credentials are not found.
177                    } else {
178                        // Some providers have noisy failure states that we
179                        // don't want to spam the logs with every time the
180                        // language model selector is initialized.
181                        //
182                        // Ideally these should have more clear failure modes
183                        // that we know are safe to ignore here, like what we do
184                        // with `CredentialsNotFound` above.
185                        match provider_id.0.as_ref() {
186                            "lmstudio" | "ollama" => {
187                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
188                                //
189                                // These fail noisily, so we don't log them.
190                            }
191                            "copilot_chat" => {
192                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
193                            }
194                            _ => {
195                                log::error!(
196                                    "Failed to authenticate provider: {}: {err}",
197                                    provider_name.0
198                                );
199                            }
200                        }
201                    }
202                }
203            }
204        })
205    }
206
207    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
208        (self.get_active_model)(cx)
209    }
210}
211
/// Models split into a "Recommended" group plus per-provider groups of the rest.
struct GroupedModels {
    /// Recommended models, in the order they were supplied.
    recommended: Vec<ModelInfo>,
    /// Remaining models keyed by provider id, in the order each provider was
    /// first seen. Models that also appear in `recommended` are excluded.
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
216
217impl GroupedModels {
218    pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
219        let recommended_ids = recommended
220            .iter()
221            .map(|info| (info.model.provider_id(), info.model.id()))
222            .collect::<HashSet<_>>();
223
224        let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
225        for model in other {
226            if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
227                continue;
228            }
229
230            let provider = model.model.provider_id();
231            if let Some(models) = other_by_provider.get_mut(&provider) {
232                models.push(model);
233            } else {
234                other_by_provider.insert(provider, vec![model]);
235            }
236        }
237
238        Self {
239            recommended,
240            other: other_by_provider,
241        }
242    }
243
244    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
245        let mut entries = Vec::new();
246
247        if !self.recommended.is_empty() {
248            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
249            entries.extend(
250                self.recommended
251                    .iter()
252                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
253            );
254        }
255
256        for models in self.other.values() {
257            if models.is_empty() {
258                continue;
259            }
260            entries.push(LanguageModelPickerEntry::Separator(
261                models[0].model.provider_name().0,
262            ));
263            entries.extend(
264                models
265                    .iter()
266                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
267            );
268        }
269        entries
270    }
271
272    fn model_infos(&self) -> Vec<ModelInfo> {
273        let other = self
274            .other
275            .values()
276            .flat_map(|model| model.iter())
277            .cloned()
278            .collect::<Vec<_>>();
279        self.recommended
280            .iter()
281            .chain(&other)
282            .cloned()
283            .collect::<Vec<_>>()
284    }
285}
286
/// One row in the selector list.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable group header showing the given title.
    Separator(SharedString),
}
291
/// Searches a fixed list of models, either by case-insensitive substring of the
/// model name or by fuzzy matching on `"{provider_name}/{model_name}"`.
struct ModelMatcher {
    models: Vec<ModelInfo>,
    /// Executor the fuzzy match runs on (and is blocked on).
    bg_executor: BackgroundExecutor,
    /// One candidate per entry in `models`; a candidate's id is its index there.
    candidates: Vec<StringMatchCandidate>,
}
297
298impl ModelMatcher {
299    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
300        let candidates = Self::make_match_candidates(&models);
301        Self {
302            models,
303            bg_executor,
304            candidates,
305        }
306    }
307
308    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
309        let mut matches = self.bg_executor.block(match_strings(
310            &self.candidates,
311            &query,
312            false,
313            true,
314            100,
315            &Default::default(),
316            self.bg_executor.clone(),
317        ));
318
319        let sorting_key = |mat: &StringMatch| {
320            let candidate = &self.candidates[mat.candidate_id];
321            (Reverse(OrderedFloat(mat.score)), candidate.id)
322        };
323        matches.sort_unstable_by_key(sorting_key);
324
325        let matched_models: Vec<_> = matches
326            .into_iter()
327            .map(|mat| self.models[mat.candidate_id].clone())
328            .collect();
329
330        matched_models
331    }
332
333    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
334        self.models
335            .iter()
336            .filter(|m| {
337                m.model
338                    .name()
339                    .0
340                    .to_lowercase()
341                    .contains(&query.to_lowercase())
342            })
343            .cloned()
344            .collect::<Vec<_>>()
345    }
346
347    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
348        model_infos
349            .iter()
350            .enumerate()
351            .map(|(index, model)| {
352                StringMatchCandidate::new(
353                    index,
354                    &format!(
355                        "{}/{}",
356                        &model.model.provider_name().0,
357                        &model.model.name().0
358                    ),
359                )
360            })
361            .collect::<Vec<_>>()
362    }
363}
364
365impl PickerDelegate for LanguageModelPickerDelegate {
366    type ListItem = AnyElement;
367
368    fn match_count(&self) -> usize {
369        self.filtered_entries.len()
370    }
371
372    fn selected_index(&self) -> usize {
373        self.selected_index
374    }
375
376    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
377        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
378        cx.notify();
379    }
380
381    fn can_select(
382        &mut self,
383        ix: usize,
384        _window: &mut Window,
385        _cx: &mut Context<Picker<Self>>,
386    ) -> bool {
387        match self.filtered_entries.get(ix) {
388            Some(LanguageModelPickerEntry::Model(_)) => true,
389            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
390        }
391    }
392
393    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
394        "Select a model…".into()
395    }
396
397    fn update_matches(
398        &mut self,
399        query: String,
400        window: &mut Window,
401        cx: &mut Context<Picker<Self>>,
402    ) -> Task<()> {
403        let all_models = self.all_models.clone();
404        let current_index = self.selected_index;
405        let bg_executor = cx.background_executor();
406
407        let language_model_registry = LanguageModelRegistry::global(cx);
408
409        let configured_providers = language_model_registry
410            .read(cx)
411            .providers()
412            .into_iter()
413            .filter(|provider| provider.is_authenticated(cx))
414            .collect::<Vec<_>>();
415
416        let configured_provider_ids = configured_providers
417            .iter()
418            .map(|provider| provider.id())
419            .collect::<Vec<_>>();
420
421        let recommended_models = all_models
422            .recommended
423            .iter()
424            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
425            .cloned()
426            .collect::<Vec<_>>();
427
428        let available_models = all_models
429            .model_infos()
430            .iter()
431            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
432            .cloned()
433            .collect::<Vec<_>>();
434
435        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
436        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
437
438        let recommended = matcher_rec.exact_search(&query);
439        let all = matcher_all.fuzzy_search(&query);
440
441        let filtered_models = GroupedModels::new(all, recommended);
442
443        cx.spawn_in(window, async move |this, cx| {
444            this.update_in(cx, |this, window, cx| {
445                this.delegate.filtered_entries = filtered_models.entries();
446                // Preserve selection focus
447                let new_index = if current_index >= this.delegate.filtered_entries.len() {
448                    0
449                } else {
450                    current_index
451                };
452                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
453                cx.notify();
454            })
455            .ok();
456        })
457    }
458
459    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
460        if let Some(LanguageModelPickerEntry::Model(model_info)) =
461            self.filtered_entries.get(self.selected_index)
462        {
463            let model = model_info.model.clone();
464            (self.on_model_changed)(model.clone(), cx);
465
466            let current_index = self.selected_index;
467            self.set_selected_index(current_index, window, cx);
468
469            cx.emit(DismissEvent);
470        }
471    }
472
473    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
474        cx.emit(DismissEvent);
475    }
476
477    fn render_match(
478        &self,
479        ix: usize,
480        selected: bool,
481        _: &mut Window,
482        cx: &mut Context<Picker<Self>>,
483    ) -> Option<Self::ListItem> {
484        match self.filtered_entries.get(ix)? {
485            LanguageModelPickerEntry::Separator(title) => Some(
486                div()
487                    .px_2()
488                    .pb_1()
489                    .when(ix > 1, |this| {
490                        this.mt_1()
491                            .pt_2()
492                            .border_t_1()
493                            .border_color(cx.theme().colors().border_variant)
494                    })
495                    .child(
496                        Label::new(title)
497                            .size(LabelSize::XSmall)
498                            .color(Color::Muted),
499                    )
500                    .into_any_element(),
501            ),
502            LanguageModelPickerEntry::Model(model_info) => {
503                let active_model = (self.get_active_model)(cx);
504                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
505                let active_model_id = active_model.map(|m| m.model.id());
506
507                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
508                    && Some(model_info.model.id()) == active_model_id;
509
510                let model_icon_color = if is_selected {
511                    Color::Accent
512                } else {
513                    Color::Muted
514                };
515
516                Some(
517                    ListItem::new(ix)
518                        .inset(true)
519                        .spacing(ListItemSpacing::Sparse)
520                        .toggle_state(selected)
521                        .start_slot(
522                            Icon::new(model_info.icon)
523                                .color(model_icon_color)
524                                .size(IconSize::Small),
525                        )
526                        .child(
527                            h_flex()
528                                .w_full()
529                                .pl_0p5()
530                                .gap_1p5()
531                                .w(px(240.))
532                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
533                        )
534                        .end_slot(div().pr_3().when(is_selected, |this| {
535                            this.child(
536                                Icon::new(IconName::Check)
537                                    .color(Color::Accent)
538                                    .size(IconSize::Small),
539                            )
540                        }))
541                        .into_any_element(),
542                )
543            }
544        }
545    }
546
547    fn render_footer(
548        &self,
549        _: &mut Window,
550        cx: &mut Context<Picker<Self>>,
551    ) -> Option<gpui::AnyElement> {
552        use feature_flags::FeatureFlagAppExt;
553
554        let plan = proto::Plan::ZedPro;
555
556        Some(
557            h_flex()
558                .w_full()
559                .border_t_1()
560                .border_color(cx.theme().colors().border_variant)
561                .p_1()
562                .gap_4()
563                .justify_between()
564                .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
565                    this.child(match plan {
566                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
567                            .icon(IconName::ZedAssistant)
568                            .icon_size(IconSize::Small)
569                            .icon_color(Color::Muted)
570                            .icon_position(IconPosition::Start)
571                            .on_click(|_, window, cx| {
572                                window
573                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
574                            }),
575                        Plan::Free | Plan::ZedProTrial => Button::new(
576                            "try-pro",
577                            if plan == Plan::ZedProTrial {
578                                "Upgrade to Pro"
579                            } else {
580                                "Try Pro"
581                            },
582                        )
583                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
584                    })
585                })
586                .child(
587                    Button::new("configure", "Configure")
588                        .icon(IconName::Settings)
589                        .icon_size(IconSize::Small)
590                        .icon_color(Color::Muted)
591                        .icon_position(IconPosition::Start)
592                        .on_click(|_, window, cx| {
593                            window.dispatch_action(
594                                zed_actions::agent::OpenConfiguration.boxed_clone(),
595                                cx,
596                            );
597                        }),
598                )
599                .into_any(),
600        )
601    }
602}
603
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Stub [`LanguageModel`] that only carries identity data; token counting
    /// and completion are left unimplemented.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Uses `name` as both the model name and id, and `provider` as both
        /// the provider id and the provider name.
        fn new(name: &str, provider: &str) -> Self {
            let name = name.to_string();
            let provider = provider.to_string();
            Self {
                name: LanguageModelName::from(name.clone()),
                id: LanguageModelId::from(name),
                provider_id: LanguageModelProviderId::from(provider.clone()),
                provider_name: LanguageModelProviderName::from(provider),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        /// `"{provider_id}/{name}"`, the identifier the assertions compare.
        fn telemetry_id(&self) -> String {
            let provider = &self.provider_id.0;
            let model = &self.name.0;
            format!("{provider}/{model}")
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, model name)` pairs, all using the
    /// generic AI icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        let mut infos = Vec::with_capacity(model_specs.len());
        for (provider, name) in model_specs {
            infos.push(ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            });
        }
        infos
    }

    /// Asserts that the telemetry ids (`"provider/name"`) of `result` equal
    /// `expected`, in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, (info, expected_name)) in result.iter().zip(expected.iter()).enumerate() {
            assert_eq!(
                info.model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    /// Model list shared by the matcher tests.
    fn sample_models() -> Vec<ModelInfo> {
        create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ])
    }

    /// Flattens the per-provider groups of `grouped`, in provider order.
    fn other_models(grouped: &GroupedModels) -> Vec<ModelInfo> {
        grouped.other.values().flatten().cloned().collect()
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // Matching ignores case and keeps the original model order.
        assert_models_eq(
            matcher.exact_search("GPT-4.1"),
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // `zed/gpt-4.1` and `openai/gpt-4.1` score identically; the one listed
        // earlier in the input must come first.
        assert_models_eq(
            matcher.fuzzy_search("41"),
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // The provider part of the candidate is searchable too ("ol" → "ollama").
        assert_models_eq(
            matcher.fuzzy_search("ol"),
            vec!["ollama/mistral", "ollama/deepseek"],
        );

        // Non-contiguous fuzzy match.
        assert_models_eq(matcher.fuzzy_search("z4n"), vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
        let recommended = create_models(vec![("zed", "claude")]);
        let everything = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped = GroupedModels::new(everything, recommended);

        // A recommended model must not be repeated in "other".
        assert_models_eq(other_models(&grouped), vec!["zed/gemini", "copilot/o3"]);
    }

    #[gpui::test]
    fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
        let recommended = create_models(vec![("zed", "claude")]);
        let everything = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Should not be filtered out from "other"
        ]);

        let grouped = GroupedModels::new(everything, recommended);

        // Only the exact (provider, model) pair is removed; the same model id
        // from another provider stays in "other".
        assert_models_eq(other_models(&grouped), vec!["zed/gemini", "copilot/claude"]);
    }
}