// language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use collections::IndexMap;
  4use futures::{StreamExt, channel::mpsc};
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
  7use language_model::{
  8    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProvider,
  9    LanguageModelProviderId, LanguageModelRegistry,
 10};
 11use ordered_float::OrderedFloat;
 12use picker::{Picker, PickerDelegate};
 13use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*};
 14use zed_actions::agent::OpenSettings;
 15
 16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
 17type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 18
 19pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 20
 21pub fn language_model_selector(
 22    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 23    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 24    popover_styles: bool,
 25    focus_handle: FocusHandle,
 26    window: &mut Window,
 27    cx: &mut Context<LanguageModelSelector>,
 28) -> LanguageModelSelector {
 29    let delegate = LanguageModelPickerDelegate::new(
 30        get_active_model,
 31        on_model_changed,
 32        popover_styles,
 33        focus_handle,
 34        window,
 35        cx,
 36    );
 37
 38    if popover_styles {
 39        Picker::list(delegate, window, cx)
 40            .show_scrollbar(true)
 41            .width(rems(20.))
 42            .max_height(Some(rems(20.).into()))
 43    } else {
 44        Picker::list(delegate, window, cx).show_scrollbar(true)
 45    }
 46}
 47
 48fn all_models(cx: &App) -> GroupedModels {
 49    let providers = LanguageModelRegistry::global(cx)
 50        .read(cx)
 51        .visible_providers();
 52
 53    let recommended = providers
 54        .iter()
 55        .flat_map(|provider| {
 56            provider
 57                .recommended_models(cx)
 58                .into_iter()
 59                .map(|model| ModelInfo {
 60                    model,
 61                    icon: ProviderIcon::from_provider(provider.as_ref()),
 62                })
 63        })
 64        .collect();
 65
 66    let all: Vec<ModelInfo> = providers
 67        .iter()
 68        .flat_map(|provider| {
 69            provider
 70                .provided_models(cx)
 71                .into_iter()
 72                .map(|model| ModelInfo {
 73                    model,
 74                    icon: ProviderIcon::from_provider(provider.as_ref()),
 75                })
 76        })
 77        .collect();
 78
 79    GroupedModels::new(all, recommended)
 80}
 81
 82#[derive(Clone)]
 83enum ProviderIcon {
 84    Name(IconName),
 85    Path(SharedString),
 86}
 87
 88impl ProviderIcon {
 89    fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
 90        if let Some(path) = provider.icon_path() {
 91            Self::Path(path)
 92        } else {
 93            Self::Name(provider.icon())
 94        }
 95    }
 96}
 97
 98#[derive(Clone)]
 99struct ModelInfo {
100    model: Arc<dyn LanguageModel>,
101    icon: ProviderIcon,
102}
103
/// The [`PickerDelegate`] behind the language model selector.
pub struct LanguageModelPickerDelegate {
    /// Invoked with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Reports the currently active model, used to preselect and mark it.
    get_active_model: GetActiveModel,
    /// Snapshot of the visible providers' models; replaced whenever the
    /// registry reports that providers were added, removed, or changed state.
    all_models: Arc<GroupedModels>,
    /// Rows currently displayed: every entry initially, then the result of the
    /// most recent `update_matches`.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Kept for the delegate's lifetime so the background authentication runs
    /// (presumably dropping a gpui `Task` cancels it — confirm).
    _authenticate_all_providers_task: Task<()>,
    /// Kept for the delegate's lifetime; re-reads models on registry changes.
    _refresh_models_task: Task<()>,
    /// Whether the picker is shown as a popover (fixed size, "Configure" footer).
    popover_styles: bool,
    /// Focus context used to look up the key binding shown in the footer.
    focus_handle: FocusHandle,
}
115
116impl LanguageModelPickerDelegate {
117    fn new(
118        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
119        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
120        popover_styles: bool,
121        focus_handle: FocusHandle,
122        window: &mut Window,
123        cx: &mut Context<Picker<Self>>,
124    ) -> Self {
125        let on_model_changed = Arc::new(on_model_changed);
126        let models = all_models(cx);
127        let entries = models.entries();
128
129        Self {
130            on_model_changed,
131            all_models: Arc::new(models),
132            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
133            filtered_entries: entries,
134            get_active_model: Arc::new(get_active_model),
135            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
136            _refresh_models_task: {
137                // Create a channel to signal when models need refreshing
138                let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
139
140                // Subscribe to registry events and send refresh signals through the channel
141                let registry = LanguageModelRegistry::global(cx);
142                cx.subscribe(&registry, move |_picker, _, event, _cx| match event {
143                    language_model::Event::ProviderStateChanged(_)
144                    | language_model::Event::AddedProvider(_)
145                    | language_model::Event::RemovedProvider(_)
146                    | language_model::Event::ProvidersChanged => {
147                        refresh_tx.unbounded_send(()).ok();
148                    }
149                    language_model::Event::DefaultModelChanged
150                    | language_model::Event::InlineAssistantModelChanged
151                    | language_model::Event::CommitMessageModelChanged
152                    | language_model::Event::ThreadSummaryModelChanged => {}
153                })
154                .detach();
155
156                // Spawn a task that listens for refresh signals and updates the picker
157                cx.spawn_in(window, async move |this, cx| {
158                    while let Some(()) = refresh_rx.next().await {
159                        if this
160                            .update_in(cx, |picker, window, cx| {
161                                picker.delegate.all_models = Arc::new(all_models(cx));
162                                picker.refresh(window, cx);
163                            })
164                            .is_err()
165                        {
166                            // Picker was dropped, exit the loop
167                            break;
168                        }
169                    }
170                })
171            },
172            popover_styles,
173            focus_handle,
174        }
175    }
176
177    fn get_active_model_index(
178        entries: &[LanguageModelPickerEntry],
179        active_model: Option<ConfiguredModel>,
180    ) -> usize {
181        entries
182            .iter()
183            .position(|entry| {
184                if let LanguageModelPickerEntry::Model(model) = entry {
185                    active_model
186                        .as_ref()
187                        .map(|active_model| {
188                            active_model.model.id() == model.model.id()
189                                && active_model.provider.id() == model.model.provider_id()
190                        })
191                        .unwrap_or_default()
192                } else {
193                    false
194                }
195            })
196            .unwrap_or(0)
197    }
198
199    /// Authenticates all providers in the [`LanguageModelRegistry`].
200    ///
201    /// We do this so that we can populate the language selector with all of the
202    /// models from the configured providers.
203    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
204        let authenticate_all_providers = LanguageModelRegistry::global(cx)
205            .read(cx)
206            .providers()
207            .iter()
208            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
209            .collect::<Vec<_>>();
210
211        cx.spawn(async move |_cx| {
212            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
213                if let Err(err) = authenticate_task.await {
214                    if matches!(err, AuthenticateError::CredentialsNotFound) {
215                        // Since we're authenticating these providers in the
216                        // background for the purposes of populating the
217                        // language selector, we don't care about providers
218                        // where the credentials are not found.
219                    } else {
220                        // Some providers have noisy failure states that we
221                        // don't want to spam the logs with every time the
222                        // language model selector is initialized.
223                        //
224                        // Ideally these should have more clear failure modes
225                        // that we know are safe to ignore here, like what we do
226                        // with `CredentialsNotFound` above.
227                        match provider_id.0.as_ref() {
228                            "lmstudio" | "ollama" => {
229                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
230                                //
231                                // These fail noisily, so we don't log them.
232                            }
233                            "copilot_chat" => {
234                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
235                            }
236                            _ => {
237                                log::error!(
238                                    "Failed to authenticate provider: {}: {err:#}",
239                                    provider_name.0
240                                );
241                            }
242                        }
243                    }
244                }
245            }
246        })
247    }
248
249    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
250        (self.get_active_model)(cx)
251    }
252}
253
254struct GroupedModels {
255    recommended: Vec<ModelInfo>,
256    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
257}
258
259impl GroupedModels {
260    pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
261        let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
262        for model in all {
263            let provider = model.model.provider_id();
264            if let Some(models) = all_by_provider.get_mut(&provider) {
265                models.push(model);
266            } else {
267                all_by_provider.insert(provider, vec![model]);
268            }
269        }
270
271        Self {
272            recommended,
273            all: all_by_provider,
274        }
275    }
276
277    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
278        let mut entries = Vec::new();
279
280        if !self.recommended.is_empty() {
281            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
282            entries.extend(
283                self.recommended
284                    .iter()
285                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
286            );
287        }
288
289        for models in self.all.values() {
290            if models.is_empty() {
291                continue;
292            }
293            entries.push(LanguageModelPickerEntry::Separator(
294                models[0].model.provider_name().0,
295            ));
296            entries.extend(
297                models
298                    .iter()
299                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
300            );
301        }
302        entries
303    }
304}
305
/// A row in the picker: either a selectable model or a non-selectable header.
enum LanguageModelPickerEntry {
    /// A model that can be selected and confirmed.
    Model(ModelInfo),
    /// A section header ("Recommended" or a provider's name).
    Separator(SharedString),
}
310
311struct ModelMatcher {
312    models: Vec<ModelInfo>,
313    bg_executor: BackgroundExecutor,
314    candidates: Vec<StringMatchCandidate>,
315}
316
317impl ModelMatcher {
318    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
319        let candidates = Self::make_match_candidates(&models);
320        Self {
321            models,
322            bg_executor,
323            candidates,
324        }
325    }
326
327    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
328        let mut matches = self.bg_executor.block(match_strings(
329            &self.candidates,
330            query,
331            false,
332            true,
333            100,
334            &Default::default(),
335            self.bg_executor.clone(),
336        ));
337
338        let sorting_key = |mat: &StringMatch| {
339            let candidate = &self.candidates[mat.candidate_id];
340            (Reverse(OrderedFloat(mat.score)), candidate.id)
341        };
342        matches.sort_unstable_by_key(sorting_key);
343
344        let matched_models: Vec<_> = matches
345            .into_iter()
346            .map(|mat| self.models[mat.candidate_id].clone())
347            .collect();
348
349        matched_models
350    }
351
352    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
353        self.models
354            .iter()
355            .filter(|m| {
356                m.model
357                    .name()
358                    .0
359                    .to_lowercase()
360                    .contains(&query.to_lowercase())
361            })
362            .cloned()
363            .collect::<Vec<_>>()
364    }
365
366    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
367        model_infos
368            .iter()
369            .enumerate()
370            .map(|(index, model)| {
371                StringMatchCandidate::new(
372                    index,
373                    &format!(
374                        "{}/{}",
375                        &model.model.provider_name().0,
376                        &model.model.name().0
377                    ),
378                )
379            })
380            .collect::<Vec<_>>()
381    }
382}
383
384impl PickerDelegate for LanguageModelPickerDelegate {
385    type ListItem = AnyElement;
386
387    fn match_count(&self) -> usize {
388        self.filtered_entries.len()
389    }
390
391    fn selected_index(&self) -> usize {
392        self.selected_index
393    }
394
395    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
396        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
397        cx.notify();
398    }
399
400    fn can_select(
401        &mut self,
402        ix: usize,
403        _window: &mut Window,
404        _cx: &mut Context<Picker<Self>>,
405    ) -> bool {
406        match self.filtered_entries.get(ix) {
407            Some(LanguageModelPickerEntry::Model(_)) => true,
408            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
409        }
410    }
411
412    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
413        "Select a model…".into()
414    }
415
416    fn update_matches(
417        &mut self,
418        query: String,
419        window: &mut Window,
420        cx: &mut Context<Picker<Self>>,
421    ) -> Task<()> {
422        let all_models = self.all_models.clone();
423        let active_model = (self.get_active_model)(cx);
424        let bg_executor = cx.background_executor();
425
426        let language_model_registry = LanguageModelRegistry::global(cx);
427
428        let configured_providers = language_model_registry
429            .read(cx)
430            .visible_providers()
431            .into_iter()
432            .filter(|provider| provider.is_authenticated(cx))
433            .collect::<Vec<_>>();
434
435        let configured_provider_ids = configured_providers
436            .iter()
437            .map(|provider| provider.id())
438            .collect::<Vec<_>>();
439
440        let recommended_models = all_models
441            .recommended
442            .iter()
443            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
444            .cloned()
445            .collect::<Vec<_>>();
446
447        let available_models = all_models
448            .all
449            .values()
450            .flat_map(|models| models.iter())
451            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
452            .cloned()
453            .collect::<Vec<_>>();
454
455        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
456        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
457
458        let recommended = matcher_rec.exact_search(&query);
459        let all = matcher_all.fuzzy_search(&query);
460
461        let filtered_models = GroupedModels::new(all, recommended);
462
463        cx.spawn_in(window, async move |this, cx| {
464            this.update_in(cx, |this, window, cx| {
465                this.delegate.filtered_entries = filtered_models.entries();
466                // Finds the currently selected model in the list
467                let new_index =
468                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
469                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
470                cx.notify();
471            })
472            .ok();
473        })
474    }
475
476    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
477        if let Some(LanguageModelPickerEntry::Model(model_info)) =
478            self.filtered_entries.get(self.selected_index)
479        {
480            let model = model_info.model.clone();
481            (self.on_model_changed)(model.clone(), cx);
482
483            let current_index = self.selected_index;
484            self.set_selected_index(current_index, window, cx);
485
486            cx.emit(DismissEvent);
487        }
488    }
489
490    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
491        cx.emit(DismissEvent);
492    }
493
494    fn render_match(
495        &self,
496        ix: usize,
497        selected: bool,
498        _: &mut Window,
499        cx: &mut Context<Picker<Self>>,
500    ) -> Option<Self::ListItem> {
501        match self.filtered_entries.get(ix)? {
502            LanguageModelPickerEntry::Separator(title) => Some(
503                div()
504                    .px_2()
505                    .pb_1()
506                    .when(ix > 1, |this| {
507                        this.mt_1()
508                            .pt_2()
509                            .border_t_1()
510                            .border_color(cx.theme().colors().border_variant)
511                    })
512                    .child(
513                        Label::new(title)
514                            .size(LabelSize::XSmall)
515                            .color(Color::Muted),
516                    )
517                    .into_any_element(),
518            ),
519            LanguageModelPickerEntry::Model(model_info) => {
520                let active_model = (self.get_active_model)(cx);
521                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
522                let active_model_id = active_model.map(|m| m.model.id());
523
524                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
525                    && Some(model_info.model.id()) == active_model_id;
526
527                let model_icon_color = if is_selected {
528                    Color::Accent
529                } else {
530                    Color::Muted
531                };
532
533                Some(
534                    ListItem::new(ix)
535                        .inset(true)
536                        .spacing(ListItemSpacing::Sparse)
537                        .toggle_state(selected)
538                        .child(
539                            h_flex()
540                                .w_full()
541                                .gap_1p5()
542                                .child(match &model_info.icon {
543                                    ProviderIcon::Name(icon_name) => Icon::new(*icon_name)
544                                        .color(model_icon_color)
545                                        .size(IconSize::Small),
546                                    ProviderIcon::Path(icon_path) => {
547                                        Icon::from_external_svg(icon_path.clone())
548                                            .color(model_icon_color)
549                                            .size(IconSize::Small)
550                                    }
551                                })
552                                .child(Label::new(model_info.model.name().0).truncate()),
553                        )
554                        .end_slot(div().pr_3().when(is_selected, |this| {
555                            this.child(
556                                Icon::new(IconName::Check)
557                                    .color(Color::Accent)
558                                    .size(IconSize::Small),
559                            )
560                        }))
561                        .into_any_element(),
562                )
563            }
564        }
565    }
566
567    fn render_footer(
568        &self,
569        _window: &mut Window,
570        cx: &mut Context<Picker<Self>>,
571    ) -> Option<gpui::AnyElement> {
572        let focus_handle = self.focus_handle.clone();
573
574        if !self.popover_styles {
575            return None;
576        }
577
578        Some(
579            h_flex()
580                .w_full()
581                .p_1p5()
582                .border_t_1()
583                .border_color(cx.theme().colors().border_variant)
584                .child(
585                    Button::new("configure", "Configure")
586                        .full_width()
587                        .style(ButtonStyle::Outlined)
588                        .key_binding(
589                            KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
590                                .map(|kb| kb.size(rems_from_px(12.))),
591                        )
592                        .on_click(|_, window, cx| {
593                            window.dispatch_action(OpenSettings.boxed_clone(), cx);
594                        }),
595                )
596                .into_any(),
597        )
598    }
599}
600
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// A [`LanguageModel`] stub that only carries identity. The request-related
    /// methods are `unimplemented!()` because no test here issues a request.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// The model id mirrors `name`, and the provider id mirrors `provider`.
        fn new(name: &str, provider: &str) -> Self {
            let (name, provider) = (name.to_string(), provider.to_string());
            Self {
                name: LanguageModelName::from(name.clone()),
                id: LanguageModelId::from(name),
                provider_id: LanguageModelProviderId::from(provider.clone()),
                provider_name: LanguageModelProviderName::from(provider),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        /// `provider/name` — the string the assertions below compare against.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Turns `(provider, model name)` pairs into [`ModelInfo`]s that share a
    /// generic icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        let mut infos = Vec::with_capacity(model_specs.len());
        for (provider, name) in model_specs {
            infos.push(ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: ProviderIcon::Name(IconName::Ai),
            });
        }
        infos
    }

    /// Asserts that `result` lists exactly the `expected` telemetry ids, in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, (info, expected_name)) in result.iter().zip(expected.iter()).enumerate() {
            assert_eq!(
                info.model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    /// The model list shared by the exact- and fuzzy-search tests.
    fn search_fixture() -> Vec<ModelInfo> {
        create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ])
    }

    /// Flattens the per-provider groups of `grouped`, in provider order.
    fn flatten_all(grouped: &GroupedModels) -> Vec<ModelInfo> {
        grouped.all.values().flatten().cloned().collect()
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(search_fixture(), cx.background_executor.clone());

        // Input order is preserved, and matching ignores case.
        assert_models_eq(
            matcher.exact_search("GPT-4.1"),
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(search_fixture(), cx.background_executor.clone());

        // Equal scores keep input order: `zed/gpt-4.1` and `openai/gpt-4.1`
        // score the same, but the zed entry comes first in the input list, so
        // it comes first in the results too.
        assert_models_eq(
            matcher.fuzzy_search("41"),
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // The provider name is part of the searchable text ("ol" → "ollama").
        assert_models_eq(
            matcher.fuzzy_search("ol"),
            vec!["ollama/mistral", "ollama/deepseek"],
        );

        // Non-contiguous fuzzy match across provider and model name.
        assert_models_eq(matcher.fuzzy_search("z4n"), vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // Being recommended does not remove a model from "all".
        assert_models_eq(
            flatten_all(&grouped_models),
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // Same-named models from different providers are kept apart, and every
        // model appears in "all" whether or not it is recommended.
        assert_models_eq(
            flatten_all(&grouped_models),
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }
}