1use std::{cmp::Reverse, sync::Arc};
2
3use agent_settings::AgentSettings;
4use collections::{HashMap, HashSet, IndexMap};
5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
6use gpui::{
7 Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, ForegroundExecutor,
8 Subscription, Task,
9};
10use language_model::{
11 AuthenticateError, ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId,
12 LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
13};
14use ordered_float::OrderedFloat;
15use picker::{Picker, PickerDelegate};
16use settings::Settings;
17use ui::prelude::*;
18use zed_actions::agent::OpenSettings;
19
20use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
21
/// Callback invoked with the model that was just chosen (on confirm or when
/// cycling through favorites).
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback that returns the currently active model configuration, if any.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
/// Callback invoked with a model and the favorite state it should be switched to.
type OnToggleFavorite = Arc<dyn Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static>;

/// A [`Picker`] driven by a [`LanguageModelPickerDelegate`].
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
27
28pub fn language_model_selector(
29 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
30 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
31 on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
32 popover_styles: bool,
33 focus_handle: FocusHandle,
34 window: &mut Window,
35 cx: &mut Context<LanguageModelSelector>,
36) -> LanguageModelSelector {
37 let delegate = LanguageModelPickerDelegate::new(
38 get_active_model,
39 on_model_changed,
40 on_toggle_favorite,
41 popover_styles,
42 focus_handle,
43 window,
44 cx,
45 );
46
47 if popover_styles {
48 Picker::list(delegate, window, cx)
49 .show_scrollbar(true)
50 .width(rems(20.))
51 .max_height(Some(rems(20.).into()))
52 } else {
53 Picker::list(delegate, window, cx).show_scrollbar(true)
54 }
55}
56
57fn all_models(cx: &App) -> GroupedModels {
58 let lm_registry = LanguageModelRegistry::global(cx).read(cx);
59 let providers = lm_registry.visible_providers();
60
61 let mut favorites_index = FavoritesIndex::default();
62
63 for sel in &AgentSettings::get_global(cx).favorite_models {
64 favorites_index
65 .entry(sel.provider.0.clone().into())
66 .or_default()
67 .insert(sel.model.clone().into());
68 }
69
70 let recommended = providers
71 .iter()
72 .flat_map(|provider| {
73 provider
74 .recommended_models(cx)
75 .into_iter()
76 .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
77 })
78 .collect();
79
80 let all = providers
81 .iter()
82 .flat_map(|provider| {
83 provider
84 .provided_models(cx)
85 .into_iter()
86 .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
87 })
88 .collect();
89
90 GroupedModels::new(all, recommended)
91}
92
93type FavoritesIndex = HashMap<LanguageModelProviderId, HashSet<LanguageModelId>>;
94
95#[derive(Clone)]
96struct ModelInfo {
97 model: Arc<dyn LanguageModel>,
98 icon: IconOrSvg,
99 is_favorite: bool,
100}
101
102impl ModelInfo {
103 fn new(
104 provider: &dyn LanguageModelProvider,
105 model: Arc<dyn LanguageModel>,
106 favorites_index: &FavoritesIndex,
107 ) -> Self {
108 let is_favorite = favorites_index
109 .get(&provider.id())
110 .map_or(false, |set| set.contains(&model.id()));
111
112 Self {
113 model,
114 icon: provider.icon(),
115 is_favorite,
116 }
117 }
118}
119
/// Picker delegate that lists language models grouped into favorite,
/// recommended and per-provider sections.
pub struct LanguageModelPickerDelegate {
    /// Called with the chosen model on confirm and when cycling favorites.
    on_model_changed: OnModelChanged,
    /// Returns the currently active model, if any.
    get_active_model: GetActiveModel,
    /// Called with a model and the favorite state it should be switched to.
    on_toggle_favorite: OnToggleFavorite,
    /// Every known model, grouped; rebuilt when the registry reports a
    /// provider change.
    all_models: Arc<GroupedModels>,
    /// Rows currently shown in the picker (separators and models).
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Task returned by `authenticate_all_providers`; stored but never read.
    // NOTE(review): presumably held so the task is not dropped early — confirm
    // against gpui's `Task` drop semantics.
    _authenticate_all_providers_task: Task<()>,
    /// Subscription to [`LanguageModelRegistry`] events; stored but never read.
    _subscriptions: Vec<Subscription>,
    /// Whether the picker uses popover styling; the footer is only rendered
    /// when this is `true`.
    popover_styles: bool,
    /// Focus handle handed to the footer.
    focus_handle: FocusHandle,
}
132
133impl LanguageModelPickerDelegate {
134 fn new(
135 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
136 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
137 on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
138 popover_styles: bool,
139 focus_handle: FocusHandle,
140 window: &mut Window,
141 cx: &mut Context<Picker<Self>>,
142 ) -> Self {
143 let on_model_changed = Arc::new(on_model_changed);
144 let models = all_models(cx);
145 let entries = models.entries();
146
147 Self {
148 on_model_changed,
149 all_models: Arc::new(models),
150 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
151 filtered_entries: entries,
152 get_active_model: Arc::new(get_active_model),
153 on_toggle_favorite: Arc::new(on_toggle_favorite),
154 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
155 _subscriptions: vec![cx.subscribe_in(
156 &LanguageModelRegistry::global(cx),
157 window,
158 |picker, _, event, window, cx| {
159 match event {
160 language_model::Event::ProviderStateChanged(_)
161 | language_model::Event::AddedProvider(_)
162 | language_model::Event::RemovedProvider(_) => {
163 let query = picker.query(cx);
164 picker.delegate.all_models = Arc::new(all_models(cx));
165 // Update matches will automatically drop the previous task
166 // if we get a provider event again
167 picker.update_matches(query, window, cx)
168 }
169 _ => {}
170 }
171 },
172 )],
173 popover_styles,
174 focus_handle,
175 }
176 }
177
178 fn get_active_model_index(
179 entries: &[LanguageModelPickerEntry],
180 active_model: Option<ConfiguredModel>,
181 ) -> usize {
182 entries
183 .iter()
184 .position(|entry| {
185 if let LanguageModelPickerEntry::Model(model) = entry {
186 active_model
187 .as_ref()
188 .map(|active_model| {
189 active_model.model.id() == model.model.id()
190 && active_model.provider.id() == model.model.provider_id()
191 })
192 .unwrap_or_default()
193 } else {
194 false
195 }
196 })
197 .unwrap_or(0)
198 }
199
200 /// Authenticates all providers in the [`LanguageModelRegistry`].
201 ///
202 /// We do this so that we can populate the language selector with all of the
203 /// models from the configured providers.
204 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
205 let authenticate_all_providers = LanguageModelRegistry::global(cx)
206 .read(cx)
207 .visible_providers()
208 .iter()
209 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
210 .collect::<Vec<_>>();
211
212 cx.spawn(async move |_cx| {
213 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
214 if let Err(err) = authenticate_task.await {
215 if matches!(err, AuthenticateError::CredentialsNotFound) {
216 // Since we're authenticating these providers in the
217 // background for the purposes of populating the
218 // language selector, we don't care about providers
219 // where the credentials are not found.
220 } else {
221 // Some providers have noisy failure states that we
222 // don't want to spam the logs with every time the
223 // language model selector is initialized.
224 //
225 // Ideally these should have more clear failure modes
226 // that we know are safe to ignore here, like what we do
227 // with `CredentialsNotFound` above.
228 match provider_id.0.as_ref() {
229 "lmstudio" | "ollama" => {
230 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
231 //
232 // These fail noisily, so we don't log them.
233 }
234 "copilot_chat" => {
235 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
236 }
237 _ => {
238 log::error!(
239 "Failed to authenticate provider: {}: {err:#}",
240 provider_name.0
241 );
242 }
243 }
244 }
245 }
246 }
247 })
248 }
249
    /// Returns the currently active model by calling the `get_active_model`
    /// callback supplied at construction.
    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
        (self.get_active_model)(cx)
    }
