language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use agent_settings::AgentSettings;
  4use collections::{HashMap, HashSet, IndexMap};
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{
  7    Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, ForegroundExecutor,
  8    Subscription, Task,
  9};
 10use language_model::{
 11    ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId, LanguageModelProvider,
 12    LanguageModelProviderId, LanguageModelRegistry,
 13};
 14use ordered_float::OrderedFloat;
 15use picker::{Picker, PickerDelegate};
 16use settings::Settings;
 17use ui::prelude::*;
 18use zed_actions::agent::OpenSettings;
 19
 20use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
 21
 22type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
 23type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 24type OnToggleFavorite = Arc<dyn Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static>;
 25
 26pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 27
 28pub fn language_model_selector(
 29    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 30    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 31    on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
 32    popover_styles: bool,
 33    focus_handle: FocusHandle,
 34    window: &mut Window,
 35    cx: &mut Context<LanguageModelSelector>,
 36) -> LanguageModelSelector {
 37    let delegate = LanguageModelPickerDelegate::new(
 38        get_active_model,
 39        on_model_changed,
 40        on_toggle_favorite,
 41        popover_styles,
 42        focus_handle,
 43        window,
 44        cx,
 45    );
 46
 47    if popover_styles {
 48        Picker::list(delegate, window, cx)
 49            .show_scrollbar(true)
 50            .width(rems(20.))
 51            .max_height(Some(rems(20.).into()))
 52    } else {
 53        Picker::list(delegate, window, cx).show_scrollbar(true)
 54    }
 55}
 56
 57fn all_models(cx: &App) -> GroupedModels {
 58    let lm_registry = LanguageModelRegistry::global(cx).read(cx);
 59    let providers = lm_registry.visible_providers();
 60
 61    let mut favorites_index = FavoritesIndex::default();
 62
 63    for sel in &AgentSettings::get_global(cx).favorite_models {
 64        favorites_index
 65            .entry(sel.provider.0.clone().into())
 66            .or_default()
 67            .insert(sel.model.clone().into());
 68    }
 69
 70    let recommended = providers
 71        .iter()
 72        .flat_map(|provider| {
 73            provider
 74                .recommended_models(cx)
 75                .into_iter()
 76                .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
 77        })
 78        .collect();
 79
 80    let all = providers
 81        .iter()
 82        .flat_map(|provider| {
 83            provider
 84                .provided_models(cx)
 85                .into_iter()
 86                .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
 87        })
 88        .collect();
 89
 90    GroupedModels::new(all, recommended)
 91}
 92
 93type FavoritesIndex = HashMap<LanguageModelProviderId, HashSet<LanguageModelId>>;
 94
 95#[derive(Clone)]
 96struct ModelInfo {
 97    model: Arc<dyn LanguageModel>,
 98    icon: IconOrSvg,
 99    is_favorite: bool,
100}
101
102impl ModelInfo {
103    fn new(
104        provider: &dyn LanguageModelProvider,
105        model: Arc<dyn LanguageModel>,
106        favorites_index: &FavoritesIndex,
107    ) -> Self {
108        let is_favorite = favorites_index
109            .get(&provider.id())
110            .map_or(false, |set| set.contains(&model.id()));
111
112        Self {
113            model,
114            icon: provider.icon(),
115            is_favorite,
116        }
117    }
118}
119
120pub struct LanguageModelPickerDelegate {
121    on_model_changed: OnModelChanged,
122    get_active_model: GetActiveModel,
123    on_toggle_favorite: OnToggleFavorite,
124    all_models: Arc<GroupedModels>,
125    filtered_entries: Vec<LanguageModelPickerEntry>,
126    selected_index: usize,
127    _subscriptions: Vec<Subscription>,
128    popover_styles: bool,
129    focus_handle: FocusHandle,
130}
131
132impl LanguageModelPickerDelegate {
133    fn new(
134        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
135        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
136        on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
137        popover_styles: bool,
138        focus_handle: FocusHandle,
139        window: &mut Window,
140        cx: &mut Context<Picker<Self>>,
141    ) -> Self {
142        let on_model_changed = Arc::new(on_model_changed);
143        let models = all_models(cx);
144        let entries = models.entries();
145
146        Self {
147            on_model_changed,
148            all_models: Arc::new(models),
149            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
150            filtered_entries: entries,
151            get_active_model: Arc::new(get_active_model),
152            on_toggle_favorite: Arc::new(on_toggle_favorite),
153            _subscriptions: vec![cx.subscribe_in(
154                &LanguageModelRegistry::global(cx),
155                window,
156                |picker, _, event, window, cx| {
157                    match event {
158                        language_model::Event::ProviderStateChanged(_)
159                        | language_model::Event::AddedProvider(_)
160                        | language_model::Event::RemovedProvider(_) => {
161                            let query = picker.query(cx);
162                            picker.delegate.all_models = Arc::new(all_models(cx));
163                            // Update matches will automatically drop the previous task
164                            // if we get a provider event again
165                            picker.update_matches(query, window, cx)
166                        }
167                        _ => {}
168                    }
169                },
170            )],
171            popover_styles,
172            focus_handle,
173        }
174    }
175
176    fn get_active_model_index(
177        entries: &[LanguageModelPickerEntry],
178        active_model: Option<ConfiguredModel>,
179    ) -> usize {
180        entries
181            .iter()
182            .position(|entry| {
183                if let LanguageModelPickerEntry::Model(model) = entry {
184                    active_model
185                        .as_ref()
186                        .map(|active_model| {
187                            active_model.model.id() == model.model.id()
188                                && active_model.provider.id() == model.model.provider_id()
189                        })
190                        .unwrap_or_default()
191                } else {
192                    false
193                }
194            })
195            .unwrap_or(0)
196    }
197
198    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
199        (self.get_active_model)(cx)
200    }
201
202    pub fn favorites_count(&self) -> usize {
203        self.all_models.favorites.len()
204    }
205
206    pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
207        if self.all_models.favorites.is_empty() {
208            return;
209        }
210
211        let active_model = (self.get_active_model)(cx);
212        let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
213        let active_model_id = active_model.as_ref().map(|m| m.model.id());
214
215        let current_index = self
216            .all_models
217            .favorites
218            .iter()
219            .position(|info| {
220                Some(info.model.provider_id()) == active_provider_id
221                    && Some(info.model.id()) == active_model_id
222            })
223            .unwrap_or(usize::MAX);
224
225        let next_index = if current_index == usize::MAX {
226            0
227        } else {
228            (current_index + 1) % self.all_models.favorites.len()
229        };
230
231        let next_model = self.all_models.favorites[next_index].model.clone();
232
233        (self.on_model_changed)(next_model, cx);
234
235        // Align the picker selection with the newly-active model
236        let new_index =
237            Self::get_active_model_index(&self.filtered_entries, (self.get_active_model)(cx));
238        self.set_selected_index(new_index, window, cx);
239    }
240}
241
242struct GroupedModels {
243    favorites: Vec<ModelInfo>,
244    recommended: Vec<ModelInfo>,
245    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
246}
247
248impl GroupedModels {
249    pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
250        let favorites = all
251            .iter()
252            .filter(|info| info.is_favorite)
253            .cloned()
254            .collect();
255
256        let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
257        for model in all {
258            let provider = model.model.provider_id();
259            if let Some(models) = all_by_provider.get_mut(&provider) {
260                models.push(model);
261            } else {
262                all_by_provider.insert(provider, vec![model]);
263            }
264        }
265
266        Self {
267            favorites,
268            recommended,
269            all: all_by_provider,
270        }
271    }
272
273    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
274        let mut entries = Vec::new();
275
276        if !self.favorites.is_empty() {
277            entries.push(LanguageModelPickerEntry::Separator("Favorite".into()));
278            for info in &self.favorites {
279                entries.push(LanguageModelPickerEntry::Model(info.clone()));
280            }
281        }
282
283        if !self.recommended.is_empty() {
284            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
285            for info in &self.recommended {
286                entries.push(LanguageModelPickerEntry::Model(info.clone()));
287            }
288        }
289
290        for models in self.all.values() {
291            if models.is_empty() {
292                continue;
293            }
294            entries.push(LanguageModelPickerEntry::Separator(
295                models[0].model.provider_name().0,
296            ));
297            for info in models {
298                entries.push(LanguageModelPickerEntry::Model(info.clone()));
299            }
300        }
301
302        entries
303    }
304}
305
306enum LanguageModelPickerEntry {
307    Model(ModelInfo),
308    Separator(SharedString),
309}
310
311struct ModelMatcher {
312    models: Vec<ModelInfo>,
313    fg_executor: ForegroundExecutor,
314    bg_executor: BackgroundExecutor,
315    candidates: Vec<StringMatchCandidate>,
316}
317
318impl ModelMatcher {
319    fn new(
320        models: Vec<ModelInfo>,
321        fg_executor: ForegroundExecutor,
322        bg_executor: BackgroundExecutor,
323    ) -> ModelMatcher {
324        let candidates = Self::make_match_candidates(&models);
325        Self {
326            models,
327            fg_executor,
328            bg_executor,
329            candidates,
330        }
331    }
332
333    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
334        let mut matches = self.fg_executor.block_on(match_strings(
335            &self.candidates,
336            query,
337            false,
338            true,
339            100,
340            &Default::default(),
341            self.bg_executor.clone(),
342        ));
343
344        let sorting_key = |mat: &StringMatch| {
345            let candidate = &self.candidates[mat.candidate_id];
346            (Reverse(OrderedFloat(mat.score)), candidate.id)
347        };
348        matches.sort_unstable_by_key(sorting_key);
349
350        let matched_models: Vec<_> = matches
351            .into_iter()
352            .map(|mat| self.models[mat.candidate_id].clone())
353            .collect();
354
355        matched_models
356    }
357
358    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
359        self.models
360            .iter()
361            .filter(|m| {
362                m.model
363                    .name()
364                    .0
365                    .to_lowercase()
366                    .contains(&query.to_lowercase())
367            })
368            .cloned()
369            .collect::<Vec<_>>()
370    }
371
372    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
373        model_infos
374            .iter()
375            .enumerate()
376            .map(|(index, model)| {
377                StringMatchCandidate::new(
378                    index,
379                    &format!(
380                        "{}/{}",
381                        &model.model.provider_name().0,
382                        &model.model.name().0
383                    ),
384                )
385            })
386            .collect::<Vec<_>>()
387    }
388}
389
390impl PickerDelegate for LanguageModelPickerDelegate {
391    type ListItem = AnyElement;
392
393    fn match_count(&self) -> usize {
394        self.filtered_entries.len()
395    }
396
397    fn selected_index(&self) -> usize {
398        self.selected_index
399    }
400
401    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
402        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
403        cx.notify();
404    }
405
406    fn can_select(&self, ix: usize, _window: &mut Window, _cx: &mut Context<Picker<Self>>) -> bool {
407        match self.filtered_entries.get(ix) {
408            Some(LanguageModelPickerEntry::Model(_)) => true,
409            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
410        }
411    }
412
413    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
414        "Select a model…".into()
415    }
416
417    fn update_matches(
418        &mut self,
419        query: String,
420        window: &mut Window,
421        cx: &mut Context<Picker<Self>>,
422    ) -> Task<()> {
423        let all_models = self.all_models.clone();
424        let active_model = (self.get_active_model)(cx);
425        let fg_executor = cx.foreground_executor();
426        let bg_executor = cx.background_executor();
427
428        let language_model_registry = LanguageModelRegistry::global(cx);
429
430        let configured_providers = language_model_registry
431            .read(cx)
432            .visible_providers()
433            .into_iter()
434            .filter(|provider| provider.is_authenticated(cx))
435            .collect::<Vec<_>>();
436
437        let configured_provider_ids = configured_providers
438            .iter()
439            .map(|provider| provider.id())
440            .collect::<Vec<_>>();
441
442        let recommended_models = all_models
443            .recommended
444            .iter()
445            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
446            .cloned()
447            .collect::<Vec<_>>();
448
449        let available_models = all_models
450            .all
451            .values()
452            .flat_map(|models| models.iter())
453            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
454            .cloned()
455            .collect::<Vec<_>>();
456
457        let matcher_rec =
458            ModelMatcher::new(recommended_models, fg_executor.clone(), bg_executor.clone());
459        let matcher_all =
460            ModelMatcher::new(available_models, fg_executor.clone(), bg_executor.clone());
461
462        let recommended = matcher_rec.exact_search(&query);
463        let all = matcher_all.fuzzy_search(&query);
464
465        let filtered_models = GroupedModels::new(all, recommended);
466
467        cx.spawn_in(window, async move |this, cx| {
468            this.update_in(cx, |this, window, cx| {
469                this.delegate.filtered_entries = filtered_models.entries();
470                // Finds the currently selected model in the list
471                let new_index =
472                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
473                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
474                cx.notify();
475            })
476            .ok();
477        })
478    }
479
480    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
481        if let Some(LanguageModelPickerEntry::Model(model_info)) =
482            self.filtered_entries.get(self.selected_index)
483        {
484            let model = model_info.model.clone();
485            (self.on_model_changed)(model.clone(), cx);
486
487            let current_index = self.selected_index;
488            self.set_selected_index(current_index, window, cx);
489
490            cx.emit(DismissEvent);
491        }
492    }
493
494    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
495        cx.emit(DismissEvent);
496    }
497
498    fn render_match(
499        &self,
500        ix: usize,
501        selected: bool,
502        _: &mut Window,
503        cx: &mut Context<Picker<Self>>,
504    ) -> Option<Self::ListItem> {
505        match self.filtered_entries.get(ix)? {
506            LanguageModelPickerEntry::Separator(title) => {
507                Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
508            }
509            LanguageModelPickerEntry::Model(model_info) => {
510                let active_model = (self.get_active_model)(cx);
511                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
512                let active_model_id = active_model.map(|m| m.model.id());
513
514                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
515                    && Some(model_info.model.id()) == active_model_id;
516
517                let model_cost = model_info
518                    .model
519                    .model_cost_info()
520                    .map(|cost| cost.to_shared_string());
521
522                let is_favorite = model_info.is_favorite;
523                let handle_action_click = {
524                    let model = model_info.model.clone();
525                    let on_toggle_favorite = self.on_toggle_favorite.clone();
526                    cx.listener(move |picker, _, window, cx| {
527                        on_toggle_favorite(model.clone(), !is_favorite, cx);
528                        picker.refresh(window, cx);
529                    })
530                };
531
532                Some(
533                    ModelSelectorListItem::new(ix, model_info.model.name().0)
534                        .map(|this| match &model_info.icon {
535                            IconOrSvg::Icon(icon_name) => this.icon(*icon_name),
536                            IconOrSvg::Svg(icon_path) => this.icon_path(icon_path.clone()),
537                        })
538                        .is_selected(is_selected)
539                        .is_focused(selected)
540                        .is_latest(model_info.model.is_latest())
541                        .is_favorite(is_favorite)
542                        .cost_info(model_cost)
543                        .on_toggle_favorite(handle_action_click)
544                        .into_any_element(),
545                )
546            }
547        }
548    }
549
550    fn render_footer(
551        &self,
552        _window: &mut Window,
553        _cx: &mut Context<Picker<Self>>,
554    ) -> Option<gpui::AnyElement> {
555        let focus_handle = self.focus_handle.clone();
556
557        if !self.popover_styles {
558            return None;
559        }
560
561        Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
562    }
563}
564
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal [`LanguageModel`] whose id and display name are both the given name.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, model name)` pairs, none of them favorites.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        create_models_with_favorites(model_specs, vec![])
    }

    /// Like `create_models`, but flags every pair listed in `favorites` as a favorite.
    fn create_models_with_favorites(
        model_specs: Vec<(&str, &str)>,
        favorites: Vec<(&str, &str)>,
    ) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| {
                let is_favorite = favorites
                    .iter()
                    .any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name);
                ModelInfo {
                    model: Arc::new(TestLanguageModel::new(name, provider)),
                    icon: IconOrSvg::Icon(IconName::ZedAgent),
                    is_favorite,
                }
            })
            .collect()
    }

    /// Asserts that `result` lists exactly the `expected` models, in order, identified
    /// by their `provider/name` telemetry id. On failure the full lists are printed.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        let actual: Vec<String> = result
            .iter()
            .map(|info| info.model.telemetry_id())
            .collect();
        assert_eq!(
            actual, expected,
            "models (as provider/name ids, in order) don't match"
        );
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-5"),
            ("zed", "gpt-5-mini"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-5"),
            ("openai", "gpt-5-mini"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(
            models,
            cx.foreground_executor().clone(),
            cx.background_executor.clone(),
        );

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-5");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-5",
                "zed/gpt-5-mini",
                "openai/gpt-5",
                "openai/gpt-5-mini",
            ],
        );
    }

    #[gpui::test]
    fn test_exact_search_with_empty_query_keeps_all_models(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "gpt-5"),
            ("openai", "gpt-5-mini"),
            ("ollama", "mistral"),
        ]);
        let matcher = ModelMatcher::new(
            models,
            cx.foreground_executor().clone(),
            cx.background_executor.clone(),
        );

        // Every name contains the empty string, so nothing is filtered out.
        assert_models_eq(
            matcher.exact_search(""),
            vec!["zed/gpt-5", "openai/gpt-5-mini", "ollama/mistral"],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-5"),
            ("zed", "gpt-5-mini"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-5"),
            ("openai", "gpt-5-mini"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(
            models,
            cx.foreground_executor().clone(),
            cx.background_executor.clone(),
        );

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-5-mini` and `openai/gpt-5-mini` have identical
        // similarity scores, but `zed/gpt-5-mini` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("mini");
        assert_models_eq(results, vec!["zed/gpt-5-mini", "openai/gpt-5-mini"]);

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Only the "Thinking" variant of Claude contains the query
        let results = matcher.fuzzy_search("thinking");
        assert_models_eq(results, vec!["zed/Claude 3.7 Sonnet Thinking"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "all"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Also recommended, still appears in "all"
            ("zed", "gemini"),
            ("copilot", "claude"), // Same name, different provider: a distinct entry in "all"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }

    #[gpui::test]
    fn test_all_models_grouped_by_provider_in_first_seen_order(_cx: &mut TestAppContext) {
        let all_models = create_models(vec![
            ("zed", "claude"),
            ("openai", "gpt-4"),
            ("zed", "gemini"),
            ("openai", "gpt-3.5"),
        ]);

        let grouped_models = GroupedModels::new(all_models, vec![]);

        let provider_ids: Vec<String> = grouped_models
            .all
            .keys()
            .map(|id| id.0.to_string())
            .collect();
        assert_eq!(provider_ids, vec!["zed", "openai"]);

        assert_models_eq(
            grouped_models.all.values().flatten().cloned().collect(),
            vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"],
        );
    }

    #[gpui::test]
    fn test_entries_have_one_header_per_non_empty_section(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("openai", "gpt-4"), ("zed", "gemini")],
            vec![("openai", "gpt-4")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        let headers: Vec<String> = entries
            .iter()
            .filter_map(|entry| match entry {
                LanguageModelPickerEntry::Separator(title) => Some(title.to_string()),
                LanguageModelPickerEntry::Model(_) => None,
            })
            .collect();
        assert_eq!(headers, vec!["Favorite", "Recommended", "zed", "openai"]);

        // 4 headers + 1 favorite + 1 recommended + 3 models in the provider sections.
        assert_eq!(entries.len(), 9);
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "gemini")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert_models_eq(grouped_models.favorites, vec!["zed/gemini"]);
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![("zed", "claude"), ("zed", "gemini")]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Recommended"
        ));

        assert!(grouped_models.favorites.is_empty());
    }

    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "claude")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        for entry in &entries {
            if let LanguageModelPickerEntry::Model(info) = entry {
                if info.model.telemetry_id() == "zed/claude" {
                    assert!(info.is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(
                        !info.is_favorite,
                        "{} should not be a favorite",
                        info.model.telemetry_id()
                    );
                }
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_other_sections(_cx: &mut TestAppContext) {
        let favorites = vec![("zed", "gemini"), ("openai", "gpt-4")];

        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], favorites.clone());

        let all_models = create_models_with_favorites(
            vec![
                ("zed", "claude"),
                ("zed", "gemini"),
                ("openai", "gpt-4"),
                ("openai", "gpt-3.5"),
            ],
            favorites,
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        assert_models_eq(grouped_models.favorites, vec!["zed/gemini", "openai/gpt-4"]);
        assert_models_eq(grouped_models.recommended, vec!["zed/claude"]);
        assert_models_eq(
            grouped_models.all.values().flatten().cloned().collect(),
            vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"],
        );
    }
}