1use std::{cmp::Reverse, sync::Arc};
2
3use collections::{HashSet, IndexMap};
4use feature_flags::ZedProFeatureFlag;
5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
6use gpui::{
7 Action, AnyElement, AnyView, App, BackgroundExecutor, Corner, DismissEvent, Entity,
8 EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
9 action_with_deprecated_aliases,
10};
11use language_model::{
12 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
13 LanguageModelRegistry,
14};
15use ordered_float::OrderedFloat;
16use picker::{Picker, PickerDelegate};
17use proto::Plan;
18use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
19
// Defines the `agent::ToggleModelSelector` action. The old `assistant::` and `assistant2::`
// names are registered as deprecated aliases, presumably so keymaps written against those
// names keep resolving to this action.
action_with_deprecated_aliases!(
    agent,
    ToggleModelSelector,
    [
        "assistant::ToggleModelSelector",
        "assistant2::ToggleModelSelector"
    ]
);
28
/// Page opened by the "Try Pro" / "Upgrade to Pro" button in the picker footer.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Callback invoked with the model the user confirms in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback reporting the currently active model (if any); used to mark it in the list.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
33
/// A picker for choosing a language model from every provider registered in the
/// [`LanguageModelRegistry`].
///
/// Any event emitted by the inner [`Picker`] is re-emitted as [`DismissEvent`].
pub struct LanguageModelSelector {
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    // Held so background provider authentication is not dropped (dropping a gpui `Task`
    // cancels it).
    _authenticate_all_providers_task: Task<()>,
    // Registry and picker subscriptions; held for the selector's lifetime.
    _subscriptions: Vec<Subscription>,
}
39
40impl LanguageModelSelector {
41 pub fn new(
42 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
43 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
44 window: &mut Window,
45 cx: &mut Context<Self>,
46 ) -> Self {
47 let on_model_changed = Arc::new(on_model_changed);
48
49 let all_models = Self::all_models(cx);
50 let entries = all_models.entries();
51
52 let delegate = LanguageModelPickerDelegate {
53 language_model_selector: cx.entity().downgrade(),
54 on_model_changed: on_model_changed.clone(),
55 all_models: Arc::new(all_models),
56 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
57 filtered_entries: entries,
58 get_active_model: Arc::new(get_active_model),
59 };
60
61 let picker = cx.new(|cx| {
62 Picker::list(delegate, window, cx)
63 .show_scrollbar(true)
64 .width(rems(20.))
65 .max_height(Some(rems(20.).into()))
66 });
67
68 let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
69
70 LanguageModelSelector {
71 picker,
72 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
73 _subscriptions: vec![
74 cx.subscribe_in(
75 &LanguageModelRegistry::global(cx),
76 window,
77 Self::handle_language_model_registry_event,
78 ),
79 subscription,
80 ],
81 }
82 }
83
84 fn handle_language_model_registry_event(
85 &mut self,
86 _registry: &Entity<LanguageModelRegistry>,
87 event: &language_model::Event,
88 window: &mut Window,
89 cx: &mut Context<Self>,
90 ) {
91 match event {
92 language_model::Event::ProviderStateChanged
93 | language_model::Event::AddedProvider(_)
94 | language_model::Event::RemovedProvider(_) => {
95 self.picker.update(cx, |this, cx| {
96 let query = this.query(cx);
97 this.delegate.all_models = Arc::new(Self::all_models(cx));
98 // Update matches will automatically drop the previous task
99 // if we get a provider event again
100 this.update_matches(query, window, cx)
101 });
102 }
103 _ => {}
104 }
105 }
106
107 /// Authenticates all providers in the [`LanguageModelRegistry`].
108 ///
109 /// We do this so that we can populate the language selector with all of the
110 /// models from the configured providers.
111 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
112 let authenticate_all_providers = LanguageModelRegistry::global(cx)
113 .read(cx)
114 .providers()
115 .iter()
116 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
117 .collect::<Vec<_>>();
118
119 cx.spawn(async move |_cx| {
120 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
121 if let Err(err) = authenticate_task.await {
122 if matches!(err, AuthenticateError::CredentialsNotFound) {
123 // Since we're authenticating these providers in the
124 // background for the purposes of populating the
125 // language selector, we don't care about providers
126 // where the credentials are not found.
127 } else {
128 // Some providers have noisy failure states that we
129 // don't want to spam the logs with every time the
130 // language model selector is initialized.
131 //
132 // Ideally these should have more clear failure modes
133 // that we know are safe to ignore here, like what we do
134 // with `CredentialsNotFound` above.
135 match provider_id.0.as_ref() {
136 "lmstudio" | "ollama" => {
137 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
138 //
139 // These fail noisily, so we don't log them.
140 }
141 "copilot_chat" => {
142 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
143 }
144 _ => {
145 log::error!(
146 "Failed to authenticate provider: {}: {err}",
147 provider_name.0
148 );
149 }
150 }
151 }
152 }
153 }
154 })
155 }
156
157 fn all_models(cx: &App) -> GroupedModels {
158 let providers = LanguageModelRegistry::global(cx).read(cx).providers();
159
160 let recommended = providers
161 .iter()
162 .flat_map(|provider| {
163 provider
164 .recommended_models(cx)
165 .into_iter()
166 .map(|model| ModelInfo {
167 model,
168 icon: provider.icon(),
169 })
170 })
171 .collect();
172
173 let other = providers
174 .iter()
175 .flat_map(|provider| {
176 provider
177 .provided_models(cx)
178 .into_iter()
179 .map(|model| ModelInfo {
180 model,
181 icon: provider.icon(),
182 })
183 })
184 .collect();
185
186 GroupedModels::new(other, recommended)
187 }
188
189 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
190 (self.picker.read(cx).delegate.get_active_model)(cx)
191 }
192
193 fn get_active_model_index(
194 entries: &[LanguageModelPickerEntry],
195 active_model: Option<ConfiguredModel>,
196 ) -> usize {
197 entries
198 .iter()
199 .position(|entry| {
200 if let LanguageModelPickerEntry::Model(model) = entry {
201 active_model
202 .as_ref()
203 .map(|active_model| {
204 active_model.model.id() == model.model.id()
205 && active_model.provider.id() == model.model.provider_id()
206 })
207 .unwrap_or_default()
208 } else {
209 false
210 }
211 })
212 .unwrap_or(0)
213 }
214}
215
// The selector emits `DismissEvent`; `LanguageModelSelector::new` re-emits every picker
// event as one.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}

