language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use collections::IndexMap;
  4use futures::{StreamExt, channel::mpsc};
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
  7use language_model::{
  8    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProvider,
  9    LanguageModelProviderId, LanguageModelRegistry,
 10};
 11use ordered_float::OrderedFloat;
 12use picker::{Picker, PickerDelegate};
 13use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*};
 14use zed_actions::agent::OpenSettings;
 15
/// Callback invoked with the chosen model when the user confirms a selection in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback that reports the currently active (configured) model, if any.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

/// A picker that lists language models, driven by [`LanguageModelPickerDelegate`].
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 20
 21pub fn language_model_selector(
 22    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 23    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 24    popover_styles: bool,
 25    focus_handle: FocusHandle,
 26    window: &mut Window,
 27    cx: &mut Context<LanguageModelSelector>,
 28) -> LanguageModelSelector {
 29    let delegate = LanguageModelPickerDelegate::new(
 30        get_active_model,
 31        on_model_changed,
 32        popover_styles,
 33        focus_handle,
 34        window,
 35        cx,
 36    );
 37
 38    if popover_styles {
 39        Picker::list(delegate, window, cx)
 40            .show_scrollbar(true)
 41            .width(rems(20.))
 42            .max_height(Some(rems(20.).into()))
 43    } else {
 44        Picker::list(delegate, window, cx).show_scrollbar(true)
 45    }
 46}
 47
 48fn all_models(cx: &App) -> GroupedModels {
 49    let providers = LanguageModelRegistry::global(cx).read(cx).providers();
 50
 51    let recommended = providers
 52        .iter()
 53        .flat_map(|provider| {
 54            provider
 55                .recommended_models(cx)
 56                .into_iter()
 57                .map(|model| ModelInfo {
 58                    model,
 59                    icon: ProviderIcon::from_provider(provider.as_ref()),
 60                })
 61        })
 62        .collect();
 63
 64    let all: Vec<ModelInfo> = providers
 65        .iter()
 66        .flat_map(|provider| {
 67            provider
 68                .provided_models(cx)
 69                .into_iter()
 70                .map(|model| ModelInfo {
 71                    model,
 72                    icon: ProviderIcon::from_provider(provider.as_ref()),
 73                })
 74        })
 75        .collect();
 76
 77    GroupedModels::new(all, recommended)
 78}
 79
 80#[derive(Clone)]
 81enum ProviderIcon {
 82    Name(IconName),
 83    Path(SharedString),
 84}
 85
 86impl ProviderIcon {
 87    fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
 88        if let Some(path) = provider.icon_path() {
 89            Self::Path(path)
 90        } else {
 91            Self::Name(provider.icon())
 92        }
 93    }
 94}
 95
