language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use collections::IndexMap;
  4use futures::{StreamExt, channel::mpsc};
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
  7use language_model::{
  8    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProvider,
  9    LanguageModelProviderId, LanguageModelRegistry,
 10};
 11use ordered_float::OrderedFloat;
 12use picker::{Picker, PickerDelegate};
 13use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*};
 14use zed_actions::agent::OpenSettings;
 15
 16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
 17type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 18
 19pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 20
 21pub fn language_model_selector(
 22    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 23    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 24    popover_styles: bool,
 25    focus_handle: FocusHandle,
 26    window: &mut Window,
 27    cx: &mut Context<LanguageModelSelector>,
 28) -> LanguageModelSelector {
 29    let delegate = LanguageModelPickerDelegate::new(
 30        get_active_model,
 31        on_model_changed,
 32        popover_styles,
 33        focus_handle,
 34        window,
 35        cx,
 36    );
 37
 38    if popover_styles {
 39        Picker::list(delegate, window, cx)
 40            .show_scrollbar(true)
 41            .width(rems(20.))
 42            .max_height(Some(rems(20.).into()))
 43    } else {
 44        Picker::list(delegate, window, cx).show_scrollbar(true)
 45    }
 46}
 47
 48fn all_models(cx: &App) -> GroupedModels {
 49    let providers = LanguageModelRegistry::global(cx)
 50        .read(cx)
 51        .visible_providers();
 52
 53    let recommended = providers
 54        .iter()
 55        .flat_map(|provider| {
 56            provider
 57                .recommended_models(cx)
 58                .into_iter()
 59                .map(|model| ModelInfo {
 60                    model,
 61                    icon: ProviderIcon::from_provider(provider.as_ref()),
 62                })
 63        })
 64        .collect();
 65
 66    let all: Vec<ModelInfo> = providers
 67        .iter()
 68        .flat_map(|provider| {
 69            provider
 70                .provided_models(cx)
 71                .into_iter()
 72                .map(|model| ModelInfo {
 73                    model,
 74                    icon: ProviderIcon::from_provider(provider.as_ref()),
 75                })
 76        })
 77        .collect();
 78
 79    GroupedModels::new(all, recommended)
 80}
 81
 82#[derive(Clone)]
 83enum ProviderIcon {
 84    Name(IconName),
 85    Path(SharedString),
 86}
 87
 88impl ProviderIcon {
 89    fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
 90        if let Some(path) = provider.icon_path() {
 91            Self::Path(path)
 92        } else {
 93            Self::Name(provider.icon())
 94        }
 95    }
 96}
 97
