1use std::{cmp::Reverse, sync::Arc};
2
3use collections::{HashSet, IndexMap};
4use feature_flags::ZedProFeatureFlag;
5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
6use gpui::{
7 Action, AnyElement, AnyView, App, BackgroundExecutor, Corner, DismissEvent, Entity,
8 EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
9 action_with_deprecated_aliases,
10};
11use language_model::{
12 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
13 LanguageModelRegistry,
14};
15use ordered_float::OrderedFloat;
16use picker::{Picker, PickerDelegate};
17use proto::Plan;
18use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
19
20action_with_deprecated_aliases!(
21 agent,
22 ToggleModelSelector,
23 [
24 "assistant::ToggleModelSelector",
25 "assistant2::ToggleModelSelector"
26 ]
27);
28
/// Page opened by the picker footer's "Try Pro" / "Upgrade to Pro" button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
30
31type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
32type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
33
/// A model picker view: wraps a [`Picker`] over the models exposed by the
/// [`LanguageModelRegistry`] and emits [`DismissEvent`] when it should close.
pub struct LanguageModelSelector {
    picker: Entity<Picker<LanguageModelPickerDelegate>>,
    // Never read; stored so the background provider-authentication task is
    // owned by the selector (assumes dropping a gpui `Task` cancels it — confirm).
    _authenticate_all_providers_task: Task<()>,
    // Registry and picker subscriptions, kept alive for the selector's lifetime.
    _subscriptions: Vec<Subscription>,
}
39
40impl LanguageModelSelector {
41 pub fn new(
42 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
43 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
44 window: &mut Window,
45 cx: &mut Context<Self>,
46 ) -> Self {
47 let on_model_changed = Arc::new(on_model_changed);
48
49 let all_models = Self::all_models(cx);
50 let entries = all_models.entries();
51
52 let delegate = LanguageModelPickerDelegate {
53 language_model_selector: cx.entity().downgrade(),
54 on_model_changed: on_model_changed.clone(),
55 all_models: Arc::new(all_models),
56 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
57 filtered_entries: entries,
58 get_active_model: Arc::new(get_active_model),
59 };
60
61 let picker = cx.new(|cx| {
62 Picker::list(delegate, window, cx)
63 .show_scrollbar(true)
64 .width(rems(20.))
65 .max_height(Some(rems(20.).into()))
66 });
67
68 let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
69
70 LanguageModelSelector {
71 picker,
72 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
73 _subscriptions: vec![
74 cx.subscribe_in(
75 &LanguageModelRegistry::global(cx),
76 window,
77 Self::handle_language_model_registry_event,
78 ),
79 subscription,
80 ],
81 }
82 }
83
84 fn handle_language_model_registry_event(
85 &mut self,
86 _registry: &Entity<LanguageModelRegistry>,
87 event: &language_model::Event,
88 window: &mut Window,
89 cx: &mut Context<Self>,
90 ) {
91 match event {
92 language_model::Event::ProviderStateChanged
93 | language_model::Event::AddedProvider(_)
94 | language_model::Event::RemovedProvider(_) => {
95 self.picker.update(cx, |this, cx| {
96 let query = this.query(cx);
97 this.delegate.all_models = Arc::new(Self::all_models(cx));
98 // Update matches will automatically drop the previous task
99 // if we get a provider event again
100 this.update_matches(query, window, cx)
101 });
102 }
103 _ => {}
104 }
105 }
106
107 /// Authenticates all providers in the [`LanguageModelRegistry`].
108 ///
109 /// We do this so that we can populate the language selector with all of the
110 /// models from the configured providers.
111 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
112 let authenticate_all_providers = LanguageModelRegistry::global(cx)
113 .read(cx)
114 .providers()
115 .iter()
116 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
117 .collect::<Vec<_>>();
118
119 cx.spawn(async move |_cx| {
120 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
121 if let Err(err) = authenticate_task.await {
122 if matches!(err, AuthenticateError::CredentialsNotFound) {
123 // Since we're authenticating these providers in the
124 // background for the purposes of populating the
125 // language selector, we don't care about providers
126 // where the credentials are not found.
127 } else {
128 // Some providers have noisy failure states that we
129 // don't want to spam the logs with every time the
130 // language model selector is initialized.
131 //
132 // Ideally these should have more clear failure modes
133 // that we know are safe to ignore here, like what we do
134 // with `CredentialsNotFound` above.
135 match provider_id.0.as_ref() {
136 "lmstudio" | "ollama" => {
137 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
138 //
139 // These fail noisily, so we don't log them.
140 }
141 "copilot_chat" => {
142 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
143 }
144 _ => {
145 log::error!(
146 "Failed to authenticate provider: {}: {err}",
147 provider_name.0
148 );
149 }
150 }
151 }
152 }
153 }
154 })
155 }
156
157 fn all_models(cx: &App) -> GroupedModels {
158 let mut recommended = Vec::new();
159 let mut recommended_set = HashSet::default();
160 for provider in LanguageModelRegistry::global(cx)
161 .read(cx)
162 .providers()
163 .iter()
164 {
165 let models = provider.recommended_models(cx);
166 recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
167 recommended.extend(
168 provider
169 .recommended_models(cx)
170 .into_iter()
171 .map(move |model| ModelInfo {
172 model: model.clone(),
173 icon: provider.icon(),
174 }),
175 );
176 }
177
178 let other_models = LanguageModelRegistry::global(cx)
179 .read(cx)
180 .providers()
181 .iter()
182 .map(|provider| {
183 (
184 provider.id(),
185 provider
186 .provided_models(cx)
187 .into_iter()
188 .filter_map(|model| {
189 let not_included =
190 !recommended_set.contains(&(model.provider_id(), model.id()));
191 not_included.then(|| ModelInfo {
192 model: model.clone(),
193 icon: provider.icon(),
194 })
195 })
196 .collect::<Vec<_>>(),
197 )
198 })
199 .collect::<IndexMap<_, _>>();
200
201 GroupedModels {
202 recommended,
203 other: other_models,
204 }
205 }
206
207 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
208 (self.picker.read(cx).delegate.get_active_model)(cx)
209 }
210
211 fn get_active_model_index(
212 entries: &[LanguageModelPickerEntry],
213 active_model: Option<ConfiguredModel>,
214 ) -> usize {
215 entries
216 .iter()
217 .position(|entry| {
218 if let LanguageModelPickerEntry::Model(model) = entry {
219 active_model
220 .as_ref()
221 .map(|active_model| {
222 active_model.model.id() == model.model.id()
223 && active_model.provider.id() == model.model.provider_id()
224 })
225 .unwrap_or_default()
226 } else {
227 false
228 }
229 })
230 .unwrap_or(0)
231 }
232}
233
234impl EventEmitter<DismissEvent> for LanguageModelSelector {}
235
236impl Focusable for LanguageModelSelector {
237 fn focus_handle(&self, cx: &App) -> FocusHandle {
238 self.picker.focus_handle(cx)
239 }
240}
241
242impl Render for LanguageModelSelector {
243 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
244 self.picker.clone()
245 }
246}
247
/// Element that renders `trigger` and shows a [`LanguageModelSelector`] inside
/// a `PopoverMenu` when triggered.
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
    T: PopoverTrigger + ButtonCommon,
    TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
    /// Selector entity returned as the popover's menu content.
    language_model_selector: Entity<LanguageModelSelector>,
    /// Element that opens the popover.
    trigger: T,
    /// Builds the tooltip view for the trigger.
    tooltip: TT,
    /// Optional handle forwarded to `PopoverMenu::with_handle` when rendering.
    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
    /// Passed to `PopoverMenu::anchor`.
    anchor: Corner,
}
260
261impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
262where
263 T: PopoverTrigger + ButtonCommon,
264 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
265{
266 pub fn new(
267 language_model_selector: Entity<LanguageModelSelector>,
268 trigger: T,
269 tooltip: TT,
270 anchor: Corner,
271 ) -> Self {
272 Self {
273 language_model_selector,
274 trigger,
275 tooltip,
276 handle: None,
277 anchor,
278 }
279 }
280
281 pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
282 self.handle = Some(handle);
283 self
284 }
285}
286
287impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
288where
289 T: PopoverTrigger + ButtonCommon,
290 TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
291{
292 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
293 let language_model_selector = self.language_model_selector.clone();
294
295 PopoverMenu::new("model-switcher")
296 .menu(move |_window, _cx| Some(language_model_selector.clone()))
297 .trigger_with_tooltip(self.trigger, self.tooltip)
298 .anchor(self.anchor)
299 .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
300 .offset(gpui::Point {
301 x: px(0.0),
302 y: px(-2.0),
303 })
304 }
305}
306
/// A model paired with the icon of the provider that supplied it.
#[derive(Clone)]
struct ModelInfo {
    model: Arc<dyn LanguageModel>,
    icon: IconName,
}

