language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use collections::{HashSet, IndexMap};
  4use feature_flags::ZedProFeatureFlag;
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{
  7    Action, AnyElement, AnyView, App, BackgroundExecutor, Corner, DismissEvent, Entity,
  8    EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
  9    action_with_deprecated_aliases,
 10};
 11use language_model::{
 12    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
 13    LanguageModelRegistry,
 14};
 15use ordered_float::OrderedFloat;
 16use picker::{Picker, PickerDelegate};
 17use proto::Plan;
 18use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
 19
 20action_with_deprecated_aliases!(
 21    agent,
 22    ToggleModelSelector,
 23    [
 24        "assistant::ToggleModelSelector",
 25        "assistant2::ToggleModelSelector"
 26    ]
 27);
 28
 29const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 30
 31type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
 32type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 33
 34pub struct LanguageModelSelector {
 35    picker: Entity<Picker<LanguageModelPickerDelegate>>,
 36    _authenticate_all_providers_task: Task<()>,
 37    _subscriptions: Vec<Subscription>,
 38}
 39
 40impl LanguageModelSelector {
 41    pub fn new(
 42        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 43        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 44        window: &mut Window,
 45        cx: &mut Context<Self>,
 46    ) -> Self {
 47        let on_model_changed = Arc::new(on_model_changed);
 48
 49        let all_models = Self::all_models(cx);
 50        let entries = all_models.entries();
 51
 52        let delegate = LanguageModelPickerDelegate {
 53            language_model_selector: cx.entity().downgrade(),
 54            on_model_changed: on_model_changed.clone(),
 55            all_models: Arc::new(all_models),
 56            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
 57            filtered_entries: entries,
 58            get_active_model: Arc::new(get_active_model),
 59        };
 60
 61        let picker = cx.new(|cx| {
 62            Picker::list(delegate, window, cx)
 63                .show_scrollbar(true)
 64                .width(rems(20.))
 65                .max_height(Some(rems(20.).into()))
 66        });
 67
 68        let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
 69
 70        LanguageModelSelector {
 71            picker,
 72            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
 73            _subscriptions: vec![
 74                cx.subscribe_in(
 75                    &LanguageModelRegistry::global(cx),
 76                    window,
 77                    Self::handle_language_model_registry_event,
 78                ),
 79                subscription,
 80            ],
 81        }
 82    }
 83
 84    fn handle_language_model_registry_event(
 85        &mut self,
 86        _registry: &Entity<LanguageModelRegistry>,
 87        event: &language_model::Event,
 88        window: &mut Window,
 89        cx: &mut Context<Self>,
 90    ) {
 91        match event {
 92            language_model::Event::ProviderStateChanged
 93            | language_model::Event::AddedProvider(_)
 94            | language_model::Event::RemovedProvider(_) => {
 95                self.picker.update(cx, |this, cx| {
 96                    let query = this.query(cx);
 97                    this.delegate.all_models = Arc::new(Self::all_models(cx));
 98                    // Update matches will automatically drop the previous task
 99                    // if we get a provider event again
100                    this.update_matches(query, window, cx)
101                });
102            }
103            _ => {}
104        }
105    }
106
107    /// Authenticates all providers in the [`LanguageModelRegistry`].
108    ///
109    /// We do this so that we can populate the language selector with all of the
110    /// models from the configured providers.
111    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
112        let authenticate_all_providers = LanguageModelRegistry::global(cx)
113            .read(cx)
114            .providers()
115            .iter()
116            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
117            .collect::<Vec<_>>();
118
119        cx.spawn(async move |_cx| {
120            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
121                if let Err(err) = authenticate_task.await {
122                    if matches!(err, AuthenticateError::CredentialsNotFound) {
123                        // Since we're authenticating these providers in the
124                        // background for the purposes of populating the
125                        // language selector, we don't care about providers
126                        // where the credentials are not found.
127                    } else {
128                        // Some providers have noisy failure states that we
129                        // don't want to spam the logs with every time the
130                        // language model selector is initialized.
131                        //
132                        // Ideally these should have more clear failure modes
133                        // that we know are safe to ignore here, like what we do
134                        // with `CredentialsNotFound` above.
135                        match provider_id.0.as_ref() {
136                            "lmstudio" | "ollama" => {
137                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
138                                //
139                                // These fail noisily, so we don't log them.
140                            }
141                            "copilot_chat" => {
142                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
143                            }
144                            _ => {
145                                log::error!(
146                                    "Failed to authenticate provider: {}: {err}",
147                                    provider_name.0
148                                );
149                            }
150                        }
151                    }
152                }
153            }
154        })
155    }
156
157    fn all_models(cx: &App) -> GroupedModels {
158        let mut recommended = Vec::new();
159        let mut recommended_set = HashSet::default();
160        for provider in LanguageModelRegistry::global(cx)
161            .read(cx)
162            .providers()
163            .iter()
164        {
165            let models = provider.recommended_models(cx);
166            recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
167            recommended.extend(
168                provider
169                    .recommended_models(cx)
170                    .into_iter()
171                    .map(move |model| ModelInfo {
172                        model: model.clone(),
173                        icon: provider.icon(),
174                    }),
175            );
176        }
177
178        let other_models = LanguageModelRegistry::global(cx)
179            .read(cx)
180            .providers()
181            .iter()
182            .map(|provider| {
183                (
184                    provider.id(),
185                    provider
186                        .provided_models(cx)
187                        .into_iter()
188                        .filter_map(|model| {
189                            let not_included =
190                                !recommended_set.contains(&(model.provider_id(), model.id()));
191                            not_included.then(|| ModelInfo {
192                                model: model.clone(),
193                                icon: provider.icon(),
194                            })
195                        })
196                        .collect::<Vec<_>>(),
197                )
198            })
199            .collect::<IndexMap<_, _>>();
200
201        GroupedModels {
202            recommended,
203            other: other_models,
204        }
205    }
206
207    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
208        (self.picker.read(cx).delegate.get_active_model)(cx)
209    }
210
211    fn get_active_model_index(
212        entries: &[LanguageModelPickerEntry],
213        active_model: Option<ConfiguredModel>,
214    ) -> usize {
215        entries
216            .iter()
217            .position(|entry| {
218                if let LanguageModelPickerEntry::Model(model) = entry {
219                    active_model
220                        .as_ref()
221                        .map(|active_model| {
222                            active_model.model.id() == model.model.id()
223                                && active_model.provider.id() == model.model.provider_id()
224                        })
225                        .unwrap_or_default()
226                } else {
227                    false
228                }
229            })
230            .unwrap_or(0)
231    }
232}
233
234impl EventEmitter<DismissEvent> for LanguageModelSelector {}
235
236impl Focusable for LanguageModelSelector {
237    fn focus_handle(&self, cx: &App) -> FocusHandle {
238        self.picker.focus_handle(cx)
239    }
240}
241
242impl Render for LanguageModelSelector {
243    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
244        self.picker.clone()
245    }
246}
247
248#[derive(IntoElement)]
249pub struct LanguageModelSelectorPopoverMenu<T, TT>
250where
251    T: PopoverTrigger + ButtonCommon,
252    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
253{
254    language_model_selector: Entity<LanguageModelSelector>,
255    trigger: T,
256    tooltip: TT,
257    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
258    anchor: Corner,
259}
260
261impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
262where
263    T: PopoverTrigger + ButtonCommon,
264    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
265{
266    pub fn new(
267        language_model_selector: Entity<LanguageModelSelector>,
268        trigger: T,
269        tooltip: TT,
270        anchor: Corner,
271    ) -> Self {
272        Self {
273            language_model_selector,
274            trigger,
275            tooltip,
276            handle: None,
277            anchor,
278        }
279    }
280
281    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
282        self.handle = Some(handle);
283        self
284    }
285}
286
287impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
288where
289    T: PopoverTrigger + ButtonCommon,
290    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
291{
292    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
293        let language_model_selector = self.language_model_selector.clone();
294
295        PopoverMenu::new("model-switcher")
296            .menu(move |_window, _cx| Some(language_model_selector.clone()))
297            .trigger_with_tooltip(self.trigger, self.tooltip)
298            .anchor(self.anchor)
299            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
300            .offset(gpui::Point {
301                x: px(0.0),
302                y: px(-2.0),
303            })
304    }
305}
306
307#[derive(Clone)]
308struct ModelInfo {
309    model: Arc<dyn LanguageModel>,
310    icon: IconName,
311}
312
313pub struct LanguageModelPickerDelegate {
314    language_model_selector: WeakEntity<LanguageModelSelector>,
315    on_model_changed: OnModelChanged,
316    get_active_model: GetActiveModel,
317    all_models: Arc<GroupedModels>,
318    filtered_entries: Vec<LanguageModelPickerEntry>,
319    selected_index: usize,
320}
321
322struct GroupedModels {
323    recommended: Vec<ModelInfo>,
324    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
325}
326
327impl GroupedModels {
328    pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
329        let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
330        for model in other {
331            let provider = model.model.provider_id();
332            if let Some(models) = other_by_provider.get_mut(&provider) {
333                models.push(model);
334            } else {
335                other_by_provider.insert(provider, vec![model]);
336            }
337        }
338
339        Self {
340            recommended,
341            other: other_by_provider,
342        }
343    }
344
345    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
346        let mut entries = Vec::new();
347
348        if !self.recommended.is_empty() {
349            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
350            entries.extend(
351                self.recommended
352                    .iter()
353                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
354            );
355        }
356
357        for models in self.other.values() {
358            if models.is_empty() {
359                continue;
360            }
361            entries.push(LanguageModelPickerEntry::Separator(
362                models[0].model.provider_name().0,
363            ));
364            entries.extend(
365                models
366                    .iter()
367                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
368            );
369        }
370        entries
371    }
372
373    fn model_infos(&self) -> Vec<ModelInfo> {
374        let other = self
375            .other
376            .values()
377            .flat_map(|model| model.iter())
378            .cloned()
379            .collect::<Vec<_>>();
380        self.recommended
381            .iter()
382            .chain(&other)
383            .cloned()
384            .collect::<Vec<_>>()
385    }
386}
387
388enum LanguageModelPickerEntry {
389    Model(ModelInfo),
390    Separator(SharedString),
391}
392
393struct ModelMatcher {
394    models: Vec<ModelInfo>,
395    bg_executor: BackgroundExecutor,
396    candidates: Vec<StringMatchCandidate>,
397}
398
399impl ModelMatcher {
400    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
401        let candidates = Self::make_match_candidates(&models);
402        Self {
403            models,
404            bg_executor,
405            candidates,
406        }
407    }
408
409    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
410        let mut matches = self.bg_executor.block(match_strings(
411            &self.candidates,
412            &query,
413            false,
414            100,
415            &Default::default(),
416            self.bg_executor.clone(),
417        ));
418
419        let sorting_key = |mat: &StringMatch| {
420            let candidate = &self.candidates[mat.candidate_id];
421            (Reverse(OrderedFloat(mat.score)), candidate.id)
422        };
423        matches.sort_unstable_by_key(sorting_key);
424
425        let matched_models: Vec<_> = matches
426            .into_iter()
427            .map(|mat| self.models[mat.candidate_id].clone())
428            .collect();
429
430        matched_models
431    }
432
433    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
434        self.models
435            .iter()
436            .filter(|m| {
437                m.model
438                    .name()
439                    .0
440                    .to_lowercase()
441                    .contains(&query.to_lowercase())
442            })
443            .cloned()
444            .collect::<Vec<_>>()
445    }
446
447    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
448        model_infos
449            .iter()
450            .enumerate()
451            .map(|(index, model)| {
452                StringMatchCandidate::new(
453                    index,
454                    &format!(
455                        "{}/{}",
456                        &model.model.provider_name().0,
457                        &model.model.name().0
458                    ),
459                )
460            })
461            .collect::<Vec<_>>()
462    }
463}
464
465impl PickerDelegate for LanguageModelPickerDelegate {
466    type ListItem = AnyElement;
467
468    fn match_count(&self) -> usize {
469        self.filtered_entries.len()
470    }
471
472    fn selected_index(&self) -> usize {
473        self.selected_index
474    }
475
476    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
477        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
478        cx.notify();
479    }
480
481    fn can_select(
482        &mut self,
483        ix: usize,
484        _window: &mut Window,
485        _cx: &mut Context<Picker<Self>>,
486    ) -> bool {
487        match self.filtered_entries.get(ix) {
488            Some(LanguageModelPickerEntry::Model(_)) => true,
489            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
490        }
491    }
492
493    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
494        "Select a model…".into()
495    }
496
497    fn update_matches(
498        &mut self,
499        query: String,
500        window: &mut Window,
501        cx: &mut Context<Picker<Self>>,
502    ) -> Task<()> {
503        let all_models = self.all_models.clone();
504        let current_index = self.selected_index;
505        let bg_executor = cx.background_executor();
506
507        let language_model_registry = LanguageModelRegistry::global(cx);
508
509        let configured_providers = language_model_registry
510            .read(cx)
511            .providers()
512            .into_iter()
513            .filter(|provider| provider.is_authenticated(cx))
514            .collect::<Vec<_>>();
515
516        let configured_provider_ids = configured_providers
517            .iter()
518            .map(|provider| provider.id())
519            .collect::<Vec<_>>();
520
521        let recommended_models = all_models
522            .recommended
523            .iter()
524            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
525            .cloned()
526            .collect::<Vec<_>>();
527
528        let available_models = all_models
529            .model_infos()
530            .iter()
531            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
532            .cloned()
533            .collect::<Vec<_>>();
534
535        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
536        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
537
538        let recommended = matcher_rec.exact_search(&query);
539        let all = matcher_all.fuzzy_search(&query);
540
541        let filtered_models = GroupedModels::new(all, recommended);
542
543        cx.spawn_in(window, async move |this, cx| {
544            this.update_in(cx, |this, window, cx| {
545                this.delegate.filtered_entries = filtered_models.entries();
546                // Preserve selection focus
547                let new_index = if current_index >= this.delegate.filtered_entries.len() {
548                    0
549                } else {
550                    current_index
551                };
552                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
553                cx.notify();
554            })
555            .ok();
556        })
557    }
558
559    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
560        if let Some(LanguageModelPickerEntry::Model(model_info)) =
561            self.filtered_entries.get(self.selected_index)
562        {
563            let model = model_info.model.clone();
564            (self.on_model_changed)(model.clone(), cx);
565
566            let current_index = self.selected_index;
567            self.set_selected_index(current_index, window, cx);
568
569            cx.emit(DismissEvent);
570        }
571    }
572
573    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
574        self.language_model_selector
575            .update(cx, |_this, cx| cx.emit(DismissEvent))
576            .ok();
577    }
578
579    fn render_match(
580        &self,
581        ix: usize,
582        selected: bool,
583        _: &mut Window,
584        cx: &mut Context<Picker<Self>>,
585    ) -> Option<Self::ListItem> {
586        match self.filtered_entries.get(ix)? {
587            LanguageModelPickerEntry::Separator(title) => Some(
588                div()
589                    .px_2()
590                    .pb_1()
591                    .when(ix > 1, |this| {
592                        this.mt_1()
593                            .pt_2()
594                            .border_t_1()
595                            .border_color(cx.theme().colors().border_variant)
596                    })
597                    .child(
598                        Label::new(title)
599                            .size(LabelSize::XSmall)
600                            .color(Color::Muted),
601                    )
602                    .into_any_element(),
603            ),
604            LanguageModelPickerEntry::Model(model_info) => {
605                let active_model = (self.get_active_model)(cx);
606                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
607                let active_model_id = active_model.map(|m| m.model.id());
608
609                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
610                    && Some(model_info.model.id()) == active_model_id;
611
612                let model_icon_color = if is_selected {
613                    Color::Accent
614                } else {
615                    Color::Muted
616                };
617
618                Some(
619                    ListItem::new(ix)
620                        .inset(true)
621                        .spacing(ListItemSpacing::Sparse)
622                        .toggle_state(selected)
623                        .start_slot(
624                            Icon::new(model_info.icon)
625                                .color(model_icon_color)
626                                .size(IconSize::Small),
627                        )
628                        .child(
629                            h_flex()
630                                .w_full()
631                                .pl_0p5()
632                                .gap_1p5()
633                                .w(px(240.))
634                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
635                        )
636                        .end_slot(div().pr_3().when(is_selected, |this| {
637                            this.child(
638                                Icon::new(IconName::Check)
639                                    .color(Color::Accent)
640                                    .size(IconSize::Small),
641                            )
642                        }))
643                        .into_any_element(),
644                )
645            }
646        }
647    }
648
649    fn render_footer(
650        &self,
651        _: &mut Window,
652        cx: &mut Context<Picker<Self>>,
653    ) -> Option<gpui::AnyElement> {
654        use feature_flags::FeatureFlagAppExt;
655
656        let plan = proto::Plan::ZedPro;
657
658        Some(
659            h_flex()
660                .w_full()
661                .border_t_1()
662                .border_color(cx.theme().colors().border_variant)
663                .p_1()
664                .gap_4()
665                .justify_between()
666                .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
667                    this.child(match plan {
668                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
669                            .icon(IconName::ZedAssistant)
670                            .icon_size(IconSize::Small)
671                            .icon_color(Color::Muted)
672                            .icon_position(IconPosition::Start)
673                            .on_click(|_, window, cx| {
674                                window
675                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
676                            }),
677                        Plan::Free | Plan::ZedProTrial => Button::new(
678                            "try-pro",
679                            if plan == Plan::ZedProTrial {
680                                "Upgrade to Pro"
681                            } else {
682                                "Try Pro"
683                            },
684                        )
685                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
686                    })
687                })
688                .child(
689                    Button::new("configure", "Configure")
690                        .icon(IconName::Settings)
691                        .icon_size(IconSize::Small)
692                        .icon_color(Color::Muted)
693                        .icon_position(IconPosition::Start)
694                        .on_click(|_, window, cx| {
695                            window.dispatch_action(
696                                zed_actions::agent::OpenConfiguration.boxed_clone(),
697                                cx,
698                            );
699                        }),
700                )
701                .into_any(),
702        )
703    }
704}
705
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// `(provider, model name)` pairs shared by the matcher tests, in the
    /// order the matcher receives them.
    const SAMPLE_MODELS: &[(&str, &str)] = &[
        ("zed", "Claude 3.7 Sonnet"),
        ("zed", "Claude 3.7 Sonnet Thinking"),
        ("zed", "gpt-4.1"),
        ("zed", "gpt-4.1-nano"),
        ("openai", "gpt-3.5-turbo"),
        ("openai", "gpt-4.1"),
        ("openai", "gpt-4.1-nano"),
        ("ollama", "mistral"),
        ("ollama", "deepseek"),
    ];

    /// Minimal [`LanguageModel`] for driving [`ModelMatcher`]; its id is the
    /// same string as its display name, and its provider name equals its id.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            let name = name.to_string();
            let provider = provider.to_string();
            Self {
                name: LanguageModelName::from(name.clone()),
                id: LanguageModelId::from(name),
                provider_id: LanguageModelProviderId::from(provider.clone()),
                provider_name: LanguageModelProviderName::from(provider),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        /// `"<provider id>/<model name>"`; the tests identify models by this.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> usize {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<usize>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            http_client::Result<
                BoxStream<
                    'static,
                    http_client::Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
            >,
        > {
            unimplemented!()
        }
    }

    /// Wraps each `(provider, name)` pair in a [`ModelInfo`] with a fixed icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        let mut models = Vec::with_capacity(model_specs.len());
        for (provider, name) in model_specs {
            models.push(ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            });
        }
        models
    }

    /// Asserts that `result` lists exactly `expected` telemetry ids, in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, (info, expected_name)) in result.iter().zip(expected.iter()).enumerate() {
            assert_eq!(
                info.model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(
            create_models(SAMPLE_MODELS.to_vec()),
            cx.background_executor.clone(),
        );

        // The order of models should be maintained, case doesn't matter.
        assert_models_eq(
            matcher.exact_search("GPT-4.1"),
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(
            create_models(SAMPLE_MODELS.to_vec()),
            cx.background_executor.clone(),
        );

        // Results should preserve models order whenever possible.
        // `zed/gpt-4.1` and `openai/gpt-4.1` have identical similarity scores,
        // but `zed/gpt-4.1` was higher in the models list, so it comes first.
        assert_models_eq(
            matcher.fuzzy_search("41"),
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well ("ol" meaning "ollama").
        assert_models_eq(
            matcher.fuzzy_search("ol"),
            vec!["ollama/mistral", "ollama/deepseek"],
        );

        // Fuzzy search across provider and model name.
        assert_models_eq(matcher.fuzzy_search("z4n"), vec!["zed/gpt-4.1-nano"]);
    }
}