1use std::{cmp::Reverse, sync::Arc};
2
3use agent_settings::AgentSettings;
4use collections::{HashMap, HashSet, IndexMap};
5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
6use gpui::{
7 Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, ForegroundExecutor,
8 Subscription, Task,
9};
10use language_model::{
11 ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId, LanguageModelProvider,
12 LanguageModelProviderId, LanguageModelRegistry,
13};
14use ordered_float::OrderedFloat;
15use picker::{Picker, PickerDelegate};
16use settings::Settings;
17use ui::prelude::*;
18use zed_actions::agent::OpenSettings;
19
20use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
21
22type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
23type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
24type OnToggleFavorite = Arc<dyn Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static>;
25
26pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
27
28pub fn language_model_selector(
29 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
30 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
31 on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
32 popover_styles: bool,
33 focus_handle: FocusHandle,
34 window: &mut Window,
35 cx: &mut Context<LanguageModelSelector>,
36) -> LanguageModelSelector {
37 let delegate = LanguageModelPickerDelegate::new(
38 get_active_model,
39 on_model_changed,
40 on_toggle_favorite,
41 popover_styles,
42 focus_handle,
43 window,
44 cx,
45 );
46
47 if popover_styles {
48 Picker::list(delegate, window, cx)
49 .show_scrollbar(true)
50 .width(rems(20.))
51 .max_height(Some(rems(20.).into()))
52 } else {
53 Picker::list(delegate, window, cx).show_scrollbar(true)
54 }
55}
56
57fn all_models(cx: &App) -> GroupedModels {
58 let lm_registry = LanguageModelRegistry::global(cx).read(cx);
59 let providers = lm_registry.visible_providers();
60
61 let mut favorites_index = FavoritesIndex::default();
62
63 for sel in &AgentSettings::get_global(cx).favorite_models {
64 favorites_index
65 .entry(sel.provider.0.clone().into())
66 .or_default()
67 .insert(sel.model.clone().into());
68 }
69
70 let recommended = providers
71 .iter()
72 .flat_map(|provider| {
73 provider
74 .recommended_models(cx)
75 .into_iter()
76 .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
77 })
78 .collect();
79
80 let all = providers
81 .iter()
82 .flat_map(|provider| {
83 provider
84 .provided_models(cx)
85 .into_iter()
86 .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
87 })
88 .collect();
89
90 GroupedModels::new(all, recommended)
91}
92
/// Favorite models keyed by provider id: each provider maps to the set of its favorite model ids.
type FavoritesIndex = HashMap<LanguageModelProviderId, HashSet<LanguageModelId>>;

