language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use collections::IndexMap;
  4use futures::{StreamExt, channel::mpsc};
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
  7use language_model::{
  8    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProvider,
  9    LanguageModelProviderId, LanguageModelRegistry,
 10};
 11use ordered_float::OrderedFloat;
 12use picker::{Picker, PickerDelegate};
 13use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*};
 14use zed_actions::agent::OpenSettings;
 15
/// Callback invoked with the chosen model when the user confirms an entry.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback that reports the currently active model, or `None` if there is none.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

/// A [`Picker`] driven by a [`LanguageModelPickerDelegate`].
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 20
 21pub fn language_model_selector(
 22    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 23    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 24    popover_styles: bool,
 25    focus_handle: FocusHandle,
 26    window: &mut Window,
 27    cx: &mut Context<LanguageModelSelector>,
 28) -> LanguageModelSelector {
 29    let delegate = LanguageModelPickerDelegate::new(
 30        get_active_model,
 31        on_model_changed,
 32        popover_styles,
 33        focus_handle,
 34        window,
 35        cx,
 36    );
 37
 38    if popover_styles {
 39        Picker::list(delegate, window, cx)
 40            .show_scrollbar(true)
 41            .width(rems(20.))
 42            .max_height(Some(rems(20.).into()))
 43    } else {
 44        Picker::list(delegate, window, cx).show_scrollbar(true)
 45    }
 46}
 47
 48fn all_models(cx: &App) -> GroupedModels {
 49    let providers = LanguageModelRegistry::global(cx)
 50        .read(cx)
 51        .visible_providers();
 52
 53    let recommended = providers
 54        .iter()
 55        .flat_map(|provider| {
 56            provider
 57                .recommended_models(cx)
 58                .into_iter()
 59                .map(|model| ModelInfo {
 60                    model,
 61                    icon: ProviderIcon::from_provider(provider.as_ref()),
 62                })
 63        })
 64        .collect();
 65
 66    let all: Vec<ModelInfo> = providers
 67        .iter()
 68        .flat_map(|provider| {
 69            provider
 70                .provided_models(cx)
 71                .into_iter()
 72                .map(|model| ModelInfo {
 73                    model,
 74                    icon: ProviderIcon::from_provider(provider.as_ref()),
 75                })
 76        })
 77        .collect();
 78
 79    GroupedModels::new(all, recommended)
 80}
 81
 82#[derive(Clone)]
 83enum ProviderIcon {
 84    Name(IconName),
 85    Path(SharedString),
 86}
 87
 88impl ProviderIcon {
 89    fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
 90        if let Some(path) = provider.icon_path() {
 91            Self::Path(path)
 92        } else {
 93            Self::Name(provider.icon())
 94        }
 95    }
 96}
 97