253
    /// Returns how many favorite models are in the unfiltered model set.
    pub fn favorites_count(&self) -> usize {
        self.all_models.favorites.len()
    }
257
258 pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
259 if self.all_models.favorites.is_empty() {
260 return;
261 }
262
263 let active_model = (self.get_active_model)(cx);
264 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
265 let active_model_id = active_model.as_ref().map(|m| m.model.id());
266
267 let current_index = self
268 .all_models
269 .favorites
270 .iter()
271 .position(|info| {
272 Some(info.model.provider_id()) == active_provider_id
273 && Some(info.model.id()) == active_model_id
274 })
275 .unwrap_or(usize::MAX);
276
277 let next_index = if current_index == usize::MAX {
278 0
279 } else {
280 (current_index + 1) % self.all_models.favorites.len()
281 };
282
283 let next_model = self.all_models.favorites[next_index].model.clone();
284
285 (self.on_model_changed)(next_model, cx);
286
287 // Align the picker selection with the newly-active model
288 let new_index =
289 Self::get_active_model_index(&self.filtered_entries, (self.get_active_model)(cx));
290 self.set_selected_index(new_index, window, cx);
291 }
292}
293
294struct GroupedModels {
295 favorites: Vec<ModelInfo>,
296 recommended: Vec<ModelInfo>,
297 all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
298}
299
300impl GroupedModels {
301 pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
302 let favorites = all
303 .iter()
304 .filter(|info| info.is_favorite)
305 .cloned()
306 .collect();
307
308 let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
309 for model in all {
310 let provider = model.model.provider_id();
311 if let Some(models) = all_by_provider.get_mut(&provider) {
312 models.push(model);
313 } else {
314 all_by_provider.insert(provider, vec![model]);
315 }
316 }
317
318 Self {
319 favorites,
320 recommended,
321 all: all_by_provider,
322 }
323 }
324
325 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
326 let mut entries = Vec::new();
327
328 if !self.favorites.is_empty() {
329 entries.push(LanguageModelPickerEntry::Separator("Favorite".into()));
330 for info in &self.favorites {
331 entries.push(LanguageModelPickerEntry::Model(info.clone()));
332 }
333 }
334
335 if !self.recommended.is_empty() {
336 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
337 for info in &self.recommended {
338 entries.push(LanguageModelPickerEntry::Model(info.clone()));
339 }
340 }
341
342 for models in self.all.values() {
343 if models.is_empty() {
344 continue;
345 }
346 entries.push(LanguageModelPickerEntry::Separator(
347 models[0].model.provider_name().0,
348 ));
349 for info in models {
350 entries.push(LanguageModelPickerEntry::Model(info.clone()));
351 }
352 }
353
354 entries
355 }
356}
357
/// A single row of the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable section header carrying the section title.
    Separator(SharedString),
}
362
363struct ModelMatcher {
364 models: Vec<ModelInfo>,
365 fg_executor: ForegroundExecutor,
366 bg_executor: BackgroundExecutor,
367 candidates: Vec<StringMatchCandidate>,
368}
369
370impl ModelMatcher {
371 fn new(
372 models: Vec<ModelInfo>,
373 fg_executor: ForegroundExecutor,
374 bg_executor: BackgroundExecutor,
375 ) -> ModelMatcher {
376 let candidates = Self::make_match_candidates(&models);
377 Self {
378 models,
379 fg_executor,
380 bg_executor,
381 candidates,
382 }
383 }
384
385 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
386 let mut matches = self.fg_executor.block_on(match_strings(
387 &self.candidates,
388 query,
389 false,
390 true,
391 100,
392 &Default::default(),
393 self.bg_executor.clone(),
394 ));
395
396 let sorting_key = |mat: &StringMatch| {
397 let candidate = &self.candidates[mat.candidate_id];
398 (Reverse(OrderedFloat(mat.score)), candidate.id)
399 };
400 matches.sort_unstable_by_key(sorting_key);
401
402 let matched_models: Vec<_> = matches
403 .into_iter()
404 .map(|mat| self.models[mat.candidate_id].clone())
405 .collect();
406
407 matched_models
408 }
409
410 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
411 self.models
412 .iter()
413 .filter(|m| {
414 m.model
415 .name()
416 .0
417 .to_lowercase()
418 .contains(&query.to_lowercase())
419 })
420 .cloned()
421 .collect::<Vec<_>>()
422 }
423
424 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
425 model_infos
426 .iter()
427 .enumerate()
428 .map(|(index, model)| {
429 StringMatchCandidate::new(
430 index,
431 &format!(
432 "{}/{}",
433 &model.model.provider_name().0,
434 &model.model.name().0
435 ),
436 )
437 })
438 .collect::<Vec<_>>()
439 }
440}
441
442impl PickerDelegate for LanguageModelPickerDelegate {
443 type ListItem = AnyElement;
444
    /// Number of rows currently shown, separators included.
    fn match_count(&self) -> usize {
        self.filtered_entries.len()
    }
448
    /// Index into `filtered_entries` of the highlighted row.
    fn selected_index(&self) -> usize {
        self.selected_index
    }
452
453 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
454 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
455 cx.notify();
456 }
457
458 fn can_select(
459 &mut self,
460 ix: usize,
461 _window: &mut Window,
462 _cx: &mut Context<Picker<Self>>,
463 ) -> bool {
464 match self.filtered_entries.get(ix) {
465 Some(LanguageModelPickerEntry::Model(_)) => true,
466 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
467 }
468 }
469
    /// Placeholder text returned for the picker's query input.
    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
        "Select a model…".into()
    }
