language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use collections::{HashSet, IndexMap};
  4use feature_flags::ZedProFeatureFlag;
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{
  7    Action, AnyElement, AnyView, App, BackgroundExecutor, Corner, DismissEvent, Entity,
  8    EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
  9    action_with_deprecated_aliases,
 10};
 11use language_model::{
 12    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
 13    LanguageModelRegistry,
 14};
 15use ordered_float::OrderedFloat;
 16use picker::{Picker, PickerDelegate};
 17use proto::Plan;
 18use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
 19
// Declares the `agent::ToggleModelSelector` action. The bracketed strings are
// older, deprecated names for the same action (from the `assistant` and
// `assistant2` namespaces) that remain accepted as aliases — presumably so
// keymaps written against the old names keep resolving.
action_with_deprecated_aliases!(
    agent,
    ToggleModelSelector,
    [
        "assistant::ToggleModelSelector",
        "assistant2::ToggleModelSelector"
    ]
);
 28
 29const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 30
 31type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
 32type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 33
 34pub struct LanguageModelSelector {
 35    picker: Entity<Picker<LanguageModelPickerDelegate>>,
 36    _authenticate_all_providers_task: Task<()>,
 37    _subscriptions: Vec<Subscription>,
 38}
 39
 40impl LanguageModelSelector {
 41    pub fn new(
 42        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 43        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 44        window: &mut Window,
 45        cx: &mut Context<Self>,
 46    ) -> Self {
 47        let on_model_changed = Arc::new(on_model_changed);
 48
 49        let all_models = Self::all_models(cx);
 50        let entries = all_models.entries();
 51
 52        let delegate = LanguageModelPickerDelegate {
 53            language_model_selector: cx.entity().downgrade(),
 54            on_model_changed: on_model_changed.clone(),
 55            all_models: Arc::new(all_models),
 56            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
 57            filtered_entries: entries,
 58            get_active_model: Arc::new(get_active_model),
 59        };
 60
 61        let picker = cx.new(|cx| {
 62            Picker::list(delegate, window, cx)
 63                .show_scrollbar(true)
 64                .width(rems(20.))
 65                .max_height(Some(rems(20.).into()))
 66        });
 67
 68        let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
 69
 70        LanguageModelSelector {
 71            picker,
 72            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
 73            _subscriptions: vec![
 74                cx.subscribe_in(
 75                    &LanguageModelRegistry::global(cx),
 76                    window,
 77                    Self::handle_language_model_registry_event,
 78                ),
 79                subscription,
 80            ],
 81        }
 82    }
 83
 84    fn handle_language_model_registry_event(
 85        &mut self,
 86        _registry: &Entity<LanguageModelRegistry>,
 87        event: &language_model::Event,
 88        window: &mut Window,
 89        cx: &mut Context<Self>,
 90    ) {
 91        match event {
 92            language_model::Event::ProviderStateChanged
 93            | language_model::Event::AddedProvider(_)
 94            | language_model::Event::RemovedProvider(_) => {
 95                self.picker.update(cx, |this, cx| {
 96                    let query = this.query(cx);
 97                    this.delegate.all_models = Arc::new(Self::all_models(cx));
 98                    // Update matches will automatically drop the previous task
 99                    // if we get a provider event again
100                    this.update_matches(query, window, cx)
101                });
102            }
103            _ => {}
104        }
105    }
106
107    /// Authenticates all providers in the [`LanguageModelRegistry`].
108    ///
109    /// We do this so that we can populate the language selector with all of the
110    /// models from the configured providers.
111    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
112        let authenticate_all_providers = LanguageModelRegistry::global(cx)
113            .read(cx)
114            .providers()
115            .iter()
116            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
117            .collect::<Vec<_>>();
118
119        cx.spawn(async move |_cx| {
120            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
121                if let Err(err) = authenticate_task.await {
122                    if matches!(err, AuthenticateError::CredentialsNotFound) {
123                        // Since we're authenticating these providers in the
124                        // background for the purposes of populating the
125                        // language selector, we don't care about providers
126                        // where the credentials are not found.
127                    } else {
128                        // Some providers have noisy failure states that we
129                        // don't want to spam the logs with every time the
130                        // language model selector is initialized.
131                        //
132                        // Ideally these should have more clear failure modes
133                        // that we know are safe to ignore here, like what we do
134                        // with `CredentialsNotFound` above.
135                        match provider_id.0.as_ref() {
136                            "lmstudio" | "ollama" => {
137                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
138                                //
139                                // These fail noisily, so we don't log them.
140                            }
141                            "copilot_chat" => {
142                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
143                            }
144                            _ => {
145                                log::error!(
146                                    "Failed to authenticate provider: {}: {err}",
147                                    provider_name.0
148                                );
149                            }
150                        }
151                    }
152                }
153            }
154        })
155    }
156
157    fn all_models(cx: &App) -> GroupedModels {
158        let mut recommended = Vec::new();
159        let mut recommended_set = HashSet::default();
160        for provider in LanguageModelRegistry::global(cx)
161            .read(cx)
162            .providers()
163            .iter()
164        {
165            let models = provider.recommended_models(cx);
166            recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
167            recommended.extend(
168                provider
169                    .recommended_models(cx)
170                    .into_iter()
171                    .map(move |model| ModelInfo {
172                        model: model.clone(),
173                        icon: provider.icon(),
174                    }),
175            );
176        }
177
178        let other_models = LanguageModelRegistry::global(cx)
179            .read(cx)
180            .providers()
181            .iter()
182            .map(|provider| {
183                (
184                    provider.id(),
185                    provider
186                        .provided_models(cx)
187                        .into_iter()
188                        .filter_map(|model| {
189                            let not_included =
190                                !recommended_set.contains(&(model.provider_id(), model.id()));
191                            not_included.then(|| ModelInfo {
192                                model: model.clone(),
193                                icon: provider.icon(),
194                            })
195                        })
196                        .collect::<Vec<_>>(),
197                )
198            })
199            .collect::<IndexMap<_, _>>();
200
201        GroupedModels {
202            recommended,
203            other: other_models,
204        }
205    }
206
207    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
208        (self.picker.read(cx).delegate.get_active_model)(cx)
209    }
210
211    fn get_active_model_index(
212        entries: &[LanguageModelPickerEntry],
213        active_model: Option<ConfiguredModel>,
214    ) -> usize {
215        entries
216            .iter()
217            .position(|entry| {
218                if let LanguageModelPickerEntry::Model(model) = entry {
219                    active_model
220                        .as_ref()
221                        .map(|active_model| {
222                            active_model.model.id() == model.model.id()
223                                && active_model.provider.id() == model.model.provider_id()
224                        })
225                        .unwrap_or_default()
226                } else {
227                    false
228                }
229            })
230            .unwrap_or(0)
231    }
232}
233
234impl EventEmitter<DismissEvent> for LanguageModelSelector {}
235
236impl Focusable for LanguageModelSelector {
237    fn focus_handle(&self, cx: &App) -> FocusHandle {
238        self.picker.focus_handle(cx)
239    }
240}
241
242impl Render for LanguageModelSelector {
243    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
244        self.picker.clone()
245    }
246}
247
248#[derive(IntoElement)]
249pub struct LanguageModelSelectorPopoverMenu<T, TT>
250where
251    T: PopoverTrigger + ButtonCommon,
252    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
253{
254    language_model_selector: Entity<LanguageModelSelector>,
255    trigger: T,
256    tooltip: TT,
257    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
258    anchor: Corner,
259}
260
261impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
262where
263    T: PopoverTrigger + ButtonCommon,
264    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
265{
266    pub fn new(
267        language_model_selector: Entity<LanguageModelSelector>,
268        trigger: T,
269        tooltip: TT,
270        anchor: Corner,
271    ) -> Self {
272        Self {
273            language_model_selector,
274            trigger,
275            tooltip,
276            handle: None,
277            anchor,
278        }
279    }
280
281    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
282        self.handle = Some(handle);
283        self
284    }
285}
286
287impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
288where
289    T: PopoverTrigger + ButtonCommon,
290    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
291{
292    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
293        let language_model_selector = self.language_model_selector.clone();
294
295        PopoverMenu::new("model-switcher")
296            .menu(move |_window, _cx| Some(language_model_selector.clone()))
297            .trigger_with_tooltip(self.trigger, self.tooltip)
298            .anchor(self.anchor)
299            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
300            .offset(gpui::Point {
301                x: px(0.0),
302                y: px(-2.0),
303            })
304    }
305}
306
/// A language model paired with the icon of the provider that supplied it.
#[derive(Clone)]
struct ModelInfo {
    model: Arc<dyn LanguageModel>,
    /// Taken from the providing provider's `icon()`; shown as the row's start slot.
    icon: IconName,
}
312
/// The [`PickerDelegate`] behind [`LanguageModelSelector`].
pub struct LanguageModelPickerDelegate {
    /// Back-reference used to emit `DismissEvent` from the selector when the
    /// picker is dismissed.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Invoked with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Reports the active model so its row can be highlighted.
    get_active_model: GetActiveModel,
    /// Models from every provider, not yet filtered by authentication or
    /// query; rebuilt when the registry reports provider changes.
    all_models: Arc<GroupedModels>,
    /// Rows (section headers and models) currently shown for the active query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
}
321
322struct GroupedModels {
323    recommended: Vec<ModelInfo>,
324    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
325}
326
327impl GroupedModels {
328    pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
329        let recommended_ids: HashSet<_> = recommended.iter().map(|info| info.model.id()).collect();
330
331        let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
332        for model in other {
333            if recommended_ids.contains(&model.model.id()) {
334                continue;
335            }
336
337            let provider = model.model.provider_id();
338            if let Some(models) = other_by_provider.get_mut(&provider) {
339                models.push(model);
340            } else {
341                other_by_provider.insert(provider, vec![model]);
342            }
343        }
344
345        Self {
346            recommended,
347            other: other_by_provider,
348        }
349    }
350
351    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
352        let mut entries = Vec::new();
353
354        if !self.recommended.is_empty() {
355            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
356            entries.extend(
357                self.recommended
358                    .iter()
359                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
360            );
361        }
362
363        for models in self.other.values() {
364            if models.is_empty() {
365                continue;
366            }
367            entries.push(LanguageModelPickerEntry::Separator(
368                models[0].model.provider_name().0,
369            ));
370            entries.extend(
371                models
372                    .iter()
373                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
374            );
375        }
376        entries
377    }
378
379    fn model_infos(&self) -> Vec<ModelInfo> {
380        let other = self
381            .other
382            .values()
383            .flat_map(|model| model.iter())
384            .cloned()
385            .collect::<Vec<_>>();
386        self.recommended
387            .iter()
388            .chain(&other)
389            .cloned()
390            .collect::<Vec<_>>()
391    }
392}
393
/// One row of the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable section header ("Recommended" or a provider name).
    Separator(SharedString),
}
398
399struct ModelMatcher {
400    models: Vec<ModelInfo>,
401    bg_executor: BackgroundExecutor,
402    candidates: Vec<StringMatchCandidate>,
403}
404
405impl ModelMatcher {
406    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
407        let candidates = Self::make_match_candidates(&models);
408        Self {
409            models,
410            bg_executor,
411            candidates,
412        }
413    }
414
415    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
416        let mut matches = self.bg_executor.block(match_strings(
417            &self.candidates,
418            &query,
419            false,
420            100,
421            &Default::default(),
422            self.bg_executor.clone(),
423        ));
424
425        let sorting_key = |mat: &StringMatch| {
426            let candidate = &self.candidates[mat.candidate_id];
427            (Reverse(OrderedFloat(mat.score)), candidate.id)
428        };
429        matches.sort_unstable_by_key(sorting_key);
430
431        let matched_models: Vec<_> = matches
432            .into_iter()
433            .map(|mat| self.models[mat.candidate_id].clone())
434            .collect();
435
436        matched_models
437    }
438
439    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
440        self.models
441            .iter()
442            .filter(|m| {
443                m.model
444                    .name()
445                    .0
446                    .to_lowercase()
447                    .contains(&query.to_lowercase())
448            })
449            .cloned()
450            .collect::<Vec<_>>()
451    }
452
453    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
454        model_infos
455            .iter()
456            .enumerate()
457            .map(|(index, model)| {
458                StringMatchCandidate::new(
459                    index,
460                    &format!(
461                        "{}/{}",
462                        &model.model.provider_name().0,
463                        &model.model.name().0
464                    ),
465                )
466            })
467            .collect::<Vec<_>>()
468    }
469}
470
471impl PickerDelegate for LanguageModelPickerDelegate {
472    type ListItem = AnyElement;
473
474    fn match_count(&self) -> usize {
475        self.filtered_entries.len()
476    }
477
478    fn selected_index(&self) -> usize {
479        self.selected_index
480    }
481
482    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
483        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
484        cx.notify();
485    }
486
487    fn can_select(
488        &mut self,
489        ix: usize,
490        _window: &mut Window,
491        _cx: &mut Context<Picker<Self>>,
492    ) -> bool {
493        match self.filtered_entries.get(ix) {
494            Some(LanguageModelPickerEntry::Model(_)) => true,
495            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
496        }
497    }
498
499    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
500        "Select a model…".into()
501    }
502
503    fn update_matches(
504        &mut self,
505        query: String,
506        window: &mut Window,
507        cx: &mut Context<Picker<Self>>,
508    ) -> Task<()> {
509        let all_models = self.all_models.clone();
510        let current_index = self.selected_index;
511        let bg_executor = cx.background_executor();
512
513        let language_model_registry = LanguageModelRegistry::global(cx);
514
515        let configured_providers = language_model_registry
516            .read(cx)
517            .providers()
518            .into_iter()
519            .filter(|provider| provider.is_authenticated(cx))
520            .collect::<Vec<_>>();
521
522        let configured_provider_ids = configured_providers
523            .iter()
524            .map(|provider| provider.id())
525            .collect::<Vec<_>>();
526
527        let recommended_models = all_models
528            .recommended
529            .iter()
530            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
531            .cloned()
532            .collect::<Vec<_>>();
533
534        let available_models = all_models
535            .model_infos()
536            .iter()
537            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
538            .cloned()
539            .collect::<Vec<_>>();
540
541        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
542        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
543
544        let recommended = matcher_rec.exact_search(&query);
545        let all = matcher_all.fuzzy_search(&query);
546
547        let filtered_models = GroupedModels::new(all, recommended);
548
549        cx.spawn_in(window, async move |this, cx| {
550            this.update_in(cx, |this, window, cx| {
551                this.delegate.filtered_entries = filtered_models.entries();
552                // Preserve selection focus
553                let new_index = if current_index >= this.delegate.filtered_entries.len() {
554                    0
555                } else {
556                    current_index
557                };
558                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
559                cx.notify();
560            })
561            .ok();
562        })
563    }
564
565    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
566        if let Some(LanguageModelPickerEntry::Model(model_info)) =
567            self.filtered_entries.get(self.selected_index)
568        {
569            let model = model_info.model.clone();
570            (self.on_model_changed)(model.clone(), cx);
571
572            let current_index = self.selected_index;
573            self.set_selected_index(current_index, window, cx);
574
575            cx.emit(DismissEvent);
576        }
577    }
578
579    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
580        self.language_model_selector
581            .update(cx, |_this, cx| cx.emit(DismissEvent))
582            .ok();
583    }
584
585    fn render_match(
586        &self,
587        ix: usize,
588        selected: bool,
589        _: &mut Window,
590        cx: &mut Context<Picker<Self>>,
591    ) -> Option<Self::ListItem> {
592        match self.filtered_entries.get(ix)? {
593            LanguageModelPickerEntry::Separator(title) => Some(
594                div()
595                    .px_2()
596                    .pb_1()
597                    .when(ix > 1, |this| {
598                        this.mt_1()
599                            .pt_2()
600                            .border_t_1()
601                            .border_color(cx.theme().colors().border_variant)
602                    })
603                    .child(
604                        Label::new(title)
605                            .size(LabelSize::XSmall)
606                            .color(Color::Muted),
607                    )
608                    .into_any_element(),
609            ),
610            LanguageModelPickerEntry::Model(model_info) => {
611                let active_model = (self.get_active_model)(cx);
612                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
613                let active_model_id = active_model.map(|m| m.model.id());
614
615                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
616                    && Some(model_info.model.id()) == active_model_id;
617
618                let model_icon_color = if is_selected {
619                    Color::Accent
620                } else {
621                    Color::Muted
622                };
623
624                Some(
625                    ListItem::new(ix)
626                        .inset(true)
627                        .spacing(ListItemSpacing::Sparse)
628                        .toggle_state(selected)
629                        .start_slot(
630                            Icon::new(model_info.icon)
631                                .color(model_icon_color)
632                                .size(IconSize::Small),
633                        )
634                        .child(
635                            h_flex()
636                                .w_full()
637                                .pl_0p5()
638                                .gap_1p5()
639                                .w(px(240.))
640                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
641                        )
642                        .end_slot(div().pr_3().when(is_selected, |this| {
643                            this.child(
644                                Icon::new(IconName::Check)
645                                    .color(Color::Accent)
646                                    .size(IconSize::Small),
647                            )
648                        }))
649                        .into_any_element(),
650                )
651            }
652        }
653    }
654
655    fn render_footer(
656        &self,
657        _: &mut Window,
658        cx: &mut Context<Picker<Self>>,
659    ) -> Option<gpui::AnyElement> {
660        use feature_flags::FeatureFlagAppExt;
661
662        let plan = proto::Plan::ZedPro;
663
664        Some(
665            h_flex()
666                .w_full()
667                .border_t_1()
668                .border_color(cx.theme().colors().border_variant)
669                .p_1()
670                .gap_4()
671                .justify_between()
672                .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
673                    this.child(match plan {
674                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
675                            .icon(IconName::ZedAssistant)
676                            .icon_size(IconSize::Small)
677                            .icon_color(Color::Muted)
678                            .icon_position(IconPosition::Start)
679                            .on_click(|_, window, cx| {
680                                window
681                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
682                            }),
683                        Plan::Free | Plan::ZedProTrial => Button::new(
684                            "try-pro",
685                            if plan == Plan::ZedProTrial {
686                                "Upgrade to Pro"
687                            } else {
688                                "Try Pro"
689                            },
690                        )
691                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
692                    })
693                })
694                .child(
695                    Button::new("configure", "Configure")
696                        .icon(IconName::Settings)
697                        .icon_size(IconSize::Small)
698                        .icon_color(Color::Muted)
699                        .icon_position(IconPosition::Start)
700                        .on_click(|_, window, cx| {
701                            window.dispatch_action(
702                                zed_actions::agent::OpenConfiguration.boxed_clone(),
703                                cx,
704                            );
705                        }),
706                )
707                .into_any(),
708        )
709    }
710}
711
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// In-memory stand-in for a real model: only identity and naming work;
    /// the request/response methods are left unimplemented.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Uses `name` as both the model's name and id, and `provider` as both
        /// the provider's id and name.
        fn new(name: &str, provider: &str) -> Self {
            let name = name.to_string();
            let provider = provider.to_string();
            Self {
                name: LanguageModelName::from(name.clone()),
                id: LanguageModelId::from(name),
                provider_id: LanguageModelProviderId::from(provider.clone()),
                provider_name: LanguageModelProviderName::from(provider),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        /// `"provider_id/name"`, which the assertions below compare against.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> usize {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<usize>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            http_client::Result<
                BoxStream<
                    'static,
                    http_client::Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, preserving order.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        let mut infos = Vec::with_capacity(model_specs.len());
        for (provider, name) in model_specs {
            infos.push(ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            });
        }
        infos
    }

    /// Fixture shared by the matcher tests. Order matters: fuzzy-score ties are
    /// broken by position in this list.
    fn sample_models() -> Vec<ModelInfo> {
        create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ])
    }

    /// Asserts `result` is exactly `expected` (as `"provider/name"` ids), in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, (info, expected_name)) in result.iter().zip(&expected).enumerate() {
            assert_eq!(
                info.model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models: Vec<ModelInfo> =
            grouped_models.other.values().flatten().cloned().collect();

        // Recommended models should not appear in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
    }
}