language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use collections::IndexMap;
  4use futures::{StreamExt, channel::mpsc};
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
  7use language_model::{
  8    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProvider,
  9    LanguageModelProviderId, LanguageModelRegistry,
 10};
 11use ordered_float::OrderedFloat;
 12use picker::{Picker, PickerDelegate};
 13use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*};
 14use zed_actions::agent::OpenSettings;
 15
 16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
 17type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 18
 19pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 20
 21pub fn language_model_selector(
 22    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 23    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 24    popover_styles: bool,
 25    focus_handle: FocusHandle,
 26    window: &mut Window,
 27    cx: &mut Context<LanguageModelSelector>,
 28) -> LanguageModelSelector {
 29    let delegate = LanguageModelPickerDelegate::new(
 30        get_active_model,
 31        on_model_changed,
 32        popover_styles,
 33        focus_handle,
 34        window,
 35        cx,
 36    );
 37
 38    if popover_styles {
 39        Picker::list(delegate, window, cx)
 40            .show_scrollbar(true)
 41            .width(rems(20.))
 42            .max_height(Some(rems(20.).into()))
 43    } else {
 44        Picker::list(delegate, window, cx).show_scrollbar(true)
 45    }
 46}
 47
 48fn all_models(cx: &App) -> GroupedModels {
 49    let providers = LanguageModelRegistry::global(cx)
 50        .read(cx)
 51        .visible_providers();
 52
 53    let recommended = providers
 54        .iter()
 55        .flat_map(|provider| {
 56            provider
 57                .recommended_models(cx)
 58                .into_iter()
 59                .map(|model| ModelInfo {
 60                    model,
 61                    icon: ProviderIcon::from_provider(provider.as_ref()),
 62                })
 63        })
 64        .collect();
 65
 66    let all: Vec<ModelInfo> = providers
 67        .iter()
 68        .flat_map(|provider| {
 69            provider
 70                .provided_models(cx)
 71                .into_iter()
 72                .map(|model| ModelInfo {
 73                    model,
 74                    icon: ProviderIcon::from_provider(provider.as_ref()),
 75                })
 76        })
 77        .collect();
 78
 79    GroupedModels::new(all, recommended)
 80}
 81
 82#[derive(Clone)]
 83enum ProviderIcon {
 84    Name(IconName),
 85    Path(SharedString),
 86}
 87
 88impl ProviderIcon {
 89    fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
 90        if let Some(path) = provider.icon_path() {
 91            Self::Path(path)
 92        } else {
 93            Self::Name(provider.icon())
 94        }
 95    }
 96}
 97
 98#[derive(Clone)]
 99struct ModelInfo {
100    model: Arc<dyn LanguageModel>,
101    icon: ProviderIcon,
102}
103
104pub struct LanguageModelPickerDelegate {
105    on_model_changed: OnModelChanged,
106    get_active_model: GetActiveModel,
107    all_models: Arc<GroupedModels>,
108    filtered_entries: Vec<LanguageModelPickerEntry>,
109    selected_index: usize,
110    _authenticate_all_providers_task: Task<()>,
111    _refresh_models_task: Task<()>,
112    popover_styles: bool,
113    focus_handle: FocusHandle,
114}
115
116impl LanguageModelPickerDelegate {
117    fn new(
118        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
119        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
120        popover_styles: bool,
121        focus_handle: FocusHandle,
122        window: &mut Window,
123        cx: &mut Context<Picker<Self>>,
124    ) -> Self {
125        let on_model_changed = Arc::new(on_model_changed);
126        let models = all_models(cx);
127        let entries = models.entries();
128
129        Self {
130            on_model_changed,
131            all_models: Arc::new(models),
132            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
133            filtered_entries: entries,
134            get_active_model: Arc::new(get_active_model),
135            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
136            _refresh_models_task: {
137                // Create a channel to signal when models need refreshing
138                let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
139
140                // Subscribe to registry events and send refresh signals through the channel
141                let registry = LanguageModelRegistry::global(cx);
142                cx.subscribe(&registry, move |_picker, _, event, _cx| match event {
143                    language_model::Event::ProviderStateChanged(_)
144                    | language_model::Event::AddedProvider(_)
145                    | language_model::Event::RemovedProvider(_) => {
146                        refresh_tx.unbounded_send(()).ok();
147                    }
148                    _ => {}
149                })
150                .detach();
151
152                // Spawn a task that listens for refresh signals and updates the picker
153                cx.spawn_in(window, async move |this, cx| {
154                    while let Some(()) = refresh_rx.next().await {
155                        let result = this.update_in(cx, |picker, window, cx| {
156                            picker.delegate.all_models = Arc::new(all_models(cx));
157                            picker.refresh(window, cx);
158                        });
159                        if result.is_err() {
160                            // Picker was dropped, exit the loop
161                            break;
162                        }
163                    }
164                })
165            },
166            popover_styles,
167            focus_handle,
168        }
169    }
170
171    fn get_active_model_index(
172        entries: &[LanguageModelPickerEntry],
173        active_model: Option<ConfiguredModel>,
174    ) -> usize {
175        entries
176            .iter()
177            .position(|entry| {
178                if let LanguageModelPickerEntry::Model(model) = entry {
179                    active_model
180                        .as_ref()
181                        .map(|active_model| {
182                            active_model.model.id() == model.model.id()
183                                && active_model.provider.id() == model.model.provider_id()
184                        })
185                        .unwrap_or_default()
186                } else {
187                    false
188                }
189            })
190            .unwrap_or(0)
191    }
192
193    /// Authenticates all providers in the [`LanguageModelRegistry`].
194    ///
195    /// We do this so that we can populate the language selector with all of the
196    /// models from the configured providers.
197    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
198        let authenticate_all_providers = LanguageModelRegistry::global(cx)
199            .read(cx)
200            .providers()
201            .iter()
202            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
203            .collect::<Vec<_>>();
204
205        cx.spawn(async move |_cx| {
206            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
207                if let Err(err) = authenticate_task.await {
208                    if matches!(err, AuthenticateError::CredentialsNotFound) {
209                        // Since we're authenticating these providers in the
210                        // background for the purposes of populating the
211                        // language selector, we don't care about providers
212                        // where the credentials are not found.
213                    } else {
214                        // Some providers have noisy failure states that we
215                        // don't want to spam the logs with every time the
216                        // language model selector is initialized.
217                        //
218                        // Ideally these should have more clear failure modes
219                        // that we know are safe to ignore here, like what we do
220                        // with `CredentialsNotFound` above.
221                        match provider_id.0.as_ref() {
222                            "lmstudio" | "ollama" => {
223                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
224                                //
225                                // These fail noisily, so we don't log them.
226                            }
227                            "copilot_chat" => {
228                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
229                            }
230                            _ => {
231                                log::error!(
232                                    "Failed to authenticate provider: {}: {err:#}",
233                                    provider_name.0
234                                );
235                            }
236                        }
237                    }
238                }
239            }
240        })
241    }
242
243    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
244        (self.get_active_model)(cx)
245    }
246}
247
248struct GroupedModels {
249    recommended: Vec<ModelInfo>,
250    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
251}
252
253impl GroupedModels {
254    pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
255        let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
256        for model in all {
257            let provider = model.model.provider_id();
258            if let Some(models) = all_by_provider.get_mut(&provider) {
259                models.push(model);
260            } else {
261                all_by_provider.insert(provider, vec![model]);
262            }
263        }
264
265        Self {
266            recommended,
267            all: all_by_provider,
268        }
269    }
270
271    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
272        let mut entries = Vec::new();
273
274        if !self.recommended.is_empty() {
275            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
276            entries.extend(
277                self.recommended
278                    .iter()
279                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
280            );
281        }
282
283        for models in self.all.values() {
284            if models.is_empty() {
285                continue;
286            }
287            entries.push(LanguageModelPickerEntry::Separator(
288                models[0].model.provider_name().0,
289            ));
290            entries.extend(
291                models
292                    .iter()
293                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
294            );
295        }
296        entries
297    }
298}
299
300enum LanguageModelPickerEntry {
301    Model(ModelInfo),
302    Separator(SharedString),
303}
304
305struct ModelMatcher {
306    models: Vec<ModelInfo>,
307    bg_executor: BackgroundExecutor,
308    candidates: Vec<StringMatchCandidate>,
309}
310
311impl ModelMatcher {
312    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
313        let candidates = Self::make_match_candidates(&models);
314        Self {
315            models,
316            bg_executor,
317            candidates,
318        }
319    }
320
321    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
322        let mut matches = self.bg_executor.block(match_strings(
323            &self.candidates,
324            query,
325            false,
326            true,
327            100,
328            &Default::default(),
329            self.bg_executor.clone(),
330        ));
331
332        let sorting_key = |mat: &StringMatch| {
333            let candidate = &self.candidates[mat.candidate_id];
334            (Reverse(OrderedFloat(mat.score)), candidate.id)
335        };
336        matches.sort_unstable_by_key(sorting_key);
337
338        let matched_models: Vec<_> = matches
339            .into_iter()
340            .map(|mat| self.models[mat.candidate_id].clone())
341            .collect();
342
343        matched_models
344    }
345
346    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
347        self.models
348            .iter()
349            .filter(|m| {
350                m.model
351                    .name()
352                    .0
353                    .to_lowercase()
354                    .contains(&query.to_lowercase())
355            })
356            .cloned()
357            .collect::<Vec<_>>()
358    }
359
360    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
361        model_infos
362            .iter()
363            .enumerate()
364            .map(|(index, model)| {
365                StringMatchCandidate::new(
366                    index,
367                    &format!(
368                        "{}/{}",
369                        &model.model.provider_name().0,
370                        &model.model.name().0
371                    ),
372                )
373            })
374            .collect::<Vec<_>>()
375    }
376}
377
378impl PickerDelegate for LanguageModelPickerDelegate {
379    type ListItem = AnyElement;
380
381    fn match_count(&self) -> usize {
382        self.filtered_entries.len()
383    }
384
385    fn selected_index(&self) -> usize {
386        self.selected_index
387    }
388
389    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
390        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
391        cx.notify();
392    }
393
394    fn can_select(
395        &mut self,
396        ix: usize,
397        _window: &mut Window,
398        _cx: &mut Context<Picker<Self>>,
399    ) -> bool {
400        match self.filtered_entries.get(ix) {
401            Some(LanguageModelPickerEntry::Model(_)) => true,
402            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
403        }
404    }
405
406    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
407        "Select a model…".into()
408    }
409
410    fn update_matches(
411        &mut self,
412        query: String,
413        window: &mut Window,
414        cx: &mut Context<Picker<Self>>,
415    ) -> Task<()> {
416        let all_models = self.all_models.clone();
417        let active_model = (self.get_active_model)(cx);
418        let bg_executor = cx.background_executor();
419
420        let language_model_registry = LanguageModelRegistry::global(cx);
421
422        let configured_providers = language_model_registry
423            .read(cx)
424            .visible_providers()
425            .into_iter()
426            .filter(|provider| provider.is_authenticated(cx))
427            .collect::<Vec<_>>();
428
429        let configured_provider_ids = configured_providers
430            .iter()
431            .map(|provider| provider.id())
432            .collect::<Vec<_>>();
433
434        let recommended_models = all_models
435            .recommended
436            .iter()
437            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
438            .cloned()
439            .collect::<Vec<_>>();
440
441        let available_models = all_models
442            .all
443            .values()
444            .flat_map(|models| models.iter())
445            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
446            .cloned()
447            .collect::<Vec<_>>();
448
449        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
450        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
451
452        let recommended = matcher_rec.exact_search(&query);
453        let all = matcher_all.fuzzy_search(&query);
454
455        let filtered_models = GroupedModels::new(all, recommended);
456
457        cx.spawn_in(window, async move |this, cx| {
458            this.update_in(cx, |this, window, cx| {
459                this.delegate.filtered_entries = filtered_models.entries();
460                // Finds the currently selected model in the list
461                let new_index =
462                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
463                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
464                cx.notify();
465            })
466            .ok();
467        })
468    }
469
470    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
471        if let Some(LanguageModelPickerEntry::Model(model_info)) =
472            self.filtered_entries.get(self.selected_index)
473        {
474            let model = model_info.model.clone();
475            (self.on_model_changed)(model.clone(), cx);
476
477            let current_index = self.selected_index;
478            self.set_selected_index(current_index, window, cx);
479
480            cx.emit(DismissEvent);
481        }
482    }
483
484    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
485        cx.emit(DismissEvent);
486    }
487
488    fn render_match(
489        &self,
490        ix: usize,
491        selected: bool,
492        _: &mut Window,
493        cx: &mut Context<Picker<Self>>,
494    ) -> Option<Self::ListItem> {
495        match self.filtered_entries.get(ix)? {
496            LanguageModelPickerEntry::Separator(title) => Some(
497                div()
498                    .px_2()
499                    .pb_1()
500                    .when(ix > 1, |this| {
501                        this.mt_1()
502                            .pt_2()
503                            .border_t_1()
504                            .border_color(cx.theme().colors().border_variant)
505                    })
506                    .child(
507                        Label::new(title)
508                            .size(LabelSize::XSmall)
509                            .color(Color::Muted),
510                    )
511                    .into_any_element(),
512            ),
513            LanguageModelPickerEntry::Model(model_info) => {
514                let active_model = (self.get_active_model)(cx);
515                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
516                let active_model_id = active_model.map(|m| m.model.id());
517
518                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
519                    && Some(model_info.model.id()) == active_model_id;
520
521                let model_icon_color = if is_selected {
522                    Color::Accent
523                } else {
524                    Color::Muted
525                };
526
527                Some(
528                    ListItem::new(ix)
529                        .inset(true)
530                        .spacing(ListItemSpacing::Sparse)
531                        .toggle_state(selected)
532                        .child(
533                            h_flex()
534                                .w_full()
535                                .gap_1p5()
536                                .child(match &model_info.icon {
537                                    ProviderIcon::Name(icon_name) => Icon::new(*icon_name)
538                                        .color(model_icon_color)
539                                        .size(IconSize::Small),
540                                    ProviderIcon::Path(icon_path) => {
541                                        Icon::from_external_svg(icon_path.clone())
542                                            .color(model_icon_color)
543                                            .size(IconSize::Small)
544                                    }
545                                })
546                                .child(Label::new(model_info.model.name().0).truncate()),
547                        )
548                        .end_slot(div().pr_3().when(is_selected, |this| {
549                            this.child(
550                                Icon::new(IconName::Check)
551                                    .color(Color::Accent)
552                                    .size(IconSize::Small),
553                            )
554                        }))
555                        .into_any_element(),
556                )
557            }
558        }
559    }
560
561    fn render_footer(
562        &self,
563        _window: &mut Window,
564        cx: &mut Context<Picker<Self>>,
565    ) -> Option<gpui::AnyElement> {
566        let focus_handle = self.focus_handle.clone();
567
568        if !self.popover_styles {
569            return None;
570        }
571
572        Some(
573            h_flex()
574                .w_full()
575                .p_1p5()
576                .border_t_1()
577                .border_color(cx.theme().colors().border_variant)
578                .child(
579                    Button::new("configure", "Configure")
580                        .full_width()
581                        .style(ButtonStyle::Outlined)
582                        .key_binding(
583                            KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
584                                .map(|kb| kb.size(rems_from_px(12.))),
585                        )
586                        .on_click(|_, window, cx| {
587                            window.dispatch_action(OpenSettings.boxed_clone(), cx);
588                        }),
589                )
590                .into_any(),
591        )
592    }
593}
594
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal [`LanguageModel`] for tests.
    ///
    /// Its id and name are both `name`, and its provider id and provider name
    /// are both `provider`.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, model name)` pairs, all with a placeholder icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: ProviderIcon::Name(IconName::Ai),
            })
            .collect()
    }

    /// Asserts that `result` holds exactly the models in `expected`, in order.
    ///
    /// Each expected entry is a `"{provider}/{name}"` telemetry id. The whole
    /// list is compared at once, so a failure prints both lists in full.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: &[&str]) {
        let actual: Vec<String> = result
            .iter()
            .map(|info| info.model.telemetry_id())
            .collect();
        assert_eq!(actual, expected, "Models don't match the expected list");
    }

    /// Renders picker entries as strings: separators become `"# {title}"` and
    /// models become their telemetry id.
    fn entry_labels(entries: &[LanguageModelPickerEntry]) -> Vec<String> {
        entries
            .iter()
            .map(|entry| match entry {
                LanguageModelPickerEntry::Separator(title) => format!("# {title}"),
                LanguageModelPickerEntry::Model(info) => info.model.telemetry_id(),
            })
            .collect()
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            &[
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            &[
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, &["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, &["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in `all`
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            &["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in `all`
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in `all` too
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            &["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }

    #[gpui::test]
    fn test_entries_layout(_cx: &mut TestAppContext) {
        let grouped_models = GroupedModels::new(
            create_models(vec![
                ("zed", "claude"),
                ("copilot", "o3"),
                ("zed", "gemini"),
            ]),
            create_models(vec![("zed", "claude")]),
        );

        // Providers appear in first-seen order, each group keeps its models'
        // order, and every section starts with a separator.
        assert_eq!(
            entry_labels(&grouped_models.entries()),
            [
                "# Recommended",
                "zed/claude",
                "# zed",
                "zed/claude",
                "zed/gemini",
                "# copilot",
                "copilot/o3",
            ],
        );

        // Without recommended models there is no "Recommended" section.
        let grouped_models =
            GroupedModels::new(create_models(vec![("ollama", "mistral")]), Vec::new());
        assert_eq!(
            entry_labels(&grouped_models.entries()),
            ["# ollama", "ollama/mistral"],
        );
    }
}