language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use agent_settings::AgentSettings;
  4use collections::{HashMap, HashSet, IndexMap};
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{
  7    Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, ForegroundExecutor,
  8    Subscription, Task,
  9};
 10use language_model::{
 11    ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId, LanguageModelProvider,
 12    LanguageModelProviderId, LanguageModelRegistry,
 13};
 14use ordered_float::OrderedFloat;
 15use picker::{Picker, PickerDelegate};
 16use settings::Settings;
 17use ui::prelude::*;
 18use zed_actions::agent::OpenSettings;
 19
 20use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
 21
 22type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
 23type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 24type OnToggleFavorite = Arc<dyn Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static>;
 25
 26pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 27
 28pub fn language_model_selector(
 29    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 30    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 31    on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
 32    popover_styles: bool,
 33    focus_handle: FocusHandle,
 34    window: &mut Window,
 35    cx: &mut Context<LanguageModelSelector>,
 36) -> LanguageModelSelector {
 37    let delegate = LanguageModelPickerDelegate::new(
 38        get_active_model,
 39        on_model_changed,
 40        on_toggle_favorite,
 41        popover_styles,
 42        focus_handle,
 43        window,
 44        cx,
 45    );
 46
 47    if popover_styles {
 48        Picker::list(delegate, window, cx)
 49            .show_scrollbar(true)
 50            .width(rems(20.))
 51            .max_height(Some(rems(20.).into()))
 52    } else {
 53        Picker::list(delegate, window, cx).show_scrollbar(true)
 54    }
 55}
 56
 57fn all_models(cx: &App) -> GroupedModels {
 58    let lm_registry = LanguageModelRegistry::global(cx).read(cx);
 59    let providers = lm_registry.visible_providers();
 60
 61    let mut favorites_index = FavoritesIndex::default();
 62
 63    for sel in &AgentSettings::get_global(cx).favorite_models {
 64        favorites_index
 65            .entry(sel.provider.0.clone().into())
 66            .or_default()
 67            .insert(sel.model.clone().into());
 68    }
 69
 70    let recommended = providers
 71        .iter()
 72        .flat_map(|provider| {
 73            provider
 74                .recommended_models(cx)
 75                .into_iter()
 76                .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
 77        })
 78        .collect();
 79
 80    let all = providers
 81        .iter()
 82        .flat_map(|provider| {
 83            provider
 84                .provided_models(cx)
 85                .into_iter()
 86                .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
 87        })
 88        .collect();
 89
 90    GroupedModels::new(all, recommended)
 91}
 92
 93type FavoritesIndex = HashMap<LanguageModelProviderId, HashSet<LanguageModelId>>;
 94
 95#[derive(Clone)]
 96struct ModelInfo {
 97    model: Arc<dyn LanguageModel>,
 98    icon: IconOrSvg,
 99    is_favorite: bool,
100}
101
102impl ModelInfo {
103    fn new(
104        provider: &dyn LanguageModelProvider,
105        model: Arc<dyn LanguageModel>,
106        favorites_index: &FavoritesIndex,
107    ) -> Self {
108        let is_favorite = favorites_index
109            .get(&provider.id())
110            .map_or(false, |set| set.contains(&model.id()));
111
112        Self {
113            model,
114            icon: provider.icon(),
115            is_favorite,
116        }
117    }
118}
119
/// Picker delegate that lists language models in "Favorite", "Recommended" and
/// per-provider sections.
pub struct LanguageModelPickerDelegate {
    /// Invoked with the model the user confirms (or cycles to via favorites).
    on_model_changed: OnModelChanged,
    /// Returns the currently active model; used to highlight and preselect its row.
    get_active_model: GetActiveModel,
    /// Invoked when a row's favorite toggle is clicked.
    on_toggle_favorite: OnToggleFavorite,
    /// Snapshot of every model from every visible provider; rebuilt when the registry
    /// reports a provider being added, removed, or changing state.
    all_models: Arc<GroupedModels>,
    /// Rows currently shown: section separators interleaved with model rows.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the focused row.
    selected_index: usize,
    /// Holds the registry subscription for the delegate's lifetime (presumably dropping
    /// a `Subscription` unsubscribes — confirm against gpui).
    _subscriptions: Vec<Subscription>,
    /// Popover mode: fixed picker size and a settings footer.
    popover_styles: bool,
    /// Focus handle handed to the footer.
    focus_handle: FocusHandle,
}
131
132impl LanguageModelPickerDelegate {
133    fn new(
134        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
135        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
136        on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
137        popover_styles: bool,
138        focus_handle: FocusHandle,
139        window: &mut Window,
140        cx: &mut Context<Picker<Self>>,
141    ) -> Self {
142        let on_model_changed = Arc::new(on_model_changed);
143        let models = all_models(cx);
144        let entries = models.entries();
145
146        Self {
147            on_model_changed,
148            all_models: Arc::new(models),
149            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
150            filtered_entries: entries,
151            get_active_model: Arc::new(get_active_model),
152            on_toggle_favorite: Arc::new(on_toggle_favorite),
153            _subscriptions: vec![cx.subscribe_in(
154                &LanguageModelRegistry::global(cx),
155                window,
156                |picker, _, event, window, cx| {
157                    match event {
158                        language_model::Event::ProviderStateChanged(_)
159                        | language_model::Event::AddedProvider(_)
160                        | language_model::Event::RemovedProvider(_) => {
161                            let query = picker.query(cx);
162                            picker.delegate.all_models = Arc::new(all_models(cx));
163                            // Update matches will automatically drop the previous task
164                            // if we get a provider event again
165                            picker.update_matches(query, window, cx)
166                        }
167                        _ => {}
168                    }
169                },
170            )],
171            popover_styles,
172            focus_handle,
173        }
174    }
175
176    fn get_active_model_index(
177        entries: &[LanguageModelPickerEntry],
178        active_model: Option<ConfiguredModel>,
179    ) -> usize {
180        entries
181            .iter()
182            .position(|entry| {
183                if let LanguageModelPickerEntry::Model(model) = entry {
184                    active_model
185                        .as_ref()
186                        .map(|active_model| {
187                            active_model.model.id() == model.model.id()
188                                && active_model.provider.id() == model.model.provider_id()
189                        })
190                        .unwrap_or_default()
191                } else {
192                    false
193                }
194            })
195            .unwrap_or(0)
196    }
197
198    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
199        (self.get_active_model)(cx)
200    }
201
202    pub fn favorites_count(&self) -> usize {
203        self.all_models.favorites.len()
204    }
205
206    pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
207        if self.all_models.favorites.is_empty() {
208            return;
209        }
210
211        let active_model = (self.get_active_model)(cx);
212        let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
213        let active_model_id = active_model.as_ref().map(|m| m.model.id());
214
215        let current_index = self
216            .all_models
217            .favorites
218            .iter()
219            .position(|info| {
220                Some(info.model.provider_id()) == active_provider_id
221                    && Some(info.model.id()) == active_model_id
222            })
223            .unwrap_or(usize::MAX);
224
225        let next_index = if current_index == usize::MAX {
226            0
227        } else {
228            (current_index + 1) % self.all_models.favorites.len()
229        };
230
231        let next_model = self.all_models.favorites[next_index].model.clone();
232
233        (self.on_model_changed)(next_model, cx);
234
235        // Align the picker selection with the newly-active model
236        let new_index =
237            Self::get_active_model_index(&self.filtered_entries, (self.get_active_model)(cx));
238        self.set_selected_index(new_index, window, cx);
239    }
240}
241
242struct GroupedModels {
243    favorites: Vec<ModelInfo>,
244    recommended: Vec<ModelInfo>,
245    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
246}
247
248impl GroupedModels {
249    pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
250        let favorites = all
251            .iter()
252            .filter(|info| info.is_favorite)
253            .cloned()
254            .collect();
255
256        let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
257        for model in all {
258            let provider = model.model.provider_id();
259            if let Some(models) = all_by_provider.get_mut(&provider) {
260                models.push(model);
261            } else {
262                all_by_provider.insert(provider, vec![model]);
263            }
264        }
265
266        Self {
267            favorites,
268            recommended,
269            all: all_by_provider,
270        }
271    }
272
273    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
274        let mut entries = Vec::new();
275
276        if !self.favorites.is_empty() {
277            entries.push(LanguageModelPickerEntry::Separator("Favorite".into()));
278            for info in &self.favorites {
279                entries.push(LanguageModelPickerEntry::Model(info.clone()));
280            }
281        }
282
283        if !self.recommended.is_empty() {
284            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
285            for info in &self.recommended {
286                entries.push(LanguageModelPickerEntry::Model(info.clone()));
287            }
288        }
289
290        for models in self.all.values() {
291            if models.is_empty() {
292                continue;
293            }
294            entries.push(LanguageModelPickerEntry::Separator(
295                models[0].model.provider_name().0,
296            ));
297            for info in models {
298                entries.push(LanguageModelPickerEntry::Model(info.clone()));
299            }
300        }
301
302        entries
303    }
304}
305
/// A single row in the picker list.
enum LanguageModelPickerEntry {
    /// A model row; the only kind of row the picker allows selecting.
    Model(ModelInfo),
    /// A non-selectable section header carrying its title.
    Separator(SharedString),
}
310
311struct ModelMatcher {
312    models: Vec<ModelInfo>,
313    fg_executor: ForegroundExecutor,
314    bg_executor: BackgroundExecutor,
315    candidates: Vec<StringMatchCandidate>,
316}
317
318impl ModelMatcher {
319    fn new(
320        models: Vec<ModelInfo>,
321        fg_executor: ForegroundExecutor,
322        bg_executor: BackgroundExecutor,
323    ) -> ModelMatcher {
324        let candidates = Self::make_match_candidates(&models);
325        Self {
326            models,
327            fg_executor,
328            bg_executor,
329            candidates,
330        }
331    }
332
333    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
334        let mut matches = self.fg_executor.block_on(match_strings(
335            &self.candidates,
336            query,
337            false,
338            true,
339            100,
340            &Default::default(),
341            self.bg_executor.clone(),
342        ));
343
344        let sorting_key = |mat: &StringMatch| {
345            let candidate = &self.candidates[mat.candidate_id];
346            (Reverse(OrderedFloat(mat.score)), candidate.id)
347        };
348        matches.sort_unstable_by_key(sorting_key);
349
350        let matched_models: Vec<_> = matches
351            .into_iter()
352            .map(|mat| self.models[mat.candidate_id].clone())
353            .collect();
354
355        matched_models
356    }
357
358    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
359        self.models
360            .iter()
361            .filter(|m| {
362                m.model
363                    .name()
364                    .0
365                    .to_lowercase()
366                    .contains(&query.to_lowercase())
367            })
368            .cloned()
369            .collect::<Vec<_>>()
370    }
371
372    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
373        model_infos
374            .iter()
375            .enumerate()
376            .map(|(index, model)| {
377                StringMatchCandidate::new(
378                    index,
379                    &format!(
380                        "{}/{}",
381                        &model.model.provider_name().0,
382                        &model.model.name().0
383                    ),
384                )
385            })
386            .collect::<Vec<_>>()
387    }
388}
389
390impl PickerDelegate for LanguageModelPickerDelegate {
391    type ListItem = AnyElement;
392
393    fn match_count(&self) -> usize {
394        self.filtered_entries.len()
395    }
396
397    fn selected_index(&self) -> usize {
398        self.selected_index
399    }
400
401    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
402        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
403        cx.notify();
404    }
405
406    fn can_select(&self, ix: usize, _window: &mut Window, _cx: &mut Context<Picker<Self>>) -> bool {
407        match self.filtered_entries.get(ix) {
408            Some(LanguageModelPickerEntry::Model(_)) => true,
409            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
410        }
411    }
412
413    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
414        "Select a model…".into()
415    }
416
417    fn update_matches(
418        &mut self,
419        query: String,
420        window: &mut Window,
421        cx: &mut Context<Picker<Self>>,
422    ) -> Task<()> {
423        let all_models = self.all_models.clone();
424        let active_model = (self.get_active_model)(cx);
425        let fg_executor = cx.foreground_executor();
426        let bg_executor = cx.background_executor();
427
428        let language_model_registry = LanguageModelRegistry::global(cx);
429
430        let configured_providers = language_model_registry
431            .read(cx)
432            .visible_providers()
433            .into_iter()
434            .filter(|provider| provider.is_authenticated(cx))
435            .collect::<Vec<_>>();
436
437        let configured_provider_ids = configured_providers
438            .iter()
439            .map(|provider| provider.id())
440            .collect::<Vec<_>>();
441
442        let recommended_models = all_models
443            .recommended
444            .iter()
445            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
446            .cloned()
447            .collect::<Vec<_>>();
448
449        let available_models = all_models
450            .all
451            .values()
452            .flat_map(|models| models.iter())
453            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
454            .cloned()
455            .collect::<Vec<_>>();
456
457        let matcher_rec =
458            ModelMatcher::new(recommended_models, fg_executor.clone(), bg_executor.clone());
459        let matcher_all =
460            ModelMatcher::new(available_models, fg_executor.clone(), bg_executor.clone());
461
462        let recommended = matcher_rec.exact_search(&query);
463        let all = matcher_all.fuzzy_search(&query);
464
465        let filtered_models = GroupedModels::new(all, recommended);
466
467        cx.spawn_in(window, async move |this, cx| {
468            this.update_in(cx, |this, window, cx| {
469                this.delegate.filtered_entries = filtered_models.entries();
470                // Finds the currently selected model in the list
471                let new_index =
472                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
473                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
474                cx.notify();
475            })
476            .ok();
477        })
478    }
479
480    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
481        if let Some(LanguageModelPickerEntry::Model(model_info)) =
482            self.filtered_entries.get(self.selected_index)
483        {
484            let model = model_info.model.clone();
485            (self.on_model_changed)(model.clone(), cx);
486
487            let current_index = self.selected_index;
488            self.set_selected_index(current_index, window, cx);
489
490            cx.emit(DismissEvent);
491        }
492    }
493
494    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
495        cx.emit(DismissEvent);
496    }
497
498    fn render_match(
499        &self,
500        ix: usize,
501        selected: bool,
502        _: &mut Window,
503        cx: &mut Context<Picker<Self>>,
504    ) -> Option<Self::ListItem> {
505        match self.filtered_entries.get(ix)? {
506            LanguageModelPickerEntry::Separator(title) => {
507                Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
508            }
509            LanguageModelPickerEntry::Model(model_info) => {
510                let active_model = (self.get_active_model)(cx);
511                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
512                let active_model_id = active_model.map(|m| m.model.id());
513
514                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
515                    && Some(model_info.model.id()) == active_model_id;
516
517                let model_cost = model_info
518                    .model
519                    .model_cost_info()
520                    .map(|cost| cost.to_shared_string());
521
522                let is_favorite = model_info.is_favorite;
523                let handle_action_click = {
524                    let model = model_info.model.clone();
525                    let on_toggle_favorite = self.on_toggle_favorite.clone();
526                    cx.listener(move |picker, _, window, cx| {
527                        on_toggle_favorite(model.clone(), !is_favorite, cx);
528                        picker.refresh(window, cx);
529                    })
530                };
531
532                Some(
533                    ModelSelectorListItem::new(ix, model_info.model.name().0)
534                        .map(|this| match &model_info.icon {
535                            IconOrSvg::Icon(icon_name) => this.icon(*icon_name),
536                            IconOrSvg::Svg(icon_path) => this.icon_path(icon_path.clone()),
537                        })
538                        .is_selected(is_selected)
539                        .is_focused(selected)
540                        .is_latest(model_info.model.is_latest())
541                        .is_favorite(is_favorite)
542                        .cost_info(model_cost)
543                        .on_toggle_favorite(handle_action_click)
544                        .into_any_element(),
545                )
546            }
547        }
548    }
549
550    fn render_footer(
551        &self,
552        _window: &mut Window,
553        _cx: &mut Context<Picker<Self>>,
554    ) -> Option<gpui::AnyElement> {
555        let focus_handle = self.focus_handle.clone();
556
557        if !self.popover_styles {
558            return None;
559        }
560
561        Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
562    }
563}
564
// Unit tests for model search and grouping. They build `ModelInfo`s directly from an
// in-memory `TestLanguageModel`, so no registry or real providers are involved.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal `LanguageModel` whose id and name are both the given model name.
    /// Only identity/metadata methods are implemented; request methods are
    /// `unimplemented!()`.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Builds a model named `name`; `provider` is used as both provider id and name.
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // "<provider>/<name>"; the assertions below use this as a compact identity.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, none of them favorited.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        create_models_with_favorites(model_specs, vec![])
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, marking as favorite every pair
    /// that also appears in `favorites`.
    fn create_models_with_favorites(
        model_specs: Vec<(&str, &str)>,
        favorites: Vec<(&str, &str)>,
    ) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| {
                let is_favorite = favorites
                    .iter()
                    .any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name);
                ModelInfo {
                    model: Arc::new(TestLanguageModel::new(name, provider)),
                    icon: IconOrSvg::Icon(IconName::ZedAgent),
                    is_favorite,
                }
            })
            .collect()
    }

    /// Asserts that `result` holds exactly the models in `expected`, in order,
    /// compared by their "<provider>/<name>" telemetry id.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-5"),
            ("zed", "gpt-5-mini"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-5"),
            ("openai", "gpt-5-mini"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(
            models,
            cx.foreground_executor().clone(),
            cx.background_executor.clone(),
        );

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-5");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-5",
                "zed/gpt-5-mini",
                "openai/gpt-5",
                "openai/gpt-5-mini",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-5"),
            ("zed", "gpt-5-mini"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-5"),
            ("openai", "gpt-5-mini"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(
            models,
            cx.foreground_executor().clone(),
            cx.background_executor.clone(),
        );

        // Ties in score must preserve the original model order.
        // In the case below, `zed/gpt-5-mini` and `openai/gpt-5-mini` have identical
        // similarity scores, but `zed/gpt-5-mini` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("mini");
        assert_models_eq(results, vec!["zed/gpt-5-mini", "openai/gpt-5-mini"]);

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Searching for "thinking" should match only the Thinking variant of Claude
        let results = matcher.fuzzy_search("thinking");
        assert_models_eq(results, vec!["zed/Claude 3.7 Sonnet Thinking"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "all"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "all"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "all"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "gemini")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        // The "Favorite" section comes first when any favorites exist.
        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert_models_eq(grouped_models.favorites, vec!["zed/gemini"]);
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![("zed", "claude"), ("zed", "gemini")]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        // Without favorites the list starts directly with "Recommended".
        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Recommended"
        ));

        assert!(grouped_models.favorites.is_empty());
    }

    // Despite its name, this checks favorite flags: only `zed/claude` is favorited,
    // and it stays flagged in every section it appears in.
    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "claude")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        for entry in &entries {
            if let LanguageModelPickerEntry::Model(info) = entry {
                if info.model.telemetry_id() == "zed/claude" {
                    assert!(info.is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(
                        !info.is_favorite,
                        "{} should not be a favorite",
                        info.model.telemetry_id()
                    );
                }
            }
        }
    }

    // Favorites are duplicated into their own group without being removed from "all".
    #[gpui::test]
    fn test_favorites_appear_in_other_sections(_cx: &mut TestAppContext) {
        let favorites = vec![("zed", "gemini"), ("openai", "gpt-4")];

        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], favorites.clone());

        let all_models = create_models_with_favorites(
            vec![
                ("zed", "claude"),
                ("zed", "gemini"),
                ("openai", "gpt-4"),
                ("openai", "gpt-3.5"),
            ],
            favorites,
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        assert_models_eq(grouped_models.favorites, vec!["zed/gemini", "openai/gpt-4"]);
        assert_models_eq(grouped_models.recommended, vec!["zed/claude"]);
        assert_models_eq(
            grouped_models.all.values().flatten().cloned().collect(),
            vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"],
        );
    }
}