1use std::{cmp::Reverse, sync::Arc};
2
3use agent_settings::AgentSettings;
4use collections::{HashMap, HashSet, IndexMap};
5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
6use gpui::{
7 Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
8};
9use language_model::{
10 AuthenticateError, ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId,
11 LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
12};
13use ordered_float::OrderedFloat;
14use picker::{Picker, PickerDelegate};
15use settings::Settings;
16use ui::prelude::*;
17use zed_actions::agent::OpenSettings;
18
19use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
20
21type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
22type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
23type OnToggleFavorite = Arc<dyn Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static>;
24
25pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
26
27pub fn language_model_selector(
28 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
29 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
30 on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
31 popover_styles: bool,
32 focus_handle: FocusHandle,
33 window: &mut Window,
34 cx: &mut Context<LanguageModelSelector>,
35) -> LanguageModelSelector {
36 let delegate = LanguageModelPickerDelegate::new(
37 get_active_model,
38 on_model_changed,
39 on_toggle_favorite,
40 popover_styles,
41 focus_handle,
42 window,
43 cx,
44 );
45
46 if popover_styles {
47 Picker::list(delegate, window, cx)
48 .show_scrollbar(true)
49 .width(rems(20.))
50 .max_height(Some(rems(20.).into()))
51 } else {
52 Picker::list(delegate, window, cx).show_scrollbar(true)
53 }
54}
55
56fn all_models(cx: &App) -> GroupedModels {
57 let lm_registry = LanguageModelRegistry::global(cx).read(cx);
58 let providers = lm_registry.visible_providers();
59
60 let mut favorites_index = FavoritesIndex::default();
61
62 for sel in &AgentSettings::get_global(cx).favorite_models {
63 favorites_index
64 .entry(sel.provider.0.clone().into())
65 .or_default()
66 .insert(sel.model.clone().into());
67 }
68
69 let recommended = providers
70 .iter()
71 .flat_map(|provider| {
72 provider
73 .recommended_models(cx)
74 .into_iter()
75 .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
76 })
77 .collect();
78
79 let all = providers
80 .iter()
81 .flat_map(|provider| {
82 provider
83 .provided_models(cx)
84 .into_iter()
85 .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
86 })
87 .collect();
88
89 GroupedModels::new(all, recommended)
90}
91
92type FavoritesIndex = HashMap<LanguageModelProviderId, HashSet<LanguageModelId>>;
93
94#[derive(Clone)]
95struct ModelInfo {
96 model: Arc<dyn LanguageModel>,
97 icon: IconOrSvg,
98 is_favorite: bool,
99}
100
101impl ModelInfo {
102 fn new(
103 provider: &dyn LanguageModelProvider,
104 model: Arc<dyn LanguageModel>,
105 favorites_index: &FavoritesIndex,
106 ) -> Self {
107 let is_favorite = favorites_index
108 .get(&provider.id())
109 .map_or(false, |set| set.contains(&model.id()));
110
111 Self {
112 model,
113 icon: provider.icon(),
114 is_favorite,
115 }
116 }
117}
118
119pub struct LanguageModelPickerDelegate {
120 on_model_changed: OnModelChanged,
121 get_active_model: GetActiveModel,
122 on_toggle_favorite: OnToggleFavorite,
123 all_models: Arc<GroupedModels>,
124 filtered_entries: Vec<LanguageModelPickerEntry>,
125 selected_index: usize,
126 _authenticate_all_providers_task: Task<()>,
127 _subscriptions: Vec<Subscription>,
128 popover_styles: bool,
129 focus_handle: FocusHandle,
130}
131
132impl LanguageModelPickerDelegate {
133 fn new(
134 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
135 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
136 on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
137 popover_styles: bool,
138 focus_handle: FocusHandle,
139 window: &mut Window,
140 cx: &mut Context<Picker<Self>>,
141 ) -> Self {
142 let on_model_changed = Arc::new(on_model_changed);
143 let models = all_models(cx);
144 let entries = models.entries();
145
146 Self {
147 on_model_changed,
148 all_models: Arc::new(models),
149 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
150 filtered_entries: entries,
151 get_active_model: Arc::new(get_active_model),
152 on_toggle_favorite: Arc::new(on_toggle_favorite),
153 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
154 _subscriptions: vec![cx.subscribe_in(
155 &LanguageModelRegistry::global(cx),
156 window,
157 |picker, _, event, window, cx| {
158 match event {
159 language_model::Event::ProviderStateChanged(_)
160 | language_model::Event::AddedProvider(_)
161 | language_model::Event::RemovedProvider(_) => {
162 let query = picker.query(cx);
163 picker.delegate.all_models = Arc::new(all_models(cx));
164 // Update matches will automatically drop the previous task
165 // if we get a provider event again
166 picker.update_matches(query, window, cx)
167 }
168 _ => {}
169 }
170 },
171 )],
172 popover_styles,
173 focus_handle,
174 }
175 }
176
177 fn get_active_model_index(
178 entries: &[LanguageModelPickerEntry],
179 active_model: Option<ConfiguredModel>,
180 ) -> usize {
181 entries
182 .iter()
183 .position(|entry| {
184 if let LanguageModelPickerEntry::Model(model) = entry {
185 active_model
186 .as_ref()
187 .map(|active_model| {
188 active_model.model.id() == model.model.id()
189 && active_model.provider.id() == model.model.provider_id()
190 })
191 .unwrap_or_default()
192 } else {
193 false
194 }
195 })
196 .unwrap_or(0)
197 }
198
199 /// Authenticates all providers in the [`LanguageModelRegistry`].
200 ///
201 /// We do this so that we can populate the language selector with all of the
202 /// models from the configured providers.
203 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
204 let authenticate_all_providers = LanguageModelRegistry::global(cx)
205 .read(cx)
206 .visible_providers()
207 .iter()
208 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
209 .collect::<Vec<_>>();
210
211 cx.spawn(async move |_cx| {
212 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
213 if let Err(err) = authenticate_task.await {
214 if matches!(err, AuthenticateError::CredentialsNotFound) {
215 // Since we're authenticating these providers in the
216 // background for the purposes of populating the
217 // language selector, we don't care about providers
218 // where the credentials are not found.
219 } else {
220 // Some providers have noisy failure states that we
221 // don't want to spam the logs with every time the
222 // language model selector is initialized.
223 //
224 // Ideally these should have more clear failure modes
225 // that we know are safe to ignore here, like what we do
226 // with `CredentialsNotFound` above.
227 match provider_id.0.as_ref() {
228 "lmstudio" | "ollama" => {
229 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
230 //
231 // These fail noisily, so we don't log them.
232 }
233 "copilot_chat" => {
234 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
235 }
236 _ => {
237 log::error!(
238 "Failed to authenticate provider: {}: {err:#}",
239 provider_name.0
240 );
241 }
242 }
243 }
244 }
245 }
246 })
247 }
248
249 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
250 (self.get_active_model)(cx)
251 }
252
253 pub fn favorites_count(&self) -> usize {
254 self.all_models.favorites.len()
255 }
256
257 pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
258 if self.all_models.favorites.is_empty() {
259 return;
260 }
261
262 let active_model = (self.get_active_model)(cx);
263 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
264 let active_model_id = active_model.as_ref().map(|m| m.model.id());
265
266 let current_index = self
267 .all_models
268 .favorites
269 .iter()
270 .position(|info| {
271 Some(info.model.provider_id()) == active_provider_id
272 && Some(info.model.id()) == active_model_id
273 })
274 .unwrap_or(usize::MAX);
275
276 let next_index = if current_index == usize::MAX {
277 0
278 } else {
279 (current_index + 1) % self.all_models.favorites.len()
280 };
281
282 let next_model = self.all_models.favorites[next_index].model.clone();
283
284 (self.on_model_changed)(next_model, cx);
285
286 // Align the picker selection with the newly-active model
287 let new_index =
288 Self::get_active_model_index(&self.filtered_entries, (self.get_active_model)(cx));
289 self.set_selected_index(new_index, window, cx);
290 }
291}
292
293struct GroupedModels {
294 favorites: Vec<ModelInfo>,
295 recommended: Vec<ModelInfo>,
296 all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
297}
298
299impl GroupedModels {
300 pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
301 let favorites = all
302 .iter()
303 .filter(|info| info.is_favorite)
304 .cloned()
305 .collect();
306
307 let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
308 for model in all {
309 let provider = model.model.provider_id();
310 if let Some(models) = all_by_provider.get_mut(&provider) {
311 models.push(model);
312 } else {
313 all_by_provider.insert(provider, vec![model]);
314 }
315 }
316
317 Self {
318 favorites,
319 recommended,
320 all: all_by_provider,
321 }
322 }
323
324 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
325 let mut entries = Vec::new();
326
327 if !self.favorites.is_empty() {
328 entries.push(LanguageModelPickerEntry::Separator("Favorite".into()));
329 for info in &self.favorites {
330 entries.push(LanguageModelPickerEntry::Model(info.clone()));
331 }
332 }
333
334 if !self.recommended.is_empty() {
335 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
336 for info in &self.recommended {
337 entries.push(LanguageModelPickerEntry::Model(info.clone()));
338 }
339 }
340
341 for models in self.all.values() {
342 if models.is_empty() {
343 continue;
344 }
345 entries.push(LanguageModelPickerEntry::Separator(
346 models[0].model.provider_name().0,
347 ));
348 for info in models {
349 entries.push(LanguageModelPickerEntry::Model(info.clone()));
350 }
351 }
352
353 entries
354 }
355}
356
357enum LanguageModelPickerEntry {
358 Model(ModelInfo),
359 Separator(SharedString),
360}
361
362struct ModelMatcher {
363 models: Vec<ModelInfo>,
364 bg_executor: BackgroundExecutor,
365 candidates: Vec<StringMatchCandidate>,
366}
367
368impl ModelMatcher {
369 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
370 let candidates = Self::make_match_candidates(&models);
371 Self {
372 models,
373 bg_executor,
374 candidates,
375 }
376 }
377
378 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
379 let mut matches = self.bg_executor.block(match_strings(
380 &self.candidates,
381 query,
382 false,
383 true,
384 100,
385 &Default::default(),
386 self.bg_executor.clone(),
387 ));
388
389 let sorting_key = |mat: &StringMatch| {
390 let candidate = &self.candidates[mat.candidate_id];
391 (Reverse(OrderedFloat(mat.score)), candidate.id)
392 };
393 matches.sort_unstable_by_key(sorting_key);
394
395 let matched_models: Vec<_> = matches
396 .into_iter()
397 .map(|mat| self.models[mat.candidate_id].clone())
398 .collect();
399
400 matched_models
401 }
402
403 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
404 self.models
405 .iter()
406 .filter(|m| {
407 m.model
408 .name()
409 .0
410 .to_lowercase()
411 .contains(&query.to_lowercase())
412 })
413 .cloned()
414 .collect::<Vec<_>>()
415 }
416
417 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
418 model_infos
419 .iter()
420 .enumerate()
421 .map(|(index, model)| {
422 StringMatchCandidate::new(
423 index,
424 &format!(
425 "{}/{}",
426 &model.model.provider_name().0,
427 &model.model.name().0
428 ),
429 )
430 })
431 .collect::<Vec<_>>()
432 }
433}
434
435impl PickerDelegate for LanguageModelPickerDelegate {
436 type ListItem = AnyElement;
437
438 fn match_count(&self) -> usize {
439 self.filtered_entries.len()
440 }
441
442 fn selected_index(&self) -> usize {
443 self.selected_index
444 }
445
446 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
447 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
448 cx.notify();
449 }
450
451 fn can_select(
452 &mut self,
453 ix: usize,
454 _window: &mut Window,
455 _cx: &mut Context<Picker<Self>>,
456 ) -> bool {
457 match self.filtered_entries.get(ix) {
458 Some(LanguageModelPickerEntry::Model(_)) => true,
459 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
460 }
461 }
462
463 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
464 "Select a model…".into()
465 }
466
467 fn update_matches(
468 &mut self,
469 query: String,
470 window: &mut Window,
471 cx: &mut Context<Picker<Self>>,
472 ) -> Task<()> {
473 let all_models = self.all_models.clone();
474 let active_model = (self.get_active_model)(cx);
475 let bg_executor = cx.background_executor();
476
477 let language_model_registry = LanguageModelRegistry::global(cx);
478
479 let configured_providers = language_model_registry
480 .read(cx)
481 .visible_providers()
482 .into_iter()
483 .filter(|provider| provider.is_authenticated(cx))
484 .collect::<Vec<_>>();
485
486 let configured_provider_ids = configured_providers
487 .iter()
488 .map(|provider| provider.id())
489 .collect::<Vec<_>>();
490
491 let recommended_models = all_models
492 .recommended
493 .iter()
494 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
495 .cloned()
496 .collect::<Vec<_>>();
497
498 let available_models = all_models
499 .all
500 .values()
501 .flat_map(|models| models.iter())
502 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
503 .cloned()
504 .collect::<Vec<_>>();
505
506 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
507 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
508
509 let recommended = matcher_rec.exact_search(&query);
510 let all = matcher_all.fuzzy_search(&query);
511
512 let filtered_models = GroupedModels::new(all, recommended);
513
514 cx.spawn_in(window, async move |this, cx| {
515 this.update_in(cx, |this, window, cx| {
516 this.delegate.filtered_entries = filtered_models.entries();
517 // Finds the currently selected model in the list
518 let new_index =
519 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
520 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
521 cx.notify();
522 })
523 .ok();
524 })
525 }
526
527 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
528 if let Some(LanguageModelPickerEntry::Model(model_info)) =
529 self.filtered_entries.get(self.selected_index)
530 {
531 let model = model_info.model.clone();
532 (self.on_model_changed)(model.clone(), cx);
533
534 let current_index = self.selected_index;
535 self.set_selected_index(current_index, window, cx);
536
537 cx.emit(DismissEvent);
538 }
539 }
540
541 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
542 cx.emit(DismissEvent);
543 }
544
545 fn render_match(
546 &self,
547 ix: usize,
548 selected: bool,
549 _: &mut Window,
550 cx: &mut Context<Picker<Self>>,
551 ) -> Option<Self::ListItem> {
552 match self.filtered_entries.get(ix)? {
553 LanguageModelPickerEntry::Separator(title) => {
554 Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
555 }
556 LanguageModelPickerEntry::Model(model_info) => {
557 let active_model = (self.get_active_model)(cx);
558 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
559 let active_model_id = active_model.map(|m| m.model.id());
560
561 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
562 && Some(model_info.model.id()) == active_model_id;
563
564 let is_favorite = model_info.is_favorite;
565 let handle_action_click = {
566 let model = model_info.model.clone();
567 let on_toggle_favorite = self.on_toggle_favorite.clone();
568 cx.listener(move |picker, _, window, cx| {
569 on_toggle_favorite(model.clone(), !is_favorite, cx);
570 picker.refresh(window, cx);
571 })
572 };
573
574 Some(
575 ModelSelectorListItem::new(ix, model_info.model.name().0)
576 .map(|this| match &model_info.icon {
577 IconOrSvg::Icon(icon_name) => this.icon(*icon_name),
578 IconOrSvg::Svg(icon_path) => this.icon_path(icon_path.clone()),
579 })
580 .is_selected(is_selected)
581 .is_focused(selected)
582 .is_favorite(is_favorite)
583 .on_toggle_favorite(handle_action_click)
584 .into_any_element(),
585 )
586 }
587 }
588 }
589
590 fn render_footer(
591 &self,
592 _window: &mut Window,
593 _cx: &mut Context<Picker<Self>>,
594 ) -> Option<gpui::AnyElement> {
595 let focus_handle = self.focus_handle.clone();
596
597 if !self.popover_styles {
598 return None;
599 }
600
601 Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
602 }
603}
604
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal model whose id equals its name; only identity methods are usable.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, model name)` pairs, none of them favorites.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        create_models_with_favorites(model_specs, vec![])
    }

    /// Builds `ModelInfo`s from `(provider, model name)` pairs, flagging those
    /// listed in `favorites` as favorites.
    fn create_models_with_favorites(
        model_specs: Vec<(&str, &str)>,
        favorites: Vec<(&str, &str)>,
    ) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| {
                let is_favorite = favorites
                    .iter()
                    .any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name);
                ModelInfo {
                    model: Arc::new(TestLanguageModel::new(name, provider)),
                    icon: IconOrSvg::Icon(IconName::Ai),
                    is_favorite,
                }
            })
            .collect()
    }

    /// Asserts `result` lists exactly the `"provider/name"` ids in `expected`, in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_exact_match_with_empty_query_returns_all_models(cx: &mut TestAppContext) {
        let models = create_models(vec![("zed", "claude"), ("openai", "gpt-4")]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Every name contains the empty string, so nothing is filtered out.
        assert_models_eq(
            matcher.exact_search(""),
            vec!["zed/claude", "openai/gpt-4"],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Recommended, but must still be listed under its provider
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Recommended, but must still be listed under its provider
            ("zed", "gemini"),
            ("copilot", "claude"), // Same name, different provider: listed separately
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "gemini")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert_models_eq(grouped_models.favorites, vec!["zed/gemini"]);
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![("zed", "claude"), ("zed", "gemini")]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Recommended"
        ));

        assert!(grouped_models.favorites.is_empty());
    }

    #[gpui::test]
    fn test_entries_preserve_favorite_flags(_cx: &mut TestAppContext) {
        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "claude")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        for entry in &entries {
            if let LanguageModelPickerEntry::Model(info) = entry {
                if info.model.telemetry_id() == "zed/claude" {
                    assert!(info.is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(
                        !info.is_favorite,
                        "{} should not be a favorite",
                        info.model.telemetry_id()
                    );
                }
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_other_sections(_cx: &mut TestAppContext) {
        let favorites = vec![("zed", "gemini"), ("openai", "gpt-4")];

        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], favorites.clone());

        let all_models = create_models_with_favorites(
            vec![
                ("zed", "claude"),
                ("zed", "gemini"),
                ("openai", "gpt-4"),
                ("openai", "gpt-3.5"),
            ],
            favorites,
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        assert_models_eq(grouped_models.favorites, vec!["zed/gemini", "openai/gpt-4"]);
        assert_models_eq(grouped_models.recommended, vec!["zed/claude"]);
        assert_models_eq(
            grouped_models.all.values().flatten().cloned().collect(),
            vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"],
        );
    }

    #[gpui::test]
    fn test_entries_are_grouped_by_provider(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("openai", "gpt-4")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("openai", "gpt-4"), ("zed", "gemini")],
            vec![("zed", "gemini")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let labels = grouped_models
            .entries()
            .iter()
            .map(|entry| match entry {
                LanguageModelPickerEntry::Separator(title) => format!("# {title}"),
                LanguageModelPickerEntry::Model(info) => info.model.telemetry_id(),
            })
            .collect::<Vec<_>>();

        // Favorites, then recommended, then one section per provider in
        // first-seen order, each keeping its models' original order.
        assert_eq!(
            labels,
            vec![
                "# Favorite",
                "zed/gemini",
                "# Recommended",
                "openai/gpt-4",
                "# zed",
                "zed/claude",
                "zed/gemini",
                "# openai",
                "openai/gpt-4",
            ]
        );
    }

    #[gpui::test]
    fn test_active_model_index_defaults_to_zero(_cx: &mut TestAppContext) {
        let grouped_models = GroupedModels::new(create_models(vec![("zed", "claude")]), vec![]);
        let entries = grouped_models.entries();

        assert_eq!(
            LanguageModelPickerDelegate::get_active_model_index(&entries, None),
            0
        );
    }
}