1use std::{cmp::Reverse, sync::Arc};
2
3use collections::{HashSet, IndexMap};
4use feature_flags::ZedProFeatureFlag;
5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
6use gpui::{
7 Action, AnyElement, AnyView, App, BackgroundExecutor, Corner, DismissEvent, Entity,
8 EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
9 action_with_deprecated_aliases,
10};
11use language_model::{
12 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
13 LanguageModelRegistry,
14};
15use ordered_float::OrderedFloat;
16use picker::{Picker, PickerDelegate};
17use proto::Plan;
18use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
19
// Declares the `agent::ToggleModelSelector` action. The older
// `assistant::ToggleModelSelector` and `assistant2::ToggleModelSelector` names are
// registered as deprecated aliases (presumably so existing keymaps that use the
// old namespaces keep resolving — confirm against the macro's docs).
action_with_deprecated_aliases!(
    agent,
    ToggleModelSelector,
    [
        "assistant::ToggleModelSelector",
        "assistant2::ToggleModelSelector"
    ]
);
28
/// Page opened by the "Try Pro" / "Upgrade to Pro" button in the picker footer.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
30
/// Callback invoked with the model the user confirmed in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Returns the currently active model, if any. Used to pre-select the active
/// entry and to mark it with an accent icon and check mark.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
33
/// A searchable list of language models from every provider registered in the
/// [`LanguageModelRegistry`], implemented as a [`Picker`].
pub struct LanguageModelSelector {
    /// The picker that renders and filters the model list.
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    /// Background task authenticating every provider; stored so it lives as
    /// long as the selector (presumably dropping it would cancel the work).
    _authenticate_all_providers_task: Task<()>,
    /// Subscriptions to registry events and to the picker's dismiss event.
    _subscriptions: Vec<Subscription>,
}
39
40impl LanguageModelSelector {
41 pub fn new(
42 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
43 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
44 window: &mut Window,
45 cx: &mut Context<Self>,
46 ) -> Self {
47 let on_model_changed = Arc::new(on_model_changed);
48
49 let all_models = Self::all_models(cx);
50 let entries = all_models.entries();
51
52 let delegate = LanguageModelPickerDelegate {
53 language_model_selector: cx.entity().downgrade(),
54 on_model_changed: on_model_changed.clone(),
55 all_models: Arc::new(all_models),
56 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
57 filtered_entries: entries,
58 get_active_model: Arc::new(get_active_model),
59 };
60
61 let picker = cx.new(|cx| {
62 Picker::list(delegate, window, cx)
63 .show_scrollbar(true)
64 .width(rems(20.))
65 .max_height(Some(rems(20.).into()))
66 });
67
68 let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
69
70 LanguageModelSelector {
71 picker,
72 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
73 _subscriptions: vec![
74 cx.subscribe_in(
75 &LanguageModelRegistry::global(cx),
76 window,
77 Self::handle_language_model_registry_event,
78 ),
79 subscription,
80 ],
81 }
82 }
83
84 fn handle_language_model_registry_event(
85 &mut self,
86 _registry: &Entity<LanguageModelRegistry>,
87 event: &language_model::Event,
88 window: &mut Window,
89 cx: &mut Context<Self>,
90 ) {
91 match event {
92 language_model::Event::ProviderStateChanged
93 | language_model::Event::AddedProvider(_)
94 | language_model::Event::RemovedProvider(_) => {
95 self.picker.update(cx, |this, cx| {
96 let query = this.query(cx);
97 this.delegate.all_models = Arc::new(Self::all_models(cx));
98 // Update matches will automatically drop the previous task
99 // if we get a provider event again
100 this.update_matches(query, window, cx)
101 });
102 }
103 _ => {}
104 }
105 }
106
107 /// Authenticates all providers in the [`LanguageModelRegistry`].
108 ///
109 /// We do this so that we can populate the language selector with all of the
110 /// models from the configured providers.
111 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
112 let authenticate_all_providers = LanguageModelRegistry::global(cx)
113 .read(cx)
114 .providers()
115 .iter()
116 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
117 .collect::<Vec<_>>();
118
119 cx.spawn(async move |_cx| {
120 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
121 if let Err(err) = authenticate_task.await {
122 if matches!(err, AuthenticateError::CredentialsNotFound) {
123 // Since we're authenticating these providers in the
124 // background for the purposes of populating the
125 // language selector, we don't care about providers
126 // where the credentials are not found.
127 } else {
128 // Some providers have noisy failure states that we
129 // don't want to spam the logs with every time the
130 // language model selector is initialized.
131 //
132 // Ideally these should have more clear failure modes
133 // that we know are safe to ignore here, like what we do
134 // with `CredentialsNotFound` above.
135 match provider_id.0.as_ref() {
136 "lmstudio" | "ollama" => {
137 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
138 //
139 // These fail noisily, so we don't log them.
140 }
141 "copilot_chat" => {
142 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
143 }
144 _ => {
145 log::error!(
146 "Failed to authenticate provider: {}: {err}",
147 provider_name.0
148 );
149 }
150 }
151 }
152 }
153 }
154 })
155 }
156
157 fn all_models(cx: &App) -> GroupedModels {
158 let mut recommended = Vec::new();
159 let mut recommended_set = HashSet::default();
160 for provider in LanguageModelRegistry::global(cx)
161 .read(cx)
162 .providers()
163 .iter()
164 {
165 let models = provider.recommended_models(cx);
166 recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
167 recommended.extend(
168 provider
169 .recommended_models(cx)
170 .into_iter()
171 .map(move |model| ModelInfo {
172 model: model.clone(),
173 icon: provider.icon(),
174 }),
175 );
176 }
177
178 let other_models = LanguageModelRegistry::global(cx)
179 .read(cx)
180 .providers()
181 .iter()
182 .map(|provider| {
183 (
184 provider.id(),
185 provider
186 .provided_models(cx)
187 .into_iter()
188 .filter_map(|model| {
189 let not_included =
190 !recommended_set.contains(&(model.provider_id(), model.id()));
191 not_included.then(|| ModelInfo {
192 model: model.clone(),
193 icon: provider.icon(),
194 })
195 })
196 .collect::<Vec<_>>(),
197 )
198 })
199 .collect::<IndexMap<_, _>>();
200
201 GroupedModels {
202 recommended,
203 other: other_models,
204 }
205 }
206
207 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
208 (self.picker.read(cx).delegate.get_active_model)(cx)
209 }
210
211 fn get_active_model_index(
212 entries: &[LanguageModelPickerEntry],
213 active_model: Option<ConfiguredModel>,
214 ) -> usize {
215 entries
216 .iter()
217 .position(|entry| {
218 if let LanguageModelPickerEntry::Model(model) = entry {
219 active_model
220 .as_ref()
221 .map(|active_model| {
222 active_model.model.id() == model.model.id()
223 && active_model.provider.id() == model.model.provider_id()
224 })
225 .unwrap_or_default()
226 } else {
227 false
228 }
229 })
230 .unwrap_or(0)
231 }
232}
233
// The selector emits `DismissEvent` when its picker asks to be closed.
impl EventEmitter<DismissEvent> for LanguageModelSelector {}