/// [`PickerDelegate`] backing [`LanguageModelSelector`].
pub struct LanguageModelPickerDelegate {
    /// Owning selector; `DismissEvent` is re-emitted on it when the picker is dismissed.
    language_model_selector: WeakEntity<LanguageModelSelector>,
    /// Invoked with the chosen model on confirm.
    on_model_changed: OnModelChanged,
    /// Reports the active model; used to highlight its row.
    get_active_model: GetActiveModel,
    /// Snapshot of every model, replaced on registry provider events.
    all_models: Arc<GroupedModels>,
    /// Rows currently shown: section separators interleaved with models.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
}

/// Models split into a "Recommended" list and per-provider groups.
struct GroupedModels {
    /// Shown first, under a "Recommended" header.
    recommended: Vec<ModelInfo>,
    /// Other models grouped by provider; iteration follows insertion order.
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
326
327impl GroupedModels {
328 pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
329 let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
330 for model in other {
331 let provider = model.model.provider_id();
332 if let Some(models) = other_by_provider.get_mut(&provider) {
333 models.push(model);
334 } else {
335 other_by_provider.insert(provider, vec![model]);
336 }
337 }
338
339 Self {
340 recommended,
341 other: other_by_provider,
342 }
343 }
344
345 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
346 let mut entries = Vec::new();
347
348 if !self.recommended.is_empty() {
349 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
350 entries.extend(
351 self.recommended
352 .iter()
353 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
354 );
355 }
356
357 for models in self.other.values() {
358 if models.is_empty() {
359 continue;
360 }
361 entries.push(LanguageModelPickerEntry::Separator(
362 models[0].model.provider_name().0,
363 ));
364 entries.extend(
365 models
366 .iter()
367 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
368 );
369 }
370 entries
371 }
372
373 fn model_infos(&self) -> Vec<ModelInfo> {
374 let other = self
375 .other
376 .values()
377 .flat_map(|model| model.iter())
378 .cloned()
379 .collect::<Vec<_>>();
380 self.recommended
381 .iter()
382 .chain(&other)
383 .cloned()
384 .collect::<Vec<_>>()
385 }
386}
387
/// A row in the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable section header with the given title.
    Separator(SharedString),
}