/// A model paired with the icon of the provider that supplies it.
#[derive(Clone)]
struct ModelInfo {
    model: Arc<dyn LanguageModel>,
    icon: ProviderIcon,
}
101
/// [`PickerDelegate`] that lists language models, grouped by provider, with an optional
/// "Recommended" section at the top.
pub struct LanguageModelPickerDelegate {
    /// Called with the chosen model when the user confirms a selection.
    on_model_changed: OnModelChanged,
    /// Reports the active model; used for the initial selection and the checkmark.
    get_active_model: GetActiveModel,
    /// Every model known to the registry; rebuilt when providers are added, removed,
    /// or change state.
    all_models: Arc<GroupedModels>,
    /// Rows (separators and models) currently shown for the active query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Background authentication of all providers. Stored so the task lives as long as
    /// the delegate (presumably dropping a `Task` cancels it — confirm).
    _authenticate_all_providers_task: Task<()>,
    /// Listener that refreshes `all_models` on registry changes; stored for the same reason.
    _refresh_models_task: Task<()>,
    /// Whether the picker is presented as a popover (fixed size, "Configure" footer).
    popover_styles: bool,
    /// Focus handle used to resolve the footer button's key binding.
    focus_handle: FocusHandle,
}
113
114impl LanguageModelPickerDelegate {
115    fn new(
116        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
117        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
118        popover_styles: bool,
119        focus_handle: FocusHandle,
120        window: &mut Window,
121        cx: &mut Context<Picker<Self>>,
122    ) -> Self {
123        let on_model_changed = Arc::new(on_model_changed);
124        let models = all_models(cx);
125        let entries = models.entries();
126
127        Self {
128            on_model_changed,
129            all_models: Arc::new(models),
130            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
131            filtered_entries: entries,
132            get_active_model: Arc::new(get_active_model),
133            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
134            _refresh_models_task: {
135                // Create a channel to signal when models need refreshing
136                let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
137
138                // Subscribe to registry events and send refresh signals through the channel
139                let registry = LanguageModelRegistry::global(cx);
140                cx.subscribe(&registry, move |_picker, _, event, _cx| match event {
141                    language_model::Event::ProviderStateChanged(_) => {
142                        refresh_tx.unbounded_send(()).ok();
143                    }
144                    language_model::Event::AddedProvider(_) => {
145                        refresh_tx.unbounded_send(()).ok();
146                    }
147                    language_model::Event::RemovedProvider(_) => {
148                        refresh_tx.unbounded_send(()).ok();
149                    }
150                    _ => {}
151                })
152                .detach();
153
154                // Spawn a task that listens for refresh signals and updates the picker
155                cx.spawn_in(window, async move |this, cx| {
156                    while let Some(()) = refresh_rx.next().await {
157                        let result = this.update_in(cx, |picker, window, cx| {
158                            picker.delegate.all_models = Arc::new(all_models(cx));
159                            picker.refresh(window, cx);
160                        });
161                        if result.is_err() {
162                            // Picker was dropped, exit the loop
163                            break;
164                        }
165                    }
166                })
167            },
168            popover_styles,
169            focus_handle,
170        }
171    }
172
173    fn get_active_model_index(
174        entries: &[LanguageModelPickerEntry],
175        active_model: Option<ConfiguredModel>,
176    ) -> usize {
177        entries
178            .iter()
179            .position(|entry| {
180                if let LanguageModelPickerEntry::Model(model) = entry {
181                    active_model
182                        .as_ref()
183                        .map(|active_model| {
184                            active_model.model.id() == model.model.id()
185                                && active_model.provider.id() == model.model.provider_id()
186                        })
187                        .unwrap_or_default()
188                } else {
189                    false
190                }
191            })
192            .unwrap_or(0)
193    }
194
195    /// Authenticates all providers in the [`LanguageModelRegistry`].
196    ///
197    /// We do this so that we can populate the language selector with all of the
198    /// models from the configured providers.
199    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
200        let authenticate_all_providers = LanguageModelRegistry::global(cx)
201            .read(cx)
202            .providers()
203            .iter()
204            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
205            .collect::<Vec<_>>();
206
207        cx.spawn(async move |_cx| {
208            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
209                if let Err(err) = authenticate_task.await {
210                    if matches!(err, AuthenticateError::CredentialsNotFound) {
211                        // Since we're authenticating these providers in the
212                        // background for the purposes of populating the
213                        // language selector, we don't care about providers
214                        // where the credentials are not found.
215                    } else {
216                        // Some providers have noisy failure states that we
217                        // don't want to spam the logs with every time the
218                        // language model selector is initialized.
219                        //
220                        // Ideally these should have more clear failure modes
221                        // that we know are safe to ignore here, like what we do
222                        // with `CredentialsNotFound` above.
223                        match provider_id.0.as_ref() {
224                            "lmstudio" | "ollama" => {
225                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
226                                //
227                                // These fail noisily, so we don't log them.
228                            }
229                            "copilot_chat" => {
230                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
231                            }
232                            _ => {
233                                log::error!(
234                                    "Failed to authenticate provider: {}: {err:#}",
235                                    provider_name.0
236                                );
237                            }
238                        }
239                    }
240                }
241            }
242        })
243    }
244
245    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
246        (self.get_active_model)(cx)
247    }
248}
249
250struct GroupedModels {
251    recommended: Vec<ModelInfo>,
252    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
253}
254
255impl GroupedModels {
256    pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
257        let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
258        for model in all {
259            let provider = model.model.provider_id();
260            if let Some(models) = all_by_provider.get_mut(&provider) {
261                models.push(model);
262            } else {
263                all_by_provider.insert(provider, vec![model]);
264            }
265        }
266
267        Self {
268            recommended,
269            all: all_by_provider,
270        }
271    }
272
273    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
274        let mut entries = Vec::new();
275
276        if !self.recommended.is_empty() {
277            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
278            entries.extend(
279                self.recommended
280                    .iter()
281                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
282            );
283        }
284
285        for models in self.all.values() {
286            if models.is_empty() {
287                continue;
288            }
289            entries.push(LanguageModelPickerEntry::Separator(
290                models[0].model.provider_name().0,
291            ));
292            entries.extend(
293                models
294                    .iter()
295                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
296            );
297        }
298        entries
299    }
300}
301
/// A single picker row: either a selectable model or a non-selectable section header.
enum LanguageModelPickerEntry {
    Model(ModelInfo),
    Separator(SharedString),
}
306
307struct ModelMatcher {
308    models: Vec<ModelInfo>,
309    bg_executor: BackgroundExecutor,
310    candidates: Vec<StringMatchCandidate>,
311}
312
313impl ModelMatcher {
314    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
315        let candidates = Self::make_match_candidates(&models);
316        Self {
317            models,
318            bg_executor,
319            candidates,
320        }
321    }
322
323    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
324        let mut matches = self.bg_executor.block(match_strings(
325            &self.candidates,
326            query,
327            false,
328            true,
329            100,
330            &Default::default(),
331            self.bg_executor.clone(),
332        ));
333
334        let sorting_key = |mat: &StringMatch| {
335            let candidate = &self.candidates[mat.candidate_id];
336            (Reverse(OrderedFloat(mat.score)), candidate.id)
337        };
338        matches.sort_unstable_by_key(sorting_key);
339
340        let matched_models: Vec<_> = matches
341            .into_iter()
342            .map(|mat| self.models[mat.candidate_id].clone())
343            .collect();
344
345        matched_models
346    }
347
348    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
349        self.models
350            .iter()
351            .filter(|m| {
352                m.model
353                    .name()
354                    .0
355                    .to_lowercase()
356                    .contains(&query.to_lowercase())
357            })
358            .cloned()
359            .collect::<Vec<_>>()
360    }
361
362    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
363        model_infos
364            .iter()
365            .enumerate()
366            .map(|(index, model)| {
367                StringMatchCandidate::new(
368                    index,
369                    &format!(
370                        "{}/{}",
371                        &model.model.provider_name().0,
372                        &model.model.name().0
373                    ),
374                )
375            })
376            .collect::<Vec<_>>()
377    }
378}
379
380impl PickerDelegate for LanguageModelPickerDelegate {
381    type ListItem = AnyElement;
382
383    fn match_count(&self) -> usize {
384        self.filtered_entries.len()
385    }
386
387    fn selected_index(&self) -> usize {
388        self.selected_index
389    }
390
391    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
392        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
393        cx.notify();
394    }
395
396    fn can_select(
397        &mut self,
398        ix: usize,
399        _window: &mut Window,
400        _cx: &mut Context<Picker<Self>>,
401    ) -> bool {
402        match self.filtered_entries.get(ix) {
403            Some(LanguageModelPickerEntry::Model(_)) => true,
404            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
405        }
406    }
407
408    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
409        "Select a model…".into()
410    }
411
412    fn update_matches(
413        &mut self,
414        query: String,
415        window: &mut Window,
416        cx: &mut Context<Picker<Self>>,
417    ) -> Task<()> {
418        let all_models = self.all_models.clone();
419        let active_model = (self.get_active_model)(cx);
420        let bg_executor = cx.background_executor();
421
422        let language_model_registry = LanguageModelRegistry::global(cx);
423
424        let configured_providers = language_model_registry
425            .read(cx)
426            .providers()
427            .into_iter()
428            .filter(|provider| provider.is_authenticated(cx))
429            .collect::<Vec<_>>();
430
431        let configured_provider_ids = configured_providers
432            .iter()
433            .map(|provider| provider.id())
434            .collect::<Vec<_>>();
435
436        let recommended_models = all_models
437            .recommended
438            .iter()
439            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
440            .cloned()
441            .collect::<Vec<_>>();
442
443        let available_models = all_models
444            .all
445            .values()
446            .flat_map(|models| models.iter())
447            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
448            .cloned()
449            .collect::<Vec<_>>();
450
451        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
452        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
453
454        let recommended = matcher_rec.exact_search(&query);
455        let all = matcher_all.fuzzy_search(&query);
456
457        let filtered_models = GroupedModels::new(all, recommended);
458
459        cx.spawn_in(window, async move |this, cx| {
460            this.update_in(cx, |this, window, cx| {
461                this.delegate.filtered_entries = filtered_models.entries();
462                // Finds the currently selected model in the list
463                let new_index =
464                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
465                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
466                cx.notify();
467            })
468            .ok();
469        })
470    }
471
472    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
473        if let Some(LanguageModelPickerEntry::Model(model_info)) =
474            self.filtered_entries.get(self.selected_index)
475        {
476            let model = model_info.model.clone();
477            (self.on_model_changed)(model.clone(), cx);
478
479            let current_index = self.selected_index;
480            self.set_selected_index(current_index, window, cx);
481
482            cx.emit(DismissEvent);
483        }
484    }
485
486    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
487        cx.emit(DismissEvent);
488    }
489
490    fn render_match(
491        &self,
492        ix: usize,
493        selected: bool,
494        _: &mut Window,
495        cx: &mut Context<Picker<Self>>,
496    ) -> Option<Self::ListItem> {
497        match self.filtered_entries.get(ix)? {
498            LanguageModelPickerEntry::Separator(title) => Some(
499                div()
500                    .px_2()
501                    .pb_1()
502                    .when(ix > 1, |this| {
503                        this.mt_1()
504                            .pt_2()
505                            .border_t_1()
506                            .border_color(cx.theme().colors().border_variant)
507                    })
508                    .child(
509                        Label::new(title)
510                            .size(LabelSize::XSmall)
511                            .color(Color::Muted),
512                    )
513                    .into_any_element(),
514            ),
515            LanguageModelPickerEntry::Model(model_info) => {
516                let active_model = (self.get_active_model)(cx);
517                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
518                let active_model_id = active_model.map(|m| m.model.id());
519
520                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
521                    && Some(model_info.model.id()) == active_model_id;
522
523                let model_icon_color = if is_selected {
524                    Color::Accent
525                } else {
526                    Color::Muted
527                };
528
529                Some(
530                    ListItem::new(ix)
531                        .inset(true)
532                        .spacing(ListItemSpacing::Sparse)
533                        .toggle_state(selected)
534                        .child(
535                            h_flex()
536                                .w_full()
537                                .gap_1p5()
538                                .child(match &model_info.icon {
539                                    ProviderIcon::Name(icon_name) => Icon::new(*icon_name)
540                                        .color(model_icon_color)
541                                        .size(IconSize::Small),
542                                    ProviderIcon::Path(icon_path) => {
543                                        Icon::from_external_svg(icon_path.clone())
544                                            .color(model_icon_color)
545                                            .size(IconSize::Small)
546                                    }
547                                })
548                                .child(Label::new(model_info.model.name().0).truncate()),
549                        )
550                        .end_slot(div().pr_3().when(is_selected, |this| {
551                            this.child(
552                                Icon::new(IconName::Check)
553                                    .color(Color::Accent)
554                                    .size(IconSize::Small),
555                            )
556                        }))
557                        .into_any_element(),
558                )
559            }
560        }
561    }
562
563    fn render_footer(
564        &self,
565        _window: &mut Window,
566        cx: &mut Context<Picker<Self>>,
567    ) -> Option<gpui::AnyElement> {
568        let focus_handle = self.focus_handle.clone();
569
570        if !self.popover_styles {
571            return None;
572        }
573
574        Some(
575            h_flex()
576                .w_full()
577                .p_1p5()
578                .border_t_1()
579                .border_color(cx.theme().colors().border_variant)
580                .child(
581                    Button::new("configure", "Configure")
582                        .full_width()
583                        .style(ButtonStyle::Outlined)
584                        .key_binding(
585                            KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
586                                .map(|kb| kb.size(rems_from_px(12.))),
587                        )
588                        .on_click(|_, window, cx| {
589                            window.dispatch_action(OpenSettings.boxed_clone(), cx);
590                        }),
591                )
592                .into_any(),
593        )
594    }
595}
596
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal model whose id and name are both the given name, and whose provider id
    /// and provider name are both the given provider.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: ProviderIcon::Name(IconName::Ai),
            })
            .collect()
    }

    /// Asserts that `result` holds exactly the models whose telemetry ids
    /// ("provider/name") are listed in `expected`, in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_exact_search_empty_query_and_provider_names(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("openai", "gpt-4.1"),
            ("ollama", "mistral"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // An empty query matches every model, preserving order.
        let results = matcher.exact_search("");
        assert_models_eq(
            results,
            vec!["zed/Claude 3.7 Sonnet", "openai/gpt-4.1", "ollama/mistral"],
        );

        // Unlike fuzzy search, exact search only looks at model names, not providers.
        let results = matcher.exact_search("openai");
        assert_models_eq(results, vec![]);
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Recommended, but must still appear in its provider group
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Recommended, but must still appear in "all"
            ("zed", "gemini"),
            ("copilot", "claude"), // Same name, different provider: grouped separately
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }

    #[gpui::test]
    fn test_entries_layout(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"),
            ("copilot", "o3"),
            ("zed", "gemini"), // Interleaved: must still be grouped under "zed"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let rows = grouped_models
            .entries()
            .iter()
            .map(|entry| match entry {
                LanguageModelPickerEntry::Separator(title) => format!("## {title}"),
                LanguageModelPickerEntry::Model(info) => info.model.telemetry_id(),
            })
            .collect::<Vec<_>>();

        // "Recommended" comes first; providers keep first-seen order.
        assert_eq!(
            rows,
            vec![
                "## Recommended",
                "zed/claude",
                "## zed",
                "zed/claude",
                "zed/gemini",
                "## copilot",
                "copilot/o3",
            ]
        );

        // No models at all produces no rows (not even separators).
        assert!(GroupedModels::new(Vec::new(), Vec::new()).entries().is_empty());
    }
}