// Focus is delegated to the inner picker.
impl Focusable for LanguageModelSelector {
    fn focus_handle(&self, cx: &App) -> FocusHandle {
        self.picker.focus_handle(cx)
    }
}

// The selector renders as its picker.
impl Render for LanguageModelSelector {
    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
        self.picker.clone()
    }
}
247
/// A [`PopoverMenu`] that shows a [`LanguageModelSelector`] when its trigger is
/// clicked.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
    T: PopoverTrigger + ButtonCommon,
    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
    /// The selector displayed inside the popover.
    language_model_selector: Entity<LanguageModelSelector>,
    /// Element that opens the popover.
    trigger: T,
    /// Builds the trigger's tooltip view.
    tooltip: TT,
    /// Optional handle forwarded to the popover menu.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
    /// Corner the popover is anchored to.
    anchor: Corner,
}
260
261impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
262where
263 T: PopoverTrigger + ButtonCommon,
264 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
265{
266 pub fn new(
267 language_model_selector: Entity<LanguageModelSelector>,
268 trigger: T,
269 tooltip: TT,
270 anchor: Corner,
271 ) -> Self {
272 Self {
273 language_model_selector,
274 trigger,
275 tooltip,
276 handle: None,
277 anchor,
278 }
279 }
280
281 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
282 self.handle = Some(handle);
283 self
284 }
285}
286
287impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
288where
289 T: PopoverTrigger + ButtonCommon,
290 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
291{
292 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
293 let language_model_selector = self.language_model_selector.clone();
294
295 PopoverMenu::new("model-switcher")
296 .menu(move |_window, _cx| Some(language_model_selector.clone()))
297 .trigger_with_tooltip(self.trigger, self.tooltip)
298 .anchor(self.anchor)
299 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
300 .offset(gpui::Point {
301 x: px(0.0),
302 y: px(-2.0),
303 })
304 }
305}
306
/// A model paired with the icon of the provider that exposes it.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// Icon of the model's provider, shown next to the model name.
    icon: IconName,
}
312
/// [`PickerDelegate`] that lists models grouped under "Recommended" and one
/// section per provider.
pub struct LanguageModelPickerDelegate {
    /// Owning selector; used to emit `DismissEvent` when the picker is dismissed.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Invoked with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Queried to highlight the currently active model.
    get_active_model: GetActiveModel,
    /// Every known model; replaced when the registry reports provider changes.
    all_models: Arc<GroupedModels>,
    /// Rows (separators and models) currently shown for the query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
}
321
/// Models split into a recommended list and per-provider lists.
struct GroupedModels {
    /// Recommended models, shown first under a "Recommended" header.
    recommended: Vec<ModelInfo>,
    /// Remaining models keyed by provider, in insertion order.
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
326
327impl GroupedModels {
328 pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
329 let recommended_ids: HashSet<_> = recommended.iter().map(|info| info.model.id()).collect();
330
331 let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
332 for model in other {
333 if recommended_ids.contains(&model.model.id()) {
334 continue;
335 }
336
337 let provider = model.model.provider_id();
338 if let Some(models) = other_by_provider.get_mut(&provider) {
339 models.push(model);
340 } else {
341 other_by_provider.insert(provider, vec![model]);
342 }
343 }
344
345 Self {
346 recommended,
347 other: other_by_provider,
348 }
349 }
350
351 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
352 let mut entries = Vec::new();
353
354 if !self.recommended.is_empty() {
355 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
356 entries.extend(
357 self.recommended
358 .iter()
359 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
360 );
361 }
362
363 for models in self.other.values() {
364 if models.is_empty() {
365 continue;
366 }
367 entries.push(LanguageModelPickerEntry::Separator(
368 models[0].model.provider_name().0,
369 ));
370 entries.extend(
371 models
372 .iter()
373 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
374 );
375 }
376 entries
377 }
378
379 fn model_infos(&self) -> Vec<ModelInfo> {
380 let other = self
381 .other
382 .values()
383 .flat_map(|model| model.iter())
384 .cloned()
385 .collect::<Vec<_>>();
386 self.recommended
387 .iter()
388 .chain(&other)
389 .cloned()
390 .collect::<Vec<_>>()
391 }
392}
393
/// One row of the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model.
    Model(ModelInfo),
    /// A non-selectable section header with the given title.
    Separator(SharedString),
}
398
/// Searches a fixed list of models by name.
struct ModelMatcher {
    /// Models being searched; `candidates[i]` describes `models[i]`.
    models: Vec<ModelInfo>,
    /// Executor used to run and block on fuzzy matching.
    bg_executor: BackgroundExecutor,
    /// One `"{provider_name}/{model_name}"` candidate per model, with id = index.
    candidates: Vec<StringMatchCandidate>,
}
404
405impl ModelMatcher {
406 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
407 let candidates = Self::make_match_candidates(&models);
408 Self {
409 models,
410 bg_executor,
411 candidates,
412 }
413 }
414
415 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
416 let mut matches = self.bg_executor.block(match_strings(
417 &self.candidates,
418 &query,
419 false,
420 100,
421 &Default::default(),
422 self.bg_executor.clone(),
423 ));
424
425 let sorting_key = |mat: &StringMatch| {
426 let candidate = &self.candidates[mat.candidate_id];
427 (Reverse(OrderedFloat(mat.score)), candidate.id)
428 };
429 matches.sort_unstable_by_key(sorting_key);
430
431 let matched_models: Vec<_> = matches
432 .into_iter()
433 .map(|mat| self.models[mat.candidate_id].clone())
434 .collect();
435
436 matched_models
437 }
438
439 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
440 self.models
441 .iter()
442 .filter(|m| {
443 m.model
444 .name()
445 .0
446 .to_lowercase()
447 .contains(&query.to_lowercase())
448 })
449 .cloned()
450 .collect::<Vec<_>>()
451 }
452
453 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
454 model_infos
455 .iter()
456 .enumerate()
457 .map(|(index, model)| {
458 StringMatchCandidate::new(
459 index,
460 &format!(
461 "{}/{}",
462 &model.model.provider_name().0,
463 &model.model.name().0
464 ),
465 )
466 })
467 .collect::<Vec<_>>()
468 }
469}
470
471impl PickerDelegate for LanguageModelPickerDelegate {
472 type ListItem = AnyElement;
473
474 fn match_count(&self) -> usize {
475 self.filtered_entries.len()
476 }
477
478 fn selected_index(&self) -> usize {
479 self.selected_index
480 }
481
482 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
483 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
484 cx.notify();
485 }
486
487 fn can_select(
488 &mut self,
489 ix: usize,
490 _window: &mut Window,
491 _cx: &mut Context<Picker<Self>>,
492 ) -> bool {
493 match self.filtered_entries.get(ix) {
494 Some(LanguageModelPickerEntry::Model(_)) => true,
495 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
496 }
497 }
498
499 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
500 "Select a model…".into()
501 }
502
503 fn update_matches(
504 &mut self,
505 query: String,
506 window: &mut Window,
507 cx: &mut Context<Picker<Self>>,
508 ) -> Task<()> {
509 let all_models = self.all_models.clone();
510 let current_index = self.selected_index;
511 let bg_executor = cx.background_executor();
512
513 let language_model_registry = LanguageModelRegistry::global(cx);
514
515 let configured_providers = language_model_registry
516 .read(cx)
517 .providers()
518 .into_iter()
519 .filter(|provider| provider.is_authenticated(cx))
520 .collect::<Vec<_>>();
521
522 let configured_provider_ids = configured_providers
523 .iter()
524 .map(|provider| provider.id())
525 .collect::<Vec<_>>();
526
527 let recommended_models = all_models
528 .recommended
529 .iter()
530 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
531 .cloned()
532 .collect::<Vec<_>>();
533
534 let available_models = all_models
535 .model_infos()
536 .iter()
537 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
538 .cloned()
539 .collect::<Vec<_>>();
540
541 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
542 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
543
544 let recommended = matcher_rec.exact_search(&query);
545 let all = matcher_all.fuzzy_search(&query);
546
547 let filtered_models = GroupedModels::new(all, recommended);
548
549 cx.spawn_in(window, async move |this, cx| {
550 this.update_in(cx, |this, window, cx| {
551 this.delegate.filtered_entries = filtered_models.entries();
552 // Preserve selection focus
553 let new_index = if current_index >= this.delegate.filtered_entries.len() {
554 0
555 } else {
556 current_index
557 };
558 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
559 cx.notify();
560 })
561 .ok();
562 })
563 }
564
565 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
566 if let Some(LanguageModelPickerEntry::Model(model_info)) =
567 self.filtered_entries.get(self.selected_index)
568 {
569 let model = model_info.model.clone();
570 (self.on_model_changed)(model.clone(), cx);
571
572 let current_index = self.selected_index;
573 self.set_selected_index(current_index, window, cx);
574
575 cx.emit(DismissEvent);
576 }
577 }
578
579 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
580 self.language_model_selector
581 .update(cx, |_this, cx| cx.emit(DismissEvent))
582 .ok();
583 }
584
585 fn render_match(
586 &self,
587 ix: usize,
588 selected: bool,
589 _: &mut Window,
590 cx: &mut Context<Picker<Self>>,
591 ) -> Option<Self::ListItem> {
592 match self.filtered_entries.get(ix)? {
593 LanguageModelPickerEntry::Separator(title) => Some(
594 div()
595 .px_2()
596 .pb_1()
597 .when(ix > 1, |this| {
598 this.mt_1()
599 .pt_2()
600 .border_t_1()
601 .border_color(cx.theme().colors().border_variant)
602 })
603 .child(
604 Label::new(title)
605 .size(LabelSize::XSmall)
606 .color(Color::Muted),
607 )
608 .into_any_element(),
609 ),
610 LanguageModelPickerEntry::Model(model_info) => {
611 let active_model = (self.get_active_model)(cx);
612 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
613 let active_model_id = active_model.map(|m| m.model.id());
614
615 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
616 && Some(model_info.model.id()) == active_model_id;
617
618 let model_icon_color = if is_selected {
619 Color::Accent
620 } else {
621 Color::Muted
622 };
623
624 Some(
625 ListItem::new(ix)
626 .inset(true)
627 .spacing(ListItemSpacing::Sparse)
628 .toggle_state(selected)
629 .start_slot(
630 Icon::new(model_info.icon)
631 .color(model_icon_color)
632 .size(IconSize::Small),
633 )
634 .child(
635 h_flex()
636 .w_full()
637 .pl_0p5()
638 .gap_1p5()
639 .w(px(240.))
640 .child(Label::new(model_info.model.name().0.clone()).truncate()),
641 )
642 .end_slot(div().pr_3().when(is_selected, |this| {
643 this.child(
644 Icon::new(IconName::Check)
645 .color(Color::Accent)
646 .size(IconSize::Small),
647 )
648 }))
649 .into_any_element(),
650 )
651 }
652 }
653 }
654
655 fn render_footer(
656 &self,
657 _: &mut Window,
658 cx: &mut Context<Picker<Self>>,
659 ) -> Option<gpui::AnyElement> {
660 use feature_flags::FeatureFlagAppExt;
661
662 let plan = proto::Plan::ZedPro;
663
664 Some(
665 h_flex()
666 .w_full()
667 .border_t_1()
668 .border_color(cx.theme().colors().border_variant)
669 .p_1()
670 .gap_4()
671 .justify_between()
672 .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
673 this.child(match plan {
674 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
675 .icon(IconName::ZedAssistant)
676 .icon_size(IconSize::Small)
677 .icon_color(Color::Muted)
678 .icon_position(IconPosition::Start)
679 .on_click(|_, window, cx| {
680 window
681 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
682 }),
683 Plan::Free | Plan::ZedProTrial => Button::new(
684 "try-pro",
685 if plan == Plan::ZedProTrial {
686 "Upgrade to Pro"
687 } else {
688 "Try Pro"
689 },
690 )
691 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
692 })
693 })
694 .child(
695 Button::new("configure", "Configure")
696 .icon(IconName::Settings)
697 .icon_size(IconSize::Small)
698 .icon_color(Color::Muted)
699 .icon_position(IconPosition::Start)
700 .on_click(|_, window, cx| {
701 window.dispatch_action(
702 zed_actions::agent::OpenConfiguration.boxed_clone(),
703 cx,
704 );
705 }),
706 )
707 .into_any(),
708 )
709 }
710}
711
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal `LanguageModel` stub. Only its identity (name, id, provider) is
    /// meaningful; completion and token counting are unimplemented.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Uses `name` as both name and id, and `provider` as both provider id
        /// and provider name.
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // `"{provider_id}/{name}"`; the tests below identify models by this
        // string.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> usize {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<usize>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            http_client::Result<
                BoxStream<
                    'static,
                    http_client::Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, all with the `Ai`
    /// icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Asserts that `result` has exactly the models in `expected`, in the same
    /// order, compared by `"{provider}/{name}"` telemetry id.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should not appear in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
    }
}