/// Searches a fixed list of models either by fuzzy-matching "provider/model"
/// strings or by case-insensitive substring match on the model name.
struct ModelMatcher {
    models: Vec<ModelInfo>,
    bg_executor: BackgroundExecutor,
    // One candidate per entry in `models`; a candidate's id is its index there.
    candidates: Vec<StringMatchCandidate>,
}
398
399impl ModelMatcher {
400 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
401 let candidates = Self::make_match_candidates(&models);
402 Self {
403 models,
404 bg_executor,
405 candidates,
406 }
407 }
408
409 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
410 let mut matches = self.bg_executor.block(match_strings(
411 &self.candidates,
412 &query,
413 false,
414 100,
415 &Default::default(),
416 self.bg_executor.clone(),
417 ));
418
419 let sorting_key = |mat: &StringMatch| {
420 let candidate = &self.candidates[mat.candidate_id];
421 (Reverse(OrderedFloat(mat.score)), candidate.id)
422 };
423 matches.sort_unstable_by_key(sorting_key);
424
425 let matched_models: Vec<_> = matches
426 .into_iter()
427 .map(|mat| self.models[mat.candidate_id].clone())
428 .collect();
429
430 matched_models
431 }
432
433 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
434 self.models
435 .iter()
436 .filter(|m| {
437 m.model
438 .name()
439 .0
440 .to_lowercase()
441 .contains(&query.to_lowercase())
442 })
443 .cloned()
444 .collect::<Vec<_>>()
445 }
446
447 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
448 model_infos
449 .iter()
450 .enumerate()
451 .map(|(index, model)| {
452 StringMatchCandidate::new(
453 index,
454 &format!(
455 "{}/{}",
456 &model.model.provider_name().0,
457 &model.model.name().0
458 ),
459 )
460 })
461 .collect::<Vec<_>>()
462 }
463}
464
465impl PickerDelegate for LanguageModelPickerDelegate {
466 type ListItem = AnyElement;
467
468 fn match_count(&self) -> usize {
469 self.filtered_entries.len()
470 }
471
472 fn selected_index(&self) -> usize {
473 self.selected_index
474 }
475
476 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
477 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
478 cx.notify();
479 }
480
481 fn can_select(
482 &mut self,
483 ix: usize,
484 _window: &mut Window,
485 _cx: &mut Context<Picker<Self>>,
486 ) -> bool {
487 match self.filtered_entries.get(ix) {
488 Some(LanguageModelPickerEntry::Model(_)) => true,
489 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
490 }
491 }
492
493 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
494 "Select a model…".into()
495 }
496
497 fn update_matches(
498 &mut self,
499 query: String,
500 window: &mut Window,
501 cx: &mut Context<Picker<Self>>,
502 ) -> Task<()> {
503 let all_models = self.all_models.clone();
504 let current_index = self.selected_index;
505 let bg_executor = cx.background_executor();
506
507 let language_model_registry = LanguageModelRegistry::global(cx);
508
509 let configured_providers = language_model_registry
510 .read(cx)
511 .providers()
512 .into_iter()
513 .filter(|provider| provider.is_authenticated(cx))
514 .collect::<Vec<_>>();
515
516 let configured_provider_ids = configured_providers
517 .iter()
518 .map(|provider| provider.id())
519 .collect::<Vec<_>>();
520
521 let recommended_models = all_models
522 .recommended
523 .iter()
524 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
525 .cloned()
526 .collect::<Vec<_>>();
527
528 let available_models = all_models
529 .model_infos()
530 .iter()
531 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
532 .cloned()
533 .collect::<Vec<_>>();
534
535 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
536 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
537
538 let recommended = matcher_rec.exact_search(&query);
539 let all = matcher_all.fuzzy_search(&query);
540
541 let filtered_models = GroupedModels::new(all, recommended);
542
543 cx.spawn_in(window, async move |this, cx| {
544 this.update_in(cx, |this, window, cx| {
545 this.delegate.filtered_entries = filtered_models.entries();
546 // Preserve selection focus
547 let new_index = if current_index >= this.delegate.filtered_entries.len() {
548 0
549 } else {
550 current_index
551 };
552 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
553 cx.notify();
554 })
555 .ok();
556 })
557 }
558
559 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
560 if let Some(LanguageModelPickerEntry::Model(model_info)) =
561 self.filtered_entries.get(self.selected_index)
562 {
563 let model = model_info.model.clone();
564 (self.on_model_changed)(model.clone(), cx);
565
566 let current_index = self.selected_index;
567 self.set_selected_index(current_index, window, cx);
568
569 cx.emit(DismissEvent);
570 }
571 }
572
573 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
574 self.language_model_selector
575 .update(cx, |_this, cx| cx.emit(DismissEvent))
576 .ok();
577 }
578
579 fn render_match(
580 &self,
581 ix: usize,
582 selected: bool,
583 _: &mut Window,
584 cx: &mut Context<Picker<Self>>,
585 ) -> Option<Self::ListItem> {
586 match self.filtered_entries.get(ix)? {
587 LanguageModelPickerEntry::Separator(title) => Some(
588 div()
589 .px_2()
590 .pb_1()
591 .when(ix > 1, |this| {
592 this.mt_1()
593 .pt_2()
594 .border_t_1()
595 .border_color(cx.theme().colors().border_variant)
596 })
597 .child(
598 Label::new(title)
599 .size(LabelSize::XSmall)
600 .color(Color::Muted),
601 )
602 .into_any_element(),
603 ),
604 LanguageModelPickerEntry::Model(model_info) => {
605 let active_model = (self.get_active_model)(cx);
606 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
607 let active_model_id = active_model.map(|m| m.model.id());
608
609 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
610 && Some(model_info.model.id()) == active_model_id;
611
612 let model_icon_color = if is_selected {
613 Color::Accent
614 } else {
615 Color::Muted
616 };
617
618 Some(
619 ListItem::new(ix)
620 .inset(true)
621 .spacing(ListItemSpacing::Sparse)
622 .toggle_state(selected)
623 .start_slot(
624 Icon::new(model_info.icon)
625 .color(model_icon_color)
626 .size(IconSize::Small),
627 )
628 .child(
629 h_flex()
630 .w_full()
631 .pl_0p5()
632 .gap_1p5()
633 .w(px(240.))
634 .child(Label::new(model_info.model.name().0.clone()).truncate()),
635 )
636 .end_slot(div().pr_3().when(is_selected, |this| {
637 this.child(
638 Icon::new(IconName::Check)
639 .color(Color::Accent)
640 .size(IconSize::Small),
641 )
642 }))
643 .into_any_element(),
644 )
645 }
646 }
647 }
648
649 fn render_footer(
650 &self,
651 _: &mut Window,
652 cx: &mut Context<Picker<Self>>,
653 ) -> Option<gpui::AnyElement> {
654 use feature_flags::FeatureFlagAppExt;
655
656 let plan = proto::Plan::ZedPro;
657
658 Some(
659 h_flex()
660 .w_full()
661 .border_t_1()
662 .border_color(cx.theme().colors().border_variant)
663 .p_1()
664 .gap_4()
665 .justify_between()
666 .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
667 this.child(match plan {
668 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
669 .icon(IconName::ZedAssistant)
670 .icon_size(IconSize::Small)
671 .icon_color(Color::Muted)
672 .icon_position(IconPosition::Start)
673 .on_click(|_, window, cx| {
674 window
675 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
676 }),
677 Plan::Free | Plan::ZedProTrial => Button::new(
678 "try-pro",
679 if plan == Plan::ZedProTrial {
680 "Upgrade to Pro"
681 } else {
682 "Try Pro"
683 },
684 )
685 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
686 })
687 })
688 .child(
689 Button::new("configure", "Configure")
690 .icon(IconName::Settings)
691 .icon_size(IconSize::Small)
692 .icon_color(Color::Muted)
693 .icon_position(IconPosition::Start)
694 .on_click(|_, window, cx| {
695 window.dispatch_action(
696 zed_actions::agent::OpenConfiguration.boxed_clone(),
697 cx,
698 );
699 }),
700 )
701 .into_any(),
702 )
703 }
704}
705
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal [`LanguageModel`] for matcher tests: only the identity
    /// accessors are meaningful; completion/token methods are stubs.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// The model id equals `name`; the provider id equals `provider`.
        fn new(name: &str, provider: &str) -> Self {
            let (name, provider) = (name.to_string(), provider.to_string());
            Self {
                name: LanguageModelName::from(name.clone()),
                id: LanguageModelId::from(name),
                provider_id: LanguageModelProviderId::from(provider.clone()),
                provider_name: LanguageModelProviderName::from(provider),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        /// "provider/name", used by the assertions below to identify models.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> usize {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<usize>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            http_client::Result<
                BoxStream<
                    'static,
                    http_client::Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds one `ModelInfo` per `(provider, name)` pair, in order.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        let mut infos = Vec::with_capacity(model_specs.len());
        for (provider, name) in model_specs {
            infos.push(ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            });
        }
        infos
    }

    /// Asserts `result` lists exactly the `expected` "provider/name" ids, in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (position, (info, expected_name)) in result.iter().zip(&expected).enumerate() {
            assert_eq!(
                info.model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                position
            );
        }
    }

    /// Matcher over the fixed model list shared by both tests.
    fn sample_matcher(cx: &TestAppContext) -> ModelMatcher {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        ModelMatcher::new(models, cx.background_executor.clone())
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let matcher = sample_matcher(cx);

        // The order of models should be maintained, case doesn't matter
        assert_models_eq(
            matcher.exact_search("GPT-4.1"),
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let matcher = sample_matcher(cx);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        assert_models_eq(
            matcher.fuzzy_search("41"),
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well ("ol" meaning "ollama").
        assert_models_eq(
            matcher.fuzzy_search("ol"),
            vec!["ollama/mistral", "ollama/deepseek"],
        );

        // Fuzzy search
        assert_models_eq(matcher.fuzzy_search("z4n"), vec!["zed/gpt-4.1-nano"]);
    }
}