473
474 fn update_matches(
475 &mut self,
476 query: String,
477 window: &mut Window,
478 cx: &mut Context<Picker<Self>>,
479 ) -> Task<()> {
480 let all_models = self.all_models.clone();
481 let active_model = (self.get_active_model)(cx);
482 let fg_executor = cx.foreground_executor();
483 let bg_executor = cx.background_executor();
484
485 let language_model_registry = LanguageModelRegistry::global(cx);
486
487 let configured_providers = language_model_registry
488 .read(cx)
489 .visible_providers()
490 .into_iter()
491 .filter(|provider| provider.is_authenticated(cx))
492 .collect::<Vec<_>>();
493
494 let configured_provider_ids = configured_providers
495 .iter()
496 .map(|provider| provider.id())
497 .collect::<Vec<_>>();
498
499 let recommended_models = all_models
500 .recommended
501 .iter()
502 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
503 .cloned()
504 .collect::<Vec<_>>();
505
506 let available_models = all_models
507 .all
508 .values()
509 .flat_map(|models| models.iter())
510 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
511 .cloned()
512 .collect::<Vec<_>>();
513
514 let matcher_rec =
515 ModelMatcher::new(recommended_models, fg_executor.clone(), bg_executor.clone());
516 let matcher_all =
517 ModelMatcher::new(available_models, fg_executor.clone(), bg_executor.clone());
518
519 let recommended = matcher_rec.exact_search(&query);
520 let all = matcher_all.fuzzy_search(&query);
521
522 let filtered_models = GroupedModels::new(all, recommended);
523
524 cx.spawn_in(window, async move |this, cx| {
525 this.update_in(cx, |this, window, cx| {
526 this.delegate.filtered_entries = filtered_models.entries();
527 // Finds the currently selected model in the list
528 let new_index =
529 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
530 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
531 cx.notify();
532 })
533 .ok();
534 })
535 }
536
537 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
538 if let Some(LanguageModelPickerEntry::Model(model_info)) =
539 self.filtered_entries.get(self.selected_index)
540 {
541 let model = model_info.model.clone();
542 (self.on_model_changed)(model.clone(), cx);
543
544 let current_index = self.selected_index;
545 self.set_selected_index(current_index, window, cx);
546
547 cx.emit(DismissEvent);
548 }
549 }
550
    /// Emits [`DismissEvent`] when the picker is dismissed.
    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
        cx.emit(DismissEvent);
    }
