language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use agent_settings::AgentSettings;
  4use collections::{HashMap, HashSet, IndexMap};
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{
  7    Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
  8};
  9use language_model::{
 10    AuthenticateError, ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId,
 11    LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
 12};
 13use ordered_float::OrderedFloat;
 14use picker::{Picker, PickerDelegate};
 15use settings::Settings;
 16use ui::prelude::*;
 17use zed_actions::agent::OpenSettings;
 18
 19use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
 20
 21type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
 22type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 23type OnToggleFavorite = Arc<dyn Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static>;
 24
 25pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 26
 27pub fn language_model_selector(
 28    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 29    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 30    on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
 31    popover_styles: bool,
 32    focus_handle: FocusHandle,
 33    window: &mut Window,
 34    cx: &mut Context<LanguageModelSelector>,
 35) -> LanguageModelSelector {
 36    let delegate = LanguageModelPickerDelegate::new(
 37        get_active_model,
 38        on_model_changed,
 39        on_toggle_favorite,
 40        popover_styles,
 41        focus_handle,
 42        window,
 43        cx,
 44    );
 45
 46    if popover_styles {
 47        Picker::list(delegate, window, cx)
 48            .show_scrollbar(true)
 49            .width(rems(20.))
 50            .max_height(Some(rems(20.).into()))
 51    } else {
 52        Picker::list(delegate, window, cx).show_scrollbar(true)
 53    }
 54}
 55
 56fn all_models(cx: &App) -> GroupedModels {
 57    let lm_registry = LanguageModelRegistry::global(cx).read(cx);
 58    let providers = lm_registry.visible_providers();
 59
 60    let mut favorites_index = FavoritesIndex::default();
 61
 62    for sel in &AgentSettings::get_global(cx).favorite_models {
 63        favorites_index
 64            .entry(sel.provider.0.clone().into())
 65            .or_default()
 66            .insert(sel.model.clone().into());
 67    }
 68
 69    let recommended = providers
 70        .iter()
 71        .flat_map(|provider| {
 72            provider
 73                .recommended_models(cx)
 74                .into_iter()
 75                .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
 76        })
 77        .collect();
 78
 79    let all = providers
 80        .iter()
 81        .flat_map(|provider| {
 82            provider
 83                .provided_models(cx)
 84                .into_iter()
 85                .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
 86        })
 87        .collect();
 88
 89    GroupedModels::new(all, recommended)
 90}
 91
 92type FavoritesIndex = HashMap<LanguageModelProviderId, HashSet<LanguageModelId>>;
 93
 94#[derive(Clone)]
 95struct ModelInfo {
 96    model: Arc<dyn LanguageModel>,
 97    icon: IconOrSvg,
 98    is_favorite: bool,
 99}