/// A model as displayed by the picker.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// Icon of the provider that supplies the model.
    icon: IconOrSvg,
    /// Whether the user marked this model as a favorite.
    is_favorite: bool,
}
101
102impl ModelInfo {
103 fn new(
104 provider: &dyn LanguageModelProvider,
105 model: Arc<dyn LanguageModel>,
106 favorites_index: &FavoritesIndex,
107 ) -> Self {
108 let is_favorite = favorites_index
109 .get(&provider.id())
110 .map_or(false, |set| set.contains(&model.id()));
111
112 Self {
113 model,
114 icon: provider.icon(),
115 is_favorite,
116 }
117 }
118}
119
/// [`PickerDelegate`] state for the language model selector.
pub struct LanguageModelPickerDelegate {
    /// Called with the chosen model on confirm and when cycling favorites.
    on_model_changed: OnModelChanged,
    /// Reports the currently active model; used to place the selection and mark the active row.
    get_active_model: GetActiveModel,
    /// Called with a model and its requested new favorite state when its toggle is clicked.
    on_toggle_favorite: OnToggleFavorite,
    /// Every model from the visible providers; rebuilt when the registry reports provider changes.
    all_models: Arc<GroupedModels>,
    /// Entries currently listed, after filtering by query and provider authentication.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted entry.
    selected_index: usize,
    /// Holds the registry subscription created in `new` (presumably dropping it unsubscribes).
    _subscriptions: Vec<Subscription>,
    /// Whether the picker is shown as a popover; the footer is only rendered in that mode.
    popover_styles: bool,
    /// Focus handle handed to the footer.
    focus_handle: FocusHandle,
}
131
132impl LanguageModelPickerDelegate {
133 fn new(
134 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
135 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
136 on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
137 popover_styles: bool,
138 focus_handle: FocusHandle,
139 window: &mut Window,
140 cx: &mut Context<Picker<Self>>,
141 ) -> Self {
142 let on_model_changed = Arc::new(on_model_changed);
143 let models = all_models(cx);
144 let entries = models.entries();
145
146 Self {
147 on_model_changed,
148 all_models: Arc::new(models),
149 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
150 filtered_entries: entries,
151 get_active_model: Arc::new(get_active_model),
152 on_toggle_favorite: Arc::new(on_toggle_favorite),
153 _subscriptions: vec![cx.subscribe_in(
154 &LanguageModelRegistry::global(cx),
155 window,
156 |picker, _, event, window, cx| {
157 match event {
158 language_model::Event::ProviderStateChanged(_)
159 | language_model::Event::AddedProvider(_)
160 | language_model::Event::RemovedProvider(_) => {
161 let query = picker.query(cx);
162 picker.delegate.all_models = Arc::new(all_models(cx));
163 // Update matches will automatically drop the previous task
164 // if we get a provider event again
165 picker.update_matches(query, window, cx)
166 }
167 _ => {}
168 }
169 },
170 )],
171 popover_styles,
172 focus_handle,
173 }
174 }
175
176 fn get_active_model_index(
177 entries: &[LanguageModelPickerEntry],
178 active_model: Option<ConfiguredModel>,
179 ) -> usize {
180 entries
181 .iter()
182 .position(|entry| {
183 if let LanguageModelPickerEntry::Model(model) = entry {
184 active_model
185 .as_ref()
186 .map(|active_model| {
187 active_model.model.id() == model.model.id()
188 && active_model.provider.id() == model.model.provider_id()
189 })
190 .unwrap_or_default()
191 } else {
192 false
193 }
194 })
195 .unwrap_or(0)
196 }
197
198 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
199 (self.get_active_model)(cx)
200 }
201
202 pub fn favorites_count(&self) -> usize {
203 self.all_models.favorites.len()
204 }
205
206 pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
207 if self.all_models.favorites.is_empty() {
208 return;
209 }
210
211 let active_model = (self.get_active_model)(cx);
212 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
213 let active_model_id = active_model.as_ref().map(|m| m.model.id());
214
215 let current_index = self
216 .all_models
217 .favorites
218 .iter()
219 .position(|info| {
220 Some(info.model.provider_id()) == active_provider_id
221 && Some(info.model.id()) == active_model_id
222 })
223 .unwrap_or(usize::MAX);
224
225 let next_index = if current_index == usize::MAX {
226 0
227 } else {
228 (current_index + 1) % self.all_models.favorites.len()
229 };
230
231 let next_model = self.all_models.favorites[next_index].model.clone();
232
233 (self.on_model_changed)(next_model, cx);
234
235 // Align the picker selection with the newly-active model
236 let new_index =
237 Self::get_active_model_index(&self.filtered_entries, (self.get_active_model)(cx));
238 self.set_selected_index(new_index, window, cx);
239 }
240}
241
/// Models split into the sections the picker displays.
struct GroupedModels {
    /// Models from the full list that are flagged as favorites, in that list's order.
    favorites: Vec<ModelInfo>,
    /// Recommended models, stored as given.
    recommended: Vec<ModelInfo>,
    /// All models keyed by provider id, in the order each provider was first encountered.
    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
247
248impl GroupedModels {
249 pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
250 let favorites = all
251 .iter()
252 .filter(|info| info.is_favorite)
253 .cloned()
254 .collect();
255
256 let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
257 for model in all {
258 let provider = model.model.provider_id();
259 if let Some(models) = all_by_provider.get_mut(&provider) {
260 models.push(model);
261 } else {
262 all_by_provider.insert(provider, vec![model]);
263 }
264 }
265
266 Self {
267 favorites,
268 recommended,
269 all: all_by_provider,
270 }
271 }
272
273 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
274 let mut entries = Vec::new();
275
276 if !self.favorites.is_empty() {
277 entries.push(LanguageModelPickerEntry::Separator("Favorite".into()));
278 for info in &self.favorites {
279 entries.push(LanguageModelPickerEntry::Model(info.clone()));
280 }
281 }
282
283 if !self.recommended.is_empty() {
284 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
285 for info in &self.recommended {
286 entries.push(LanguageModelPickerEntry::Model(info.clone()));
287 }
288 }
289
290 for models in self.all.values() {
291 if models.is_empty() {
292 continue;
293 }
294 entries.push(LanguageModelPickerEntry::Separator(
295 models[0].model.provider_name().0,
296 ));
297 for info in models {
298 entries.push(LanguageModelPickerEntry::Model(info.clone()));
299 }
300 }
301
302 entries
303 }
304}
305
/// One row of the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model.
    Model(ModelInfo),
    /// A non-selectable section header with the given title.
    Separator(SharedString),
}
310
/// Searches a fixed list of models by name.
struct ModelMatcher {
    /// The models being searched; results are clones of these.
    models: Vec<ModelInfo>,
    /// Executor used to block on the fuzzy-matching future.
    fg_executor: ForegroundExecutor,
    /// Executor handed to `match_strings` for the matching work.
    bg_executor: BackgroundExecutor,
    /// One fuzzy-match candidate per entry of `models`; candidate ids are indices into `models`.
    candidates: Vec<StringMatchCandidate>,
}
317
318impl ModelMatcher {
319 fn new(
320 models: Vec<ModelInfo>,
321 fg_executor: ForegroundExecutor,
322 bg_executor: BackgroundExecutor,
323 ) -> ModelMatcher {
324 let candidates = Self::make_match_candidates(&models);
325 Self {
326 models,
327 fg_executor,
328 bg_executor,
329 candidates,
330 }
331 }
332
333 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
334 let mut matches = self.fg_executor.block_on(match_strings(
335 &self.candidates,
336 query,
337 false,
338 true,
339 100,
340 &Default::default(),
341 self.bg_executor.clone(),
342 ));
343
344 let sorting_key = |mat: &StringMatch| {
345 let candidate = &self.candidates[mat.candidate_id];
346 (Reverse(OrderedFloat(mat.score)), candidate.id)
347 };
348 matches.sort_unstable_by_key(sorting_key);
349
350 let matched_models: Vec<_> = matches
351 .into_iter()
352 .map(|mat| self.models[mat.candidate_id].clone())
353 .collect();
354
355 matched_models
356 }
357
358 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
359 self.models
360 .iter()
361 .filter(|m| {
362 m.model
363 .name()
364 .0
365 .to_lowercase()
366 .contains(&query.to_lowercase())
367 })
368 .cloned()
369 .collect::<Vec<_>>()
370 }
371
372 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
373 model_infos
374 .iter()
375 .enumerate()
376 .map(|(index, model)| {
377 StringMatchCandidate::new(
378 index,
379 &format!(
380 "{}/{}",
381 &model.model.provider_name().0,
382 &model.model.name().0
383 ),
384 )
385 })
386 .collect::<Vec<_>>()
387 }
388}
389
390impl PickerDelegate for LanguageModelPickerDelegate {
391 type ListItem = AnyElement;
392
393 fn match_count(&self) -> usize {
394 self.filtered_entries.len()
395 }
396
397 fn selected_index(&self) -> usize {
398 self.selected_index
399 }
400
401 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
402 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
403 cx.notify();
404 }
405
406 fn can_select(&self, ix: usize, _window: &mut Window, _cx: &mut Context<Picker<Self>>) -> bool {
407 match self.filtered_entries.get(ix) {
408 Some(LanguageModelPickerEntry::Model(_)) => true,
409 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
410 }
411 }
412
413 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
414 "Select a model…".into()
415 }
416
417 fn update_matches(
418 &mut self,
419 query: String,
420 window: &mut Window,
421 cx: &mut Context<Picker<Self>>,
422 ) -> Task<()> {
423 let all_models = self.all_models.clone();
424 let active_model = (self.get_active_model)(cx);
425 let fg_executor = cx.foreground_executor();
426 let bg_executor = cx.background_executor();
427
428 let language_model_registry = LanguageModelRegistry::global(cx);
429
430 let configured_providers = language_model_registry
431 .read(cx)
432 .visible_providers()
433 .into_iter()
434 .filter(|provider| provider.is_authenticated(cx))
435 .collect::<Vec<_>>();
436
437 let configured_provider_ids = configured_providers
438 .iter()
439 .map(|provider| provider.id())
440 .collect::<Vec<_>>();
441
442 let recommended_models = all_models
443 .recommended
444 .iter()
445 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
446 .cloned()
447 .collect::<Vec<_>>();
448
449 let available_models = all_models
450 .all
451 .values()
452 .flat_map(|models| models.iter())
453 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
454 .cloned()
455 .collect::<Vec<_>>();
456
457 let matcher_rec =
458 ModelMatcher::new(recommended_models, fg_executor.clone(), bg_executor.clone());
459 let matcher_all =
460 ModelMatcher::new(available_models, fg_executor.clone(), bg_executor.clone());
461
462 let recommended = matcher_rec.exact_search(&query);
463 let all = matcher_all.fuzzy_search(&query);
464
465 let filtered_models = GroupedModels::new(all, recommended);
466
467 cx.spawn_in(window, async move |this, cx| {
468 this.update_in(cx, |this, window, cx| {
469 this.delegate.filtered_entries = filtered_models.entries();
470 // Finds the currently selected model in the list
471 let new_index =
472 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
473 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
474 cx.notify();
475 })
476 .ok();
477 })
478 }
479
480 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
481 if let Some(LanguageModelPickerEntry::Model(model_info)) =
482 self.filtered_entries.get(self.selected_index)
483 {
484 let model = model_info.model.clone();
485 (self.on_model_changed)(model.clone(), cx);
486
487 let current_index = self.selected_index;
488 self.set_selected_index(current_index, window, cx);
489
490 cx.emit(DismissEvent);
491 }
492 }
493
494 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
495 cx.emit(DismissEvent);
496 }
497
498 fn render_match(
499 &self,
500 ix: usize,
501 selected: bool,
502 _: &mut Window,
503 cx: &mut Context<Picker<Self>>,
504 ) -> Option<Self::ListItem> {
505 match self.filtered_entries.get(ix)? {
506 LanguageModelPickerEntry::Separator(title) => {
507 Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
508 }
509 LanguageModelPickerEntry::Model(model_info) => {
510 let active_model = (self.get_active_model)(cx);
511 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
512 let active_model_id = active_model.map(|m| m.model.id());
513
514 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
515 && Some(model_info.model.id()) == active_model_id;
516
517 let model_cost = model_info
518 .model
519 .model_cost_info()
520 .map(|cost| cost.to_shared_string());
521
522 let is_favorite = model_info.is_favorite;
523 let handle_action_click = {
524 let model = model_info.model.clone();
525 let on_toggle_favorite = self.on_toggle_favorite.clone();
526 cx.listener(move |picker, _, window, cx| {
527 on_toggle_favorite(model.clone(), !is_favorite, cx);
528 picker.refresh(window, cx);
529 })
530 };
531
532 Some(
533 ModelSelectorListItem::new(ix, model_info.model.name().0)
534 .map(|this| match &model_info.icon {
535 IconOrSvg::Icon(icon_name) => this.icon(*icon_name),
536 IconOrSvg::Svg(icon_path) => this.icon_path(icon_path.clone()),
537 })
538 .is_selected(is_selected)
539 .is_focused(selected)
540 .is_latest(model_info.model.is_latest())
541 .is_favorite(is_favorite)
542 .cost_info(model_cost)
543 .on_toggle_favorite(handle_action_click)
544 .into_any_element(),
545 )
546 }
547 }
548 }
549
550 fn render_footer(
551 &self,
552 _window: &mut Window,
553 _cx: &mut Context<Picker<Self>>,
554 ) -> Option<gpui::AnyElement> {
555 let focus_handle = self.focus_handle.clone();
556
557 if !self.popover_styles {
558 return None;
559 }
560
561 Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
562 }
563}
564
// Unit tests for model matching (`ModelMatcher`) and section grouping (`GroupedModels`).
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal [`LanguageModel`] for tests: its id and display name are both the given model
    /// name, and its provider id and provider name are both the given provider string.
    /// The token-counting and completion methods are `unimplemented!()`.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // `provider/name`; the tests use this string to identify models in assertions.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds one non-favorite `ModelInfo` per `(provider, name)` pair.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        create_models_with_favorites(model_specs, vec![])
    }

    /// Builds one `ModelInfo` per `(provider, name)` pair, flagging as favorites the pairs
    /// that also appear in `favorites`.
    fn create_models_with_favorites(
        model_specs: Vec<(&str, &str)>,
        favorites: Vec<(&str, &str)>,
    ) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| {
                let is_favorite = favorites
                    .iter()
                    .any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name);
                ModelInfo {
                    model: Arc::new(TestLanguageModel::new(name, provider)),
                    icon: IconOrSvg::Icon(IconName::ZedAgent),
                    is_favorite,
                }
            })
            .collect()
    }

    /// Asserts that `result` holds exactly the models whose telemetry ids are `expected`,
    /// in the same order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-5"),
            ("zed", "gpt-5-mini"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-5"),
            ("openai", "gpt-5-mini"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(
            models,
            cx.foreground_executor().clone(),
            cx.background_executor.clone(),
        );

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-5");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-5",
                "zed/gpt-5-mini",
                "openai/gpt-5",
                "openai/gpt-5-mini",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-5"),
            ("zed", "gpt-5-mini"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-5"),
            ("openai", "gpt-5-mini"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(
            models,
            cx.foreground_executor().clone(),
            cx.background_executor.clone(),
        );

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-5-mini` and `openai/gpt-5-mini` have identical
        // similarity scores, but `zed/gpt-5-mini` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("mini");
        assert_models_eq(results, vec!["zed/gpt-5-mini", "openai/gpt-5-mini"]);

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search - "thinking" should match only the Claude 3.7 Sonnet Thinking variant
        let results = matcher.fuzzy_search("thinking");
        assert_models_eq(results, vec!["zed/Claude 3.7 Sonnet Thinking"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "gemini")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        // The favorites section comes first when it is non-empty.
        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert_models_eq(grouped_models.favorites, vec!["zed/gemini"]);
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![("zed", "claude"), ("zed", "gemini")]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        // Without favorites, the list opens with the recommended section.
        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Recommended"
        ));

        assert!(grouped_models.favorites.is_empty());
    }

    // Despite its name, this checks that every listed entry carries the expected
    // `is_favorite` flag.
    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "claude")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        for entry in &entries {
            if let LanguageModelPickerEntry::Model(info) = entry {
                if info.model.telemetry_id() == "zed/claude" {
                    assert!(info.is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(
                        !info.is_favorite,
                        "{} should not be a favorite",
                        info.model.telemetry_id()
                    );
                }
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_other_sections(_cx: &mut TestAppContext) {
        let favorites = vec![("zed", "gemini"), ("openai", "gpt-4")];

        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], favorites.clone());

        let all_models = create_models_with_favorites(
            vec![
                ("zed", "claude"),
                ("zed", "gemini"),
                ("openai", "gpt-4"),
                ("openai", "gpt-3.5"),
            ],
            favorites,
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // Favorites are listed in their own section and also stay in their provider group.
        assert_models_eq(grouped_models.favorites, vec!["zed/gemini", "openai/gpt-4"]);
        assert_models_eq(grouped_models.recommended, vec!["zed/claude"]);
        assert_models_eq(
            grouped_models.all.values().flatten().cloned().collect(),
            vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"],
        );
    }
}