language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use agent_settings::AgentSettings;
  4use collections::{HashMap, HashSet, IndexMap};
  5use futures::{StreamExt, channel::mpsc};
  6use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  7use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
  8use language_model::{
  9    AuthenticateError, ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId,
 10    LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
 11};
 12use ordered_float::OrderedFloat;
 13use picker::{Picker, PickerDelegate};
 14use settings::Settings;
 15use ui::prelude::*;
 16use zed_actions::agent::OpenSettings;
 17
 18use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
 19
/// Callback invoked with the chosen model when the picker changes the active model.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback returning the currently active model, if there is one.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
/// Callback invoked with a model and the favorite state it should be switched to.
type OnToggleFavorite = Arc<dyn Fn(Arc<dyn LanguageModel>, bool, &App) + 'static>;

/// A [`Picker`] whose entries are supplied by a [`LanguageModelPickerDelegate`].
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 25
 26pub fn language_model_selector(
 27    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 28    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 29    on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &App) + 'static,
 30    popover_styles: bool,
 31    focus_handle: FocusHandle,
 32    window: &mut Window,
 33    cx: &mut Context<LanguageModelSelector>,
 34) -> LanguageModelSelector {
 35    let delegate = LanguageModelPickerDelegate::new(
 36        get_active_model,
 37        on_model_changed,
 38        on_toggle_favorite,
 39        popover_styles,
 40        focus_handle,
 41        window,
 42        cx,
 43    );
 44
 45    if popover_styles {
 46        Picker::list(delegate, window, cx)
 47            .show_scrollbar(true)
 48            .width(rems(20.))
 49            .max_height(Some(rems(20.).into()))
 50    } else {
 51        Picker::list(delegate, window, cx).show_scrollbar(true)
 52    }
 53}
 54
 55fn all_models(cx: &App) -> GroupedModels {
 56    let lm_registry = LanguageModelRegistry::global(cx).read(cx);
 57    let providers = lm_registry.visible_providers();
 58
 59    let mut favorites_index = FavoritesIndex::default();
 60
 61    for sel in &AgentSettings::get_global(cx).favorite_models {
 62        favorites_index
 63            .entry(sel.provider.0.clone().into())
 64            .or_default()
 65            .insert(sel.model.clone().into());
 66    }
 67
 68    let recommended = providers
 69        .iter()
 70        .flat_map(|provider| {
 71            provider
 72                .recommended_models(cx)
 73                .into_iter()
 74                .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
 75        })
 76        .collect();
 77
 78    let all: Vec<ModelInfo> = providers
 79        .iter()
 80        .flat_map(|provider| {
 81            provider
 82                .provided_models(cx)
 83                .into_iter()
 84                .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
 85        })
 86        .collect();
 87
 88    GroupedModels::new(all, recommended)
 89}
 90
/// Favorite model ids, keyed by the id of the provider they belong to.
type FavoritesIndex = HashMap<LanguageModelProviderId, HashSet<LanguageModelId>>;
 92
/// A language model together with the presentation data the picker needs.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// Icon taken from the model's provider.
    icon: IconOrSvg,
    /// Whether the model was found in the favorites index when this info was built.
    is_favorite: bool,
}
 99
