language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use collections::{HashSet, IndexMap};
  4use feature_flags::ZedProFeatureFlag;
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{
  7    Action, AnyElement, AnyView, App, BackgroundExecutor, Corner, DismissEvent, Entity,
  8    EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
  9    action_with_deprecated_aliases,
 10};
 11use language_model::{
 12    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
 13    LanguageModelRegistry,
 14};
 15use ordered_float::OrderedFloat;
 16use picker::{Picker, PickerDelegate};
 17use proto::Plan;
 18use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
 19
// Declares the `agent::ToggleModelSelector` action. The former
// `assistant::ToggleModelSelector` and `assistant2::ToggleModelSelector` names
// are registered as deprecated aliases (presumably so keymaps still using the
// old names keep resolving — confirm against the macro's definition).
action_with_deprecated_aliases!(
    agent,
    ToggleModelSelector,
    [
        "assistant::ToggleModelSelector",
        "assistant2::ToggleModelSelector"
    ]
);
 28
/// URL opened by the footer's "Try Pro" / "Upgrade to Pro" button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Callback invoked with the model the user confirmed in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback reporting the currently active model, if any; used to preselect
/// and to mark the active entry in the picker.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 33
/// A picker listing the language models offered by the providers registered in
/// the [`LanguageModelRegistry`], grouped into a "Recommended" section followed
/// by one section per provider.
pub struct LanguageModelSelector {
    /// The picker UI; its delegate owns the model list and the callbacks.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// Background task authenticating every registered provider; stored so it
    /// is not dropped before the selector is.
    _authenticate_all_providers_task: Task<()>,
    /// Registry-event and picker-event subscriptions, kept for the selector's
    /// lifetime.
    _subscriptions: Vec<Subscription>,
}
 39
 40impl LanguageModelSelector {
 41    pub fn new(
 42        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 43        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 44        window: &mut Window,
 45        cx: &mut Context<Self>,
 46    ) -> Self {
 47        let on_model_changed = Arc::new(on_model_changed);
 48
 49        let all_models = Self::all_models(cx);
 50        let entries = all_models.entries();
 51
 52        let delegate = LanguageModelPickerDelegate {
 53            language_model_selector: cx.entity().downgrade(),
 54            on_model_changed: on_model_changed.clone(),
 55            all_models: Arc::new(all_models),
 56            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
 57            filtered_entries: entries,
 58            get_active_model: Arc::new(get_active_model),
 59        };
 60
 61        let picker = cx.new(|cx| {
 62            Picker::list(delegate, window, cx)
 63                .show_scrollbar(true)
 64                .width(rems(20.))
 65                .max_height(Some(rems(20.).into()))
 66        });
 67
 68        let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
 69
 70        LanguageModelSelector {
 71            picker,
 72            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
 73            _subscriptions: vec![
 74                cx.subscribe_in(
 75                    &LanguageModelRegistry::global(cx),
 76                    window,
 77                    Self::handle_language_model_registry_event,
 78                ),
 79                subscription,
 80            ],
 81        }
 82    }
 83
 84    fn handle_language_model_registry_event(
 85        &mut self,
 86        _registry: &Entity<LanguageModelRegistry>,
 87        event: &language_model::Event,
 88        window: &mut Window,
 89        cx: &mut Context<Self>,
 90    ) {
 91        match event {
 92            language_model::Event::ProviderStateChanged
 93            | language_model::Event::AddedProvider(_)
 94            | language_model::Event::RemovedProvider(_) => {
 95                self.picker.update(cx, |this, cx| {
 96                    let query = this.query(cx);
 97                    this.delegate.all_models = Arc::new(Self::all_models(cx));
 98                    // Update matches will automatically drop the previous task
 99                    // if we get a provider event again
100                    this.update_matches(query, window, cx)
101                });
102            }
103            _ => {}
104        }
105    }
106
107    /// Authenticates all providers in the [`LanguageModelRegistry`].
108    ///
109    /// We do this so that we can populate the language selector with all of the
110    /// models from the configured providers.
111    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
112        let authenticate_all_providers = LanguageModelRegistry::global(cx)
113            .read(cx)
114            .providers()
115            .iter()
116            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
117            .collect::<Vec<_>>();
118
119        cx.spawn(async move |_cx| {
120            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
121                if let Err(err) = authenticate_task.await {
122                    if matches!(err, AuthenticateError::CredentialsNotFound) {
123                        // Since we're authenticating these providers in the
124                        // background for the purposes of populating the
125                        // language selector, we don't care about providers
126                        // where the credentials are not found.
127                    } else {
128                        // Some providers have noisy failure states that we
129                        // don't want to spam the logs with every time the
130                        // language model selector is initialized.
131                        //
132                        // Ideally these should have more clear failure modes
133                        // that we know are safe to ignore here, like what we do
134                        // with `CredentialsNotFound` above.
135                        match provider_id.0.as_ref() {
136                            "lmstudio" | "ollama" => {
137                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
138                                //
139                                // These fail noisily, so we don't log them.
140                            }
141                            "copilot_chat" => {
142                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
143                            }
144                            _ => {
145                                log::error!(
146                                    "Failed to authenticate provider: {}: {err}",
147                                    provider_name.0
148                                );
149                            }
150                        }
151                    }
152                }
153            }
154        })
155    }
156
157    fn all_models(cx: &App) -> GroupedModels {
158        let providers = LanguageModelRegistry::global(cx).read(cx).providers();
159
160        let recommended = providers
161            .iter()
162            .flat_map(|provider| {
163                provider
164                    .recommended_models(cx)
165                    .into_iter()
166                    .map(|model| ModelInfo {
167                        model,
168                        icon: provider.icon(),
169                    })
170            })
171            .collect();
172
173        let other = providers
174            .iter()
175            .flat_map(|provider| {
176                provider
177                    .provided_models(cx)
178                    .into_iter()
179                    .map(|model| ModelInfo {
180                        model,
181                        icon: provider.icon(),
182                    })
183            })
184            .collect();
185
186        GroupedModels::new(other, recommended)
187    }
188
189    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
190        (self.picker.read(cx).delegate.get_active_model)(cx)
191    }
192
193    fn get_active_model_index(
194        entries: &[LanguageModelPickerEntry],
195        active_model: Option<ConfiguredModel>,
196    ) -> usize {
197        entries
198            .iter()
199            .position(|entry| {
200                if let LanguageModelPickerEntry::Model(model) = entry {
201                    active_model
202                        .as_ref()
203                        .map(|active_model| {
204                            active_model.model.id() == model.model.id()
205                                && active_model.provider.id() == model.model.provider_id()
206                        })
207                        .unwrap_or_default()
208                } else {
209                    false
210                }
211            })
212            .unwrap_or(0)
213    }
214}
215
/// The selector emits [`DismissEvent`]; it re-emits it for events from its
/// picker and when the picker delegate is dismissed.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}

