language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use agent_settings::AgentSettings;
  4use collections::{HashMap, HashSet, IndexMap};
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{
  7    Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, ForegroundExecutor,
  8    Subscription, Task,
  9};
 10use language_model::{
 11    AuthenticateError, ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId,
 12    LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
 13};
 14use ordered_float::OrderedFloat;
 15use picker::{Picker, PickerDelegate};
 16use settings::Settings;
 17use ui::prelude::*;
 18use zed_actions::agent::OpenSettings;
 19
 20use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
 21
 22type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
 23type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 24type OnToggleFavorite = Arc<dyn Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static>;
 25
 26pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 27
 28pub fn language_model_selector(
 29    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 30    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 31    on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
 32    popover_styles: bool,
 33    focus_handle: FocusHandle,
 34    window: &mut Window,
 35    cx: &mut Context<LanguageModelSelector>,
 36) -> LanguageModelSelector {
 37    let delegate = LanguageModelPickerDelegate::new(
 38        get_active_model,
 39        on_model_changed,
 40        on_toggle_favorite,
 41        popover_styles,
 42        focus_handle,
 43        window,
 44        cx,
 45    );
 46
 47    if popover_styles {
 48        Picker::list(delegate, window, cx)
 49            .show_scrollbar(true)
 50            .width(rems(20.))
 51            .max_height(Some(rems(20.).into()))
 52    } else {
 53        Picker::list(delegate, window, cx).show_scrollbar(true)
 54    }
 55}
 56
 57fn all_models(cx: &App) -> GroupedModels {
 58    let lm_registry = LanguageModelRegistry::global(cx).read(cx);
 59    let providers = lm_registry.visible_providers();
 60
 61    let mut favorites_index = FavoritesIndex::default();
 62
 63    for sel in &AgentSettings::get_global(cx).favorite_models {
 64        favorites_index
 65            .entry(sel.provider.0.clone().into())
 66            .or_default()
 67            .insert(sel.model.clone().into());
 68    }
 69
 70    let recommended = providers
 71        .iter()
 72        .flat_map(|provider| {
 73            provider
 74                .recommended_models(cx)
 75                .into_iter()
 76                .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
 77        })
 78        .collect();
 79
 80    let all = providers
 81        .iter()
 82        .flat_map(|provider| {
 83            provider
 84                .provided_models(cx)
 85                .into_iter()
 86                .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
 87        })
 88        .collect();
 89
 90    GroupedModels::new(all, recommended)
 91}
 92
 93type FavoritesIndex = HashMap<LanguageModelProviderId, HashSet<LanguageModelId>>;
 94
 95#[derive(Clone)]
 96struct ModelInfo {
 97    model: Arc<dyn LanguageModel>,
 98    icon: IconOrSvg,
 99    is_favorite: bool,
100}
101
102impl ModelInfo {
103    fn new(
104        provider: &dyn LanguageModelProvider,
105        model: Arc<dyn LanguageModel>,
106        favorites_index: &FavoritesIndex,
107    ) -> Self {
108        let is_favorite = favorites_index
109            .get(&provider.id())
110            .map_or(false, |set| set.contains(&model.id()));
111
112        Self {
113            model,
114            icon: provider.icon(),
115            is_favorite,
116        }
117    }
118}
119
120pub struct LanguageModelPickerDelegate {
121    on_model_changed: OnModelChanged,
122    get_active_model: GetActiveModel,
123    on_toggle_favorite: OnToggleFavorite,
124    all_models: Arc<GroupedModels>,
125    filtered_entries: Vec<LanguageModelPickerEntry>,
126    selected_index: usize,
127    _authenticate_all_providers_task: Task<()>,
128    _subscriptions: Vec<Subscription>,
129    popover_styles: bool,
130    focus_handle: FocusHandle,
131}
132
133impl LanguageModelPickerDelegate {
134    fn new(
135        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
136        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
137        on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
138        popover_styles: bool,
139        focus_handle: FocusHandle,
140        window: &mut Window,
141        cx: &mut Context<Picker<Self>>,
142    ) -> Self {
143        let on_model_changed = Arc::new(on_model_changed);
144        let models = all_models(cx);
145        let entries = models.entries();
146
147        Self {
148            on_model_changed,
149            all_models: Arc::new(models),
150            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
151            filtered_entries: entries,
152            get_active_model: Arc::new(get_active_model),
153            on_toggle_favorite: Arc::new(on_toggle_favorite),
154            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
155            _subscriptions: vec![cx.subscribe_in(
156                &LanguageModelRegistry::global(cx),
157                window,
158                |picker, _, event, window, cx| {
159                    match event {
160                        language_model::Event::ProviderStateChanged(_)
161                        | language_model::Event::AddedProvider(_)
162                        | language_model::Event::RemovedProvider(_) => {
163                            let query = picker.query(cx);
164                            picker.delegate.all_models = Arc::new(all_models(cx));
165                            // Update matches will automatically drop the previous task
166                            // if we get a provider event again
167                            picker.update_matches(query, window, cx)
168                        }
169                        _ => {}
170                    }
171                },
172            )],
173            popover_styles,
174            focus_handle,
175        }
176    }
177
178    fn get_active_model_index(
179        entries: &[LanguageModelPickerEntry],
180        active_model: Option<ConfiguredModel>,
181    ) -> usize {
182        entries
183            .iter()
184            .position(|entry| {
185                if let LanguageModelPickerEntry::Model(model) = entry {
186                    active_model
187                        .as_ref()
188                        .map(|active_model| {
189                            active_model.model.id() == model.model.id()
190                                && active_model.provider.id() == model.model.provider_id()
191                        })
192                        .unwrap_or_default()
193                } else {
194                    false
195                }
196            })
197            .unwrap_or(0)
198    }
199
200    /// Authenticates all providers in the [`LanguageModelRegistry`].
201    ///
202    /// We do this so that we can populate the language selector with all of the
203    /// models from the configured providers.
204    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
205        let authenticate_all_providers = LanguageModelRegistry::global(cx)
206            .read(cx)
207            .visible_providers()
208            .iter()
209            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
210            .collect::<Vec<_>>();
211
212        cx.spawn(async move |_cx| {
213            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
214                if let Err(err) = authenticate_task.await {
215                    if matches!(err, AuthenticateError::CredentialsNotFound) {
216                        // Since we're authenticating these providers in the
217                        // background for the purposes of populating the
218                        // language selector, we don't care about providers
219                        // where the credentials are not found.
220                    } else {
221                        // Some providers have noisy failure states that we
222                        // don't want to spam the logs with every time the
223                        // language model selector is initialized.
224                        //
225                        // Ideally these should have more clear failure modes
226                        // that we know are safe to ignore here, like what we do
227                        // with `CredentialsNotFound` above.
228                        match provider_id.0.as_ref() {
229                            "lmstudio" | "ollama" => {
230                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
231                                //
232                                // These fail noisily, so we don't log them.
233                            }
234                            "copilot_chat" => {
235                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
236                            }
237                            _ => {
238                                log::error!(
239                                    "Failed to authenticate provider: {}: {err:#}",
240                                    provider_name.0
241                                );
242                            }
243                        }
244                    }
245                }
246            }
247        })
248    }
249
250    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
251        (self.get_active_model)(cx)
252    }
253
254    pub fn favorites_count(&self) -> usize {
255        self.all_models.favorites.len()
256    }
257
258    pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
259        if self.all_models.favorites.is_empty() {
260            return;
261        }
262
263        let active_model = (self.get_active_model)(cx);
264        let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
265        let active_model_id = active_model.as_ref().map(|m| m.model.id());
266
267        let current_index = self
268            .all_models
269            .favorites
270            .iter()
271            .position(|info| {
272                Some(info.model.provider_id()) == active_provider_id
273                    && Some(info.model.id()) == active_model_id
274            })
275            .unwrap_or(usize::MAX);
276
277        let next_index = if current_index == usize::MAX {
278            0
279        } else {
280            (current_index + 1) % self.all_models.favorites.len()
281        };
282
283        let next_model = self.all_models.favorites[next_index].model.clone();
284
285        (self.on_model_changed)(next_model, cx);
286
287        // Align the picker selection with the newly-active model
288        let new_index =
289            Self::get_active_model_index(&self.filtered_entries, (self.get_active_model)(cx));
290        self.set_selected_index(new_index, window, cx);
291    }
292}
293
294struct GroupedModels {
295    favorites: Vec<ModelInfo>,
296    recommended: Vec<ModelInfo>,
297    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
298}
299
300impl GroupedModels {
301    pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
302        let favorites = all
303            .iter()
304            .filter(|info| info.is_favorite)
305            .cloned()
306            .collect();
307
308        let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
309        for model in all {
310            let provider = model.model.provider_id();
311            if let Some(models) = all_by_provider.get_mut(&provider) {
312                models.push(model);
313            } else {
314                all_by_provider.insert(provider, vec![model]);
315            }
316        }
317
318        Self {
319            favorites,
320            recommended,
321            all: all_by_provider,
322        }
323    }
324
325    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
326        let mut entries = Vec::new();
327
328        if !self.favorites.is_empty() {
329            entries.push(LanguageModelPickerEntry::Separator("Favorite".into()));
330            for info in &self.favorites {
331                entries.push(LanguageModelPickerEntry::Model(info.clone()));
332            }
333        }
334
335        if !self.recommended.is_empty() {
336            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
337            for info in &self.recommended {
338                entries.push(LanguageModelPickerEntry::Model(info.clone()));
339            }
340        }
341
342        for models in self.all.values() {
343            if models.is_empty() {
344                continue;
345            }
346            entries.push(LanguageModelPickerEntry::Separator(
347                models[0].model.provider_name().0,
348            ));
349            for info in models {
350                entries.push(LanguageModelPickerEntry::Model(info.clone()));
351            }
352        }
353
354        entries
355    }
356}
357
358enum LanguageModelPickerEntry {
359    Model(ModelInfo),
360    Separator(SharedString),
361}
362
363struct ModelMatcher {
364    models: Vec<ModelInfo>,
365    fg_executor: ForegroundExecutor,
366    bg_executor: BackgroundExecutor,
367    candidates: Vec<StringMatchCandidate>,
368}
369
370impl ModelMatcher {
371    fn new(
372        models: Vec<ModelInfo>,
373        fg_executor: ForegroundExecutor,
374        bg_executor: BackgroundExecutor,
375    ) -> ModelMatcher {
376        let candidates = Self::make_match_candidates(&models);
377        Self {
378            models,
379            fg_executor,
380            bg_executor,
381            candidates,
382        }
383    }
384
385    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
386        let mut matches = self.fg_executor.block_on(match_strings(
387            &self.candidates,
388            query,
389            false,
390            true,
391            100,
392            &Default::default(),
393            self.bg_executor.clone(),
394        ));
395
396        let sorting_key = |mat: &StringMatch| {
397            let candidate = &self.candidates[mat.candidate_id];
398            (Reverse(OrderedFloat(mat.score)), candidate.id)
399        };
400        matches.sort_unstable_by_key(sorting_key);
401
402        let matched_models: Vec<_> = matches
403            .into_iter()
404            .map(|mat| self.models[mat.candidate_id].clone())
405            .collect();
406
407        matched_models
408    }
409
410    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
411        self.models
412            .iter()
413            .filter(|m| {
414                m.model
415                    .name()
416                    .0
417                    .to_lowercase()
418                    .contains(&query.to_lowercase())
419            })
420            .cloned()
421            .collect::<Vec<_>>()
422    }
423
424    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
425        model_infos
426            .iter()
427            .enumerate()
428            .map(|(index, model)| {
429                StringMatchCandidate::new(
430                    index,
431                    &format!(
432                        "{}/{}",
433                        &model.model.provider_name().0,
434                        &model.model.name().0
435                    ),
436                )
437            })
438            .collect::<Vec<_>>()
439    }
440}
441
442impl PickerDelegate for LanguageModelPickerDelegate {
443    type ListItem = AnyElement;
444
445    fn match_count(&self) -> usize {
446        self.filtered_entries.len()
447    }
448
449    fn selected_index(&self) -> usize {
450        self.selected_index
451    }
452
453    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
454        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
455        cx.notify();
456    }
457
458    fn can_select(
459        &mut self,
460        ix: usize,
461        _window: &mut Window,
462        _cx: &mut Context<Picker<Self>>,
463    ) -> bool {
464        match self.filtered_entries.get(ix) {
465            Some(LanguageModelPickerEntry::Model(_)) => true,
466            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
467        }
468    }
469
470    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
471        "Select a model…".into()
472    }
473
474    fn update_matches(
475        &mut self,
476        query: String,
477        window: &mut Window,
478        cx: &mut Context<Picker<Self>>,
479    ) -> Task<()> {
480        let all_models = self.all_models.clone();
481        let active_model = (self.get_active_model)(cx);
482        let fg_executor = cx.foreground_executor();
483        let bg_executor = cx.background_executor();
484
485        let language_model_registry = LanguageModelRegistry::global(cx);
486
487        let configured_providers = language_model_registry
488            .read(cx)
489            .visible_providers()
490            .into_iter()
491            .filter(|provider| provider.is_authenticated(cx))
492            .collect::<Vec<_>>();
493
494        let configured_provider_ids = configured_providers
495            .iter()
496            .map(|provider| provider.id())
497            .collect::<Vec<_>>();
498
499        let recommended_models = all_models
500            .recommended
501            .iter()
502            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
503            .cloned()
504            .collect::<Vec<_>>();
505
506        let available_models = all_models
507            .all
508            .values()
509            .flat_map(|models| models.iter())
510            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
511            .cloned()
512            .collect::<Vec<_>>();
513
514        let matcher_rec =
515            ModelMatcher::new(recommended_models, fg_executor.clone(), bg_executor.clone());
516        let matcher_all =
517            ModelMatcher::new(available_models, fg_executor.clone(), bg_executor.clone());
518
519        let recommended = matcher_rec.exact_search(&query);
520        let all = matcher_all.fuzzy_search(&query);
521
522        let filtered_models = GroupedModels::new(all, recommended);
523
524        cx.spawn_in(window, async move |this, cx| {
525            this.update_in(cx, |this, window, cx| {
526                this.delegate.filtered_entries = filtered_models.entries();
527                // Finds the currently selected model in the list
528                let new_index =
529                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
530                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
531                cx.notify();
532            })
533            .ok();
534        })
535    }
536
537    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
538        if let Some(LanguageModelPickerEntry::Model(model_info)) =
539            self.filtered_entries.get(self.selected_index)
540        {
541            let model = model_info.model.clone();
542            (self.on_model_changed)(model.clone(), cx);
543
544            let current_index = self.selected_index;
545            self.set_selected_index(current_index, window, cx);
546
547            cx.emit(DismissEvent);
548        }
549    }
550
551    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
552        cx.emit(DismissEvent);
553    }
554
555    fn render_match(
556        &self,
557        ix: usize,
558        selected: bool,
559        _: &mut Window,
560        cx: &mut Context<Picker<Self>>,
561    ) -> Option<Self::ListItem> {
562        match self.filtered_entries.get(ix)? {
563            LanguageModelPickerEntry::Separator(title) => {
564                Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
565            }
566            LanguageModelPickerEntry::Model(model_info) => {
567                let active_model = (self.get_active_model)(cx);
568                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
569                let active_model_id = active_model.map(|m| m.model.id());
570
571                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
572                    && Some(model_info.model.id()) == active_model_id;
573
574                let model_cost = model_info
575                    .model
576                    .model_cost_info()
577                    .map(|cost| cost.to_shared_string());
578
579                let is_favorite = model_info.is_favorite;
580                let handle_action_click = {
581                    let model = model_info.model.clone();
582                    let on_toggle_favorite = self.on_toggle_favorite.clone();
583                    cx.listener(move |picker, _, window, cx| {
584                        on_toggle_favorite(model.clone(), !is_favorite, cx);
585                        picker.refresh(window, cx);
586                    })
587                };
588
589                Some(
590                    ModelSelectorListItem::new(ix, model_info.model.name().0)
591                        .map(|this| match &model_info.icon {
592                            IconOrSvg::Icon(icon_name) => this.icon(*icon_name),
593                            IconOrSvg::Svg(icon_path) => this.icon_path(icon_path.clone()),
594                        })
595                        .is_selected(is_selected)
596                        .is_focused(selected)
597                        .is_latest(model_info.model.is_latest())
598                        .is_favorite(is_favorite)
599                        .cost_info(model_cost)
600                        .on_toggle_favorite(handle_action_click)
601                        .into_any_element(),
602                )
603            }
604        }
605    }
606
607    fn render_footer(
608        &self,
609        _window: &mut Window,
610        _cx: &mut Context<Picker<Self>>,
611    ) -> Option<gpui::AnyElement> {
612        let focus_handle = self.focus_handle.clone();
613
614        if !self.popover_styles {
615            return None;
616        }
617
618        Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
619    }
620}
621
622#[cfg(test)]
623mod tests {
624    use super::*;
625    use futures::{future::BoxFuture, stream::BoxStream};
626    use gpui::{AsyncApp, TestAppContext, http_client};
627    use language_model::{
628        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
629        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
630        LanguageModelRequest, LanguageModelToolChoice,
631    };
632    use ui::IconName;
633
634    #[derive(Clone)]
635    struct TestLanguageModel {
636        name: LanguageModelName,
637        id: LanguageModelId,
638        provider_id: LanguageModelProviderId,
639        provider_name: LanguageModelProviderName,
640    }
641
642    impl TestLanguageModel {
643        fn new(name: &str, provider: &str) -> Self {
644            Self {
645                name: LanguageModelName::from(name.to_string()),
646                id: LanguageModelId::from(name.to_string()),
647                provider_id: LanguageModelProviderId::from(provider.to_string()),
648                provider_name: LanguageModelProviderName::from(provider.to_string()),
649            }
650        }
651    }
652
653    impl LanguageModel for TestLanguageModel {
654        fn id(&self) -> LanguageModelId {
655            self.id.clone()
656        }
657
658        fn name(&self) -> LanguageModelName {
659            self.name.clone()
660        }
661
662        fn provider_id(&self) -> LanguageModelProviderId {
663            self.provider_id.clone()
664        }
665
666        fn provider_name(&self) -> LanguageModelProviderName {
667            self.provider_name.clone()
668        }
669
670        fn supports_tools(&self) -> bool {
671            false
672        }
673
674        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
675            false
676        }
677
678        fn supports_images(&self) -> bool {
679            false
680        }
681
682        fn telemetry_id(&self) -> String {
683            format!("{}/{}", self.provider_id.0, self.name.0)
684        }
685
686        fn max_token_count(&self) -> u64 {
687            1000
688        }
689
690        fn count_tokens(
691            &self,
692            _: LanguageModelRequest,
693            _: &App,
694        ) -> BoxFuture<'static, http_client::Result<u64>> {
695            unimplemented!()
696        }
697
698        fn stream_completion(
699            &self,
700            _: LanguageModelRequest,
701            _: &AsyncApp,
702        ) -> BoxFuture<
703            'static,
704            Result<
705                BoxStream<
706                    'static,
707                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
708                >,
709                LanguageModelCompletionError,
710            >,
711        > {
712            unimplemented!()
713        }
714    }
715
716    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
717        create_models_with_favorites(model_specs, vec![])
718    }
719
720    fn create_models_with_favorites(
721        model_specs: Vec<(&str, &str)>,
722        favorites: Vec<(&str, &str)>,
723    ) -> Vec<ModelInfo> {
724        model_specs
725            .into_iter()
726            .map(|(provider, name)| {
727                let is_favorite = favorites
728                    .iter()
729                    .any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name);
730                ModelInfo {
731                    model: Arc::new(TestLanguageModel::new(name, provider)),
732                    icon: IconOrSvg::Icon(IconName::Ai),
733                    is_favorite,
734                }
735            })
736            .collect()
737    }
738
739    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
740        assert_eq!(
741            result.len(),
742            expected.len(),
743            "Number of models doesn't match"
744        );
745
746        for (i, expected_name) in expected.iter().enumerate() {
747            assert_eq!(
748                result[i].model.telemetry_id(),
749                *expected_name,
750                "Model at position {} doesn't match expected model",
751                i
752            );
753        }
754    }
755
756    #[gpui::test]
757    fn test_exact_match(cx: &mut TestAppContext) {
758        let models = create_models(vec![
759            ("zed", "Claude 3.7 Sonnet"),
760            ("zed", "Claude 3.7 Sonnet Thinking"),
761            ("zed", "gpt-5"),
762            ("zed", "gpt-5-mini"),
763            ("openai", "gpt-3.5-turbo"),
764            ("openai", "gpt-5"),
765            ("openai", "gpt-5-mini"),
766            ("ollama", "mistral"),
767            ("ollama", "deepseek"),
768        ]);
769        let matcher = ModelMatcher::new(
770            models,
771            cx.foreground_executor().clone(),
772            cx.background_executor.clone(),
773        );
774
775        // The order of models should be maintained, case doesn't matter
776        let results = matcher.exact_search("GPT-5");
777        assert_models_eq(
778            results,
779            vec![
780                "zed/gpt-5",
781                "zed/gpt-5-mini",
782                "openai/gpt-5",
783                "openai/gpt-5-mini",
784            ],
785        );
786    }
787
788    #[gpui::test]
789    fn test_fuzzy_match(cx: &mut TestAppContext) {
790        let models = create_models(vec![
791            ("zed", "Claude 3.7 Sonnet"),
792            ("zed", "Claude 3.7 Sonnet Thinking"),
793            ("zed", "gpt-5"),
794            ("zed", "gpt-5-mini"),
795            ("openai", "gpt-3.5-turbo"),
796            ("openai", "gpt-5"),
797            ("openai", "gpt-5-mini"),
798            ("ollama", "mistral"),
799            ("ollama", "deepseek"),
800        ]);
801        let matcher = ModelMatcher::new(
802            models,
803            cx.foreground_executor().clone(),
804            cx.background_executor.clone(),
805        );
806
807        // Results should preserve models order whenever possible.
808        // In the case below, `zed/gpt-5-mini` and `openai/gpt-5-mini` have identical
809        // similarity scores, but `zed/gpt-5-mini` was higher in the models list,
810        // so it should appear first in the results.
811        let results = matcher.fuzzy_search("mini");
812        assert_models_eq(results, vec!["zed/gpt-5-mini", "openai/gpt-5-mini"]);
813
814        // Model provider should be searchable as well
815        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
816        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);
817
818        // Fuzzy search - search for Claude to get the Thinking variant
819        let results = matcher.fuzzy_search("thinking");
820        assert_models_eq(results, vec!["zed/Claude 3.7 Sonnet Thinking"]);
821    }
822
823    #[gpui::test]
824    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
825        let recommended_models = create_models(vec![("zed", "claude")]);
826        let all_models = create_models(vec![
827            ("zed", "claude"), // Should also appear in "other"
828            ("zed", "gemini"),
829            ("copilot", "o3"),
830        ]);
831
832        let grouped_models = GroupedModels::new(all_models, recommended_models);
833
834        let actual_all_models = grouped_models
835            .all
836            .values()
837            .flatten()
838            .cloned()
839            .collect::<Vec<_>>();
840
841        // Recommended models should also appear in "all"
842        assert_models_eq(
843            actual_all_models,
844            vec!["zed/claude", "zed/gemini", "copilot/o3"],
845        );
846    }
847
848    #[gpui::test]
849    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
850        let recommended_models = create_models(vec![("zed", "claude")]);
851        let all_models = create_models(vec![
852            ("zed", "claude"), // Should also appear in "other"
853            ("zed", "gemini"),
854            ("copilot", "claude"), // Different provider, should appear in "other"
855        ]);
856
857        let grouped_models = GroupedModels::new(all_models, recommended_models);
858
859        let actual_all_models = grouped_models
860            .all
861            .values()
862            .flatten()
863            .cloned()
864            .collect::<Vec<_>>();
865
866        // All models should appear in "all" regardless of recommended status
867        assert_models_eq(
868            actual_all_models,
869            vec!["zed/claude", "zed/gemini", "copilot/claude"],
870        );
871    }
872
873    #[gpui::test]
874    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
875        let recommended_models = create_models(vec![("zed", "claude")]);
876        let all_models = create_models_with_favorites(
877            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
878            vec![("zed", "gemini")],
879        );
880
881        let grouped_models = GroupedModels::new(all_models, recommended_models);
882        let entries = grouped_models.entries();
883
884        assert!(matches!(
885            entries.first(),
886            Some(LanguageModelPickerEntry::Separator(s)) if s == "Favorite"
887        ));
888
889        assert_models_eq(grouped_models.favorites, vec!["zed/gemini"]);
890    }
891
892    #[gpui::test]
893    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
894        let recommended_models = create_models(vec![("zed", "claude")]);
895        let all_models = create_models(vec![("zed", "claude"), ("zed", "gemini")]);
896
897        let grouped_models = GroupedModels::new(all_models, recommended_models);
898        let entries = grouped_models.entries();
899
900        assert!(matches!(
901            entries.first(),
902            Some(LanguageModelPickerEntry::Separator(s)) if s == "Recommended"
903        ));
904
905        assert!(grouped_models.favorites.is_empty());
906    }
907
908    #[gpui::test]
909    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
910        let recommended_models =
911            create_models_with_favorites(vec![("zed", "claude")], vec![("zed", "claude")]);
912        let all_models = create_models_with_favorites(
913            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
914            vec![("zed", "claude")],
915        );
916
917        let grouped_models = GroupedModels::new(all_models, recommended_models);
918        let entries = grouped_models.entries();
919
920        for entry in &entries {
921            if let LanguageModelPickerEntry::Model(info) = entry {
922                if info.model.telemetry_id() == "zed/claude" {
923                    assert!(info.is_favorite, "zed/claude should be a favorite");
924                } else {
925                    assert!(
926                        !info.is_favorite,
927                        "{} should not be a favorite",
928                        info.model.telemetry_id()
929                    );
930                }
931            }
932        }
933    }
934
935    #[gpui::test]
936    fn test_favorites_appear_in_other_sections(_cx: &mut TestAppContext) {
937        let favorites = vec![("zed", "gemini"), ("openai", "gpt-4")];
938
939        let recommended_models =
940            create_models_with_favorites(vec![("zed", "claude")], favorites.clone());
941
942        let all_models = create_models_with_favorites(
943            vec![
944                ("zed", "claude"),
945                ("zed", "gemini"),
946                ("openai", "gpt-4"),
947                ("openai", "gpt-3.5"),
948            ],
949            favorites,
950        );
951
952        let grouped_models = GroupedModels::new(all_models, recommended_models);
953
954        assert_models_eq(grouped_models.favorites, vec!["zed/gemini", "openai/gpt-4"]);
955        assert_models_eq(grouped_models.recommended, vec!["zed/claude"]);
956        assert_models_eq(
957            grouped_models.all.values().flatten().cloned().collect(),
958            vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"],
959        );
960    }
961}