1use std::{cmp::Reverse, sync::Arc};
2
3use collections::IndexMap;
4use futures::{StreamExt, channel::mpsc};
5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
6use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
7use language_model::{
8 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
9 LanguageModelRegistry,
10};
11use ordered_float::OrderedFloat;
12use picker::{Picker, PickerDelegate};
13use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*};
14use zed_actions::agent::OpenSettings;
15
/// Callback invoked with the model the user picked when a selection is confirmed.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback reporting the currently active model, if any; used to preselect and
/// check-mark that model in the list.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

/// A [`Picker`] that lists the language models exposed by the registered providers.
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
20
21pub fn language_model_selector(
22 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
23 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
24 popover_styles: bool,
25 focus_handle: FocusHandle,
26 window: &mut Window,
27 cx: &mut Context<LanguageModelSelector>,
28) -> LanguageModelSelector {
29 let delegate = LanguageModelPickerDelegate::new(
30 get_active_model,
31 on_model_changed,
32 popover_styles,
33 focus_handle,
34 window,
35 cx,
36 );
37
38 if popover_styles {
39 Picker::list(delegate, window, cx)
40 .show_scrollbar(true)
41 .width(rems(20.))
42 .max_height(Some(rems(20.).into()))
43 } else {
44 Picker::list(delegate, window, cx).show_scrollbar(true)
45 }
46}
47
48fn all_models(cx: &App) -> GroupedModels {
49 let now = std::time::SystemTime::now()
50 .duration_since(std::time::UNIX_EPOCH)
51 .unwrap_or_default()
52 .as_millis();
53 eprintln!("[{}ms] all_models() called", now);
54 let providers = LanguageModelRegistry::global(cx).read(cx).providers();
55 eprintln!("[{}ms] all_models: got {} providers", now, providers.len());
56
57 let recommended = providers
58 .iter()
59 .flat_map(|provider| {
60 provider
61 .recommended_models(cx)
62 .into_iter()
63 .map(|model| ModelInfo {
64 model,
65 icon: provider.icon(),
66 })
67 })
68 .collect();
69
70 let all: Vec<ModelInfo> = providers
71 .iter()
72 .flat_map(|provider| {
73 let now = std::time::SystemTime::now()
74 .duration_since(std::time::UNIX_EPOCH)
75 .unwrap_or_default()
76 .as_millis();
77 eprintln!(
78 "[{}ms] all_models: calling provided_models for {:?}",
79 now,
80 provider.id()
81 );
82 let models = provider.provided_models(cx);
83 eprintln!(
84 "[{}ms] all_models: provider {:?} returned {} models",
85 now,
86 provider.id(),
87 models.len()
88 );
89 models.into_iter().map(|model| ModelInfo {
90 model,
91 icon: provider.icon(),
92 })
93 })
94 .collect();
95
96 let now = std::time::SystemTime::now()
97 .duration_since(std::time::UNIX_EPOCH)
98 .unwrap_or_default()
99 .as_millis();
100 eprintln!(
101 "[{}ms] all_models: returning {} total models",
102 now,
103 all.len()
104 );
105 GroupedModels::new(all, recommended)
106}
107
/// A language model paired with the icon of the provider that supplied it.
#[derive(Clone)]
struct ModelInfo {
    model: Arc<dyn LanguageModel>,
    /// Icon of the owning provider; rendered next to the model name.
    icon: IconName,
}
113
/// [`PickerDelegate`] listing the models from the [`LanguageModelRegistry`], shown as a
/// "Recommended" section followed by one section per provider.
pub struct LanguageModelPickerDelegate {
    /// Called with the chosen model when the user confirms a selection.
    on_model_changed: OnModelChanged,
    /// Reports the active model; used for the initial selection and the check mark.
    get_active_model: GetActiveModel,
    /// All models from all providers; replaced whenever the registry reports a
    /// provider being added, removed, or changing state.
    all_models: Arc<GroupedModels>,
    /// Rows currently displayed (separators and models). Only models from
    /// authenticated providers are included, matched against the query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Background authentication of every provider; owned here so it shares the
    /// delegate's lifetime (presumably dropping a gpui `Task` cancels it).
    _authenticate_all_providers_task: Task<()>,
    /// Task that rebuilds `all_models` when the registry changes; owned for the
    /// same reason as `_authenticate_all_providers_task`.
    _refresh_models_task: Task<()>,
    /// Whether the picker is shown as a popover (fixed size plus "Configure" footer).
    popover_styles: bool,
    /// Used to resolve the key binding shown on the footer's "Configure" button.
    focus_handle: FocusHandle,
}
125
126impl LanguageModelPickerDelegate {
127 fn new(
128 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
129 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
130 popover_styles: bool,
131 focus_handle: FocusHandle,
132 window: &mut Window,
133 cx: &mut Context<Picker<Self>>,
134 ) -> Self {
135 let now = std::time::SystemTime::now()
136 .duration_since(std::time::UNIX_EPOCH)
137 .unwrap_or_default()
138 .as_millis();
139 eprintln!("[{}ms] LanguageModelPickerDelegate::new() called", now);
140 let on_model_changed = Arc::new(on_model_changed);
141 let models = all_models(cx);
142 eprintln!(
143 "[{}ms] LanguageModelPickerDelegate::new() got {} models from all_models()",
144 now,
145 models.all.len()
146 );
147 let entries = models.entries();
148
149 Self {
150 on_model_changed,
151 all_models: Arc::new(models),
152 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
153 filtered_entries: entries,
154 get_active_model: Arc::new(get_active_model),
155 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
156 _refresh_models_task: {
157 let now = std::time::SystemTime::now()
158 .duration_since(std::time::UNIX_EPOCH)
159 .unwrap_or_default()
160 .as_millis();
161 eprintln!(
162 "[{}ms] LanguageModelPickerDelegate::new() setting up refresh task for LanguageModelRegistry",
163 now
164 );
165
166 // Create a channel to signal when models need refreshing
167 let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
168
169 // Subscribe to registry events and send refresh signals through the channel
170 let registry = LanguageModelRegistry::global(cx);
171 eprintln!(
172 "[{}ms] LanguageModelPickerDelegate::new() subscribing to registry entity_id: {:?}",
173 now,
174 registry.entity_id()
175 );
176 cx.subscribe(®istry, move |_picker, _, event, _cx| {
177 match event {
178 language_model::Event::ProviderStateChanged(id) => {
179 let now = std::time::SystemTime::now()
180 .duration_since(std::time::UNIX_EPOCH)
181 .unwrap_or_default()
182 .as_millis();
183 eprintln!(
184 "[{}ms] LanguageModelSelector: ProviderStateChanged event for {:?}, sending refresh signal",
185 now, id
186 );
187 refresh_tx.unbounded_send(()).ok();
188 }
189 language_model::Event::AddedProvider(id) => {
190 let now = std::time::SystemTime::now()
191 .duration_since(std::time::UNIX_EPOCH)
192 .unwrap_or_default()
193 .as_millis();
194 eprintln!(
195 "[{}ms] LanguageModelSelector: AddedProvider event for {:?}, sending refresh signal",
196 now, id
197 );
198 refresh_tx.unbounded_send(()).ok();
199 }
200 language_model::Event::RemovedProvider(id) => {
201 let now = std::time::SystemTime::now()
202 .duration_since(std::time::UNIX_EPOCH)
203 .unwrap_or_default()
204 .as_millis();
205 eprintln!(
206 "[{}ms] LanguageModelSelector: RemovedProvider event for {:?}, sending refresh signal",
207 now, id
208 );
209 refresh_tx.unbounded_send(()).ok();
210 }
211 _ => {}
212 }
213 })
214 .detach();
215
216 // Spawn a task that listens for refresh signals and updates the picker
217 cx.spawn_in(window, async move |this, cx| {
218 while let Some(()) = refresh_rx.next().await {
219 let now = std::time::SystemTime::now()
220 .duration_since(std::time::UNIX_EPOCH)
221 .unwrap_or_default()
222 .as_millis();
223 eprintln!(
224 "[{}ms] LanguageModelSelector: refresh signal received, updating models",
225 now
226 );
227 let result = this.update_in(cx, |picker, window, cx| {
228 picker.delegate.all_models = Arc::new(all_models(cx));
229 picker.refresh(window, cx);
230 });
231 if result.is_err() {
232 // Picker was dropped, exit the loop
233 break;
234 }
235 }
236 })
237 },
238 popover_styles,
239 focus_handle,
240 }
241 }
242
243 fn get_active_model_index(
244 entries: &[LanguageModelPickerEntry],
245 active_model: Option<ConfiguredModel>,
246 ) -> usize {
247 entries
248 .iter()
249 .position(|entry| {
250 if let LanguageModelPickerEntry::Model(model) = entry {
251 active_model
252 .as_ref()
253 .map(|active_model| {
254 active_model.model.id() == model.model.id()
255 && active_model.provider.id() == model.model.provider_id()
256 })
257 .unwrap_or_default()
258 } else {
259 false
260 }
261 })
262 .unwrap_or(0)
263 }
264
265 /// Authenticates all providers in the [`LanguageModelRegistry`].
266 ///
267 /// We do this so that we can populate the language selector with all of the
268 /// models from the configured providers.
269 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
270 let authenticate_all_providers = LanguageModelRegistry::global(cx)
271 .read(cx)
272 .providers()
273 .iter()
274 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
275 .collect::<Vec<_>>();
276
277 cx.spawn(async move |_cx| {
278 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
279 if let Err(err) = authenticate_task.await {
280 if matches!(err, AuthenticateError::CredentialsNotFound) {
281 // Since we're authenticating these providers in the
282 // background for the purposes of populating the
283 // language selector, we don't care about providers
284 // where the credentials are not found.
285 } else {
286 // Some providers have noisy failure states that we
287 // don't want to spam the logs with every time the
288 // language model selector is initialized.
289 //
290 // Ideally these should have more clear failure modes
291 // that we know are safe to ignore here, like what we do
292 // with `CredentialsNotFound` above.
293 match provider_id.0.as_ref() {
294 "lmstudio" | "ollama" => {
295 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
296 //
297 // These fail noisily, so we don't log them.
298 }
299 "copilot_chat" => {
300 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
301 }
302 _ => {
303 log::error!(
304 "Failed to authenticate provider: {}: {err:#}",
305 provider_name.0
306 );
307 }
308 }
309 }
310 }
311 }
312 })
313 }
314
315 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
316 (self.get_active_model)(cx)
317 }
318}
319
320struct GroupedModels {
321 recommended: Vec<ModelInfo>,
322 all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
323}
324
325impl GroupedModels {
326 pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
327 let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
328 for model in all {
329 let provider = model.model.provider_id();
330 if let Some(models) = all_by_provider.get_mut(&provider) {
331 models.push(model);
332 } else {
333 all_by_provider.insert(provider, vec![model]);
334 }
335 }
336
337 Self {
338 recommended,
339 all: all_by_provider,
340 }
341 }
342
343 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
344 let mut entries = Vec::new();
345
346 if !self.recommended.is_empty() {
347 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
348 entries.extend(
349 self.recommended
350 .iter()
351 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
352 );
353 }
354
355 for models in self.all.values() {
356 if models.is_empty() {
357 continue;
358 }
359 entries.push(LanguageModelPickerEntry::Separator(
360 models[0].model.provider_name().0,
361 ));
362 entries.extend(
363 models
364 .iter()
365 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
366 );
367 }
368 entries
369 }
370}
371
/// A single row in the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable section header, e.g. "Recommended" or a provider name.
    Separator(SharedString),
}
376
377struct ModelMatcher {
378 models: Vec<ModelInfo>,
379 bg_executor: BackgroundExecutor,
380 candidates: Vec<StringMatchCandidate>,
381}
382
383impl ModelMatcher {
384 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
385 let candidates = Self::make_match_candidates(&models);
386 Self {
387 models,
388 bg_executor,
389 candidates,
390 }
391 }
392
393 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
394 let mut matches = self.bg_executor.block(match_strings(
395 &self.candidates,
396 query,
397 false,
398 true,
399 100,
400 &Default::default(),
401 self.bg_executor.clone(),
402 ));
403
404 let sorting_key = |mat: &StringMatch| {
405 let candidate = &self.candidates[mat.candidate_id];
406 (Reverse(OrderedFloat(mat.score)), candidate.id)
407 };
408 matches.sort_unstable_by_key(sorting_key);
409
410 let matched_models: Vec<_> = matches
411 .into_iter()
412 .map(|mat| self.models[mat.candidate_id].clone())
413 .collect();
414
415 matched_models
416 }
417
418 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
419 self.models
420 .iter()
421 .filter(|m| {
422 m.model
423 .name()
424 .0
425 .to_lowercase()
426 .contains(&query.to_lowercase())
427 })
428 .cloned()
429 .collect::<Vec<_>>()
430 }
431
432 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
433 model_infos
434 .iter()
435 .enumerate()
436 .map(|(index, model)| {
437 StringMatchCandidate::new(
438 index,
439 &format!(
440 "{}/{}",
441 &model.model.provider_name().0,
442 &model.model.name().0
443 ),
444 )
445 })
446 .collect::<Vec<_>>()
447 }
448}
449
450impl PickerDelegate for LanguageModelPickerDelegate {
451 type ListItem = AnyElement;
452
453 fn match_count(&self) -> usize {
454 self.filtered_entries.len()
455 }
456
457 fn selected_index(&self) -> usize {
458 self.selected_index
459 }
460
461 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
462 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
463 cx.notify();
464 }
465
466 fn can_select(
467 &mut self,
468 ix: usize,
469 _window: &mut Window,
470 _cx: &mut Context<Picker<Self>>,
471 ) -> bool {
472 match self.filtered_entries.get(ix) {
473 Some(LanguageModelPickerEntry::Model(_)) => true,
474 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
475 }
476 }
477
478 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
479 "Select a model…".into()
480 }
481
482 fn update_matches(
483 &mut self,
484 query: String,
485 window: &mut Window,
486 cx: &mut Context<Picker<Self>>,
487 ) -> Task<()> {
488 let all_models = self.all_models.clone();
489 let active_model = (self.get_active_model)(cx);
490 let bg_executor = cx.background_executor();
491
492 let language_model_registry = LanguageModelRegistry::global(cx);
493
494 let configured_providers = language_model_registry
495 .read(cx)
496 .providers()
497 .into_iter()
498 .filter(|provider| provider.is_authenticated(cx))
499 .collect::<Vec<_>>();
500
501 let configured_provider_ids = configured_providers
502 .iter()
503 .map(|provider| provider.id())
504 .collect::<Vec<_>>();
505
506 let recommended_models = all_models
507 .recommended
508 .iter()
509 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
510 .cloned()
511 .collect::<Vec<_>>();
512
513 let available_models = all_models
514 .all
515 .values()
516 .flat_map(|models| models.iter())
517 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
518 .cloned()
519 .collect::<Vec<_>>();
520
521 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
522 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
523
524 let recommended = matcher_rec.exact_search(&query);
525 let all = matcher_all.fuzzy_search(&query);
526
527 let filtered_models = GroupedModels::new(all, recommended);
528
529 cx.spawn_in(window, async move |this, cx| {
530 this.update_in(cx, |this, window, cx| {
531 this.delegate.filtered_entries = filtered_models.entries();
532 // Finds the currently selected model in the list
533 let new_index =
534 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
535 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
536 cx.notify();
537 })
538 .ok();
539 })
540 }
541
542 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
543 if let Some(LanguageModelPickerEntry::Model(model_info)) =
544 self.filtered_entries.get(self.selected_index)
545 {
546 let model = model_info.model.clone();
547 (self.on_model_changed)(model.clone(), cx);
548
549 let current_index = self.selected_index;
550 self.set_selected_index(current_index, window, cx);
551
552 cx.emit(DismissEvent);
553 }
554 }
555
556 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
557 cx.emit(DismissEvent);
558 }
559
560 fn render_match(
561 &self,
562 ix: usize,
563 selected: bool,
564 _: &mut Window,
565 cx: &mut Context<Picker<Self>>,
566 ) -> Option<Self::ListItem> {
567 match self.filtered_entries.get(ix)? {
568 LanguageModelPickerEntry::Separator(title) => Some(
569 div()
570 .px_2()
571 .pb_1()
572 .when(ix > 1, |this| {
573 this.mt_1()
574 .pt_2()
575 .border_t_1()
576 .border_color(cx.theme().colors().border_variant)
577 })
578 .child(
579 Label::new(title)
580 .size(LabelSize::XSmall)
581 .color(Color::Muted),
582 )
583 .into_any_element(),
584 ),
585 LanguageModelPickerEntry::Model(model_info) => {
586 let active_model = (self.get_active_model)(cx);
587 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
588 let active_model_id = active_model.map(|m| m.model.id());
589
590 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
591 && Some(model_info.model.id()) == active_model_id;
592
593 let model_icon_color = if is_selected {
594 Color::Accent
595 } else {
596 Color::Muted
597 };
598
599 Some(
600 ListItem::new(ix)
601 .inset(true)
602 .spacing(ListItemSpacing::Sparse)
603 .toggle_state(selected)
604 .child(
605 h_flex()
606 .w_full()
607 .gap_1p5()
608 .child(
609 Icon::new(model_info.icon)
610 .color(model_icon_color)
611 .size(IconSize::Small),
612 )
613 .child(Label::new(model_info.model.name().0).truncate()),
614 )
615 .end_slot(div().pr_3().when(is_selected, |this| {
616 this.child(
617 Icon::new(IconName::Check)
618 .color(Color::Accent)
619 .size(IconSize::Small),
620 )
621 }))
622 .into_any_element(),
623 )
624 }
625 }
626 }
627
628 fn render_footer(
629 &self,
630 _window: &mut Window,
631 cx: &mut Context<Picker<Self>>,
632 ) -> Option<gpui::AnyElement> {
633 let focus_handle = self.focus_handle.clone();
634
635 if !self.popover_styles {
636 return None;
637 }
638
639 Some(
640 h_flex()
641 .w_full()
642 .p_1p5()
643 .border_t_1()
644 .border_color(cx.theme().colors().border_variant)
645 .child(
646 Button::new("configure", "Configure")
647 .full_width()
648 .style(ButtonStyle::Outlined)
649 .key_binding(
650 KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
651 .map(|kb| kb.size(rems_from_px(12.))),
652 )
653 .on_click(|_, window, cx| {
654 window.dispatch_action(OpenSettings.boxed_clone(), cx);
655 }),
656 )
657 .into_any(),
658 )
659 }
660}
661
#[cfg(test)]
mod tests {
    //! Unit tests for `ModelMatcher` search ordering and `GroupedModels` grouping.
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal `LanguageModel` stub. The model's name doubles as its id, and the
    /// provider string doubles as both the provider id and the provider name.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // "<provider>/<name>"; the assertions below identify models by this string.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        // Not exercised by these tests.
        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        // Not exercised by these tests.
        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, all using `IconName::Ai`.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Asserts that `result` contains exactly the `"<provider>/<name>"` ids in `expected`, in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Recommended, but should still appear in `all`
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Recommended, but should still appear in `all`
            ("zed", "gemini"),
            ("copilot", "claude"), // Same name, different provider: a distinct entry in `all`
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }
}