/// Focus is delegated to the inner picker.
impl Focusable for LanguageModelSelector {
    fn focus_handle(&self, cx: &App) -> FocusHandle {
        self.picker.focus_handle(cx)
    }
}

/// The selector renders as its inner picker.
impl Render for LanguageModelSelector {
    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
        self.picker.clone()
    }
}
229
/// A popover menu whose content is a [`LanguageModelSelector`], opened from
/// `trigger`.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
    T: PopoverTrigger + ButtonCommon,
    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
    /// The selector shown as the menu's content.
    language_model_selector: Entity<LanguageModelSelector>,
    /// The element that opens the menu.
    trigger: T,
    /// Builds the trigger's tooltip view.
    tooltip: TT,
    /// Optional external handle to the menu; `None` unless set via `with_handle`.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
    /// Corner the menu is anchored to.
    anchor: Corner,
}
242
243impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
244where
245    T: PopoverTrigger + ButtonCommon,
246    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
247{
248    pub fn new(
249        language_model_selector: Entity<LanguageModelSelector>,
250        trigger: T,
251        tooltip: TT,
252        anchor: Corner,
253    ) -> Self {
254        Self {
255            language_model_selector,
256            trigger,
257            tooltip,
258            handle: None,
259            anchor,
260        }
261    }
262
263    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
264        self.handle = Some(handle);
265        self
266    }
267}
268
269impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
270where
271    T: PopoverTrigger + ButtonCommon,
272    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
273{
274    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
275        let language_model_selector = self.language_model_selector.clone();
276
277        PopoverMenu::new("model-switcher")
278            .menu(move |_window, _cx| Some(language_model_selector.clone()))
279            .trigger_with_tooltip(self.trigger, self.tooltip)
280            .anchor(self.anchor)
281            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
282            .offset(gpui::Point {
283                x: px(0.0),
284                y: px(-2.0),
285            })
286    }
287}
288
/// A language model paired with the icon of the provider that offers it.
#[derive(Clone)]
struct ModelInfo {
    model: Arc<dyn LanguageModel>,
    /// The offering provider's icon, rendered at the start of the model's row.
    icon: IconName,
}
294
/// [`PickerDelegate`] backing the [`LanguageModelSelector`]'s picker.
pub struct LanguageModelPickerDelegate {
    /// Back-reference used to emit `DismissEvent` from the selector when the
    /// picker is dismissed.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Invoked with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Reports the currently active model, used to mark it in the list.
    get_active_model: GetActiveModel,
    /// Every model from every registered provider, regardless of
    /// authentication state; rebuilt on registry events.
    all_models: Arc<GroupedModels>,
    /// Rows currently shown: section separators and the models matching the
    /// current query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
}
303
/// Models split into a "recommended" list and the remaining models grouped by
/// provider id (in first-seen provider order).
struct GroupedModels {
    recommended: Vec<ModelInfo>,
    /// Non-recommended models keyed by provider id; a model listed in
    /// `recommended` for the same provider is not repeated here.
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
308
309impl GroupedModels {
310    pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
311        let recommended_ids = recommended
312            .iter()
313            .map(|info| (info.model.provider_id(), info.model.id()))
314            .collect::<HashSet<_>>();
315
316        let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
317        for model in other {
318            if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
319                continue;
320            }
321
322            let provider = model.model.provider_id();
323            if let Some(models) = other_by_provider.get_mut(&provider) {
324                models.push(model);
325            } else {
326                other_by_provider.insert(provider, vec![model]);
327            }
328        }
329
330        Self {
331            recommended,
332            other: other_by_provider,
333        }
334    }
335
336    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
337        let mut entries = Vec::new();
338
339        if !self.recommended.is_empty() {
340            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
341            entries.extend(
342                self.recommended
343                    .iter()
344                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
345            );
346        }
347
348        for models in self.other.values() {
349            if models.is_empty() {
350                continue;
351            }
352            entries.push(LanguageModelPickerEntry::Separator(
353                models[0].model.provider_name().0,
354            ));
355            entries.extend(
356                models
357                    .iter()
358                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
359            );
360        }
361        entries
362    }
363
364    fn model_infos(&self) -> Vec<ModelInfo> {
365        let other = self
366            .other
367            .values()
368            .flat_map(|model| model.iter())
369            .cloned()
370            .collect::<Vec<_>>();
371        self.recommended
372            .iter()
373            .chain(&other)
374            .cloned()
375            .collect::<Vec<_>>()
376    }
377}
378
/// One row of the picker.
enum LanguageModelPickerEntry {
    /// A selectable model.
    Model(ModelInfo),
    /// A non-selectable section header with the given title.
    Separator(SharedString),
}
383
/// Searches a fixed list of models, by fuzzy match or by substring.
struct ModelMatcher {
    models: Vec<ModelInfo>,
    /// Executor that runs (and is blocked on for) fuzzy matching.
    bg_executor: BackgroundExecutor,
    /// One fuzzy-match candidate per entry in `models`, at the same index.
    candidates: Vec<StringMatchCandidate>,
}
389
390impl ModelMatcher {
391    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
392        let candidates = Self::make_match_candidates(&models);
393        Self {
394            models,
395            bg_executor,
396            candidates,
397        }
398    }
399
400    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
401        let mut matches = self.bg_executor.block(match_strings(
402            &self.candidates,
403            &query,
404            false,
405            100,
406            &Default::default(),
407            self.bg_executor.clone(),
408        ));
409
410        let sorting_key = |mat: &StringMatch| {
411            let candidate = &self.candidates[mat.candidate_id];
412            (Reverse(OrderedFloat(mat.score)), candidate.id)
413        };
414        matches.sort_unstable_by_key(sorting_key);
415
416        let matched_models: Vec<_> = matches
417            .into_iter()
418            .map(|mat| self.models[mat.candidate_id].clone())
419            .collect();
420
421        matched_models
422    }
423
424    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
425        self.models
426            .iter()
427            .filter(|m| {
428                m.model
429                    .name()
430                    .0
431                    .to_lowercase()
432                    .contains(&query.to_lowercase())
433            })
434            .cloned()
435            .collect::<Vec<_>>()
436    }
437
438    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
439        model_infos
440            .iter()
441            .enumerate()
442            .map(|(index, model)| {
443                StringMatchCandidate::new(
444                    index,
445                    &format!(
446                        "{}/{}",
447                        &model.model.provider_name().0,
448                        &model.model.name().0
449                    ),
450                )
451            })
452            .collect::<Vec<_>>()
453    }
454}
455
456impl PickerDelegate for LanguageModelPickerDelegate {
457    type ListItem = AnyElement;
458
459    fn match_count(&self) -> usize {
460        self.filtered_entries.len()
461    }
462
463    fn selected_index(&self) -> usize {
464        self.selected_index
465    }
466
467    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
468        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
469        cx.notify();
470    }
471
472    fn can_select(
473        &mut self,
474        ix: usize,
475        _window: &mut Window,
476        _cx: &mut Context<Picker<Self>>,
477    ) -> bool {
478        match self.filtered_entries.get(ix) {
479            Some(LanguageModelPickerEntry::Model(_)) => true,
480            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
481        }
482    }
483
484    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
485        "Select a model…".into()
486    }
487
488    fn update_matches(
489        &mut self,
490        query: String,
491        window: &mut Window,
492        cx: &mut Context<Picker<Self>>,
493    ) -> Task<()> {
494        let all_models = self.all_models.clone();
495        let current_index = self.selected_index;
496        let bg_executor = cx.background_executor();
497
498        let language_model_registry = LanguageModelRegistry::global(cx);
499
500        let configured_providers = language_model_registry
501            .read(cx)
502            .providers()
503            .into_iter()
504            .filter(|provider| provider.is_authenticated(cx))
505            .collect::<Vec<_>>();
506
507        let configured_provider_ids = configured_providers
508            .iter()
509            .map(|provider| provider.id())
510            .collect::<Vec<_>>();
511
512        let recommended_models = all_models
513            .recommended
514            .iter()
515            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
516            .cloned()
517            .collect::<Vec<_>>();
518
519        let available_models = all_models
520            .model_infos()
521            .iter()
522            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
523            .cloned()
524            .collect::<Vec<_>>();
525
526        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
527        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
528
529        let recommended = matcher_rec.exact_search(&query);
530        let all = matcher_all.fuzzy_search(&query);
531
532        let filtered_models = GroupedModels::new(all, recommended);
533
534        cx.spawn_in(window, async move |this, cx| {
535            this.update_in(cx, |this, window, cx| {
536                this.delegate.filtered_entries = filtered_models.entries();
537                // Preserve selection focus
538                let new_index = if current_index >= this.delegate.filtered_entries.len() {
539                    0
540                } else {
541                    current_index
542                };
543                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
544                cx.notify();
545            })
546            .ok();
547        })
548    }
549
550    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
551        if let Some(LanguageModelPickerEntry::Model(model_info)) =
552            self.filtered_entries.get(self.selected_index)
553        {
554            let model = model_info.model.clone();
555            (self.on_model_changed)(model.clone(), cx);
556
557            let current_index = self.selected_index;
558            self.set_selected_index(current_index, window, cx);
559
560            cx.emit(DismissEvent);
561        }
562    }
563
564    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
565        self.language_model_selector
566            .update(cx, |_this, cx| cx.emit(DismissEvent))
567            .ok();
568    }
569
570    fn render_match(
571        &self,
572        ix: usize,
573        selected: bool,
574        _: &mut Window,
575        cx: &mut Context<Picker<Self>>,
576    ) -> Option<Self::ListItem> {
577        match self.filtered_entries.get(ix)? {
578            LanguageModelPickerEntry::Separator(title) => Some(
579                div()
580                    .px_2()
581                    .pb_1()
582                    .when(ix > 1, |this| {
583                        this.mt_1()
584                            .pt_2()
585                            .border_t_1()
586                            .border_color(cx.theme().colors().border_variant)
587                    })
588                    .child(
589                        Label::new(title)
590                            .size(LabelSize::XSmall)
591                            .color(Color::Muted),
592                    )
593                    .into_any_element(),
594            ),
595            LanguageModelPickerEntry::Model(model_info) => {
596                let active_model = (self.get_active_model)(cx);
597                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
598                let active_model_id = active_model.map(|m| m.model.id());
599
600                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
601                    && Some(model_info.model.id()) == active_model_id;
602
603                let model_icon_color = if is_selected {
604                    Color::Accent
605                } else {
606                    Color::Muted
607                };
608
609                Some(
610                    ListItem::new(ix)
611                        .inset(true)
612                        .spacing(ListItemSpacing::Sparse)
613                        .toggle_state(selected)
614                        .start_slot(
615                            Icon::new(model_info.icon)
616                                .color(model_icon_color)
617                                .size(IconSize::Small),
618                        )
619                        .child(
620                            h_flex()
621                                .w_full()
622                                .pl_0p5()
623                                .gap_1p5()
624                                .w(px(240.))
625                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
626                        )
627                        .end_slot(div().pr_3().when(is_selected, |this| {
628                            this.child(
629                                Icon::new(IconName::Check)
630                                    .color(Color::Accent)
631                                    .size(IconSize::Small),
632                            )
633                        }))
634                        .into_any_element(),
635                )
636            }
637        }
638    }
639
640    fn render_footer(
641        &self,
642        _: &mut Window,
643        cx: &mut Context<Picker<Self>>,
644    ) -> Option<gpui::AnyElement> {
645        use feature_flags::FeatureFlagAppExt;
646
647        let plan = proto::Plan::ZedPro;
648
649        Some(
650            h_flex()
651                .w_full()
652                .border_t_1()
653                .border_color(cx.theme().colors().border_variant)
654                .p_1()
655                .gap_4()
656                .justify_between()
657                .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
658                    this.child(match plan {
659                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
660                            .icon(IconName::ZedAssistant)
661                            .icon_size(IconSize::Small)
662                            .icon_color(Color::Muted)
663                            .icon_position(IconPosition::Start)
664                            .on_click(|_, window, cx| {
665                                window
666                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
667                            }),
668                        Plan::Free | Plan::ZedProTrial => Button::new(
669                            "try-pro",
670                            if plan == Plan::ZedProTrial {
671                                "Upgrade to Pro"
672                            } else {
673                                "Try Pro"
674                            },
675                        )
676                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
677                    })
678                })
679                .child(
680                    Button::new("configure", "Configure")
681                        .icon(IconName::Settings)
682                        .icon_size(IconSize::Small)
683                        .icon_color(Color::Muted)
684                        .icon_position(IconPosition::Start)
685                        .on_click(|_, window, cx| {
686                            window.dispatch_action(
687                                zed_actions::agent::OpenConfiguration.boxed_clone(),
688                                cx,
689                            );
690                        }),
691                )
692                .into_any(),
693        )
694    }
695}
696
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal [`LanguageModel`] stub: only identity (name, id, provider) is
    /// meaningful; token counting and completion are unimplemented.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Uses `name` as both model name and id, and `provider` as both
        /// provider id and provider name.
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // "provider/name" — the string the assertions below compare against.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> usize {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<usize>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            http_client::Result<
                BoxStream<
                    'static,
                    http_client::Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, model name)` pairs, in order.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Asserts `result` contains exactly the models named by `expected`
    /// ("provider/name"), in the same order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should not appear in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
    }

    #[gpui::test]
    fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Should not be filtered out from "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Only the exact recommended (provider, model) pair is excluded; a model
        // with the same name from a different provider stays in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
    }
}