554
555 fn render_match(
556 &self,
557 ix: usize,
558 selected: bool,
559 _: &mut Window,
560 cx: &mut Context<Picker<Self>>,
561 ) -> Option<Self::ListItem> {
562 match self.filtered_entries.get(ix)? {
563 LanguageModelPickerEntry::Separator(title) => {
564 Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
565 }
566 LanguageModelPickerEntry::Model(model_info) => {
567 let active_model = (self.get_active_model)(cx);
568 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
569 let active_model_id = active_model.map(|m| m.model.id());
570
571 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
572 && Some(model_info.model.id()) == active_model_id;
573
574 let model_cost = model_info
575 .model
576 .model_cost_info()
577 .map(|cost| cost.to_shared_string());
578
579 let is_favorite = model_info.is_favorite;
580 let handle_action_click = {
581 let model = model_info.model.clone();
582 let on_toggle_favorite = self.on_toggle_favorite.clone();
583 cx.listener(move |picker, _, window, cx| {
584 on_toggle_favorite(model.clone(), !is_favorite, cx);
585 picker.refresh(window, cx);
586 })
587 };
588
589 Some(
590 ModelSelectorListItem::new(ix, model_info.model.name().0)
591 .map(|this| match &model_info.icon {
592 IconOrSvg::Icon(icon_name) => this.icon(*icon_name),
593 IconOrSvg::Svg(icon_path) => this.icon_path(icon_path.clone()),
594 })
595 .is_selected(is_selected)
596 .is_focused(selected)
597 .is_latest(model_info.model.is_latest())
598 .is_favorite(is_favorite)
599 .cost_info(model_cost)
600 .on_toggle_favorite(handle_action_click)
601 .into_any_element(),
602 )
603 }
604 }
605 }
606
607 fn render_footer(
608 &self,
609 _window: &mut Window,
610 _cx: &mut Context<Picker<Self>>,
611 ) -> Option<gpui::AnyElement> {
612 let focus_handle = self.focus_handle.clone();
613
614 if !self.popover_styles {
615 return None;
616 }
617
618 Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
619 }
620}
621
622#[cfg(test)]
623mod tests {
624 use super::*;
625 use futures::{future::BoxFuture, stream::BoxStream};
626 use gpui::{AsyncApp, TestAppContext, http_client};
627 use language_model::{
628 LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
629 LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
630 LanguageModelRequest, LanguageModelToolChoice,
631 };
632 use ui::IconName;
633
634 #[derive(Clone)]
635 struct TestLanguageModel {
636 name: LanguageModelName,
637 id: LanguageModelId,
638 provider_id: LanguageModelProviderId,
639 provider_name: LanguageModelProviderName,
640 }
641
642 impl TestLanguageModel {
643 fn new(name: &str, provider: &str) -> Self {
644 Self {
645 name: LanguageModelName::from(name.to_string()),
646 id: LanguageModelId::from(name.to_string()),
647 provider_id: LanguageModelProviderId::from(provider.to_string()),
648 provider_name: LanguageModelProviderName::from(provider.to_string()),
649 }
650 }
651 }
652
    // Test double: identity accessors return the stored values, capability
    // queries all answer `false`, and the request methods are left
    // unimplemented.
    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // Formatted as "<provider id>/<name>"; the tests use this string to
        // identify models in assertions.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        // Panics if called; no test in this module calls it.
        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        // Panics if called; no test in this module calls it.
        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }
