language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use collections::{HashSet, IndexMap};
  4use feature_flags::ZedProFeatureFlag;
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{
  7    Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task,
  8    action_with_deprecated_aliases,
  9};
 10use language_model::{
 11    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
 12    LanguageModelRegistry,
 13};
 14use ordered_float::OrderedFloat;
 15use picker::{Picker, PickerDelegate};
 16use proto::Plan;
 17use ui::{ListItem, ListItemSpacing, prelude::*};
 18
// `ToggleModelSelector` action in the `agent` namespace, also accepted under the
// deprecated names listed below.
action_with_deprecated_aliases!(
    agent,
    ToggleModelSelector,
    [
        "assistant::ToggleModelSelector",
        "assistant2::ToggleModelSelector"
    ]
);

/// Page opened by the footer's "Try Pro" / "Upgrade to Pro" button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Callback invoked with the model the user confirms in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback reporting the currently active model, if any. Used for the initial
/// selection and to mark the active row with a checkmark.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

/// The model selector view: a [`Picker`] driven by [`LanguageModelPickerDelegate`].
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 34
 35pub fn language_model_selector(
 36    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 37    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 38    window: &mut Window,
 39    cx: &mut Context<LanguageModelSelector>,
 40) -> LanguageModelSelector {
 41    let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx);
 42    Picker::list(delegate, window, cx)
 43        .show_scrollbar(true)
 44        .width(rems(20.))
 45        .max_height(Some(rems(20.).into()))
 46}
 47
 48fn all_models(cx: &App) -> GroupedModels {
 49    let mut recommended = Vec::new();
 50    let mut recommended_set = HashSet::default();
 51    for provider in LanguageModelRegistry::global(cx)
 52        .read(cx)
 53        .providers()
 54        .iter()
 55    {
 56        let models = provider.recommended_models(cx);
 57        recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
 58        recommended.extend(
 59            provider
 60                .recommended_models(cx)
 61                .into_iter()
 62                .map(move |model| ModelInfo {
 63                    model: model.clone(),
 64                    icon: provider.icon(),
 65                }),
 66        );
 67    }
 68
 69    let other_models = LanguageModelRegistry::global(cx)
 70        .read(cx)
 71        .providers()
 72        .iter()
 73        .map(|provider| {
 74            (
 75                provider.id(),
 76                provider
 77                    .provided_models(cx)
 78                    .into_iter()
 79                    .filter_map(|model| {
 80                        let not_included =
 81                            !recommended_set.contains(&(model.provider_id(), model.id()));
 82                        not_included.then(|| ModelInfo {
 83                            model: model.clone(),
 84                            icon: provider.icon(),
 85                        })
 86                    })
 87                    .collect::<Vec<_>>(),
 88            )
 89        })
 90        .collect::<IndexMap<_, _>>();
 91
 92    GroupedModels {
 93        recommended,
 94        other: other_models,
 95    }
 96}
 97
