language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use agent_settings::AgentSettings;
  4use collections::{HashMap, HashSet, IndexMap};
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{
  7    Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
  8};
  9use language_model::{
 10    AuthenticateError, ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId,
 11    LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
 12};
 13use ordered_float::OrderedFloat;
 14use picker::{Picker, PickerDelegate};
 15use settings::Settings;
 16use ui::prelude::*;
 17use zed_actions::agent::OpenSettings;
 18
 19use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
 20
/// Callback invoked with the model the user picked (on confirm or when cycling favorites).
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback reporting the currently active model, if any; used to highlight and preselect it.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
/// Callback invoked when a model's favorite toggle is clicked; the `bool` is the
/// requested new favorite state (the negation of the model's current state).
type OnToggleFavorite = Arc<dyn Fn(Arc<dyn LanguageModel>, bool, &App) + 'static>;

/// A picker listing the language models offered by the registry's visible providers.
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 26
 27pub fn language_model_selector(
 28    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 29    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 30    on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &App) + 'static,
 31    popover_styles: bool,
 32    focus_handle: FocusHandle,
 33    window: &mut Window,
 34    cx: &mut Context<LanguageModelSelector>,
 35) -> LanguageModelSelector {
 36    let delegate = LanguageModelPickerDelegate::new(
 37        get_active_model,
 38        on_model_changed,
 39        on_toggle_favorite,
 40        popover_styles,
 41        focus_handle,
 42        window,
 43        cx,
 44    );
 45
 46    if popover_styles {
 47        Picker::list(delegate, window, cx)
 48            .show_scrollbar(true)
 49            .width(rems(20.))
 50            .max_height(Some(rems(20.).into()))
 51    } else {
 52        Picker::list(delegate, window, cx).show_scrollbar(true)
 53    }
 54}
 55
 56fn all_models(cx: &App) -> GroupedModels {
 57    let lm_registry = LanguageModelRegistry::global(cx).read(cx);
 58    let providers = lm_registry.visible_providers();
 59
 60    let mut favorites_index = FavoritesIndex::default();
 61
 62    for sel in &AgentSettings::get_global(cx).favorite_models {
 63        favorites_index
 64            .entry(sel.provider.0.clone().into())
 65            .or_default()
 66            .insert(sel.model.clone().into());
 67    }
 68
 69    let recommended = providers
 70        .iter()
 71        .flat_map(|provider| {
 72            provider
 73                .recommended_models(cx)
 74                .into_iter()
 75                .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
 76        })
 77        .collect();
 78
 79    let all = providers
 80        .iter()
 81        .flat_map(|provider| {
 82            provider
 83                .provided_models(cx)
 84                .into_iter()
 85                .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
 86        })
 87        .collect();
 88
 89    GroupedModels::new(all, recommended)
 90}
 91
 92type FavoritesIndex = HashMap<LanguageModelProviderId, HashSet<LanguageModelId>>;
 93
 94#[derive(Clone)]
 95struct ModelInfo {
 96    model: Arc<dyn LanguageModel>,
 97    icon: IconOrSvg,
 98    is_favorite: bool,
 99}