715
716 fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
717 create_models_with_favorites(model_specs, vec![])
718 }
719
720 fn create_models_with_favorites(
721 model_specs: Vec<(&str, &str)>,
722 favorites: Vec<(&str, &str)>,
723 ) -> Vec<ModelInfo> {
724 model_specs
725 .into_iter()
726 .map(|(provider, name)| {
727 let is_favorite = favorites
728 .iter()
729 .any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name);
730 ModelInfo {
731 model: Arc::new(TestLanguageModel::new(name, provider)),
732 icon: IconOrSvg::Icon(IconName::Ai),
733 is_favorite,
734 }
735 })
736 .collect()
737 }
738
739 fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
740 assert_eq!(
741 result.len(),
742 expected.len(),
743 "Number of models doesn't match"
744 );
745
746 for (i, expected_name) in expected.iter().enumerate() {
747 assert_eq!(
748 result[i].model.telemetry_id(),
749 *expected_name,
750 "Model at position {} doesn't match expected model",
751 i
752 );
753 }
754 }
755
756 #[gpui::test]
757 fn test_exact_match(cx: &mut TestAppContext) {
758 let models = create_models(vec![
759 ("zed", "Claude 3.7 Sonnet"),
760 ("zed", "Claude 3.7 Sonnet Thinking"),
761 ("zed", "gpt-5"),
762 ("zed", "gpt-5-mini"),
763 ("openai", "gpt-3.5-turbo"),
764 ("openai", "gpt-5"),
765 ("openai", "gpt-5-mini"),
766 ("ollama", "mistral"),
767 ("ollama", "deepseek"),
768 ]);
769 let matcher = ModelMatcher::new(
770 models,
771 cx.foreground_executor().clone(),
772 cx.background_executor.clone(),
773 );
774
775 // The order of models should be maintained, case doesn't matter
776 let results = matcher.exact_search("GPT-5");
777 assert_models_eq(
778 results,
779 vec![
780 "zed/gpt-5",
781 "zed/gpt-5-mini",
782 "openai/gpt-5",
783 "openai/gpt-5-mini",
784 ],
785 );
786 }
787
788 #[gpui::test]
789 fn test_fuzzy_match(cx: &mut TestAppContext) {
790 let models = create_models(vec![
791 ("zed", "Claude 3.7 Sonnet"),
792 ("zed", "Claude 3.7 Sonnet Thinking"),
793 ("zed", "gpt-5"),
794 ("zed", "gpt-5-mini"),
795 ("openai", "gpt-3.5-turbo"),
796 ("openai", "gpt-5"),
797 ("openai", "gpt-5-mini"),
798 ("ollama", "mistral"),
799 ("ollama", "deepseek"),
800 ]);
801 let matcher = ModelMatcher::new(
802 models,
803 cx.foreground_executor().clone(),
804 cx.background_executor.clone(),
805 );
806
807 // Results should preserve models order whenever possible.
808 // In the case below, `zed/gpt-5-mini` and `openai/gpt-5-mini` have identical
809 // similarity scores, but `zed/gpt-5-mini` was higher in the models list,
810 // so it should appear first in the results.
811 let results = matcher.fuzzy_search("mini");
812 assert_models_eq(results, vec!["zed/gpt-5-mini", "openai/gpt-5-mini"]);
813
814 // Model provider should be searchable as well
815 let results = matcher.fuzzy_search("ol"); // meaning "ollama"
816 assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);
817
818 // Fuzzy search - search for Claude to get the Thinking variant
819 let results = matcher.fuzzy_search("thinking");
820 assert_models_eq(results, vec!["zed/Claude 3.7 Sonnet Thinking"]);
821 }
822
823 #[gpui::test]
824 fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
825 let recommended_models = create_models(vec![("zed", "claude")]);
826 let all_models = create_models(vec![
827 ("zed", "claude"), // Should also appear in "other"
828 ("zed", "gemini"),
829 ("copilot", "o3"),
830 ]);
831
832 let grouped_models = GroupedModels::new(all_models, recommended_models);
833
834 let actual_all_models = grouped_models
835 .all
836 .values()
837 .flatten()
838 .cloned()
839 .collect::<Vec<_>>();
840
841 // Recommended models should also appear in "all"
842 assert_models_eq(
843 actual_all_models,
844 vec!["zed/claude", "zed/gemini", "copilot/o3"],
845 );
846 }
847
848 #[gpui::test]
849 fn test_models_from_different_providers(_cx: &mut TestAppContext) {
850 let recommended_models = create_models(vec![("zed", "claude")]);
851 let all_models = create_models(vec![
852 ("zed", "claude"), // Should also appear in "other"
853 ("zed", "gemini"),
854 ("copilot", "claude"), // Different provider, should appear in "other"
855 ]);
856
857 let grouped_models = GroupedModels::new(all_models, recommended_models);
858
859 let actual_all_models = grouped_models
860 .all
861 .values()
862 .flatten()
863 .cloned()
864 .collect::<Vec<_>>();
865
866 // All models should appear in "all" regardless of recommended status
867 assert_models_eq(
868 actual_all_models,
869 vec!["zed/claude", "zed/gemini", "copilot/claude"],
870 );
871 }
872
873 #[gpui::test]
874 fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
875 let recommended_models = create_models(vec![("zed", "claude")]);
876 let all_models = create_models_with_favorites(
877 vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
878 vec![("zed", "gemini")],
879 );
880
881 let grouped_models = GroupedModels::new(all_models, recommended_models);
882 let entries = grouped_models.entries();
883
884 assert!(matches!(
885 entries.first(),
886 Some(LanguageModelPickerEntry::Separator(s)) if s == "Favorite"
887 ));
888
889 assert_models_eq(grouped_models.favorites, vec!["zed/gemini"]);
890 }
891
892 #[gpui::test]
893 fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
894 let recommended_models = create_models(vec![("zed", "claude")]);
895 let all_models = create_models(vec![("zed", "claude"), ("zed", "gemini")]);
896
897 let grouped_models = GroupedModels::new(all_models, recommended_models);
898 let entries = grouped_models.entries();
899
900 assert!(matches!(
901 entries.first(),
902 Some(LanguageModelPickerEntry::Separator(s)) if s == "Recommended"
903 ));
904
905 assert!(grouped_models.favorites.is_empty());
906 }
907
908 #[gpui::test]
909 fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
910 let recommended_models =
911 create_models_with_favorites(vec![("zed", "claude")], vec![("zed", "claude")]);
912 let all_models = create_models_with_favorites(
913 vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
914 vec![("zed", "claude")],
915 );
916
917 let grouped_models = GroupedModels::new(all_models, recommended_models);
918 let entries = grouped_models.entries();
919
920 for entry in &entries {
921 if let LanguageModelPickerEntry::Model(info) = entry {
922 if info.model.telemetry_id() == "zed/claude" {
923 assert!(info.is_favorite, "zed/claude should be a favorite");
924 } else {
925 assert!(
926 !info.is_favorite,
927 "{} should not be a favorite",
928 info.model.telemetry_id()
929 );
930 }
931 }
932 }
933 }
934
935 #[gpui::test]
936 fn test_favorites_appear_in_other_sections(_cx: &mut TestAppContext) {
937 let favorites = vec![("zed", "gemini"), ("openai", "gpt-4")];
938
939 let recommended_models =
940 create_models_with_favorites(vec![("zed", "claude")], favorites.clone());
941
942 let all_models = create_models_with_favorites(
943 vec![
944 ("zed", "claude"),
945 ("zed", "gemini"),
946 ("openai", "gpt-4"),
947 ("openai", "gpt-3.5"),
948 ],
949 favorites,
950 );
951
952 let grouped_models = GroupedModels::new(all_models, recommended_models);
953
954 assert_models_eq(grouped_models.favorites, vec!["zed/gemini", "openai/gpt-4"]);
955 assert_models_eq(grouped_models.recommended, vec!["zed/claude"]);
956 assert_models_eq(
957 grouped_models.all.values().flatten().cloned().collect(),
958 vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"],
959 );
960 }
961}