/// A language model together with the icon of the provider it came from.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// Icon shown next to the model's name in the picker.
    icon: ProviderIcon,
}
103
104pub struct LanguageModelPickerDelegate {
105    on_model_changed: OnModelChanged,
106    get_active_model: GetActiveModel,
107    all_models: Arc<GroupedModels>,
108    filtered_entries: Vec<LanguageModelPickerEntry>,
109    selected_index: usize,
110    _authenticate_all_providers_task: Task<()>,
111    _refresh_models_task: Task<()>,
112    popover_styles: bool,
113    focus_handle: FocusHandle,
114}
115
116impl LanguageModelPickerDelegate {
117    fn new(
118        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
119        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
120        popover_styles: bool,
121        focus_handle: FocusHandle,
122        window: &mut Window,
123        cx: &mut Context<Picker<Self>>,
124    ) -> Self {
125        let on_model_changed = Arc::new(on_model_changed);
126        let models = all_models(cx);
127        let entries = models.entries();
128
129        Self {
130            on_model_changed,
131            all_models: Arc::new(models),
132            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
133            filtered_entries: entries,
134            get_active_model: Arc::new(get_active_model),
135            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
136            _refresh_models_task: {
137                // Create a channel to signal when models need refreshing
138                let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
139
140                // Subscribe to registry events and send refresh signals through the channel
141                let registry = LanguageModelRegistry::global(cx);
142                cx.subscribe(&registry, move |_picker, _, event, _cx| match event {
143                    language_model::Event::ProviderStateChanged(_) => {
144                        refresh_tx.unbounded_send(()).ok();
145                    }
146                    language_model::Event::AddedProvider(_) => {
147                        refresh_tx.unbounded_send(()).ok();
148                    }
149                    language_model::Event::RemovedProvider(_) => {
150                        refresh_tx.unbounded_send(()).ok();
151                    }
152                    language_model::Event::ProvidersChanged => {
153                        refresh_tx.unbounded_send(()).ok();
154                    }
155                    _ => {}
156                })
157                .detach();
158
159                // Spawn a task that listens for refresh signals and updates the picker
160                cx.spawn_in(window, async move |this, cx| {
161                    while let Some(()) = refresh_rx.next().await {
162                        let result = this.update_in(cx, |picker, window, cx| {
163                            picker.delegate.all_models = Arc::new(all_models(cx));
164                            picker.refresh(window, cx);
165                        });
166                        if result.is_err() {
167                            // Picker was dropped, exit the loop
168                            break;
169                        }
170                    }
171                })
172            },
173            popover_styles,
174            focus_handle,
175        }
176    }
177
178    fn get_active_model_index(
179        entries: &[LanguageModelPickerEntry],
180        active_model: Option<ConfiguredModel>,
181    ) -> usize {
182        entries
183            .iter()
184            .position(|entry| {
185                if let LanguageModelPickerEntry::Model(model) = entry {
186                    active_model
187                        .as_ref()
188                        .map(|active_model| {
189                            active_model.model.id() == model.model.id()
190                                && active_model.provider.id() == model.model.provider_id()
191                        })
192                        .unwrap_or_default()
193                } else {
194                    false
195                }
196            })
197            .unwrap_or(0)
198    }
199
200    /// Authenticates all providers in the [`LanguageModelRegistry`].
201    ///
202    /// We do this so that we can populate the language selector with all of the
203    /// models from the configured providers.
204    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
205        let authenticate_all_providers = LanguageModelRegistry::global(cx)
206            .read(cx)
207            .providers()
208            .iter()
209            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
210            .collect::<Vec<_>>();
211
212        cx.spawn(async move |_cx| {
213            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
214                if let Err(err) = authenticate_task.await {
215                    if matches!(err, AuthenticateError::CredentialsNotFound) {
216                        // Since we're authenticating these providers in the
217                        // background for the purposes of populating the
218                        // language selector, we don't care about providers
219                        // where the credentials are not found.
220                    } else {
221                        // Some providers have noisy failure states that we
222                        // don't want to spam the logs with every time the
223                        // language model selector is initialized.
224                        //
225                        // Ideally these should have more clear failure modes
226                        // that we know are safe to ignore here, like what we do
227                        // with `CredentialsNotFound` above.
228                        match provider_id.0.as_ref() {
229                            "lmstudio" | "ollama" => {
230                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
231                                //
232                                // These fail noisily, so we don't log them.
233                            }
234                            "copilot_chat" => {
235                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
236                            }
237                            _ => {
238                                log::error!(
239                                    "Failed to authenticate provider: {}: {err:#}",
240                                    provider_name.0
241                                );
242                            }
243                        }
244                    }
245                }
246            }
247        })
248    }
249
250    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
251        (self.get_active_model)(cx)
252    }
253}
254
255struct GroupedModels {
256    recommended: Vec<ModelInfo>,
257    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
258}
259
260impl GroupedModels {
261    pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
262        let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
263        for model in all {
264            let provider = model.model.provider_id();
265            if let Some(models) = all_by_provider.get_mut(&provider) {
266                models.push(model);
267            } else {
268                all_by_provider.insert(provider, vec![model]);
269            }
270        }
271
272        Self {
273            recommended,
274            all: all_by_provider,
275        }
276    }
277
278    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
279        let mut entries = Vec::new();
280
281        if !self.recommended.is_empty() {
282            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
283            entries.extend(
284                self.recommended
285                    .iter()
286                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
287            );
288        }
289
290        for models in self.all.values() {
291            if models.is_empty() {
292                continue;
293            }
294            entries.push(LanguageModelPickerEntry::Separator(
295                models[0].model.provider_name().0,
296            ));
297            entries.extend(
298                models
299                    .iter()
300                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
301            );
302        }
303        entries
304    }
305}
306
/// A single row in the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model.
    Model(ModelInfo),
    /// A non-selectable section heading carrying its title.
    Separator(SharedString),
}
311
312struct ModelMatcher {
313    models: Vec<ModelInfo>,
314    bg_executor: BackgroundExecutor,
315    candidates: Vec<StringMatchCandidate>,
316}
317
318impl ModelMatcher {
319    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
320        let candidates = Self::make_match_candidates(&models);
321        Self {
322            models,
323            bg_executor,
324            candidates,
325        }
326    }
327
328    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
329        let mut matches = self.bg_executor.block(match_strings(
330            &self.candidates,
331            query,
332            false,
333            true,
334            100,
335            &Default::default(),
336            self.bg_executor.clone(),
337        ));
338
339        let sorting_key = |mat: &StringMatch| {
340            let candidate = &self.candidates[mat.candidate_id];
341            (Reverse(OrderedFloat(mat.score)), candidate.id)
342        };
343        matches.sort_unstable_by_key(sorting_key);
344
345        let matched_models: Vec<_> = matches
346            .into_iter()
347            .map(|mat| self.models[mat.candidate_id].clone())
348            .collect();
349
350        matched_models
351    }
352
353    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
354        self.models
355            .iter()
356            .filter(|m| {
357                m.model
358                    .name()
359                    .0
360                    .to_lowercase()
361                    .contains(&query.to_lowercase())
362            })
363            .cloned()
364            .collect::<Vec<_>>()
365    }
366
367    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
368        model_infos
369            .iter()
370            .enumerate()
371            .map(|(index, model)| {
372                StringMatchCandidate::new(
373                    index,
374                    &format!(
375                        "{}/{}",
376                        &model.model.provider_name().0,
377                        &model.model.name().0
378                    ),
379                )
380            })
381            .collect::<Vec<_>>()
382    }
383}
384
385impl PickerDelegate for LanguageModelPickerDelegate {
386    type ListItem = AnyElement;
387
388    fn match_count(&self) -> usize {
389        self.filtered_entries.len()
390    }
391
392    fn selected_index(&self) -> usize {
393        self.selected_index
394    }
395
396    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
397        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
398        cx.notify();
399    }
400
401    fn can_select(
402        &mut self,
403        ix: usize,
404        _window: &mut Window,
405        _cx: &mut Context<Picker<Self>>,
406    ) -> bool {
407        match self.filtered_entries.get(ix) {
408            Some(LanguageModelPickerEntry::Model(_)) => true,
409            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
410        }
411    }
412
413    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
414        "Select a model…".into()
415    }
416
417    fn update_matches(
418        &mut self,
419        query: String,
420        window: &mut Window,
421        cx: &mut Context<Picker<Self>>,
422    ) -> Task<()> {
423        let all_models = self.all_models.clone();
424        let active_model = (self.get_active_model)(cx);
425        let bg_executor = cx.background_executor();
426
427        let language_model_registry = LanguageModelRegistry::global(cx);
428
429        let configured_providers = language_model_registry
430            .read(cx)
431            .visible_providers()
432            .into_iter()
433            .filter(|provider| provider.is_authenticated(cx))
434            .collect::<Vec<_>>();
435
436        let configured_provider_ids = configured_providers
437            .iter()
438            .map(|provider| provider.id())
439            .collect::<Vec<_>>();
440
441        let recommended_models = all_models
442            .recommended
443            .iter()
444            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
445            .cloned()
446            .collect::<Vec<_>>();
447
448        let available_models = all_models
449            .all
450            .values()
451            .flat_map(|models| models.iter())
452            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
453            .cloned()
454            .collect::<Vec<_>>();
455
456        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
457        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
458
459        let recommended = matcher_rec.exact_search(&query);
460        let all = matcher_all.fuzzy_search(&query);
461
462        let filtered_models = GroupedModels::new(all, recommended);
463
464        cx.spawn_in(window, async move |this, cx| {
465            this.update_in(cx, |this, window, cx| {
466                this.delegate.filtered_entries = filtered_models.entries();
467                // Finds the currently selected model in the list
468                let new_index =
469                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
470                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
471                cx.notify();
472            })
473            .ok();
474        })
475    }
476
477    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
478        if let Some(LanguageModelPickerEntry::Model(model_info)) =
479            self.filtered_entries.get(self.selected_index)
480        {
481            let model = model_info.model.clone();
482            (self.on_model_changed)(model.clone(), cx);
483
484            let current_index = self.selected_index;
485            self.set_selected_index(current_index, window, cx);
486
487            cx.emit(DismissEvent);
488        }
489    }
490
491    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
492        cx.emit(DismissEvent);
493    }
494
495    fn render_match(
496        &self,
497        ix: usize,
498        selected: bool,
499        _: &mut Window,
500        cx: &mut Context<Picker<Self>>,
501    ) -> Option<Self::ListItem> {
502        match self.filtered_entries.get(ix)? {
503            LanguageModelPickerEntry::Separator(title) => Some(
504                div()
505                    .px_2()
506                    .pb_1()
507                    .when(ix > 1, |this| {
508                        this.mt_1()
509                            .pt_2()
510                            .border_t_1()
511                            .border_color(cx.theme().colors().border_variant)
512                    })
513                    .child(
514                        Label::new(title)
515                            .size(LabelSize::XSmall)
516                            .color(Color::Muted),
517                    )
518                    .into_any_element(),
519            ),
520            LanguageModelPickerEntry::Model(model_info) => {
521                let active_model = (self.get_active_model)(cx);
522                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
523                let active_model_id = active_model.map(|m| m.model.id());
524
525                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
526                    && Some(model_info.model.id()) == active_model_id;
527
528                let model_icon_color = if is_selected {
529                    Color::Accent
530                } else {
531                    Color::Muted
532                };
533
534                Some(
535                    ListItem::new(ix)
536                        .inset(true)
537                        .spacing(ListItemSpacing::Sparse)
538                        .toggle_state(selected)
539                        .child(
540                            h_flex()
541                                .w_full()
542                                .gap_1p5()
543                                .child(match &model_info.icon {
544                                    ProviderIcon::Name(icon_name) => Icon::new(*icon_name)
545                                        .color(model_icon_color)
546                                        .size(IconSize::Small),
547                                    ProviderIcon::Path(icon_path) => {
548                                        Icon::from_external_svg(icon_path.clone())
549                                            .color(model_icon_color)
550                                            .size(IconSize::Small)
551                                    }
552                                })
553                                .child(Label::new(model_info.model.name().0).truncate()),
554                        )
555                        .end_slot(div().pr_3().when(is_selected, |this| {
556                            this.child(
557                                Icon::new(IconName::Check)
558                                    .color(Color::Accent)
559                                    .size(IconSize::Small),
560                            )
561                        }))
562                        .into_any_element(),
563                )
564            }
565        }
566    }
567
568    fn render_footer(
569        &self,
570        _window: &mut Window,
571        cx: &mut Context<Picker<Self>>,
572    ) -> Option<gpui::AnyElement> {
573        let focus_handle = self.focus_handle.clone();
574
575        if !self.popover_styles {
576            return None;
577        }
578
579        Some(
580            h_flex()
581                .w_full()
582                .p_1p5()
583                .border_t_1()
584                .border_color(cx.theme().colors().border_variant)
585                .child(
586                    Button::new("configure", "Configure")
587                        .full_width()
588                        .style(ButtonStyle::Outlined)
589                        .key_binding(
590                            KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
591                                .map(|kb| kb.size(rems_from_px(12.))),
592                        )
593                        .on_click(|_, window, cx| {
594                            window.dispatch_action(OpenSettings.boxed_clone(), cx);
595                        }),
596                )
597                .into_any(),
598        )
599    }
600}
601
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal [`LanguageModel`] used to exercise matching and grouping; it
    /// only carries identifying names and cannot produce completions.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// The model's id mirrors its name, and the provider's display name
        /// mirrors the provider id.
        fn new(name: &str, provider: &str) -> Self {
            let model_name = name.to_owned();
            let provider = provider.to_owned();
            Self {
                name: LanguageModelName::from(model_name.clone()),
                id: LanguageModelId::from(model_name),
                provider_id: LanguageModelProviderId::from(provider.clone()),
                provider_name: LanguageModelProviderName::from(provider),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        /// `provider/name`, which the assertions below compare against.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Turns `(provider, model name)` pairs into [`ModelInfo`]s.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        let mut infos = Vec::with_capacity(model_specs.len());
        for (provider, name) in model_specs {
            infos.push(ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: ProviderIcon::Name(IconName::Ai),
            });
        }
        infos
    }

    /// Asserts that `result` holds exactly the models whose telemetry ids are
    /// listed in `expected`, in the same order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, (actual, expected_name)) in result.iter().zip(expected.iter()).enumerate() {
            assert_eq!(
                actual.model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    /// Matcher over the model list shared by the search tests.
    fn sample_matcher(cx: &TestAppContext) -> ModelMatcher {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        ModelMatcher::new(models, cx.background_executor.clone())
    }

    /// All models grouped by provider, flattened back into a single list.
    fn flatten_all(grouped: &GroupedModels) -> Vec<ModelInfo> {
        grouped.all.values().flatten().cloned().collect::<Vec<_>>()
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let matcher = sample_matcher(cx);

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let matcher = sample_matcher(cx);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // Recommended models should also appear in "all"
        assert_models_eq(
            flatten_all(&grouped_models),
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            flatten_all(&grouped_models),
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }
}