language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use collections::IndexMap;
  4use futures::{StreamExt, channel::mpsc};
  5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  6use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
  7use language_model::{
  8    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
  9    LanguageModelRegistry,
 10};
 11use ordered_float::OrderedFloat;
 12use picker::{Picker, PickerDelegate};
 13use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*};
 14use zed_actions::agent::OpenSettings;
 15
 16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
 17type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 18
 19pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 20
 21pub fn language_model_selector(
 22    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 23    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 24    popover_styles: bool,
 25    focus_handle: FocusHandle,
 26    window: &mut Window,
 27    cx: &mut Context<LanguageModelSelector>,
 28) -> LanguageModelSelector {
 29    let delegate = LanguageModelPickerDelegate::new(
 30        get_active_model,
 31        on_model_changed,
 32        popover_styles,
 33        focus_handle,
 34        window,
 35        cx,
 36    );
 37
 38    if popover_styles {
 39        Picker::list(delegate, window, cx)
 40            .show_scrollbar(true)
 41            .width(rems(20.))
 42            .max_height(Some(rems(20.).into()))
 43    } else {
 44        Picker::list(delegate, window, cx).show_scrollbar(true)
 45    }
 46}
 47
 48fn all_models(cx: &App) -> GroupedModels {
 49    let now = std::time::SystemTime::now()
 50        .duration_since(std::time::UNIX_EPOCH)
 51        .unwrap_or_default()
 52        .as_millis();
 53    eprintln!("[{}ms] all_models() called", now);
 54    let providers = LanguageModelRegistry::global(cx).read(cx).providers();
 55    eprintln!("[{}ms] all_models: got {} providers", now, providers.len());
 56
 57    let recommended = providers
 58        .iter()
 59        .flat_map(|provider| {
 60            provider
 61                .recommended_models(cx)
 62                .into_iter()
 63                .map(|model| ModelInfo {
 64                    model,
 65                    icon: provider.icon(),
 66                })
 67        })
 68        .collect();
 69
 70    let all: Vec<ModelInfo> = providers
 71        .iter()
 72        .flat_map(|provider| {
 73            let now = std::time::SystemTime::now()
 74                .duration_since(std::time::UNIX_EPOCH)
 75                .unwrap_or_default()
 76                .as_millis();
 77            eprintln!(
 78                "[{}ms] all_models: calling provided_models for {:?}",
 79                now,
 80                provider.id()
 81            );
 82            let models = provider.provided_models(cx);
 83            eprintln!(
 84                "[{}ms] all_models: provider {:?} returned {} models",
 85                now,
 86                provider.id(),
 87                models.len()
 88            );
 89            models.into_iter().map(|model| ModelInfo {
 90                model,
 91                icon: provider.icon(),
 92            })
 93        })
 94        .collect();
 95
 96    let now = std::time::SystemTime::now()
 97        .duration_since(std::time::UNIX_EPOCH)
 98        .unwrap_or_default()
 99        .as_millis();
100    eprintln!(
101        "[{}ms] all_models: returning {} total models",
102        now,
103        all.len()
104    );
105    GroupedModels::new(all, recommended)
106}
107
108#[derive(Clone)]
109struct ModelInfo {
110    model: Arc<dyn LanguageModel>,
111    icon: IconName,
112}
113
114pub struct LanguageModelPickerDelegate {
115    on_model_changed: OnModelChanged,
116    get_active_model: GetActiveModel,
117    all_models: Arc<GroupedModels>,
118    filtered_entries: Vec<LanguageModelPickerEntry>,
119    selected_index: usize,
120    _authenticate_all_providers_task: Task<()>,
121    _refresh_models_task: Task<()>,
122    popover_styles: bool,
123    focus_handle: FocusHandle,
124}
125
126impl LanguageModelPickerDelegate {
127    fn new(
128        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
129        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
130        popover_styles: bool,
131        focus_handle: FocusHandle,
132        window: &mut Window,
133        cx: &mut Context<Picker<Self>>,
134    ) -> Self {
135        let now = std::time::SystemTime::now()
136            .duration_since(std::time::UNIX_EPOCH)
137            .unwrap_or_default()
138            .as_millis();
139        eprintln!("[{}ms] LanguageModelPickerDelegate::new() called", now);
140        let on_model_changed = Arc::new(on_model_changed);
141        let models = all_models(cx);
142        eprintln!(
143            "[{}ms] LanguageModelPickerDelegate::new() got {} models from all_models()",
144            now,
145            models.all.len()
146        );
147        let entries = models.entries();
148
149        Self {
150            on_model_changed,
151            all_models: Arc::new(models),
152            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
153            filtered_entries: entries,
154            get_active_model: Arc::new(get_active_model),
155            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
156            _refresh_models_task: {
157                let now = std::time::SystemTime::now()
158                    .duration_since(std::time::UNIX_EPOCH)
159                    .unwrap_or_default()
160                    .as_millis();
161                eprintln!(
162                    "[{}ms] LanguageModelPickerDelegate::new() setting up refresh task for LanguageModelRegistry",
163                    now
164                );
165
166                // Create a channel to signal when models need refreshing
167                let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
168
169                // Subscribe to registry events and send refresh signals through the channel
170                let registry = LanguageModelRegistry::global(cx);
171                eprintln!(
172                    "[{}ms] LanguageModelPickerDelegate::new() subscribing to registry entity_id: {:?}",
173                    now,
174                    registry.entity_id()
175                );
176                cx.subscribe(&registry, move |_picker, _, event, _cx| {
177                    match event {
178                        language_model::Event::ProviderStateChanged(id) => {
179                            let now = std::time::SystemTime::now()
180                                .duration_since(std::time::UNIX_EPOCH)
181                                .unwrap_or_default()
182                                .as_millis();
183                            eprintln!(
184                                "[{}ms] LanguageModelSelector: ProviderStateChanged event for {:?}, sending refresh signal",
185                                now, id
186                            );
187                            refresh_tx.unbounded_send(()).ok();
188                        }
189                        language_model::Event::AddedProvider(id) => {
190                            let now = std::time::SystemTime::now()
191                                .duration_since(std::time::UNIX_EPOCH)
192                                .unwrap_or_default()
193                                .as_millis();
194                            eprintln!(
195                                "[{}ms] LanguageModelSelector: AddedProvider event for {:?}, sending refresh signal",
196                                now, id
197                            );
198                            refresh_tx.unbounded_send(()).ok();
199                        }
200                        language_model::Event::RemovedProvider(id) => {
201                            let now = std::time::SystemTime::now()
202                                .duration_since(std::time::UNIX_EPOCH)
203                                .unwrap_or_default()
204                                .as_millis();
205                            eprintln!(
206                                "[{}ms] LanguageModelSelector: RemovedProvider event for {:?}, sending refresh signal",
207                                now, id
208                            );
209                            refresh_tx.unbounded_send(()).ok();
210                        }
211                        _ => {}
212                    }
213                })
214                .detach();
215
216                // Spawn a task that listens for refresh signals and updates the picker
217                cx.spawn_in(window, async move |this, cx| {
218                    while let Some(()) = refresh_rx.next().await {
219                        let now = std::time::SystemTime::now()
220                            .duration_since(std::time::UNIX_EPOCH)
221                            .unwrap_or_default()
222                            .as_millis();
223                        eprintln!(
224                            "[{}ms] LanguageModelSelector: refresh signal received, updating models",
225                            now
226                        );
227                        let result = this.update_in(cx, |picker, window, cx| {
228                            picker.delegate.all_models = Arc::new(all_models(cx));
229                            picker.refresh(window, cx);
230                        });
231                        if result.is_err() {
232                            // Picker was dropped, exit the loop
233                            break;
234                        }
235                    }
236                })
237            },
238            popover_styles,
239            focus_handle,
240        }
241    }
242
243    fn get_active_model_index(
244        entries: &[LanguageModelPickerEntry],
245        active_model: Option<ConfiguredModel>,
246    ) -> usize {
247        entries
248            .iter()
249            .position(|entry| {
250                if let LanguageModelPickerEntry::Model(model) = entry {
251                    active_model
252                        .as_ref()
253                        .map(|active_model| {
254                            active_model.model.id() == model.model.id()
255                                && active_model.provider.id() == model.model.provider_id()
256                        })
257                        .unwrap_or_default()
258                } else {
259                    false
260                }
261            })
262            .unwrap_or(0)
263    }
264
265    /// Authenticates all providers in the [`LanguageModelRegistry`].
266    ///
267    /// We do this so that we can populate the language selector with all of the
268    /// models from the configured providers.
269    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
270        let authenticate_all_providers = LanguageModelRegistry::global(cx)
271            .read(cx)
272            .providers()
273            .iter()
274            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
275            .collect::<Vec<_>>();
276
277        cx.spawn(async move |_cx| {
278            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
279                if let Err(err) = authenticate_task.await {
280                    if matches!(err, AuthenticateError::CredentialsNotFound) {
281                        // Since we're authenticating these providers in the
282                        // background for the purposes of populating the
283                        // language selector, we don't care about providers
284                        // where the credentials are not found.
285                    } else {
286                        // Some providers have noisy failure states that we
287                        // don't want to spam the logs with every time the
288                        // language model selector is initialized.
289                        //
290                        // Ideally these should have more clear failure modes
291                        // that we know are safe to ignore here, like what we do
292                        // with `CredentialsNotFound` above.
293                        match provider_id.0.as_ref() {
294                            "lmstudio" | "ollama" => {
295                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
296                                //
297                                // These fail noisily, so we don't log them.
298                            }
299                            "copilot_chat" => {
300                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
301                            }
302                            _ => {
303                                log::error!(
304                                    "Failed to authenticate provider: {}: {err:#}",
305                                    provider_name.0
306                                );
307                            }
308                        }
309                    }
310                }
311            }
312        })
313    }
314
315    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
316        (self.get_active_model)(cx)
317    }
318}
319
320struct GroupedModels {
321    recommended: Vec<ModelInfo>,
322    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
323}
324
325impl GroupedModels {
326    pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
327        let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
328        for model in all {
329            let provider = model.model.provider_id();
330            if let Some(models) = all_by_provider.get_mut(&provider) {
331                models.push(model);
332            } else {
333                all_by_provider.insert(provider, vec![model]);
334            }
335        }
336
337        Self {
338            recommended,
339            all: all_by_provider,
340        }
341    }
342
343    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
344        let mut entries = Vec::new();
345
346        if !self.recommended.is_empty() {
347            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
348            entries.extend(
349                self.recommended
350                    .iter()
351                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
352            );
353        }
354
355        for models in self.all.values() {
356            if models.is_empty() {
357                continue;
358            }
359            entries.push(LanguageModelPickerEntry::Separator(
360                models[0].model.provider_name().0,
361            ));
362            entries.extend(
363                models
364                    .iter()
365                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
366            );
367        }
368        entries
369    }
370}
371
372enum LanguageModelPickerEntry {
373    Model(ModelInfo),
374    Separator(SharedString),
375}
376
377struct ModelMatcher {
378    models: Vec<ModelInfo>,
379    bg_executor: BackgroundExecutor,
380    candidates: Vec<StringMatchCandidate>,
381}
382
383impl ModelMatcher {
384    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
385        let candidates = Self::make_match_candidates(&models);
386        Self {
387            models,
388            bg_executor,
389            candidates,
390        }
391    }
392
393    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
394        let mut matches = self.bg_executor.block(match_strings(
395            &self.candidates,
396            query,
397            false,
398            true,
399            100,
400            &Default::default(),
401            self.bg_executor.clone(),
402        ));
403
404        let sorting_key = |mat: &StringMatch| {
405            let candidate = &self.candidates[mat.candidate_id];
406            (Reverse(OrderedFloat(mat.score)), candidate.id)
407        };
408        matches.sort_unstable_by_key(sorting_key);
409
410        let matched_models: Vec<_> = matches
411            .into_iter()
412            .map(|mat| self.models[mat.candidate_id].clone())
413            .collect();
414
415        matched_models
416    }
417
418    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
419        self.models
420            .iter()
421            .filter(|m| {
422                m.model
423                    .name()
424                    .0
425                    .to_lowercase()
426                    .contains(&query.to_lowercase())
427            })
428            .cloned()
429            .collect::<Vec<_>>()
430    }
431
432    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
433        model_infos
434            .iter()
435            .enumerate()
436            .map(|(index, model)| {
437                StringMatchCandidate::new(
438                    index,
439                    &format!(
440                        "{}/{}",
441                        &model.model.provider_name().0,
442                        &model.model.name().0
443                    ),
444                )
445            })
446            .collect::<Vec<_>>()
447    }
448}
449
450impl PickerDelegate for LanguageModelPickerDelegate {
451    type ListItem = AnyElement;
452
453    fn match_count(&self) -> usize {
454        self.filtered_entries.len()
455    }
456
457    fn selected_index(&self) -> usize {
458        self.selected_index
459    }
460
461    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
462        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
463        cx.notify();
464    }
465
466    fn can_select(
467        &mut self,
468        ix: usize,
469        _window: &mut Window,
470        _cx: &mut Context<Picker<Self>>,
471    ) -> bool {
472        match self.filtered_entries.get(ix) {
473            Some(LanguageModelPickerEntry::Model(_)) => true,
474            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
475        }
476    }
477
478    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
479        "Select a model…".into()
480    }
481
482    fn update_matches(
483        &mut self,
484        query: String,
485        window: &mut Window,
486        cx: &mut Context<Picker<Self>>,
487    ) -> Task<()> {
488        let all_models = self.all_models.clone();
489        let active_model = (self.get_active_model)(cx);
490        let bg_executor = cx.background_executor();
491
492        let language_model_registry = LanguageModelRegistry::global(cx);
493
494        let configured_providers = language_model_registry
495            .read(cx)
496            .providers()
497            .into_iter()
498            .filter(|provider| provider.is_authenticated(cx))
499            .collect::<Vec<_>>();
500
501        let configured_provider_ids = configured_providers
502            .iter()
503            .map(|provider| provider.id())
504            .collect::<Vec<_>>();
505
506        let recommended_models = all_models
507            .recommended
508            .iter()
509            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
510            .cloned()
511            .collect::<Vec<_>>();
512
513        let available_models = all_models
514            .all
515            .values()
516            .flat_map(|models| models.iter())
517            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
518            .cloned()
519            .collect::<Vec<_>>();
520
521        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
522        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
523
524        let recommended = matcher_rec.exact_search(&query);
525        let all = matcher_all.fuzzy_search(&query);
526
527        let filtered_models = GroupedModels::new(all, recommended);
528
529        cx.spawn_in(window, async move |this, cx| {
530            this.update_in(cx, |this, window, cx| {
531                this.delegate.filtered_entries = filtered_models.entries();
532                // Finds the currently selected model in the list
533                let new_index =
534                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
535                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
536                cx.notify();
537            })
538            .ok();
539        })
540    }
541
542    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
543        if let Some(LanguageModelPickerEntry::Model(model_info)) =
544            self.filtered_entries.get(self.selected_index)
545        {
546            let model = model_info.model.clone();
547            (self.on_model_changed)(model.clone(), cx);
548
549            let current_index = self.selected_index;
550            self.set_selected_index(current_index, window, cx);
551
552            cx.emit(DismissEvent);
553        }
554    }
555
556    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
557        cx.emit(DismissEvent);
558    }
559
560    fn render_match(
561        &self,
562        ix: usize,
563        selected: bool,
564        _: &mut Window,
565        cx: &mut Context<Picker<Self>>,
566    ) -> Option<Self::ListItem> {
567        match self.filtered_entries.get(ix)? {
568            LanguageModelPickerEntry::Separator(title) => Some(
569                div()
570                    .px_2()
571                    .pb_1()
572                    .when(ix > 1, |this| {
573                        this.mt_1()
574                            .pt_2()
575                            .border_t_1()
576                            .border_color(cx.theme().colors().border_variant)
577                    })
578                    .child(
579                        Label::new(title)
580                            .size(LabelSize::XSmall)
581                            .color(Color::Muted),
582                    )
583                    .into_any_element(),
584            ),
585            LanguageModelPickerEntry::Model(model_info) => {
586                let active_model = (self.get_active_model)(cx);
587                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
588                let active_model_id = active_model.map(|m| m.model.id());
589
590                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
591                    && Some(model_info.model.id()) == active_model_id;
592
593                let model_icon_color = if is_selected {
594                    Color::Accent
595                } else {
596                    Color::Muted
597                };
598
599                Some(
600                    ListItem::new(ix)
601                        .inset(true)
602                        .spacing(ListItemSpacing::Sparse)
603                        .toggle_state(selected)
604                        .child(
605                            h_flex()
606                                .w_full()
607                                .gap_1p5()
608                                .child(
609                                    Icon::new(model_info.icon)
610                                        .color(model_icon_color)
611                                        .size(IconSize::Small),
612                                )
613                                .child(Label::new(model_info.model.name().0).truncate()),
614                        )
615                        .end_slot(div().pr_3().when(is_selected, |this| {
616                            this.child(
617                                Icon::new(IconName::Check)
618                                    .color(Color::Accent)
619                                    .size(IconSize::Small),
620                            )
621                        }))
622                        .into_any_element(),
623                )
624            }
625        }
626    }
627
628    fn render_footer(
629        &self,
630        _window: &mut Window,
631        cx: &mut Context<Picker<Self>>,
632    ) -> Option<gpui::AnyElement> {
633        let focus_handle = self.focus_handle.clone();
634
635        if !self.popover_styles {
636            return None;
637        }
638
639        Some(
640            h_flex()
641                .w_full()
642                .p_1p5()
643                .border_t_1()
644                .border_color(cx.theme().colors().border_variant)
645                .child(
646                    Button::new("configure", "Configure")
647                        .full_width()
648                        .style(ButtonStyle::Outlined)
649                        .key_binding(
650                            KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
651                                .map(|kb| kb.size(rems_from_px(12.))),
652                        )
653                        .on_click(|_, window, cx| {
654                            window.dispatch_action(OpenSettings.boxed_clone(), cx);
655                        }),
656                )
657                .into_any(),
658        )
659    }
660}
661
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal [`LanguageModel`] whose id and display name are both the given name.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // "provider/name" — the key the assertions below compare against.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds one `ModelInfo` per `(provider, name)` pair, preserving order.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| {
                let model: Arc<dyn LanguageModel> =
                    Arc::new(TestLanguageModel::new(name, provider));
                ModelInfo {
                    model,
                    icon: IconName::Ai,
                }
            })
            .collect()
    }

    /// The model list shared by the matcher tests.
    fn sample_models() -> Vec<ModelInfo> {
        create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ])
    }

    /// Flattens the per-provider groups back into a single ordered list.
    fn flatten_all(grouped_models: &GroupedModels) -> Vec<ModelInfo> {
        grouped_models.all.values().flatten().cloned().collect()
    }

    /// Asserts that `result` contains exactly the `expected` telemetry ids, in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, (actual, expected_name)) in result.iter().zip(expected.iter()).enumerate() {
            assert_eq!(
                actual.model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        assert_models_eq(
            matcher.exact_search("GPT-4.1"),
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        assert_models_eq(
            matcher.fuzzy_search("41"),
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        assert_models_eq(
            matcher.fuzzy_search("ol"), // meaning "ollama"
            vec!["ollama/mistral", "ollama/deepseek"],
        );

        // Fuzzy search
        assert_models_eq(matcher.fuzzy_search("z4n"), vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // Recommended models should also appear in "all"
        assert_models_eq(
            flatten_all(&grouped_models),
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            flatten_all(&grouped_models),
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }
}