/// A model paired with the icon of the provider it came from.
#[derive(Clone)]
struct ModelInfo {
    // The model itself.
    model: Arc<dyn LanguageModel>,
    // Icon of the model's provider, shown in the picker row.
    icon: ProviderIcon,
}
103
/// [`PickerDelegate`] listing language models grouped by provider, with an
/// optional "Recommended" section at the top.
pub struct LanguageModelPickerDelegate {
    // Invoked with the model the user confirms.
    on_model_changed: OnModelChanged,
    // Reports the active model; used for pre-selection and the check mark.
    get_active_model: GetActiveModel,
    // Snapshot of all models from visible providers; rebuilt when the
    // registry reports provider changes.
    all_models: Arc<GroupedModels>,
    // Rows (separators and models) shown for the current query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    // Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    // Background authentication of all providers, kept for the delegate's lifetime.
    _authenticate_all_providers_task: Task<()>,
    // Listens for registry change signals and refreshes `all_models`;
    // kept for the delegate's lifetime.
    _refresh_models_task: Task<()>,
    // Whether popover sizing and the "Configure" footer are used.
    popover_styles: bool,
    // Focus handle used to resolve the footer's key binding.
    focus_handle: FocusHandle,
}
115
116impl LanguageModelPickerDelegate {
117    fn new(
118        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
119        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
120        popover_styles: bool,
121        focus_handle: FocusHandle,
122        window: &mut Window,
123        cx: &mut Context<Picker<Self>>,
124    ) -> Self {
125        let on_model_changed = Arc::new(on_model_changed);
126        let models = all_models(cx);
127        let entries = models.entries();
128
129        Self {
130            on_model_changed,
131            all_models: Arc::new(models),
132            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
133            filtered_entries: entries,
134            get_active_model: Arc::new(get_active_model),
135            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
136            _refresh_models_task: {
137                // Create a channel to signal when models need refreshing
138                let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
139
140                // Subscribe to registry events and send refresh signals through the channel
141                let registry = LanguageModelRegistry::global(cx);
142                cx.subscribe(&registry, move |_picker, _, event, _cx| match event {
143                    language_model::Event::ProviderStateChanged(_) => {
144                        refresh_tx.unbounded_send(()).ok();
145                    }
146                    language_model::Event::AddedProvider(_) => {
147                        refresh_tx.unbounded_send(()).ok();
148                    }
149                    language_model::Event::RemovedProvider(_) => {
150                        refresh_tx.unbounded_send(()).ok();
151                    }
152                    _ => {}
153                })
154                .detach();
155
156                // Spawn a task that listens for refresh signals and updates the picker
157                cx.spawn_in(window, async move |this, cx| {
158                    while let Some(()) = refresh_rx.next().await {
159                        let result = this.update_in(cx, |picker, window, cx| {
160                            picker.delegate.all_models = Arc::new(all_models(cx));
161                            picker.refresh(window, cx);
162                        });
163                        if result.is_err() {
164                            // Picker was dropped, exit the loop
165                            break;
166                        }
167                    }
168                })
169            },
170            popover_styles,
171            focus_handle,
172        }
173    }
174
175    fn get_active_model_index(
176        entries: &[LanguageModelPickerEntry],
177        active_model: Option<ConfiguredModel>,
178    ) -> usize {
179        entries
180            .iter()
181            .position(|entry| {
182                if let LanguageModelPickerEntry::Model(model) = entry {
183                    active_model
184                        .as_ref()
185                        .map(|active_model| {
186                            active_model.model.id() == model.model.id()
187                                && active_model.provider.id() == model.model.provider_id()
188                        })
189                        .unwrap_or_default()
190                } else {
191                    false
192                }
193            })
194            .unwrap_or(0)
195    }
196
197    /// Authenticates all providers in the [`LanguageModelRegistry`].
198    ///
199    /// We do this so that we can populate the language selector with all of the
200    /// models from the configured providers.
201    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
202        let authenticate_all_providers = LanguageModelRegistry::global(cx)
203            .read(cx)
204            .providers()
205            .iter()
206            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
207            .collect::<Vec<_>>();
208
209        cx.spawn(async move |_cx| {
210            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
211                if let Err(err) = authenticate_task.await {
212                    if matches!(err, AuthenticateError::CredentialsNotFound) {
213                        // Since we're authenticating these providers in the
214                        // background for the purposes of populating the
215                        // language selector, we don't care about providers
216                        // where the credentials are not found.
217                    } else {
218                        // Some providers have noisy failure states that we
219                        // don't want to spam the logs with every time the
220                        // language model selector is initialized.
221                        //
222                        // Ideally these should have more clear failure modes
223                        // that we know are safe to ignore here, like what we do
224                        // with `CredentialsNotFound` above.
225                        match provider_id.0.as_ref() {
226                            "lmstudio" | "ollama" => {
227                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
228                                //
229                                // These fail noisily, so we don't log them.
230                            }
231                            "copilot_chat" => {
232                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
233                            }
234                            _ => {
235                                log::error!(
236                                    "Failed to authenticate provider: {}: {err:#}",
237                                    provider_name.0
238                                );
239                            }
240                        }
241                    }
242                }
243            }
244        })
245    }
246
247    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
248        (self.get_active_model)(cx)
249    }
250}
251
/// Models arranged for display: a "Recommended" list plus every model grouped
/// by provider id.
struct GroupedModels {
    // Models shown under the "Recommended" header.
    recommended: Vec<ModelInfo>,
    // All models keyed by provider id; IndexMap iterates in insertion order,
    // so providers appear in the order they were first seen.
    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
256
257impl GroupedModels {
258    pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
259        let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
260        for model in all {
261            let provider = model.model.provider_id();
262            if let Some(models) = all_by_provider.get_mut(&provider) {
263                models.push(model);
264            } else {
265                all_by_provider.insert(provider, vec![model]);
266            }
267        }
268
269        Self {
270            recommended,
271            all: all_by_provider,
272        }
273    }
274
275    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
276        let mut entries = Vec::new();
277
278        if !self.recommended.is_empty() {
279            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
280            entries.extend(
281                self.recommended
282                    .iter()
283                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
284            );
285        }
286
287        for models in self.all.values() {
288            if models.is_empty() {
289                continue;
290            }
291            entries.push(LanguageModelPickerEntry::Separator(
292                models[0].model.provider_name().0,
293            ));
294            entries.extend(
295                models
296                    .iter()
297                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
298            );
299        }
300        entries
301    }
302}
303
/// A single row in the picker list.
enum LanguageModelPickerEntry {
    // A selectable model row.
    Model(ModelInfo),
    // A non-selectable section header carrying its title.
    Separator(SharedString),
}
308
/// Filters a fixed list of models by a query, either fuzzily against
/// "provider/name" or by case-insensitive substring on the model name.
struct ModelMatcher {
    // Models being searched; index i corresponds to candidate id i.
    models: Vec<ModelInfo>,
    // Executor used to run (and block on) the fuzzy matcher.
    bg_executor: BackgroundExecutor,
    // Precomputed fuzzy-match candidates, one per entry in `models`.
    candidates: Vec<StringMatchCandidate>,
}
314
315impl ModelMatcher {
316    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
317        let candidates = Self::make_match_candidates(&models);
318        Self {
319            models,
320            bg_executor,
321            candidates,
322        }
323    }
324
325    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
326        let mut matches = self.bg_executor.block(match_strings(
327            &self.candidates,
328            query,
329            false,
330            true,
331            100,
332            &Default::default(),
333            self.bg_executor.clone(),
334        ));
335
336        let sorting_key = |mat: &StringMatch| {
337            let candidate = &self.candidates[mat.candidate_id];
338            (Reverse(OrderedFloat(mat.score)), candidate.id)
339        };
340        matches.sort_unstable_by_key(sorting_key);
341
342        let matched_models: Vec<_> = matches
343            .into_iter()
344            .map(|mat| self.models[mat.candidate_id].clone())
345            .collect();
346
347        matched_models
348    }
349
350    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
351        self.models
352            .iter()
353            .filter(|m| {
354                m.model
355                    .name()
356                    .0
357                    .to_lowercase()
358                    .contains(&query.to_lowercase())
359            })
360            .cloned()
361            .collect::<Vec<_>>()
362    }
363
364    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
365        model_infos
366            .iter()
367            .enumerate()
368            .map(|(index, model)| {
369                StringMatchCandidate::new(
370                    index,
371                    &format!(
372                        "{}/{}",
373                        &model.model.provider_name().0,
374                        &model.model.name().0
375                    ),
376                )
377            })
378            .collect::<Vec<_>>()
379    }
380}
381
382impl PickerDelegate for LanguageModelPickerDelegate {
383    type ListItem = AnyElement;
384
385    fn match_count(&self) -> usize {
386        self.filtered_entries.len()
387    }
388
389    fn selected_index(&self) -> usize {
390        self.selected_index
391    }
392
393    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
394        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
395        cx.notify();
396    }
397
398    fn can_select(
399        &mut self,
400        ix: usize,
401        _window: &mut Window,
402        _cx: &mut Context<Picker<Self>>,
403    ) -> bool {
404        match self.filtered_entries.get(ix) {
405            Some(LanguageModelPickerEntry::Model(_)) => true,
406            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
407        }
408    }
409
410    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
411        "Select a model…".into()
412    }
413
414    fn update_matches(
415        &mut self,
416        query: String,
417        window: &mut Window,
418        cx: &mut Context<Picker<Self>>,
419    ) -> Task<()> {
420        let all_models = self.all_models.clone();
421        let active_model = (self.get_active_model)(cx);
422        let bg_executor = cx.background_executor();
423
424        let language_model_registry = LanguageModelRegistry::global(cx);
425
426        let configured_providers = language_model_registry
427            .read(cx)
428            .visible_providers()
429            .into_iter()
430            .filter(|provider| provider.is_authenticated(cx))
431            .collect::<Vec<_>>();
432
433        let configured_provider_ids = configured_providers
434            .iter()
435            .map(|provider| provider.id())
436            .collect::<Vec<_>>();
437
438        let recommended_models = all_models
439            .recommended
440            .iter()
441            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
442            .cloned()
443            .collect::<Vec<_>>();
444
445        let available_models = all_models
446            .all
447            .values()
448            .flat_map(|models| models.iter())
449            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
450            .cloned()
451            .collect::<Vec<_>>();
452
453        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
454        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
455
456        let recommended = matcher_rec.exact_search(&query);
457        let all = matcher_all.fuzzy_search(&query);
458
459        let filtered_models = GroupedModels::new(all, recommended);
460
461        cx.spawn_in(window, async move |this, cx| {
462            this.update_in(cx, |this, window, cx| {
463                this.delegate.filtered_entries = filtered_models.entries();
464                // Finds the currently selected model in the list
465                let new_index =
466                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
467                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
468                cx.notify();
469            })
470            .ok();
471        })
472    }
473
474    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
475        if let Some(LanguageModelPickerEntry::Model(model_info)) =
476            self.filtered_entries.get(self.selected_index)
477        {
478            let model = model_info.model.clone();
479            (self.on_model_changed)(model.clone(), cx);
480
481            let current_index = self.selected_index;
482            self.set_selected_index(current_index, window, cx);
483
484            cx.emit(DismissEvent);
485        }
486    }
487
488    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
489        cx.emit(DismissEvent);
490    }
491
492    fn render_match(
493        &self,
494        ix: usize,
495        selected: bool,
496        _: &mut Window,
497        cx: &mut Context<Picker<Self>>,
498    ) -> Option<Self::ListItem> {
499        match self.filtered_entries.get(ix)? {
500            LanguageModelPickerEntry::Separator(title) => Some(
501                div()
502                    .px_2()
503                    .pb_1()
504                    .when(ix > 1, |this| {
505                        this.mt_1()
506                            .pt_2()
507                            .border_t_1()
508                            .border_color(cx.theme().colors().border_variant)
509                    })
510                    .child(
511                        Label::new(title)
512                            .size(LabelSize::XSmall)
513                            .color(Color::Muted),
514                    )
515                    .into_any_element(),
516            ),
517            LanguageModelPickerEntry::Model(model_info) => {
518                let active_model = (self.get_active_model)(cx);
519                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
520                let active_model_id = active_model.map(|m| m.model.id());
521
522                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
523                    && Some(model_info.model.id()) == active_model_id;
524
525                let model_icon_color = if is_selected {
526                    Color::Accent
527                } else {
528                    Color::Muted
529                };
530
531                Some(
532                    ListItem::new(ix)
533                        .inset(true)
534                        .spacing(ListItemSpacing::Sparse)
535                        .toggle_state(selected)
536                        .child(
537                            h_flex()
538                                .w_full()
539                                .gap_1p5()
540                                .child(match &model_info.icon {
541                                    ProviderIcon::Name(icon_name) => Icon::new(*icon_name)
542                                        .color(model_icon_color)
543                                        .size(IconSize::Small),
544                                    ProviderIcon::Path(icon_path) => {
545                                        Icon::from_external_svg(icon_path.clone())
546                                            .color(model_icon_color)
547                                            .size(IconSize::Small)
548                                    }
549                                })
550                                .child(Label::new(model_info.model.name().0).truncate()),
551                        )
552                        .end_slot(div().pr_3().when(is_selected, |this| {
553                            this.child(
554                                Icon::new(IconName::Check)
555                                    .color(Color::Accent)
556                                    .size(IconSize::Small),
557                            )
558                        }))
559                        .into_any_element(),
560                )
561            }
562        }
563    }
564
565    fn render_footer(
566        &self,
567        _window: &mut Window,
568        cx: &mut Context<Picker<Self>>,
569    ) -> Option<gpui::AnyElement> {
570        let focus_handle = self.focus_handle.clone();
571
572        if !self.popover_styles {
573            return None;
574        }
575
576        Some(
577            h_flex()
578                .w_full()
579                .p_1p5()
580                .border_t_1()
581                .border_color(cx.theme().colors().border_variant)
582                .child(
583                    Button::new("configure", "Configure")
584                        .full_width()
585                        .style(ButtonStyle::Outlined)
586                        .key_binding(
587                            KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
588                                .map(|kb| kb.size(rems_from_px(12.))),
589                        )
590                        .on_click(|_, window, cx| {
591                            window.dispatch_action(OpenSettings.boxed_clone(), cx);
592                        }),
593                )
594                .into_any(),
595        )
596    }
597}
598
// Unit tests for `ModelMatcher` search ordering and `GroupedModels` grouping.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal [`LanguageModel`] for these tests. Only the identity accessors
    /// and `telemetry_id` are meaningful; completion methods are unimplemented.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Creates a model whose id equals `name`, and whose provider id and
        /// provider name both equal `provider`.
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // "provider/name" — the identifier `assert_models_eq` compares against.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, all using the same icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: ProviderIcon::Name(IconName::Ai),
            })
            .collect()
    }

    /// Asserts that `result` holds exactly the models whose telemetry ids
    /// ("provider/name") are listed in `expected`, in the same order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Matching ignores case and keeps the original model order.
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve the models' original order whenever possible.
        // Below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical similarity
        // scores, but `zed/gpt-4.1` comes first in the input list, so it must
        // come first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // The provider name is part of the searchable text as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Non-contiguous (fuzzy) characters spanning provider and model name
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Recommended, but must still appear in "all"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Recommended, but must still appear in "all"
            ("zed", "gemini"),
            ("copilot", "claude"), // Same name, different provider: a distinct entry in "all"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }
}