language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use agent_settings::AgentSettings;
  4use collections::{HashMap, HashSet, IndexMap};
  5use futures::{StreamExt, channel::mpsc};
  6use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  7use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
  8use language_model::{
  9    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelId, LanguageModelProvider,
 10    LanguageModelProviderId, LanguageModelRegistry,
 11};
 12use ordered_float::OrderedFloat;
 13use picker::{Picker, PickerDelegate};
 14use settings::Settings;
 15use ui::prelude::*;
 16use zed_actions::agent::OpenSettings;
 17
 18use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
 19
/// Callback invoked with the model the user picked.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback reporting the currently active model, if any.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
/// Callback invoked with a model and the favorite state requested for it
/// (`true` to mark it as a favorite, `false` to unmark it).
type OnToggleFavorite = Arc<dyn Fn(Arc<dyn LanguageModel>, bool, &App) + 'static>;

/// A picker listing language models, driven by [`LanguageModelPickerDelegate`].
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 25
 26pub fn language_model_selector(
 27    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 28    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 29    on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &App) + 'static,
 30    popover_styles: bool,
 31    focus_handle: FocusHandle,
 32    window: &mut Window,
 33    cx: &mut Context<LanguageModelSelector>,
 34) -> LanguageModelSelector {
 35    let delegate = LanguageModelPickerDelegate::new(
 36        get_active_model,
 37        on_model_changed,
 38        on_toggle_favorite,
 39        popover_styles,
 40        focus_handle,
 41        window,
 42        cx,
 43    );
 44
 45    if popover_styles {
 46        Picker::list(delegate, window, cx)
 47            .show_scrollbar(true)
 48            .width(rems(20.))
 49            .max_height(Some(rems(20.).into()))
 50    } else {
 51        Picker::list(delegate, window, cx).show_scrollbar(true)
 52    }
 53}
 54
 55fn all_models(cx: &App) -> GroupedModels {
 56    let lm_registry = LanguageModelRegistry::global(cx).read(cx);
 57    let providers = lm_registry.visible_providers();
 58
 59    let mut favorites_index = FavoritesIndex::default();
 60
 61    for sel in &AgentSettings::get_global(cx).favorite_models {
 62        favorites_index
 63            .entry(sel.provider.0.clone().into())
 64            .or_default()
 65            .insert(sel.model.clone().into());
 66    }
 67
 68    let recommended = providers
 69        .iter()
 70        .flat_map(|provider| {
 71            provider
 72                .recommended_models(cx)
 73                .into_iter()
 74                .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
 75        })
 76        .collect();
 77
 78    let all: Vec<ModelInfo> = providers
 79        .iter()
 80        .flat_map(|provider| {
 81            provider
 82                .provided_models(cx)
 83                .into_iter()
 84                .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
 85        })
 86        .collect();
 87
 88    GroupedModels::new(all, recommended)
 89}
 90