100impl ModelInfo {
101    fn new(
102        provider: &dyn LanguageModelProvider,
103        model: Arc<dyn LanguageModel>,
104        favorites_index: &FavoritesIndex,
105    ) -> Self {
106        let is_favorite = favorites_index
107            .get(&provider.id())
108            .map_or(false, |set| set.contains(&model.id()));
109
110        Self {
111            model,
112            icon: provider.icon(),
113            is_favorite,
114        }
115    }
116}
117
/// Picker delegate that lists language models in sections and reports
/// selection and favorite changes through caller-supplied callbacks.
pub struct LanguageModelPickerDelegate {
    /// Called with the chosen model on confirm and when cycling favorites.
    on_model_changed: OnModelChanged,
    /// Returns the model currently considered active, if any.
    get_active_model: GetActiveModel,
    /// Called from a rendered list item with the model and its new favorite state.
    on_toggle_favorite: OnToggleFavorite,
    /// Every known model; replaced when the registry reports provider changes.
    all_models: Arc<GroupedModels>,
    /// Entries currently shown, i.e. the result of the latest query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted entry.
    selected_index: usize,
    /// Handle of the background task that authenticates all visible providers.
    _authenticate_all_providers_task: Task<()>,
    /// Handle of the task that rebuilds `all_models` on registry events.
    _refresh_models_task: Task<()>,
    /// Whether the picker is shown as a popover; enables the footer.
    popover_styles: bool,
    /// Focus handle handed to the footer.
    focus_handle: FocusHandle,
}
130
131impl LanguageModelPickerDelegate {
132    fn new(
133        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
134        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
135        on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &App) + 'static,
136        popover_styles: bool,
137        focus_handle: FocusHandle,
138        window: &mut Window,
139        cx: &mut Context<Picker<Self>>,
140    ) -> Self {
141        let on_model_changed = Arc::new(on_model_changed);
142        let models = all_models(cx);
143        let entries = models.entries();
144
145        Self {
146            on_model_changed,
147            all_models: Arc::new(models),
148            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
149            filtered_entries: entries,
150            get_active_model: Arc::new(get_active_model),
151            on_toggle_favorite: Arc::new(on_toggle_favorite),
152            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
153            _refresh_models_task: {
154                // Create a channel to signal when models need refreshing
155                let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
156
157                // Subscribe to registry events and send refresh signals through the channel
158                let registry = LanguageModelRegistry::global(cx);
159                cx.subscribe(&registry, move |_picker, _, event, _cx| match event {
160                    language_model::Event::ProviderStateChanged(_)
161                    | language_model::Event::AddedProvider(_)
162                    | language_model::Event::RemovedProvider(_)
163                    | language_model::Event::ProvidersChanged => {
164                        refresh_tx.unbounded_send(()).ok();
165                    }
166                    language_model::Event::DefaultModelChanged
167                    | language_model::Event::InlineAssistantModelChanged
168                    | language_model::Event::CommitMessageModelChanged
169                    | language_model::Event::ThreadSummaryModelChanged => {}
170                })
171                .detach();
172
173                // Spawn a task that listens for refresh signals and updates the picker
174                cx.spawn_in(window, async move |this, cx| {
175                    while let Some(()) = refresh_rx.next().await {
176                        if this
177                            .update_in(cx, |picker, window, cx| {
178                                picker.delegate.all_models = Arc::new(all_models(cx));
179                                picker.refresh(window, cx);
180                            })
181                            .is_err()
182                        {
183                            // Picker was dropped, exit the loop
184                            break;
185                        }
186                    }
187                })
188            },
189            popover_styles,
190            focus_handle,
191        }
192    }
193
194    fn get_active_model_index(
195        entries: &[LanguageModelPickerEntry],
196        active_model: Option<ConfiguredModel>,
197    ) -> usize {
198        entries
199            .iter()
200            .position(|entry| {
201                if let LanguageModelPickerEntry::Model(model) = entry {
202                    active_model
203                        .as_ref()
204                        .map(|active_model| {
205                            active_model.model.id() == model.model.id()
206                                && active_model.provider.id() == model.model.provider_id()
207                        })
208                        .unwrap_or_default()
209                } else {
210                    false
211                }
212            })
213            .unwrap_or(0)
214    }
215
216    /// Authenticates all providers in the [`LanguageModelRegistry`].
217    ///
218    /// We do this so that we can populate the language selector with all of the
219    /// models from the configured providers.
220    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
221        let authenticate_all_providers = LanguageModelRegistry::global(cx)
222            .read(cx)
223            .visible_providers()
224            .iter()
225            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
226            .collect::<Vec<_>>();
227
228        cx.spawn(async move |_cx| {
229            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
230                if let Err(err) = authenticate_task.await {
231                    if matches!(err, AuthenticateError::CredentialsNotFound) {
232                        // Since we're authenticating these providers in the
233                        // background for the purposes of populating the
234                        // language selector, we don't care about providers
235                        // where the credentials are not found.
236                    } else {
237                        // Some providers have noisy failure states that we
238                        // don't want to spam the logs with every time the
239                        // language model selector is initialized.
240                        //
241                        // Ideally these should have more clear failure modes
242                        // that we know are safe to ignore here, like what we do
243                        // with `CredentialsNotFound` above.
244                        match provider_id.0.as_ref() {
245                            "lmstudio" | "ollama" => {
246                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
247                                //
248                                // These fail noisily, so we don't log them.
249                            }
250                            "copilot_chat" => {
251                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
252                            }
253                            _ => {
254                                log::error!(
255                                    "Failed to authenticate provider: {}: {err:#}",
256                                    provider_name.0
257                                );
258                            }
259                        }
260                    }
261                }
262            }
263        })
264    }
265
266    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
267        (self.get_active_model)(cx)
268    }
269
270    pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
271        if self.all_models.favorites.is_empty() {
272            return;
273        }
274
275        let active_model = (self.get_active_model)(cx);
276        let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
277        let active_model_id = active_model.as_ref().map(|m| m.model.id());
278
279        let current_index = self
280            .all_models
281            .favorites
282            .iter()
283            .position(|info| {
284                Some(info.model.provider_id()) == active_provider_id
285                    && Some(info.model.id()) == active_model_id
286            })
287            .unwrap_or(usize::MAX);
288
289        let next_index = if current_index == usize::MAX {
290            0
291        } else {
292            (current_index + 1) % self.all_models.favorites.len()
293        };
294
295        let next_model = self.all_models.favorites[next_index].model.clone();
296
297        (self.on_model_changed)(next_model, cx);
298
299        // Align the picker selection with the newly-active model
300        let new_index =
301            Self::get_active_model_index(&self.filtered_entries, (self.get_active_model)(cx));
302        self.set_selected_index(new_index, window, cx);
303    }
304}
305
/// Models organized into the sections that `entries` turns into picker rows.
struct GroupedModels {
    /// Models from the full list whose `is_favorite` flag is set.
    favorites: Vec<ModelInfo>,
    /// Models passed to `new` as recommended.
    recommended: Vec<ModelInfo>,
    /// Every model, bucketed by provider id in the order providers were first seen.
    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
311
312impl GroupedModels {
313    pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
314        let favorites = all
315            .iter()
316            .filter(|info| info.is_favorite)
317            .cloned()
318            .collect();
319
320        let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
321        for model in all {
322            let provider = model.model.provider_id();
323            if let Some(models) = all_by_provider.get_mut(&provider) {
324                models.push(model);
325            } else {
326                all_by_provider.insert(provider, vec![model]);
327            }
328        }
329
330        Self {
331            favorites,
332            recommended,
333            all: all_by_provider,
334        }
335    }
336
337    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
338        let mut entries = Vec::new();
339
340        if !self.favorites.is_empty() {
341            entries.push(LanguageModelPickerEntry::Separator("Favorite".into()));
342            for info in &self.favorites {
343                entries.push(LanguageModelPickerEntry::Model(info.clone()));
344            }
345        }
346
347        if !self.recommended.is_empty() {
348            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
349            for info in &self.recommended {
350                entries.push(LanguageModelPickerEntry::Model(info.clone()));
351            }
352        }
353
354        for models in self.all.values() {
355            if models.is_empty() {
356                continue;
357            }
358            entries.push(LanguageModelPickerEntry::Separator(
359                models[0].model.provider_name().0,
360            ));
361            for info in models {
362                entries.push(LanguageModelPickerEntry::Model(info.clone()));
363            }
364        }
365
366        entries
367    }
368}
369
/// A single row of the picker list.
enum LanguageModelPickerEntry {
    /// A model row; the only kind of row that can be selected.
    Model(ModelInfo),
    /// A section heading carrying the title to display.
    Separator(SharedString),
}
374
/// Searches a fixed list of models by exact substring or by fuzzy match.
struct ModelMatcher {
    /// The models being searched.
    models: Vec<ModelInfo>,
    /// Executor used to run the fuzzy match.
    bg_executor: BackgroundExecutor,
    /// One `provider name/model name` candidate per entry of `models`,
    /// created with the model's index as its id.
    candidates: Vec<StringMatchCandidate>,
}
380
381impl ModelMatcher {
382    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
383        let candidates = Self::make_match_candidates(&models);
384        Self {
385            models,
386            bg_executor,
387            candidates,
388        }
389    }
390
391    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
392        let mut matches = self.bg_executor.block(match_strings(
393            &self.candidates,
394            query,
395            false,
396            true,
397            100,
398            &Default::default(),
399            self.bg_executor.clone(),
400        ));
401
402        let sorting_key = |mat: &StringMatch| {
403            let candidate = &self.candidates[mat.candidate_id];
404            (Reverse(OrderedFloat(mat.score)), candidate.id)
405        };
406        matches.sort_unstable_by_key(sorting_key);
407
408        let matched_models: Vec<_> = matches
409            .into_iter()
410            .map(|mat| self.models[mat.candidate_id].clone())
411            .collect();
412
413        matched_models
414    }
415
416    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
417        self.models
418            .iter()
419            .filter(|m| {
420                m.model
421                    .name()
422                    .0
423                    .to_lowercase()
424                    .contains(&query.to_lowercase())
425            })
426            .cloned()
427            .collect::<Vec<_>>()
428    }
429
430    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
431        model_infos
432            .iter()
433            .enumerate()
434            .map(|(index, model)| {
435                StringMatchCandidate::new(
436                    index,
437                    &format!(
438                        "{}/{}",
439                        &model.model.provider_name().0,
440                        &model.model.name().0
441                    ),
442                )
443            })
444            .collect::<Vec<_>>()
445    }
446}
447
448impl PickerDelegate for LanguageModelPickerDelegate {
449    type ListItem = AnyElement;
450
451    fn match_count(&self) -> usize {
452        self.filtered_entries.len()
453    }
454
455    fn selected_index(&self) -> usize {
456        self.selected_index
457    }
458
459    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
460        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
461        cx.notify();
462    }
463
464    fn can_select(
465        &mut self,
466        ix: usize,
467        _window: &mut Window,
468        _cx: &mut Context<Picker<Self>>,
469    ) -> bool {
470        match self.filtered_entries.get(ix) {
471            Some(LanguageModelPickerEntry::Model(_)) => true,
472            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
473        }
474    }
475
476    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
477        "Select a model…".into()
478    }
479
480    fn update_matches(
481        &mut self,
482        query: String,
483        window: &mut Window,
484        cx: &mut Context<Picker<Self>>,
485    ) -> Task<()> {
486        let all_models = self.all_models.clone();
487        let active_model = (self.get_active_model)(cx);
488        let bg_executor = cx.background_executor();
489
490        let language_model_registry = LanguageModelRegistry::global(cx);
491
492        let configured_providers = language_model_registry
493            .read(cx)
494            .visible_providers()
495            .into_iter()
496            .filter(|provider| provider.is_authenticated(cx))
497            .collect::<Vec<_>>();
498
499        let configured_provider_ids = configured_providers
500            .iter()
501            .map(|provider| provider.id())
502            .collect::<Vec<_>>();
503
504        let recommended_models = all_models
505            .recommended
506            .iter()
507            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
508            .cloned()
509            .collect::<Vec<_>>();
510
511        let available_models = all_models
512            .all
513            .values()
514            .flat_map(|models| models.iter())
515            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
516            .cloned()
517            .collect::<Vec<_>>();
518
519        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
520        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
521
522        let recommended = matcher_rec.exact_search(&query);
523        let all = matcher_all.fuzzy_search(&query);
524
525        let filtered_models = GroupedModels::new(all, recommended);
526
527        cx.spawn_in(window, async move |this, cx| {
528            this.update_in(cx, |this, window, cx| {
529                this.delegate.filtered_entries = filtered_models.entries();
530                // Finds the currently selected model in the list
531                let new_index =
532                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
533                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
534                cx.notify();
535            })
536            .ok();
537        })
538    }
539
540    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
541        if let Some(LanguageModelPickerEntry::Model(model_info)) =
542            self.filtered_entries.get(self.selected_index)
543        {
544            let model = model_info.model.clone();
545            (self.on_model_changed)(model.clone(), cx);
546
547            let current_index = self.selected_index;
548            self.set_selected_index(current_index, window, cx);
549
550            cx.emit(DismissEvent);
551        }
552    }
553
554    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
555        cx.emit(DismissEvent);
556    }
557
558    fn render_match(
559        &self,
560        ix: usize,
561        selected: bool,
562        _: &mut Window,
563        cx: &mut Context<Picker<Self>>,
564    ) -> Option<Self::ListItem> {
565        match self.filtered_entries.get(ix)? {
566            LanguageModelPickerEntry::Separator(title) => {
567                Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
568            }
569            LanguageModelPickerEntry::Model(model_info) => {
570                let active_model = (self.get_active_model)(cx);
571                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
572                let active_model_id = active_model.map(|m| m.model.id());
573
574                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
575                    && Some(model_info.model.id()) == active_model_id;
576
577                let is_favorite = model_info.is_favorite;
578                let handle_action_click = {
579                    let model = model_info.model.clone();
580                    let on_toggle_favorite = self.on_toggle_favorite.clone();
581                    move |cx: &App| on_toggle_favorite(model.clone(), !is_favorite, cx)
582                };
583
584                Some(
585                    ModelSelectorListItem::new(ix, model_info.model.name().0)
586                        .map(|this| match &model_info.icon {
587                            IconOrSvg::Icon(icon_name) => this.icon(*icon_name),
588                            IconOrSvg::Svg(icon_path) => this.icon_path(icon_path.clone()),
589                        })
590                        .is_selected(is_selected)
591                        .is_focused(selected)
592                        .is_favorite(is_favorite)
593                        .on_toggle_favorite(handle_action_click)
594                        .into_any_element(),
595                )
596            }
597        }
598    }
599
600    fn render_footer(
601        &self,
602        _window: &mut Window,
603        _cx: &mut Context<Picker<Self>>,
604    ) -> Option<gpui::AnyElement> {
605        let focus_handle = self.focus_handle.clone();
606
607        if !self.popover_styles {
608            return None;
609        }
610
611        Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
612    }
613}
614
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal model used to exercise matching and grouping; it cannot
    /// count tokens or stream completions.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Uses `name` for both the model's name and id, and `provider` for
        /// both the provider's id and name.
        fn new(name: &str, provider: &str) -> Self {
            let name = name.to_string();
            let provider = provider.to_string();
            Self {
                id: LanguageModelId::from(name.clone()),
                name: LanguageModelName::from(name),
                provider_id: LanguageModelProviderId::from(provider.clone()),
                provider_name: LanguageModelProviderName::from(provider),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // The tests identify models by this `provider/name` string.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds models from `(provider, name)` pairs, none of them favorites.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        create_models_with_favorites(model_specs, Vec::new())
    }

    /// Builds models from `(provider, name)` pairs, flagging as favorite the
    /// ones that also appear in `favorites`.
    fn create_models_with_favorites(
        model_specs: Vec<(&str, &str)>,
        favorites: Vec<(&str, &str)>,
    ) -> Vec<ModelInfo> {
        let mut models = Vec::with_capacity(model_specs.len());
        for (provider, name) in model_specs {
            let is_favorite = favorites
                .iter()
                .any(|&(fav_provider, fav_name)| fav_provider == provider && fav_name == name);
            models.push(ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconOrSvg::Icon(IconName::Ai),
                is_favorite,
            });
        }
        models
    }

    /// The model list shared by the exact and fuzzy matching tests.
    fn matcher_fixture() -> Vec<ModelInfo> {
        create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ])
    }

    /// Concatenates the per-provider model lists in iteration order.
    fn flatten_groups<'a>(groups: impl Iterator<Item = &'a Vec<ModelInfo>>) -> Vec<ModelInfo> {
        groups.flatten().cloned().collect()
    }

    /// Asserts that `result` holds exactly the models named by `expected`
    /// (as `provider/name`), in the same order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (position, (info, expected_name)) in result.iter().zip(&expected).enumerate() {
            assert_eq!(
                info.model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                position
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(matcher_fixture(), cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let found = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            found,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(matcher_fixture(), cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let found = matcher.fuzzy_search("41");
        assert_models_eq(
            found,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let found = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(found, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let found = matcher.fuzzy_search("z4n");
        assert_models_eq(found, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended = create_models(vec![("zed", "claude")]);
        let everything = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped = GroupedModels::new(everything, recommended);

        // Recommended models should also appear in "all"
        assert_models_eq(
            flatten_groups(grouped.all.values()),
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended = create_models(vec![("zed", "claude")]);
        let everything = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "other"
        ]);

        let grouped = GroupedModels::new(everything, recommended);

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            flatten_groups(grouped.all.values()),
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let recommended = create_models(vec![("zed", "claude")]);
        let everything = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "gemini")],
        );

        let grouped = GroupedModels::new(everything, recommended);
        let rows = grouped.entries();

        assert!(matches!(
            rows.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert_models_eq(grouped.favorites, vec!["zed/gemini"]);
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let recommended = create_models(vec![("zed", "claude")]);
        let everything = create_models(vec![("zed", "claude"), ("zed", "gemini")]);

        let grouped = GroupedModels::new(everything, recommended);
        let rows = grouped.entries();

        assert!(matches!(
            rows.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Recommended"
        ));

        assert!(grouped.favorites.is_empty());
    }

    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let recommended =
            create_models_with_favorites(vec![("zed", "claude")], vec![("zed", "claude")]);
        let everything = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "claude")],
        );

        let grouped = GroupedModels::new(everything, recommended);

        // Every row for `zed/claude` must be flagged, and no other row may be.
        for row in &grouped.entries() {
            let LanguageModelPickerEntry::Model(info) = row else {
                continue;
            };
            let telemetry_id = info.model.telemetry_id();
            if telemetry_id == "zed/claude" {
                assert!(info.is_favorite, "zed/claude should be a favorite");
            } else {
                assert!(
                    !info.is_favorite,
                    "{} should not be a favorite",
                    telemetry_id
                );
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_other_sections(_cx: &mut TestAppContext) {
        let favorite_specs = vec![("zed", "gemini"), ("openai", "gpt-4")];

        let recommended =
            create_models_with_favorites(vec![("zed", "claude")], favorite_specs.clone());

        let everything = create_models_with_favorites(
            vec![
                ("zed", "claude"),
                ("zed", "gemini"),
                ("openai", "gpt-4"),
                ("openai", "gpt-3.5"),
            ],
            favorite_specs,
        );

        let grouped = GroupedModels::new(everything, recommended);

        assert_models_eq(grouped.favorites, vec!["zed/gemini", "openai/gpt-4"]);
        assert_models_eq(grouped.recommended, vec!["zed/claude"]);
        assert_models_eq(
            flatten_groups(grouped.all.values()),
            vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"],
        );
    }
}