impl Focusable for LanguageModelSelector {
    /// Focus is delegated to the inner picker.
    fn focus_handle(&self, cx: &App) -> FocusHandle {
        self.picker.focus_handle(cx)
    }
}

impl Render for LanguageModelSelector {
    /// The selector has no chrome of its own; it renders as its picker entity.
    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
        self.picker.clone()
    }
}
229
/// Element that wraps a [`LanguageModelSelector`] in a [`PopoverMenu`] opened by `trigger`.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
    T: PopoverTrigger + ButtonCommon,
    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
    /// Selector entity shown as the menu's contents.
    language_model_selector: Entity<LanguageModelSelector>,
    /// Element that opens the menu.
    trigger: T,
    /// Builds the trigger's tooltip view.
    tooltip: TT,
    /// Optional external handle for the menu; set via `with_handle`.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
    /// Corner passed to `PopoverMenu::anchor`.
    anchor: Corner,
}
242
243impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
244where
245 T: PopoverTrigger + ButtonCommon,
246 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
247{
248 pub fn new(
249 language_model_selector: Entity<LanguageModelSelector>,
250 trigger: T,
251 tooltip: TT,
252 anchor: Corner,
253 ) -> Self {
254 Self {
255 language_model_selector,
256 trigger,
257 tooltip,
258 handle: None,
259 anchor,
260 }
261 }
262
263 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
264 self.handle = Some(handle);
265 self
266 }
267}
268
269impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
270where
271 T: PopoverTrigger + ButtonCommon,
272 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
273{
274 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
275 let language_model_selector = self.language_model_selector.clone();
276
277 PopoverMenu::new("model-switcher")
278 .menu(move |_window, _cx| Some(language_model_selector.clone()))
279 .trigger_with_tooltip(self.trigger, self.tooltip)
280 .anchor(self.anchor)
281 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
282 .offset(gpui::Point {
283 x: px(0.0),
284 y: px(-2.0),
285 })
286 }
287}
288
/// A model paired with the icon of the provider that supplies it.
#[derive(Clone)]
struct ModelInfo {
    model: Arc<dyn LanguageModel>,
    /// Provider icon drawn at the start of the model's list row.
    icon: IconName,
}
294
/// [`PickerDelegate`] backing the list shown by [`LanguageModelSelector`].
pub struct LanguageModelPickerDelegate {
    /// Owning selector; `DismissEvent` is emitted on it when the picker is dismissed.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Invoked with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Reports the active model, which is drawn with an accent icon and a check mark.
    get_active_model: GetActiveModel,
    /// Models from every registered provider (authenticated or not), grouped for display;
    /// replaced when the registry reports provider changes.
    all_models: Arc<GroupedModels>,
    /// Rows currently displayed (separators and models) for the current query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
}
303
/// Models split into a "Recommended" group plus per-provider groups of everything else.
struct GroupedModels {
    /// Models shown under the "Recommended" header.
    recommended: Vec<ModelInfo>,
    /// Remaining models keyed by provider, in insertion order; models that also appear in
    /// `recommended` are excluded by `GroupedModels::new`.
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
308
309impl GroupedModels {
310 pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
311 let recommended_ids = recommended
312 .iter()
313 .map(|info| (info.model.provider_id(), info.model.id()))
314 .collect::<HashSet<_>>();
315
316 let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
317 for model in other {
318 if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
319 continue;
320 }
321
322 let provider = model.model.provider_id();
323 if let Some(models) = other_by_provider.get_mut(&provider) {
324 models.push(model);
325 } else {
326 other_by_provider.insert(provider, vec![model]);
327 }
328 }
329
330 Self {
331 recommended,
332 other: other_by_provider,
333 }
334 }
335
336 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
337 let mut entries = Vec::new();
338
339 if !self.recommended.is_empty() {
340 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
341 entries.extend(
342 self.recommended
343 .iter()
344 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
345 );
346 }
347
348 for models in self.other.values() {
349 if models.is_empty() {
350 continue;
351 }
352 entries.push(LanguageModelPickerEntry::Separator(
353 models[0].model.provider_name().0,
354 ));
355 entries.extend(
356 models
357 .iter()
358 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
359 );
360 }
361 entries
362 }
363
364 fn model_infos(&self) -> Vec<ModelInfo> {
365 let other = self
366 .other
367 .values()
368 .flat_map(|model| model.iter())
369 .cloned()
370 .collect::<Vec<_>>();
371 self.recommended
372 .iter()
373 .chain(&other)
374 .cloned()
375 .collect::<Vec<_>>()
376 }
377}
378
/// One row of the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable group header with the given title.
    Separator(SharedString),
}
383
/// Searches a fixed list of models either fuzzily or by case-insensitive substring.
struct ModelMatcher {
    /// Models being searched; position `i` corresponds to `candidates[i]`.
    models: Vec<ModelInfo>,
    /// Executor that fuzzy matching runs on (and is blocked on).
    bg_executor: BackgroundExecutor,
    /// One `"<provider name>/<model name>"` candidate per model, whose id is the model's
    /// index in `models`.
    candidates: Vec<StringMatchCandidate>,
}
389
390impl ModelMatcher {
391 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
392 let candidates = Self::make_match_candidates(&models);
393 Self {
394 models,
395 bg_executor,
396 candidates,
397 }
398 }
399
400 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
401 let mut matches = self.bg_executor.block(match_strings(
402 &self.candidates,
403 &query,
404 false,
405 100,
406 &Default::default(),
407 self.bg_executor.clone(),
408 ));
409
410 let sorting_key = |mat: &StringMatch| {
411 let candidate = &self.candidates[mat.candidate_id];
412 (Reverse(OrderedFloat(mat.score)), candidate.id)
413 };
414 matches.sort_unstable_by_key(sorting_key);
415
416 let matched_models: Vec<_> = matches
417 .into_iter()
418 .map(|mat| self.models[mat.candidate_id].clone())
419 .collect();
420
421 matched_models
422 }
423
424 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
425 self.models
426 .iter()
427 .filter(|m| {
428 m.model
429 .name()
430 .0
431 .to_lowercase()
432 .contains(&query.to_lowercase())
433 })
434 .cloned()
435 .collect::<Vec<_>>()
436 }
437
438 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
439 model_infos
440 .iter()
441 .enumerate()
442 .map(|(index, model)| {
443 StringMatchCandidate::new(
444 index,
445 &format!(
446 "{}/{}",
447 &model.model.provider_name().0,
448 &model.model.name().0
449 ),
450 )
451 })
452 .collect::<Vec<_>>()
453 }
454}
455
456impl PickerDelegate for LanguageModelPickerDelegate {
457 type ListItem = AnyElement;
458
459 fn match_count(&self) -> usize {
460 self.filtered_entries.len()
461 }
462
463 fn selected_index(&self) -> usize {
464 self.selected_index
465 }
466
467 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
468 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
469 cx.notify();
470 }
471
472 fn can_select(
473 &mut self,
474 ix: usize,
475 _window: &mut Window,
476 _cx: &mut Context<Picker<Self>>,
477 ) -> bool {
478 match self.filtered_entries.get(ix) {
479 Some(LanguageModelPickerEntry::Model(_)) => true,
480 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
481 }
482 }
483
484 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
485 "Select a model…".into()
486 }
487
488 fn update_matches(
489 &mut self,
490 query: String,
491 window: &mut Window,
492 cx: &mut Context<Picker<Self>>,
493 ) -> Task<()> {
494 let all_models = self.all_models.clone();
495 let current_index = self.selected_index;
496 let bg_executor = cx.background_executor();
497
498 let language_model_registry = LanguageModelRegistry::global(cx);
499
500 let configured_providers = language_model_registry
501 .read(cx)
502 .providers()
503 .into_iter()
504 .filter(|provider| provider.is_authenticated(cx))
505 .collect::<Vec<_>>();
506
507 let configured_provider_ids = configured_providers
508 .iter()
509 .map(|provider| provider.id())
510 .collect::<Vec<_>>();
511
512 let recommended_models = all_models
513 .recommended
514 .iter()
515 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
516 .cloned()
517 .collect::<Vec<_>>();
518
519 let available_models = all_models
520 .model_infos()
521 .iter()
522 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
523 .cloned()
524 .collect::<Vec<_>>();
525
526 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
527 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
528
529 let recommended = matcher_rec.exact_search(&query);
530 let all = matcher_all.fuzzy_search(&query);
531
532 let filtered_models = GroupedModels::new(all, recommended);
533
534 cx.spawn_in(window, async move |this, cx| {
535 this.update_in(cx, |this, window, cx| {
536 this.delegate.filtered_entries = filtered_models.entries();
537 // Preserve selection focus
538 let new_index = if current_index >= this.delegate.filtered_entries.len() {
539 0
540 } else {
541 current_index
542 };
543 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
544 cx.notify();
545 })
546 .ok();
547 })
548 }
549
550 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
551 if let Some(LanguageModelPickerEntry::Model(model_info)) =
552 self.filtered_entries.get(self.selected_index)
553 {
554 let model = model_info.model.clone();
555 (self.on_model_changed)(model.clone(), cx);
556
557 let current_index = self.selected_index;
558 self.set_selected_index(current_index, window, cx);
559
560 cx.emit(DismissEvent);
561 }
562 }
563
564 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
565 self.language_model_selector
566 .update(cx, |_this, cx| cx.emit(DismissEvent))
567 .ok();
568 }
569
570 fn render_match(
571 &self,
572 ix: usize,
573 selected: bool,
574 _: &mut Window,
575 cx: &mut Context<Picker<Self>>,
576 ) -> Option<Self::ListItem> {
577 match self.filtered_entries.get(ix)? {
578 LanguageModelPickerEntry::Separator(title) => Some(
579 div()
580 .px_2()
581 .pb_1()
582 .when(ix > 1, |this| {
583 this.mt_1()
584 .pt_2()
585 .border_t_1()
586 .border_color(cx.theme().colors().border_variant)
587 })
588 .child(
589 Label::new(title)
590 .size(LabelSize::XSmall)
591 .color(Color::Muted),
592 )
593 .into_any_element(),
594 ),
595 LanguageModelPickerEntry::Model(model_info) => {
596 let active_model = (self.get_active_model)(cx);
597 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
598 let active_model_id = active_model.map(|m| m.model.id());
599
600 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
601 && Some(model_info.model.id()) == active_model_id;
602
603 let model_icon_color = if is_selected {
604 Color::Accent
605 } else {
606 Color::Muted
607 };
608
609 Some(
610 ListItem::new(ix)
611 .inset(true)
612 .spacing(ListItemSpacing::Sparse)
613 .toggle_state(selected)
614 .start_slot(
615 Icon::new(model_info.icon)
616 .color(model_icon_color)
617 .size(IconSize::Small),
618 )
619 .child(
620 h_flex()
621 .w_full()
622 .pl_0p5()
623 .gap_1p5()
624 .w(px(240.))
625 .child(Label::new(model_info.model.name().0.clone()).truncate()),
626 )
627 .end_slot(div().pr_3().when(is_selected, |this| {
628 this.child(
629 Icon::new(IconName::Check)
630 .color(Color::Accent)
631 .size(IconSize::Small),
632 )
633 }))
634 .into_any_element(),
635 )
636 }
637 }
638 }
639
640 fn render_footer(
641 &self,
642 _: &mut Window,
643 cx: &mut Context<Picker<Self>>,
644 ) -> Option<gpui::AnyElement> {
645 use feature_flags::FeatureFlagAppExt;
646
647 let plan = proto::Plan::ZedPro;
648
649 Some(
650 h_flex()
651 .w_full()
652 .border_t_1()
653 .border_color(cx.theme().colors().border_variant)
654 .p_1()
655 .gap_4()
656 .justify_between()
657 .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
658 this.child(match plan {
659 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
660 .icon(IconName::ZedAssistant)
661 .icon_size(IconSize::Small)
662 .icon_color(Color::Muted)
663 .icon_position(IconPosition::Start)
664 .on_click(|_, window, cx| {
665 window
666 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
667 }),
668 Plan::Free | Plan::ZedProTrial => Button::new(
669 "try-pro",
670 if plan == Plan::ZedProTrial {
671 "Upgrade to Pro"
672 } else {
673 "Try Pro"
674 },
675 )
676 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
677 })
678 })
679 .child(
680 Button::new("configure", "Configure")
681 .icon(IconName::Settings)
682 .icon_size(IconSize::Small)
683 .icon_color(Color::Muted)
684 .icon_position(IconPosition::Start)
685 .on_click(|_, window, cx| {
686 window.dispatch_action(
687 zed_actions::agent::OpenConfiguration.boxed_clone(),
688 cx,
689 );
690 }),
691 )
692 .into_any(),
693 )
694 }
695}
696
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// In-memory [`LanguageModel`] stub for the search and grouping tests; only the identity
    /// methods return real data.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Builds a model whose id equals `name` and whose provider id and provider name both
        /// equal `provider`.
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        /// `"<provider id>/<model name>"`; `assert_models_eq` compares models by this string.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> usize {
            1000
        }

        // The completion APIs below are never exercised by these tests.
        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<usize>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            http_client::Result<
                BoxStream<
                    'static,
                    http_client::Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, all using `IconName::Ai`.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Asserts that `result` has exactly the `expected` telemetry ids, in the same order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    /// `exact_search` matches by case-insensitive substring and keeps input order.
    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    /// `fuzzy_search` ranks by score, breaks ties by input order, and also matches the
    /// provider part of `"<provider>/<model>"`.
    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    /// A model listed as recommended is removed from the per-provider groups.
    #[gpui::test]
    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should not appear in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
    }

    /// Deduplication is keyed on (provider, model): the same model name from another
    /// provider stays in "other".
    #[gpui::test]
    fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Should not be filtered out from "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should not appear in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
    }
}