100
101impl ModelInfo {
102    fn new(
103        provider: &dyn LanguageModelProvider,
104        model: Arc<dyn LanguageModel>,
105        favorites_index: &FavoritesIndex,
106    ) -> Self {
107        let is_favorite = favorites_index
108            .get(&provider.id())
109            .map_or(false, |set| set.contains(&model.id()));
110
111        Self {
112            model,
113            icon: provider.icon(),
114            is_favorite,
115        }
116    }
117}
118
119pub struct LanguageModelPickerDelegate {
120    on_model_changed: OnModelChanged,
121    get_active_model: GetActiveModel,
122    on_toggle_favorite: OnToggleFavorite,
123    all_models: Arc<GroupedModels>,
124    filtered_entries: Vec<LanguageModelPickerEntry>,
125    selected_index: usize,
126    _authenticate_all_providers_task: Task<()>,
127    _subscriptions: Vec<Subscription>,
128    popover_styles: bool,
129    focus_handle: FocusHandle,
130}
131
132impl LanguageModelPickerDelegate {
133    fn new(
134        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
135        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
136        on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
137        popover_styles: bool,
138        focus_handle: FocusHandle,
139        window: &mut Window,
140        cx: &mut Context<Picker<Self>>,
141    ) -> Self {
142        let on_model_changed = Arc::new(on_model_changed);
143        let models = all_models(cx);
144        let entries = models.entries();
145
146        Self {
147            on_model_changed,
148            all_models: Arc::new(models),
149            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
150            filtered_entries: entries,
151            get_active_model: Arc::new(get_active_model),
152            on_toggle_favorite: Arc::new(on_toggle_favorite),
153            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
154            _subscriptions: vec![cx.subscribe_in(
155                &LanguageModelRegistry::global(cx),
156                window,
157                |picker, _, event, window, cx| {
158                    match event {
159                        language_model::Event::ProviderStateChanged(_)
160                        | language_model::Event::AddedProvider(_)
161                        | language_model::Event::RemovedProvider(_) => {
162                            let query = picker.query(cx);
163                            picker.delegate.all_models = Arc::new(all_models(cx));
164                            // Update matches will automatically drop the previous task
165                            // if we get a provider event again
166                            picker.update_matches(query, window, cx)
167                        }
168                        _ => {}
169                    }
170                },
171            )],
172            popover_styles,
173            focus_handle,
174        }
175    }
176
177    fn get_active_model_index(
178        entries: &[LanguageModelPickerEntry],
179        active_model: Option<ConfiguredModel>,
180    ) -> usize {
181        entries
182            .iter()
183            .position(|entry| {
184                if let LanguageModelPickerEntry::Model(model) = entry {
185                    active_model
186                        .as_ref()
187                        .map(|active_model| {
188                            active_model.model.id() == model.model.id()
189                                && active_model.provider.id() == model.model.provider_id()
190                        })
191                        .unwrap_or_default()
192                } else {
193                    false
194                }
195            })
196            .unwrap_or(0)
197    }
198
199    /// Authenticates all providers in the [`LanguageModelRegistry`].
200    ///
201    /// We do this so that we can populate the language selector with all of the
202    /// models from the configured providers.
203    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
204        let authenticate_all_providers = LanguageModelRegistry::global(cx)
205            .read(cx)
206            .visible_providers()
207            .iter()
208            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
209            .collect::<Vec<_>>();
210
211        cx.spawn(async move |_cx| {
212            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
213                if let Err(err) = authenticate_task.await {
214                    if matches!(err, AuthenticateError::CredentialsNotFound) {
215                        // Since we're authenticating these providers in the
216                        // background for the purposes of populating the
217                        // language selector, we don't care about providers
218                        // where the credentials are not found.
219                    } else {
220                        // Some providers have noisy failure states that we
221                        // don't want to spam the logs with every time the
222                        // language model selector is initialized.
223                        //
224                        // Ideally these should have more clear failure modes
225                        // that we know are safe to ignore here, like what we do
226                        // with `CredentialsNotFound` above.
227                        match provider_id.0.as_ref() {
228                            "lmstudio" | "ollama" => {
229                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
230                                //
231                                // These fail noisily, so we don't log them.
232                            }
233                            "copilot_chat" => {
234                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
235                            }
236                            _ => {
237                                log::error!(
238                                    "Failed to authenticate provider: {}: {err:#}",
239                                    provider_name.0
240                                );
241                            }
242                        }
243                    }
244                }
245            }
246        })
247    }
248
249    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
250        (self.get_active_model)(cx)
251    }
252
253    pub fn favorites_count(&self) -> usize {
254        self.all_models.favorites.len()
255    }
256
257    pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
258        if self.all_models.favorites.is_empty() {
259            return;
260        }
261
262        let active_model = (self.get_active_model)(cx);
263        let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
264        let active_model_id = active_model.as_ref().map(|m| m.model.id());
265
266        let current_index = self
267            .all_models
268            .favorites
269            .iter()
270            .position(|info| {
271                Some(info.model.provider_id()) == active_provider_id
272                    && Some(info.model.id()) == active_model_id
273            })
274            .unwrap_or(usize::MAX);
275
276        let next_index = if current_index == usize::MAX {
277            0
278        } else {
279            (current_index + 1) % self.all_models.favorites.len()
280        };
281
282        let next_model = self.all_models.favorites[next_index].model.clone();
283
284        (self.on_model_changed)(next_model, cx);
285
286        // Align the picker selection with the newly-active model
287        let new_index =
288            Self::get_active_model_index(&self.filtered_entries, (self.get_active_model)(cx));
289        self.set_selected_index(new_index, window, cx);
290    }
291}
292
293struct GroupedModels {
294    favorites: Vec<ModelInfo>,
295    recommended: Vec<ModelInfo>,
296    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
297}
298
299impl GroupedModels {
300    pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
301        let favorites = all
302            .iter()
303            .filter(|info| info.is_favorite)
304            .cloned()
305            .collect();
306
307        let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
308        for model in all {
309            let provider = model.model.provider_id();
310            if let Some(models) = all_by_provider.get_mut(&provider) {
311                models.push(model);
312            } else {
313                all_by_provider.insert(provider, vec![model]);
314            }
315        }
316
317        Self {
318            favorites,
319            recommended,
320            all: all_by_provider,
321        }
322    }
323
324    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
325        let mut entries = Vec::new();
326
327        if !self.favorites.is_empty() {
328            entries.push(LanguageModelPickerEntry::Separator("Favorite".into()));
329            for info in &self.favorites {
330                entries.push(LanguageModelPickerEntry::Model(info.clone()));
331            }
332        }
333
334        if !self.recommended.is_empty() {
335            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
336            for info in &self.recommended {
337                entries.push(LanguageModelPickerEntry::Model(info.clone()));
338            }
339        }
340
341        for models in self.all.values() {
342            if models.is_empty() {
343                continue;
344            }
345            entries.push(LanguageModelPickerEntry::Separator(
346                models[0].model.provider_name().0,
347            ));
348            for info in models {
349                entries.push(LanguageModelPickerEntry::Model(info.clone()));
350            }
351        }
352
353        entries
354    }
355}
356
357enum LanguageModelPickerEntry {
358    Model(ModelInfo),
359    Separator(SharedString),
360}
361
362struct ModelMatcher {
363    models: Vec<ModelInfo>,
364    bg_executor: BackgroundExecutor,
365    candidates: Vec<StringMatchCandidate>,
366}
367
368impl ModelMatcher {
369    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
370        let candidates = Self::make_match_candidates(&models);
371        Self {
372            models,
373            bg_executor,
374            candidates,
375        }
376    }
377
378    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
379        let mut matches = self.bg_executor.block(match_strings(
380            &self.candidates,
381            query,
382            false,
383            true,
384            100,
385            &Default::default(),
386            self.bg_executor.clone(),
387        ));
388
389        let sorting_key = |mat: &StringMatch| {
390            let candidate = &self.candidates[mat.candidate_id];
391            (Reverse(OrderedFloat(mat.score)), candidate.id)
392        };
393        matches.sort_unstable_by_key(sorting_key);
394
395        let matched_models: Vec<_> = matches
396            .into_iter()
397            .map(|mat| self.models[mat.candidate_id].clone())
398            .collect();
399
400        matched_models
401    }
402
403    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
404        self.models
405            .iter()
406            .filter(|m| {
407                m.model
408                    .name()
409                    .0
410                    .to_lowercase()
411                    .contains(&query.to_lowercase())
412            })
413            .cloned()
414            .collect::<Vec<_>>()
415    }
416
417    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
418        model_infos
419            .iter()
420            .enumerate()
421            .map(|(index, model)| {
422                StringMatchCandidate::new(
423                    index,
424                    &format!(
425                        "{}/{}",
426                        &model.model.provider_name().0,
427                        &model.model.name().0
428                    ),
429                )
430            })
431            .collect::<Vec<_>>()
432    }
433}
434
435impl PickerDelegate for LanguageModelPickerDelegate {
436    type ListItem = AnyElement;
437
438    fn match_count(&self) -> usize {
439        self.filtered_entries.len()
440    }
441
442    fn selected_index(&self) -> usize {
443        self.selected_index
444    }
445
446    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
447        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
448        cx.notify();
449    }
450
451    fn can_select(
452        &mut self,
453        ix: usize,
454        _window: &mut Window,
455        _cx: &mut Context<Picker<Self>>,
456    ) -> bool {
457        match self.filtered_entries.get(ix) {
458            Some(LanguageModelPickerEntry::Model(_)) => true,
459            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
460        }
461    }
462
463    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
464        "Select a model…".into()
465    }
466
467    fn update_matches(
468        &mut self,
469        query: String,
470        window: &mut Window,
471        cx: &mut Context<Picker<Self>>,
472    ) -> Task<()> {
473        let all_models = self.all_models.clone();
474        let active_model = (self.get_active_model)(cx);
475        let bg_executor = cx.background_executor();
476
477        let language_model_registry = LanguageModelRegistry::global(cx);
478
479        let configured_providers = language_model_registry
480            .read(cx)
481            .visible_providers()
482            .into_iter()
483            .filter(|provider| provider.is_authenticated(cx))
484            .collect::<Vec<_>>();
485
486        let configured_provider_ids = configured_providers
487            .iter()
488            .map(|provider| provider.id())
489            .collect::<Vec<_>>();
490
491        let recommended_models = all_models
492            .recommended
493            .iter()
494            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
495            .cloned()
496            .collect::<Vec<_>>();
497
498        let available_models = all_models
499            .all
500            .values()
501            .flat_map(|models| models.iter())
502            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
503            .cloned()
504            .collect::<Vec<_>>();
505
506        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
507        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
508
509        let recommended = matcher_rec.exact_search(&query);
510        let all = matcher_all.fuzzy_search(&query);
511
512        let filtered_models = GroupedModels::new(all, recommended);
513
514        cx.spawn_in(window, async move |this, cx| {
515            this.update_in(cx, |this, window, cx| {
516                this.delegate.filtered_entries = filtered_models.entries();
517                // Finds the currently selected model in the list
518                let new_index =
519                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
520                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
521                cx.notify();
522            })
523            .ok();
524        })
525    }
526
527    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
528        if let Some(LanguageModelPickerEntry::Model(model_info)) =
529            self.filtered_entries.get(self.selected_index)
530        {
531            let model = model_info.model.clone();
532            (self.on_model_changed)(model.clone(), cx);
533
534            let current_index = self.selected_index;
535            self.set_selected_index(current_index, window, cx);
536
537            cx.emit(DismissEvent);
538        }
539    }
540
541    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
542        cx.emit(DismissEvent);
543    }
544
545    fn render_match(
546        &self,
547        ix: usize,
548        selected: bool,
549        _: &mut Window,
550        cx: &mut Context<Picker<Self>>,
551    ) -> Option<Self::ListItem> {
552        match self.filtered_entries.get(ix)? {
553            LanguageModelPickerEntry::Separator(title) => {
554                Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
555            }
556            LanguageModelPickerEntry::Model(model_info) => {
557                let active_model = (self.get_active_model)(cx);
558                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
559                let active_model_id = active_model.map(|m| m.model.id());
560
561                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
562                    && Some(model_info.model.id()) == active_model_id;
563
564                let is_favorite = model_info.is_favorite;
565                let handle_action_click = {
566                    let model = model_info.model.clone();
567                    let on_toggle_favorite = self.on_toggle_favorite.clone();
568                    cx.listener(move |picker, _, window, cx| {
569                        on_toggle_favorite(model.clone(), !is_favorite, cx);
570                        picker.refresh(window, cx);
571                    })
572                };
573
574                Some(
575                    ModelSelectorListItem::new(ix, model_info.model.name().0)
576                        .map(|this| match &model_info.icon {
577                            IconOrSvg::Icon(icon_name) => this.icon(*icon_name),
578                            IconOrSvg::Svg(icon_path) => this.icon_path(icon_path.clone()),
579                        })
580                        .is_selected(is_selected)
581                        .is_focused(selected)
582                        .is_favorite(is_favorite)
583                        .on_toggle_favorite(handle_action_click)
584                        .into_any_element(),
585                )
586            }
587        }
588    }
589
590    fn render_footer(
591        &self,
592        _window: &mut Window,
593        _cx: &mut Context<Picker<Self>>,
594    ) -> Option<gpui::AnyElement> {
595        let focus_handle = self.focus_handle.clone();
596
597        if !self.popover_styles {
598            return None;
599        }
600
601        Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
602    }
603}
604
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal [`LanguageModel`] whose id and name are both the model name, and
    /// whose provider id and name are both the provider string.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        /// `provider/name`; the tests use this as a model's identity.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds non-favorite models from `(provider, name)` pairs.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        create_models_with_favorites(model_specs, vec![])
    }

    /// Builds models from `(provider, name)` pairs, flagging those listed in `favorites`.
    fn create_models_with_favorites(
        model_specs: Vec<(&str, &str)>,
        favorites: Vec<(&str, &str)>,
    ) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| {
                let is_favorite = favorites
                    .iter()
                    .any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name);
                ModelInfo {
                    model: Arc::new(TestLanguageModel::new(name, provider)),
                    icon: IconOrSvg::Icon(IconName::Ai),
                    is_favorite,
                }
            })
            .collect()
    }

    /// Asserts that `result` is exactly `expected` (as `provider/name` ids, in
    /// order). Comparing whole lists makes a failure print both sequences.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        let actual: Vec<String> = result
            .iter()
            .map(|info| info.model.telemetry_id())
            .collect();
        assert_eq!(actual, expected, "Models (as `provider/name`) don't match");
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }

    #[gpui::test]
    fn test_entries_are_sectioned_and_grouped_by_provider(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        // Interleave providers to check that grouping keeps first-seen provider order.
        let all_models = create_models(vec![
            ("zed", "claude"),
            ("openai", "gpt-4"),
            ("zed", "gemini"),
        ]);

        let entries = GroupedModels::new(all_models, recommended_models).entries();
        let layout: Vec<String> = entries
            .iter()
            .map(|entry| match entry {
                LanguageModelPickerEntry::Separator(title) => format!("[{}]", title),
                LanguageModelPickerEntry::Model(info) => info.model.telemetry_id(),
            })
            .collect();

        assert_eq!(
            layout,
            vec![
                "[Recommended]",
                "zed/claude",
                "[zed]",
                "zed/claude",
                "zed/gemini",
                "[openai]",
                "openai/gpt-4",
            ]
        );
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "gemini")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert_models_eq(grouped_models.favorites, vec!["zed/gemini"]);
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![("zed", "claude"), ("zed", "gemini")]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Recommended"
        ));

        assert!(grouped_models.favorites.is_empty());
    }

    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "claude")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        for entry in &entries {
            if let LanguageModelPickerEntry::Model(info) = entry {
                if info.model.telemetry_id() == "zed/claude" {
                    assert!(info.is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(
                        !info.is_favorite,
                        "{} should not be a favorite",
                        info.model.telemetry_id()
                    );
                }
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_other_sections(_cx: &mut TestAppContext) {
        let favorites = vec![("zed", "gemini"), ("openai", "gpt-4")];

        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], favorites.clone());

        let all_models = create_models_with_favorites(
            vec![
                ("zed", "claude"),
                ("zed", "gemini"),
                ("openai", "gpt-4"),
                ("openai", "gpt-3.5"),
            ],
            favorites,
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        assert_models_eq(grouped_models.favorites, vec!["zed/gemini", "openai/gpt-4"]);
        assert_models_eq(grouped_models.recommended, vec!["zed/claude"]);
        assert_models_eq(
            grouped_models.all.values().flatten().cloned().collect(),
            vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"],
        );
    }
}