100
101impl ModelInfo {
102    fn new(
103        provider: &dyn LanguageModelProvider,
104        model: Arc<dyn LanguageModel>,
105        favorites_index: &FavoritesIndex,
106    ) -> Self {
107        let is_favorite = favorites_index
108            .get(&provider.id())
109            .map_or(false, |set| set.contains(&model.id()));
110
111        Self {
112            model,
113            icon: provider.icon(),
114            is_favorite,
115        }
116    }
117}
118
119pub struct LanguageModelPickerDelegate {
120    on_model_changed: OnModelChanged,
121    get_active_model: GetActiveModel,
122    on_toggle_favorite: OnToggleFavorite,
123    all_models: Arc<GroupedModels>,
124    filtered_entries: Vec<LanguageModelPickerEntry>,
125    selected_index: usize,
126    _authenticate_all_providers_task: Task<()>,
127    _subscriptions: Vec<Subscription>,
128    popover_styles: bool,
129    focus_handle: FocusHandle,
130}
131
132impl LanguageModelPickerDelegate {
133    fn new(
134        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
135        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
136        on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &App) + 'static,
137        popover_styles: bool,
138        focus_handle: FocusHandle,
139        window: &mut Window,
140        cx: &mut Context<Picker<Self>>,
141    ) -> Self {
142        let on_model_changed = Arc::new(on_model_changed);
143        let models = all_models(cx);
144        let entries = models.entries();
145
146        Self {
147            on_model_changed,
148            all_models: Arc::new(models),
149            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
150            filtered_entries: entries,
151            get_active_model: Arc::new(get_active_model),
152            on_toggle_favorite: Arc::new(on_toggle_favorite),
153            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
154            _subscriptions: vec![cx.subscribe_in(
155                &LanguageModelRegistry::global(cx),
156                window,
157                |picker, _, event, window, cx| {
158                    match event {
159                        language_model::Event::ProviderStateChanged(_)
160                        | language_model::Event::AddedProvider(_)
161                        | language_model::Event::RemovedProvider(_) => {
162                            let query = picker.query(cx);
163                            picker.delegate.all_models = Arc::new(all_models(cx));
164                            // Update matches will automatically drop the previous task
165                            // if we get a provider event again
166                            picker.update_matches(query, window, cx)
167                        }
168                        _ => {}
169                    }
170                },
171            )],
172            popover_styles,
173            focus_handle,
174        }
175    }
176
177    fn get_active_model_index(
178        entries: &[LanguageModelPickerEntry],
179        active_model: Option<ConfiguredModel>,
180    ) -> usize {
181        entries
182            .iter()
183            .position(|entry| {
184                if let LanguageModelPickerEntry::Model(model) = entry {
185                    active_model
186                        .as_ref()
187                        .map(|active_model| {
188                            active_model.model.id() == model.model.id()
189                                && active_model.provider.id() == model.model.provider_id()
190                        })
191                        .unwrap_or_default()
192                } else {
193                    false
194                }
195            })
196            .unwrap_or(0)
197    }
198
199    /// Authenticates all providers in the [`LanguageModelRegistry`].
200    ///
201    /// We do this so that we can populate the language selector with all of the
202    /// models from the configured providers.
203    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
204        let authenticate_all_providers = LanguageModelRegistry::global(cx)
205            .read(cx)
206            .visible_providers()
207            .iter()
208            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
209            .collect::<Vec<_>>();
210
211        cx.spawn(async move |_cx| {
212            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
213                if let Err(err) = authenticate_task.await {
214                    if matches!(err, AuthenticateError::CredentialsNotFound) {
215                        // Since we're authenticating these providers in the
216                        // background for the purposes of populating the
217                        // language selector, we don't care about providers
218                        // where the credentials are not found.
219                    } else {
220                        // Some providers have noisy failure states that we
221                        // don't want to spam the logs with every time the
222                        // language model selector is initialized.
223                        //
224                        // Ideally these should have more clear failure modes
225                        // that we know are safe to ignore here, like what we do
226                        // with `CredentialsNotFound` above.
227                        match provider_id.0.as_ref() {
228                            "lmstudio" | "ollama" => {
229                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
230                                //
231                                // These fail noisily, so we don't log them.
232                            }
233                            "copilot_chat" => {
234                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
235                            }
236                            _ => {
237                                log::error!(
238                                    "Failed to authenticate provider: {}: {err:#}",
239                                    provider_name.0
240                                );
241                            }
242                        }
243                    }
244                }
245            }
246        })
247    }
248
249    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
250        (self.get_active_model)(cx)
251    }
252
253    pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
254        if self.all_models.favorites.is_empty() {
255            return;
256        }
257
258        let active_model = (self.get_active_model)(cx);
259        let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
260        let active_model_id = active_model.as_ref().map(|m| m.model.id());
261
262        let current_index = self
263            .all_models
264            .favorites
265            .iter()
266            .position(|info| {
267                Some(info.model.provider_id()) == active_provider_id
268                    && Some(info.model.id()) == active_model_id
269            })
270            .unwrap_or(usize::MAX);
271
272        let next_index = if current_index == usize::MAX {
273            0
274        } else {
275            (current_index + 1) % self.all_models.favorites.len()
276        };
277
278        let next_model = self.all_models.favorites[next_index].model.clone();
279
280        (self.on_model_changed)(next_model, cx);
281
282        // Align the picker selection with the newly-active model
283        let new_index =
284            Self::get_active_model_index(&self.filtered_entries, (self.get_active_model)(cx));
285        self.set_selected_index(new_index, window, cx);
286    }
287}
288
289struct GroupedModels {
290    favorites: Vec<ModelInfo>,
291    recommended: Vec<ModelInfo>,
292    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
293}
294
295impl GroupedModels {
296    pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
297        let favorites = all
298            .iter()
299            .filter(|info| info.is_favorite)
300            .cloned()
301            .collect();
302
303        let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
304        for model in all {
305            let provider = model.model.provider_id();
306            if let Some(models) = all_by_provider.get_mut(&provider) {
307                models.push(model);
308            } else {
309                all_by_provider.insert(provider, vec![model]);
310            }
311        }
312
313        Self {
314            favorites,
315            recommended,
316            all: all_by_provider,
317        }
318    }
319
320    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
321        let mut entries = Vec::new();
322
323        if !self.favorites.is_empty() {
324            entries.push(LanguageModelPickerEntry::Separator("Favorite".into()));
325            for info in &self.favorites {
326                entries.push(LanguageModelPickerEntry::Model(info.clone()));
327            }
328        }
329
330        if !self.recommended.is_empty() {
331            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
332            for info in &self.recommended {
333                entries.push(LanguageModelPickerEntry::Model(info.clone()));
334            }
335        }
336
337        for models in self.all.values() {
338            if models.is_empty() {
339                continue;
340            }
341            entries.push(LanguageModelPickerEntry::Separator(
342                models[0].model.provider_name().0,
343            ));
344            for info in models {
345                entries.push(LanguageModelPickerEntry::Model(info.clone()));
346            }
347        }
348
349        entries
350    }
351}
352
/// A row in the picker: either a selectable model or a non-selectable section header.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A section header carrying its title (e.g. "Favorite", "Recommended", or a provider name).
    Separator(SharedString),
}
357
358struct ModelMatcher {
359    models: Vec<ModelInfo>,
360    bg_executor: BackgroundExecutor,
361    candidates: Vec<StringMatchCandidate>,
362}
363
364impl ModelMatcher {
365    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
366        let candidates = Self::make_match_candidates(&models);
367        Self {
368            models,
369            bg_executor,
370            candidates,
371        }
372    }
373
374    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
375        let mut matches = self.bg_executor.block(match_strings(
376            &self.candidates,
377            query,
378            false,
379            true,
380            100,
381            &Default::default(),
382            self.bg_executor.clone(),
383        ));
384
385        let sorting_key = |mat: &StringMatch| {
386            let candidate = &self.candidates[mat.candidate_id];
387            (Reverse(OrderedFloat(mat.score)), candidate.id)
388        };
389        matches.sort_unstable_by_key(sorting_key);
390
391        let matched_models: Vec<_> = matches
392            .into_iter()
393            .map(|mat| self.models[mat.candidate_id].clone())
394            .collect();
395
396        matched_models
397    }
398
399    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
400        self.models
401            .iter()
402            .filter(|m| {
403                m.model
404                    .name()
405                    .0
406                    .to_lowercase()
407                    .contains(&query.to_lowercase())
408            })
409            .cloned()
410            .collect::<Vec<_>>()
411    }
412
413    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
414        model_infos
415            .iter()
416            .enumerate()
417            .map(|(index, model)| {
418                StringMatchCandidate::new(
419                    index,
420                    &format!(
421                        "{}/{}",
422                        &model.model.provider_name().0,
423                        &model.model.name().0
424                    ),
425                )
426            })
427            .collect::<Vec<_>>()
428    }
429}
430
431impl PickerDelegate for LanguageModelPickerDelegate {
432    type ListItem = AnyElement;
433
434    fn match_count(&self) -> usize {
435        self.filtered_entries.len()
436    }
437
438    fn selected_index(&self) -> usize {
439        self.selected_index
440    }
441
442    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
443        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
444        cx.notify();
445    }
446
447    fn can_select(
448        &mut self,
449        ix: usize,
450        _window: &mut Window,
451        _cx: &mut Context<Picker<Self>>,
452    ) -> bool {
453        match self.filtered_entries.get(ix) {
454            Some(LanguageModelPickerEntry::Model(_)) => true,
455            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
456        }
457    }
458
459    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
460        "Select a model…".into()
461    }
462
463    fn update_matches(
464        &mut self,
465        query: String,
466        window: &mut Window,
467        cx: &mut Context<Picker<Self>>,
468    ) -> Task<()> {
469        let all_models = self.all_models.clone();
470        let active_model = (self.get_active_model)(cx);
471        let bg_executor = cx.background_executor();
472
473        let language_model_registry = LanguageModelRegistry::global(cx);
474
475        let configured_providers = language_model_registry
476            .read(cx)
477            .visible_providers()
478            .into_iter()
479            .filter(|provider| provider.is_authenticated(cx))
480            .collect::<Vec<_>>();
481
482        let configured_provider_ids = configured_providers
483            .iter()
484            .map(|provider| provider.id())
485            .collect::<Vec<_>>();
486
487        let recommended_models = all_models
488            .recommended
489            .iter()
490            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
491            .cloned()
492            .collect::<Vec<_>>();
493
494        let available_models = all_models
495            .all
496            .values()
497            .flat_map(|models| models.iter())
498            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
499            .cloned()
500            .collect::<Vec<_>>();
501
502        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
503        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
504
505        let recommended = matcher_rec.exact_search(&query);
506        let all = matcher_all.fuzzy_search(&query);
507
508        let filtered_models = GroupedModels::new(all, recommended);
509
510        cx.spawn_in(window, async move |this, cx| {
511            this.update_in(cx, |this, window, cx| {
512                this.delegate.filtered_entries = filtered_models.entries();
513                // Finds the currently selected model in the list
514                let new_index =
515                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
516                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
517                cx.notify();
518            })
519            .ok();
520        })
521    }
522
523    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
524        if let Some(LanguageModelPickerEntry::Model(model_info)) =
525            self.filtered_entries.get(self.selected_index)
526        {
527            let model = model_info.model.clone();
528            (self.on_model_changed)(model.clone(), cx);
529
530            let current_index = self.selected_index;
531            self.set_selected_index(current_index, window, cx);
532
533            cx.emit(DismissEvent);
534        }
535    }
536
537    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
538        cx.emit(DismissEvent);
539    }
540
541    fn render_match(
542        &self,
543        ix: usize,
544        selected: bool,
545        _: &mut Window,
546        cx: &mut Context<Picker<Self>>,
547    ) -> Option<Self::ListItem> {
548        match self.filtered_entries.get(ix)? {
549            LanguageModelPickerEntry::Separator(title) => {
550                Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
551            }
552            LanguageModelPickerEntry::Model(model_info) => {
553                let active_model = (self.get_active_model)(cx);
554                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
555                let active_model_id = active_model.map(|m| m.model.id());
556
557                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
558                    && Some(model_info.model.id()) == active_model_id;
559
560                let is_favorite = model_info.is_favorite;
561                let handle_action_click = {
562                    let model = model_info.model.clone();
563                    let on_toggle_favorite = self.on_toggle_favorite.clone();
564                    move |cx: &App| on_toggle_favorite(model.clone(), !is_favorite, cx)
565                };
566
567                Some(
568                    ModelSelectorListItem::new(ix, model_info.model.name().0)
569                        .map(|this| match &model_info.icon {
570                            IconOrSvg::Icon(icon_name) => this.icon(*icon_name),
571                            IconOrSvg::Svg(icon_path) => this.icon_path(icon_path.clone()),
572                        })
573                        .is_selected(is_selected)
574                        .is_focused(selected)
575                        .is_favorite(is_favorite)
576                        .on_toggle_favorite(handle_action_click)
577                        .into_any_element(),
578                )
579            }
580        }
581    }
582
583    fn render_footer(
584        &self,
585        _window: &mut Window,
586        _cx: &mut Context<Picker<Self>>,
587    ) -> Option<gpui::AnyElement> {
588        let focus_handle = self.focus_handle.clone();
589
590        if !self.popover_styles {
591            return None;
592        }
593
594        Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
595    }
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal `LanguageModel` used to build `ModelInfo` fixtures.
    /// Both its id and its display name come from the model name passed to `new`.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Creates a model called `name`; `provider` is used as both the provider
        /// id and the provider name.
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        /// Returns `"<provider_id>/<name>"`; the tests use it as a readable key
        /// when asserting on model lists.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        // Not exercised by these tests.
        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        // Not exercised by these tests.
        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds non-favorite `ModelInfo`s from `(provider, model name)` pairs.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        create_models_with_favorites(model_specs, vec![])
    }

    /// Builds `ModelInfo`s from `(provider, model name)` pairs, marking as favorite
    /// every pair that also appears in `favorites`.
    fn create_models_with_favorites(
        model_specs: Vec<(&str, &str)>,
        favorites: Vec<(&str, &str)>,
    ) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| {
                let is_favorite = favorites
                    .iter()
                    .any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name);
                ModelInfo {
                    model: Arc::new(TestLanguageModel::new(name, provider)),
                    icon: IconOrSvg::Icon(IconName::Ai),
                    is_favorite,
                }
            })
            .collect()
    }

    /// Asserts `result` matches `expected` element-by-element, comparing each
    /// model's `"provider/name"` telemetry id in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    // `exact_search` is a case-insensitive substring match on the model name
    // that keeps the input order.
    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    // `fuzzy_search` ranks by score, breaks ties by input order, and also
    // matches against the provider name.
    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    // Being recommended must not remove a model from its provider group in `all`.
    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    // Models with the same name from different providers are kept as distinct entries.
    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }

    // With at least one favorite, the first row is the "Favorite" separator.
    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "gemini")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert_models_eq(grouped_models.favorites, vec!["zed/gemini"]);
    }

    // Without favorites, the "Favorite" section is omitted and "Recommended" comes first.
    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![("zed", "claude"), ("zed", "gemini")]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Recommended"
        ));

        assert!(grouped_models.favorites.is_empty());
    }

    // Every model row carries the correct `is_favorite` flag, which drives the
    // favorite toggle rendered for that row.
    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "claude")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        for entry in &entries {
            if let LanguageModelPickerEntry::Model(info) = entry {
                if info.model.telemetry_id() == "zed/claude" {
                    assert!(info.is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(
                        !info.is_favorite,
                        "{} should not be a favorite",
                        info.model.telemetry_id()
                    );
                }
            }
        }
    }

    // Favorites are additionally listed in their own section; they are not
    // removed from the recommended or per-provider sections.
    #[gpui::test]
    fn test_favorites_appear_in_other_sections(_cx: &mut TestAppContext) {
        let favorites = vec![("zed", "gemini"), ("openai", "gpt-4")];

        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], favorites.clone());

        let all_models = create_models_with_favorites(
            vec![
                ("zed", "claude"),
                ("zed", "gemini"),
                ("openai", "gpt-4"),
                ("openai", "gpt-3.5"),
            ],
            favorites,
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        assert_models_eq(grouped_models.favorites, vec!["zed/gemini", "openai/gpt-4"]);
        assert_models_eq(grouped_models.recommended, vec!["zed/claude"]);
        assert_models_eq(
            grouped_models.all.values().flatten().cloned().collect(),
            vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"],
        );
    }
}