/// A model listed in the picker, paired with the icon of the provider that
/// supplied it.
#[derive(Clone)]
struct ModelInfo {
    /// The language model itself.
    model: Arc<dyn LanguageModel>,
    /// Icon rendered in the row's start slot (taken from the provider's `icon()`).
    icon: IconName,
}
103
/// [`PickerDelegate`] backing the language model selector.
pub struct LanguageModelPickerDelegate {
    /// Called with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Reports the active model; used for the initial selection and the checkmark.
    get_active_model: GetActiveModel,
    /// Snapshot of all provider models; rebuilt when the registry reports a
    /// provider change.
    all_models: Arc<GroupedModels>,
    /// Rows (separators and models) currently displayed.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Background task authenticating every provider. It is never read; it is
    /// stored so it stays alive as long as the delegate.
    _authenticate_all_providers_task: Task<()>,
    /// Registry subscription, stored so it stays alive as long as the delegate.
    _subscriptions: Vec<Subscription>,
}
113
114impl LanguageModelPickerDelegate {
115    fn new(
116        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
117        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
118        window: &mut Window,
119        cx: &mut Context<Picker<Self>>,
120    ) -> Self {
121        let on_model_changed = Arc::new(on_model_changed);
122        let models = all_models(cx);
123        let entries = models.entries();
124
125        Self {
126            on_model_changed: on_model_changed.clone(),
127            all_models: Arc::new(models),
128            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
129            filtered_entries: entries,
130            get_active_model: Arc::new(get_active_model),
131            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
132            _subscriptions: vec![cx.subscribe_in(
133                &LanguageModelRegistry::global(cx),
134                window,
135                |picker, _, event, window, cx| {
136                    match event {
137                        language_model::Event::ProviderStateChanged
138                        | language_model::Event::AddedProvider(_)
139                        | language_model::Event::RemovedProvider(_) => {
140                            let query = picker.query(cx);
141                            picker.delegate.all_models = Arc::new(all_models(cx));
142                            // Update matches will automatically drop the previous task
143                            // if we get a provider event again
144                            picker.update_matches(query, window, cx)
145                        }
146                        _ => {}
147                    }
148                },
149            )],
150        }
151    }
152
153    fn get_active_model_index(
154        entries: &[LanguageModelPickerEntry],
155        active_model: Option<ConfiguredModel>,
156    ) -> usize {
157        entries
158            .iter()
159            .position(|entry| {
160                if let LanguageModelPickerEntry::Model(model) = entry {
161                    active_model
162                        .as_ref()
163                        .map(|active_model| {
164                            active_model.model.id() == model.model.id()
165                                && active_model.provider.id() == model.model.provider_id()
166                        })
167                        .unwrap_or_default()
168                } else {
169                    false
170                }
171            })
172            .unwrap_or(0)
173    }
174
175    /// Authenticates all providers in the [`LanguageModelRegistry`].
176    ///
177    /// We do this so that we can populate the language selector with all of the
178    /// models from the configured providers.
179    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
180        let authenticate_all_providers = LanguageModelRegistry::global(cx)
181            .read(cx)
182            .providers()
183            .iter()
184            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
185            .collect::<Vec<_>>();
186
187        cx.spawn(async move |_cx| {
188            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
189                if let Err(err) = authenticate_task.await {
190                    if matches!(err, AuthenticateError::CredentialsNotFound) {
191                        // Since we're authenticating these providers in the
192                        // background for the purposes of populating the
193                        // language selector, we don't care about providers
194                        // where the credentials are not found.
195                    } else {
196                        // Some providers have noisy failure states that we
197                        // don't want to spam the logs with every time the
198                        // language model selector is initialized.
199                        //
200                        // Ideally these should have more clear failure modes
201                        // that we know are safe to ignore here, like what we do
202                        // with `CredentialsNotFound` above.
203                        match provider_id.0.as_ref() {
204                            "lmstudio" | "ollama" => {
205                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
206                                //
207                                // These fail noisily, so we don't log them.
208                            }
209                            "copilot_chat" => {
210                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
211                            }
212                            _ => {
213                                log::error!(
214                                    "Failed to authenticate provider: {}: {err}",
215                                    provider_name.0
216                                );
217                            }
218                        }
219                    }
220                }
221            }
222        })
223    }
224
225    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
226        (self.get_active_model)(cx)
227    }
228}
229
/// Models split into a "Recommended" group and the remaining models grouped by
/// provider.
struct GroupedModels {
    /// Models listed under the "Recommended" separator.
    recommended: Vec<ModelInfo>,
    /// Remaining models keyed by provider id; `IndexMap` keeps insertion order,
    /// which is the order the provider groups are rendered in.
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
234
235impl GroupedModels {
236    pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
237        let recommended_ids: HashSet<_> = recommended.iter().map(|info| info.model.id()).collect();
238
239        let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
240        for model in other {
241            if recommended_ids.contains(&model.model.id()) {
242                continue;
243            }
244
245            let provider = model.model.provider_id();
246            if let Some(models) = other_by_provider.get_mut(&provider) {
247                models.push(model);
248            } else {
249                other_by_provider.insert(provider, vec![model]);
250            }
251        }
252
253        Self {
254            recommended,
255            other: other_by_provider,
256        }
257    }
258
259    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
260        let mut entries = Vec::new();
261
262        if !self.recommended.is_empty() {
263            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
264            entries.extend(
265                self.recommended
266                    .iter()
267                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
268            );
269        }
270
271        for models in self.other.values() {
272            if models.is_empty() {
273                continue;
274            }
275            entries.push(LanguageModelPickerEntry::Separator(
276                models[0].model.provider_name().0,
277            ));
278            entries.extend(
279                models
280                    .iter()
281                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
282            );
283        }
284        entries
285    }
286
287    fn model_infos(&self) -> Vec<ModelInfo> {
288        let other = self
289            .other
290            .values()
291            .flat_map(|model| model.iter())
292            .cloned()
293            .collect::<Vec<_>>();
294        self.recommended
295            .iter()
296            .chain(&other)
297            .cloned()
298            .collect::<Vec<_>>()
299    }
300}
301
/// A row in the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable group heading ("Recommended" or a provider name).
    Separator(SharedString),
}
306
/// Searches a fixed list of models either fuzzily (over "provider/model"
/// strings) or by case-insensitive substring of the model name.
struct ModelMatcher {
    /// Models being searched; `candidates[i]` describes `models[i]`.
    models: Vec<ModelInfo>,
    /// Executor that the fuzzy matcher runs on and that `fuzzy_search` blocks on.
    bg_executor: BackgroundExecutor,
    /// One "provider/model" search string per entry in `models`.
    candidates: Vec<StringMatchCandidate>,
}
312
313impl ModelMatcher {
314    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
315        let candidates = Self::make_match_candidates(&models);
316        Self {
317            models,
318            bg_executor,
319            candidates,
320        }
321    }
322
323    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
324        let mut matches = self.bg_executor.block(match_strings(
325            &self.candidates,
326            &query,
327            false,
328            100,
329            &Default::default(),
330            self.bg_executor.clone(),
331        ));
332
333        let sorting_key = |mat: &StringMatch| {
334            let candidate = &self.candidates[mat.candidate_id];
335            (Reverse(OrderedFloat(mat.score)), candidate.id)
336        };
337        matches.sort_unstable_by_key(sorting_key);
338
339        let matched_models: Vec<_> = matches
340            .into_iter()
341            .map(|mat| self.models[mat.candidate_id].clone())
342            .collect();
343
344        matched_models
345    }
346
347    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
348        self.models
349            .iter()
350            .filter(|m| {
351                m.model
352                    .name()
353                    .0
354                    .to_lowercase()
355                    .contains(&query.to_lowercase())
356            })
357            .cloned()
358            .collect::<Vec<_>>()
359    }
360
361    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
362        model_infos
363            .iter()
364            .enumerate()
365            .map(|(index, model)| {
366                StringMatchCandidate::new(
367                    index,
368                    &format!(
369                        "{}/{}",
370                        &model.model.provider_name().0,
371                        &model.model.name().0
372                    ),
373                )
374            })
375            .collect::<Vec<_>>()
376    }
377}
378
379impl PickerDelegate for LanguageModelPickerDelegate {
380    type ListItem = AnyElement;
381
382    fn match_count(&self) -> usize {
383        self.filtered_entries.len()
384    }
385
386    fn selected_index(&self) -> usize {
387        self.selected_index
388    }
389
390    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
391        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
392        cx.notify();
393    }
394
395    fn can_select(
396        &mut self,
397        ix: usize,
398        _window: &mut Window,
399        _cx: &mut Context<Picker<Self>>,
400    ) -> bool {
401        match self.filtered_entries.get(ix) {
402            Some(LanguageModelPickerEntry::Model(_)) => true,
403            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
404        }
405    }
406
407    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
408        "Select a model…".into()
409    }
410
411    fn update_matches(
412        &mut self,
413        query: String,
414        window: &mut Window,
415        cx: &mut Context<Picker<Self>>,
416    ) -> Task<()> {
417        let all_models = self.all_models.clone();
418        let current_index = self.selected_index;
419        let bg_executor = cx.background_executor();
420
421        let language_model_registry = LanguageModelRegistry::global(cx);
422
423        let configured_providers = language_model_registry
424            .read(cx)
425            .providers()
426            .into_iter()
427            .filter(|provider| provider.is_authenticated(cx))
428            .collect::<Vec<_>>();
429
430        let configured_provider_ids = configured_providers
431            .iter()
432            .map(|provider| provider.id())
433            .collect::<Vec<_>>();
434
435        let recommended_models = all_models
436            .recommended
437            .iter()
438            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
439            .cloned()
440            .collect::<Vec<_>>();
441
442        let available_models = all_models
443            .model_infos()
444            .iter()
445            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
446            .cloned()
447            .collect::<Vec<_>>();
448
449        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
450        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
451
452        let recommended = matcher_rec.exact_search(&query);
453        let all = matcher_all.fuzzy_search(&query);
454
455        let filtered_models = GroupedModels::new(all, recommended);
456
457        cx.spawn_in(window, async move |this, cx| {
458            this.update_in(cx, |this, window, cx| {
459                this.delegate.filtered_entries = filtered_models.entries();
460                // Preserve selection focus
461                let new_index = if current_index >= this.delegate.filtered_entries.len() {
462                    0
463                } else {
464                    current_index
465                };
466                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
467                cx.notify();
468            })
469            .ok();
470        })
471    }
472
473    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
474        if let Some(LanguageModelPickerEntry::Model(model_info)) =
475            self.filtered_entries.get(self.selected_index)
476        {
477            let model = model_info.model.clone();
478            (self.on_model_changed)(model.clone(), cx);
479
480            let current_index = self.selected_index;
481            self.set_selected_index(current_index, window, cx);
482
483            cx.emit(DismissEvent);
484        }
485    }
486
487    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
488        cx.emit(DismissEvent);
489    }
490
491    fn render_match(
492        &self,
493        ix: usize,
494        selected: bool,
495        _: &mut Window,
496        cx: &mut Context<Picker<Self>>,
497    ) -> Option<Self::ListItem> {
498        match self.filtered_entries.get(ix)? {
499            LanguageModelPickerEntry::Separator(title) => Some(
500                div()
501                    .px_2()
502                    .pb_1()
503                    .when(ix > 1, |this| {
504                        this.mt_1()
505                            .pt_2()
506                            .border_t_1()
507                            .border_color(cx.theme().colors().border_variant)
508                    })
509                    .child(
510                        Label::new(title)
511                            .size(LabelSize::XSmall)
512                            .color(Color::Muted),
513                    )
514                    .into_any_element(),
515            ),
516            LanguageModelPickerEntry::Model(model_info) => {
517                let active_model = (self.get_active_model)(cx);
518                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
519                let active_model_id = active_model.map(|m| m.model.id());
520
521                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
522                    && Some(model_info.model.id()) == active_model_id;
523
524                let model_icon_color = if is_selected {
525                    Color::Accent
526                } else {
527                    Color::Muted
528                };
529
530                Some(
531                    ListItem::new(ix)
532                        .inset(true)
533                        .spacing(ListItemSpacing::Sparse)
534                        .toggle_state(selected)
535                        .start_slot(
536                            Icon::new(model_info.icon)
537                                .color(model_icon_color)
538                                .size(IconSize::Small),
539                        )
540                        .child(
541                            h_flex()
542                                .w_full()
543                                .pl_0p5()
544                                .gap_1p5()
545                                .w(px(240.))
546                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
547                        )
548                        .end_slot(div().pr_3().when(is_selected, |this| {
549                            this.child(
550                                Icon::new(IconName::Check)
551                                    .color(Color::Accent)
552                                    .size(IconSize::Small),
553                            )
554                        }))
555                        .into_any_element(),
556                )
557            }
558        }
559    }
560
561    fn render_footer(
562        &self,
563        _: &mut Window,
564        cx: &mut Context<Picker<Self>>,
565    ) -> Option<gpui::AnyElement> {
566        use feature_flags::FeatureFlagAppExt;
567
568        let plan = proto::Plan::ZedPro;
569
570        Some(
571            h_flex()
572                .w_full()
573                .border_t_1()
574                .border_color(cx.theme().colors().border_variant)
575                .p_1()
576                .gap_4()
577                .justify_between()
578                .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
579                    this.child(match plan {
580                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
581                            .icon(IconName::ZedAssistant)
582                            .icon_size(IconSize::Small)
583                            .icon_color(Color::Muted)
584                            .icon_position(IconPosition::Start)
585                            .on_click(|_, window, cx| {
586                                window
587                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
588                            }),
589                        Plan::Free | Plan::ZedProTrial => Button::new(
590                            "try-pro",
591                            if plan == Plan::ZedProTrial {
592                                "Upgrade to Pro"
593                            } else {
594                                "Try Pro"
595                            },
596                        )
597                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
598                    })
599                })
600                .child(
601                    Button::new("configure", "Configure")
602                        .icon(IconName::Settings)
603                        .icon_size(IconSize::Small)
604                        .icon_color(Color::Muted)
605                        .icon_position(IconPosition::Start)
606                        .on_click(|_, window, cx| {
607                            window.dispatch_action(
608                                zed_actions::agent::OpenConfiguration.boxed_clone(),
609                                cx,
610                            );
611                        }),
612                )
613                .into_any(),
614        )
615    }
616}
617
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal in-memory model used to exercise matching and grouping.
    ///
    /// `id` and `name` are both set to the model name; `provider_id` and
    /// `provider_name` are both set to the provider string.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    // Only the identity accessors matter to these tests; the request methods
    // are never called and panic via `unimplemented!()`.
    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // "provider/name"; `assert_models_eq` identifies models by this string.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> usize {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<usize>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            http_client::Result<
                BoxStream<
                    'static,
                    http_client::Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, all using the `Ai` icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Asserts that `result` contains exactly the `expected` models, in order,
    /// comparing each model's `telemetry_id()` ("provider/name").
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should not appear in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
    }
}