/// The user's favorite models: provider id → set of favorited model ids.
type FavoritesIndex = HashMap<LanguageModelProviderId, HashSet<LanguageModelId>>;
 92
 93#[derive(Clone)]
 94enum ProviderIcon {
 95    Name(IconName),
 96    Path(SharedString),
 97}
 98
 99impl ProviderIcon {
100    fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
101        if let Some(path) = provider.icon_path() {
102            Self::Path(path)
103        } else {
104            Self::Name(provider.icon())
105        }
106    }
107}
108
109#[derive(Clone)]
110struct ModelInfo {
111    model: Arc<dyn LanguageModel>,
112    icon: ProviderIcon,
113    is_favorite: bool,
114}
115
116impl ModelInfo {
117    fn new(
118        provider: &dyn LanguageModelProvider,
119        model: Arc<dyn LanguageModel>,
120        favorites_index: &FavoritesIndex,
121    ) -> Self {
122        let is_favorite = favorites_index
123            .get(&provider.id())
124            .map_or(false, |set| set.contains(&model.id()));
125
126        Self {
127            model,
128            icon: ProviderIcon::from_provider(provider),
129            is_favorite,
130        }
131    }
132}
133
/// [`PickerDelegate`] backing the language model selector.
///
/// Holds the caller-supplied callbacks, the full model list, and the rows
/// currently shown after filtering.
pub struct LanguageModelPickerDelegate {
    /// Invoked with the chosen model (on confirm, or when cycling favorites).
    on_model_changed: OnModelChanged,
    /// Reports the currently active model.
    get_active_model: GetActiveModel,
    /// Invoked with a model and the favorite state requested for it.
    on_toggle_favorite: OnToggleFavorite,
    /// Every model from the visible providers; rebuilt on registry provider events.
    all_models: Arc<GroupedModels>,
    /// Rows currently displayed (section headers and models), rebuilt per query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries`.
    selected_index: usize,
    /// Background provider authentication; stored so it lives as long as the delegate.
    _authenticate_all_providers_task: Task<()>,
    /// Waits for refresh signals and rebuilds `all_models`.
    _refresh_models_task: Task<()>,
    /// When `true`, a footer is rendered below the list.
    popover_styles: bool,
    /// Passed to the footer.
    focus_handle: FocusHandle,
}
146
147impl LanguageModelPickerDelegate {
148    fn new(
149        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
150        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
151        on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &App) + 'static,
152        popover_styles: bool,
153        focus_handle: FocusHandle,
154        window: &mut Window,
155        cx: &mut Context<Picker<Self>>,
156    ) -> Self {
157        let on_model_changed = Arc::new(on_model_changed);
158        let models = all_models(cx);
159        let entries = models.entries();
160
161        Self {
162            on_model_changed,
163            all_models: Arc::new(models),
164            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
165            filtered_entries: entries,
166            get_active_model: Arc::new(get_active_model),
167            on_toggle_favorite: Arc::new(on_toggle_favorite),
168            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
169            _refresh_models_task: {
170                // Create a channel to signal when models need refreshing
171                let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
172
173                // Subscribe to registry events and send refresh signals through the channel
174                let registry = LanguageModelRegistry::global(cx);
175                cx.subscribe(&registry, move |_picker, _, event, _cx| match event {
176                    language_model::Event::ProviderStateChanged(_)
177                    | language_model::Event::AddedProvider(_)
178                    | language_model::Event::RemovedProvider(_)
179                    | language_model::Event::ProvidersChanged => {
180                        refresh_tx.unbounded_send(()).ok();
181                    }
182                    language_model::Event::DefaultModelChanged
183                    | language_model::Event::InlineAssistantModelChanged
184                    | language_model::Event::CommitMessageModelChanged
185                    | language_model::Event::ThreadSummaryModelChanged => {}
186                })
187                .detach();
188
189                // Spawn a task that listens for refresh signals and updates the picker
190                cx.spawn_in(window, async move |this, cx| {
191                    while let Some(()) = refresh_rx.next().await {
192                        if this
193                            .update_in(cx, |picker, window, cx| {
194                                picker.delegate.all_models = Arc::new(all_models(cx));
195                                picker.refresh(window, cx);
196                            })
197                            .is_err()
198                        {
199                            // Picker was dropped, exit the loop
200                            break;
201                        }
202                    }
203                })
204            },
205            popover_styles,
206            focus_handle,
207        }
208    }
209
210    fn get_active_model_index(
211        entries: &[LanguageModelPickerEntry],
212        active_model: Option<ConfiguredModel>,
213    ) -> usize {
214        entries
215            .iter()
216            .position(|entry| {
217                if let LanguageModelPickerEntry::Model(model) = entry {
218                    active_model
219                        .as_ref()
220                        .map(|active_model| {
221                            active_model.model.id() == model.model.id()
222                                && active_model.provider.id() == model.model.provider_id()
223                        })
224                        .unwrap_or_default()
225                } else {
226                    false
227                }
228            })
229            .unwrap_or(0)
230    }
231
232    /// Authenticates all providers in the [`LanguageModelRegistry`].
233    ///
234    /// We do this so that we can populate the language selector with all of the
235    /// models from the configured providers.
236    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
237        let authenticate_all_providers = LanguageModelRegistry::global(cx)
238            .read(cx)
239            .providers()
240            .iter()
241            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
242            .collect::<Vec<_>>();
243
244        cx.spawn(async move |_cx| {
245            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
246                if let Err(err) = authenticate_task.await {
247                    if matches!(err, AuthenticateError::CredentialsNotFound) {
248                        // Since we're authenticating these providers in the
249                        // background for the purposes of populating the
250                        // language selector, we don't care about providers
251                        // where the credentials are not found.
252                    } else {
253                        // Some providers have noisy failure states that we
254                        // don't want to spam the logs with every time the
255                        // language model selector is initialized.
256                        //
257                        // Ideally these should have more clear failure modes
258                        // that we know are safe to ignore here, like what we do
259                        // with `CredentialsNotFound` above.
260                        match provider_id.0.as_ref() {
261                            "lmstudio" | "ollama" => {
262                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
263                                //
264                                // These fail noisily, so we don't log them.
265                            }
266                            "copilot_chat" => {
267                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
268                            }
269                            _ => {
270                                log::error!(
271                                    "Failed to authenticate provider: {}: {err:#}",
272                                    provider_name.0
273                                );
274                            }
275                        }
276                    }
277                }
278            }
279        })
280    }
281
282    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
283        (self.get_active_model)(cx)
284    }
285
286    pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
287        if self.all_models.favorites.is_empty() {
288            return;
289        }
290
291        let active_model = (self.get_active_model)(cx);
292        let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
293        let active_model_id = active_model.as_ref().map(|m| m.model.id());
294
295        let current_index = self
296            .all_models
297            .favorites
298            .iter()
299            .position(|info| {
300                Some(info.model.provider_id()) == active_provider_id
301                    && Some(info.model.id()) == active_model_id
302            })
303            .unwrap_or(usize::MAX);
304
305        let next_index = if current_index == usize::MAX {
306            0
307        } else {
308            (current_index + 1) % self.all_models.favorites.len()
309        };
310
311        let next_model = self.all_models.favorites[next_index].model.clone();
312
313        (self.on_model_changed)(next_model, cx);
314
315        // Align the picker selection with the newly-active model
316        let new_index =
317            Self::get_active_model_index(&self.filtered_entries, (self.get_active_model)(cx));
318        self.set_selected_index(new_index, window, cx);
319    }
320}
321
322struct GroupedModels {
323    favorites: Vec<ModelInfo>,
324    recommended: Vec<ModelInfo>,
325    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
326}
327
328impl GroupedModels {
329    pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
330        let favorites = all
331            .iter()
332            .filter(|info| info.is_favorite)
333            .cloned()
334            .collect();
335
336        let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
337        for model in all {
338            let provider = model.model.provider_id();
339            if let Some(models) = all_by_provider.get_mut(&provider) {
340                models.push(model);
341            } else {
342                all_by_provider.insert(provider, vec![model]);
343            }
344        }
345
346        Self {
347            favorites,
348            recommended,
349            all: all_by_provider,
350        }
351    }
352
353    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
354        let mut entries = Vec::new();
355
356        if !self.favorites.is_empty() {
357            entries.push(LanguageModelPickerEntry::Separator("Favorite".into()));
358            for info in &self.favorites {
359                entries.push(LanguageModelPickerEntry::Model(info.clone()));
360            }
361        }
362
363        if !self.recommended.is_empty() {
364            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
365            for info in &self.recommended {
366                entries.push(LanguageModelPickerEntry::Model(info.clone()));
367            }
368        }
369
370        for models in self.all.values() {
371            if models.is_empty() {
372                continue;
373            }
374            entries.push(LanguageModelPickerEntry::Separator(
375                models[0].model.provider_name().0,
376            ));
377            for info in models {
378                entries.push(LanguageModelPickerEntry::Model(info.clone()));
379            }
380        }
381
382        entries
383    }
384}
385
/// A single row in the picker.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable section header with the given title.
    Separator(SharedString),
}
390
391struct ModelMatcher {
392    models: Vec<ModelInfo>,
393    bg_executor: BackgroundExecutor,
394    candidates: Vec<StringMatchCandidate>,
395}
396
397impl ModelMatcher {
398    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
399        let candidates = Self::make_match_candidates(&models);
400        Self {
401            models,
402            bg_executor,
403            candidates,
404        }
405    }
406
407    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
408        let mut matches = self.bg_executor.block(match_strings(
409            &self.candidates,
410            query,
411            false,
412            true,
413            100,
414            &Default::default(),
415            self.bg_executor.clone(),
416        ));
417
418        let sorting_key = |mat: &StringMatch| {
419            let candidate = &self.candidates[mat.candidate_id];
420            (Reverse(OrderedFloat(mat.score)), candidate.id)
421        };
422        matches.sort_unstable_by_key(sorting_key);
423
424        let matched_models: Vec<_> = matches
425            .into_iter()
426            .map(|mat| self.models[mat.candidate_id].clone())
427            .collect();
428
429        matched_models
430    }
431
432    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
433        self.models
434            .iter()
435            .filter(|m| {
436                m.model
437                    .name()
438                    .0
439                    .to_lowercase()
440                    .contains(&query.to_lowercase())
441            })
442            .cloned()
443            .collect::<Vec<_>>()
444    }
445
446    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
447        model_infos
448            .iter()
449            .enumerate()
450            .map(|(index, model)| {
451                StringMatchCandidate::new(
452                    index,
453                    &format!(
454                        "{}/{}",
455                        &model.model.provider_name().0,
456                        &model.model.name().0
457                    ),
458                )
459            })
460            .collect::<Vec<_>>()
461    }
462}
463
464impl PickerDelegate for LanguageModelPickerDelegate {
465    type ListItem = AnyElement;
466
467    fn match_count(&self) -> usize {
468        self.filtered_entries.len()
469    }
470
471    fn selected_index(&self) -> usize {
472        self.selected_index
473    }
474
475    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
476        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
477        cx.notify();
478    }
479
480    fn can_select(
481        &mut self,
482        ix: usize,
483        _window: &mut Window,
484        _cx: &mut Context<Picker<Self>>,
485    ) -> bool {
486        match self.filtered_entries.get(ix) {
487            Some(LanguageModelPickerEntry::Model(_)) => true,
488            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
489        }
490    }
491
492    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
493        "Select a model…".into()
494    }
495
496    fn update_matches(
497        &mut self,
498        query: String,
499        window: &mut Window,
500        cx: &mut Context<Picker<Self>>,
501    ) -> Task<()> {
502        let all_models = self.all_models.clone();
503        let active_model = (self.get_active_model)(cx);
504        let bg_executor = cx.background_executor();
505
506        let language_model_registry = LanguageModelRegistry::global(cx);
507
508        let configured_providers = language_model_registry
509            .read(cx)
510            .visible_providers()
511            .into_iter()
512            .filter(|provider| provider.is_authenticated(cx))
513            .collect::<Vec<_>>();
514
515        let configured_provider_ids = configured_providers
516            .iter()
517            .map(|provider| provider.id())
518            .collect::<Vec<_>>();
519
520        let recommended_models = all_models
521            .recommended
522            .iter()
523            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
524            .cloned()
525            .collect::<Vec<_>>();
526
527        let available_models = all_models
528            .all
529            .values()
530            .flat_map(|models| models.iter())
531            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
532            .cloned()
533            .collect::<Vec<_>>();
534
535        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
536        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
537
538        let recommended = matcher_rec.exact_search(&query);
539        let all = matcher_all.fuzzy_search(&query);
540
541        let filtered_models = GroupedModels::new(all, recommended);
542
543        cx.spawn_in(window, async move |this, cx| {
544            this.update_in(cx, |this, window, cx| {
545                this.delegate.filtered_entries = filtered_models.entries();
546                // Finds the currently selected model in the list
547                let new_index =
548                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
549                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
550                cx.notify();
551            })
552            .ok();
553        })
554    }
555
556    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
557        if let Some(LanguageModelPickerEntry::Model(model_info)) =
558            self.filtered_entries.get(self.selected_index)
559        {
560            let model = model_info.model.clone();
561            (self.on_model_changed)(model.clone(), cx);
562
563            let current_index = self.selected_index;
564            self.set_selected_index(current_index, window, cx);
565
566            cx.emit(DismissEvent);
567        }
568    }
569
570    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
571        cx.emit(DismissEvent);
572    }
573
574    fn render_match(
575        &self,
576        ix: usize,
577        selected: bool,
578        _: &mut Window,
579        cx: &mut Context<Picker<Self>>,
580    ) -> Option<Self::ListItem> {
581        match self.filtered_entries.get(ix)? {
582            LanguageModelPickerEntry::Separator(title) => {
583                Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
584            }
585            LanguageModelPickerEntry::Model(model_info) => {
586                let active_model = (self.get_active_model)(cx);
587                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
588                let active_model_id = active_model.map(|m| m.model.id());
589
590                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
591                    && Some(model_info.model.id()) == active_model_id;
592
593                let is_favorite = model_info.is_favorite;
594                let handle_action_click = {
595                    let model = model_info.model.clone();
596                    let on_toggle_favorite = self.on_toggle_favorite.clone();
597                    move |cx: &App| on_toggle_favorite(model.clone(), !is_favorite, cx)
598                };
599
600                Some(
601                    ModelSelectorListItem::new(ix, model_info.model.name().0)
602                        .map(|this| match &model_info.icon {
603                            ProviderIcon::Name(icon_name) => this.icon(*icon_name),
604                            ProviderIcon::Path(icon_path) => this.icon_path(icon_path.clone()),
605                        })
606                        .is_selected(is_selected)
607                        .is_focused(selected)
608                        .is_favorite(is_favorite)
609                        .on_toggle_favorite(handle_action_click)
610                        .into_any_element(),
611                )
612            }
613        }
614    }
615
616    fn render_footer(
617        &self,
618        _window: &mut Window,
619        _cx: &mut Context<Picker<Self>>,
620    ) -> Option<gpui::AnyElement> {
621        let focus_handle = self.focus_handle.clone();
622
623        if !self.popover_styles {
624            return None;
625        }
626
627        Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
628    }
629}
630
631#[cfg(test)]
632mod tests {
633    use super::*;
634    use futures::{future::BoxFuture, stream::BoxStream};
635    use gpui::{AsyncApp, TestAppContext, http_client};
636    use language_model::{
637        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
638        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
639        LanguageModelRequest, LanguageModelToolChoice,
640    };
641    use ui::IconName;
642
643    #[derive(Clone)]
644    struct TestLanguageModel {
645        name: LanguageModelName,
646        id: LanguageModelId,
647        provider_id: LanguageModelProviderId,
648        provider_name: LanguageModelProviderName,
649    }
650
651    impl TestLanguageModel {
652        fn new(name: &str, provider: &str) -> Self {
653            Self {
654                name: LanguageModelName::from(name.to_string()),
655                id: LanguageModelId::from(name.to_string()),
656                provider_id: LanguageModelProviderId::from(provider.to_string()),
657                provider_name: LanguageModelProviderName::from(provider.to_string()),
658            }
659        }
660    }
661
    /// Test-only [`LanguageModel`] implementation: identity getters return the
    /// stored fields, capability queries all report `false`, and the
    /// request/streaming methods are intentionally unimplemented (they panic if
    /// a test ever reaches them).
    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // Formatted as "provider/name"; `assert_models_eq` compares against this.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }
724
725    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
726        create_models_with_favorites(model_specs, vec![])
727    }
728
729    fn create_models_with_favorites(
730        model_specs: Vec<(&str, &str)>,
731        favorites: Vec<(&str, &str)>,
732    ) -> Vec<ModelInfo> {
733        model_specs
734            .into_iter()
735            .map(|(provider, name)| {
736                let is_favorite = favorites
737                    .iter()
738                    .any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name);
739                ModelInfo {
740                    model: Arc::new(TestLanguageModel::new(name, provider)),
741                    icon: ProviderIcon::Name(IconName::Ai),
742                    is_favorite,
743                }
744            })
745            .collect()
746    }
747
748    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
749        assert_eq!(
750            result.len(),
751            expected.len(),
752            "Number of models doesn't match"
753        );
754
755        for (i, expected_name) in expected.iter().enumerate() {
756            assert_eq!(
757                result[i].model.telemetry_id(),
758                *expected_name,
759                "Model at position {} doesn't match expected model",
760                i
761            );
762        }
763    }
764
765    #[gpui::test]
766    fn test_exact_match(cx: &mut TestAppContext) {
767        let models = create_models(vec![
768            ("zed", "Claude 3.7 Sonnet"),
769            ("zed", "Claude 3.7 Sonnet Thinking"),
770            ("zed", "gpt-4.1"),
771            ("zed", "gpt-4.1-nano"),
772            ("openai", "gpt-3.5-turbo"),
773            ("openai", "gpt-4.1"),
774            ("openai", "gpt-4.1-nano"),
775            ("ollama", "mistral"),
776            ("ollama", "deepseek"),
777        ]);
778        let matcher = ModelMatcher::new(models, cx.background_executor.clone());
779
780        // The order of models should be maintained, case doesn't matter
781        let results = matcher.exact_search("GPT-4.1");
782        assert_models_eq(
783            results,
784            vec![
785                "zed/gpt-4.1",
786                "zed/gpt-4.1-nano",
787                "openai/gpt-4.1",
788                "openai/gpt-4.1-nano",
789            ],
790        );
791    }
792
793    #[gpui::test]
794    fn test_fuzzy_match(cx: &mut TestAppContext) {
795        let models = create_models(vec![
796            ("zed", "Claude 3.7 Sonnet"),
797            ("zed", "Claude 3.7 Sonnet Thinking"),
798            ("zed", "gpt-4.1"),
799            ("zed", "gpt-4.1-nano"),
800            ("openai", "gpt-3.5-turbo"),
801            ("openai", "gpt-4.1"),
802            ("openai", "gpt-4.1-nano"),
803            ("ollama", "mistral"),
804            ("ollama", "deepseek"),
805        ]);
806        let matcher = ModelMatcher::new(models, cx.background_executor.clone());
807
808        // Results should preserve models order whenever possible.
809        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
810        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
811        // so it should appear first in the results.
812        let results = matcher.fuzzy_search("41");
813        assert_models_eq(
814            results,
815            vec![
816                "zed/gpt-4.1",
817                "openai/gpt-4.1",
818                "zed/gpt-4.1-nano",
819                "openai/gpt-4.1-nano",
820            ],
821        );
822
823        // Model provider should be searchable as well
824        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
825        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);
826
827        // Fuzzy search
828        let results = matcher.fuzzy_search("z4n");
829        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
830    }
831
832    #[gpui::test]
833    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
834        let recommended_models = create_models(vec![("zed", "claude")]);
835        let all_models = create_models(vec![
836            ("zed", "claude"), // Should also appear in "other"
837            ("zed", "gemini"),
838            ("copilot", "o3"),
839        ]);
840
841        let grouped_models = GroupedModels::new(all_models, recommended_models);
842
843        let actual_all_models = grouped_models
844            .all
845            .values()
846            .flatten()
847            .cloned()
848            .collect::<Vec<_>>();
849
850        // Recommended models should also appear in "all"
851        assert_models_eq(
852            actual_all_models,
853            vec!["zed/claude", "zed/gemini", "copilot/o3"],
854        );
855    }
856
857    #[gpui::test]
858    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
859        let recommended_models = create_models(vec![("zed", "claude")]);
860        let all_models = create_models(vec![
861            ("zed", "claude"), // Should also appear in "other"
862            ("zed", "gemini"),
863            ("copilot", "claude"), // Different provider, should appear in "other"
864        ]);
865
866        let grouped_models = GroupedModels::new(all_models, recommended_models);
867
868        let actual_all_models = grouped_models
869            .all
870            .values()
871            .flatten()
872            .cloned()
873            .collect::<Vec<_>>();
874
875        // All models should appear in "all" regardless of recommended status
876        assert_models_eq(
877            actual_all_models,
878            vec!["zed/claude", "zed/gemini", "copilot/claude"],
879        );
880    }
881
882    #[gpui::test]
883    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
884        let recommended_models = create_models(vec![("zed", "claude")]);
885        let all_models = create_models_with_favorites(
886            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
887            vec![("zed", "gemini")],
888        );
889
890        let grouped_models = GroupedModels::new(all_models, recommended_models);
891        let entries = grouped_models.entries();
892
893        assert!(matches!(
894            entries.first(),
895            Some(LanguageModelPickerEntry::Separator(s)) if s == "Favorite"
896        ));
897
898        assert_models_eq(grouped_models.favorites, vec!["zed/gemini"]);
899    }
900
901    #[gpui::test]
902    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
903        let recommended_models = create_models(vec![("zed", "claude")]);
904        let all_models = create_models(vec![("zed", "claude"), ("zed", "gemini")]);
905
906        let grouped_models = GroupedModels::new(all_models, recommended_models);
907        let entries = grouped_models.entries();
908
909        assert!(matches!(
910            entries.first(),
911            Some(LanguageModelPickerEntry::Separator(s)) if s == "Recommended"
912        ));
913
914        assert!(grouped_models.favorites.is_empty());
915    }
916
917    #[gpui::test]
918    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
919        let recommended_models =
920            create_models_with_favorites(vec![("zed", "claude")], vec![("zed", "claude")]);
921        let all_models = create_models_with_favorites(
922            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
923            vec![("zed", "claude")],
924        );
925
926        let grouped_models = GroupedModels::new(all_models, recommended_models);
927        let entries = grouped_models.entries();
928
929        for entry in &entries {
930            if let LanguageModelPickerEntry::Model(info) = entry {
931                if info.model.telemetry_id() == "zed/claude" {
932                    assert!(info.is_favorite, "zed/claude should be a favorite");
933                } else {
934                    assert!(
935                        !info.is_favorite,
936                        "{} should not be a favorite",
937                        info.model.telemetry_id()
938                    );
939                }
940            }
941        }
942    }
943
944    #[gpui::test]
945    fn test_favorites_appear_in_other_sections(_cx: &mut TestAppContext) {
946        let favorites = vec![("zed", "gemini"), ("openai", "gpt-4")];
947
948        let recommended_models =
949            create_models_with_favorites(vec![("zed", "claude")], favorites.clone());
950
951        let all_models = create_models_with_favorites(
952            vec![
953                ("zed", "claude"),
954                ("zed", "gemini"),
955                ("openai", "gpt-4"),
956                ("openai", "gpt-3.5"),
957            ],
958            favorites,
959        );
960
961        let grouped_models = GroupedModels::new(all_models, recommended_models);
962
963        assert_models_eq(grouped_models.favorites, vec!["zed/gemini", "openai/gpt-4"]);
964        assert_models_eq(grouped_models.recommended, vec!["zed/claude"]);
965        assert_models_eq(
966            grouped_models.all.values().flatten().cloned().collect(),
967            vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"],
968        );
969    }
970}