1use std::{cmp::Reverse, sync::Arc};
2
3use agent_settings::AgentSettings;
4use collections::{HashMap, HashSet, IndexMap};
5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
6use gpui::{
7 Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, ForegroundExecutor,
8 Subscription, Task,
9};
10use language_model::{
11 ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId, LanguageModelProvider,
12 LanguageModelProviderId, LanguageModelRegistry,
13};
14use ordered_float::OrderedFloat;
15use picker::{Picker, PickerDelegate};
16use settings::Settings;
17use ui::prelude::*;
18use zed_actions::agent::OpenSettings;
19
20use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
21
/// Callback invoked with the model the user confirmed (or cycled to).
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback that reports the currently active model, if any.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
/// Callback invoked when a model's favorite toggle is clicked; the `bool` is the
/// requested new favorite state (the negation of the model's current state).
type OnToggleFavorite = Arc<dyn Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static>;

/// The language model selector: a [`Picker`] driven by [`LanguageModelPickerDelegate`].
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
27
28pub fn language_model_selector(
29 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
30 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
31 on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
32 popover_styles: bool,
33 focus_handle: FocusHandle,
34 window: &mut Window,
35 cx: &mut Context<LanguageModelSelector>,
36) -> LanguageModelSelector {
37 let delegate = LanguageModelPickerDelegate::new(
38 get_active_model,
39 on_model_changed,
40 on_toggle_favorite,
41 popover_styles,
42 focus_handle,
43 window,
44 cx,
45 );
46
47 if popover_styles {
48 Picker::list(delegate, window, cx)
49 .show_scrollbar(true)
50 .width(rems(20.))
51 .max_height(Some(rems(20.).into()))
52 } else {
53 Picker::list(delegate, window, cx).show_scrollbar(true)
54 }
55}
56
57fn all_models(cx: &App) -> GroupedModels {
58 let lm_registry = LanguageModelRegistry::global(cx).read(cx);
59 let providers = lm_registry.visible_providers();
60
61 let mut favorites_index = FavoritesIndex::default();
62
63 for sel in &AgentSettings::get_global(cx).favorite_models {
64 favorites_index
65 .entry(sel.provider.0.clone().into())
66 .or_default()
67 .insert(sel.model.clone().into());
68 }
69
70 let recommended = providers
71 .iter()
72 .flat_map(|provider| {
73 provider
74 .recommended_models(cx)
75 .into_iter()
76 .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
77 })
78 .collect();
79
80 let all = providers
81 .iter()
82 .flat_map(|provider| {
83 provider
84 .provided_models(cx)
85 .into_iter()
86 .map(|model| ModelInfo::new(&**provider, model, &favorites_index))
87 })
88 .collect();
89
90 GroupedModels::new(all, recommended)
91}
92
/// Favorite model ids from settings, keyed by the provider that serves them.
type FavoritesIndex = HashMap<LanguageModelProviderId, HashSet<LanguageModelId>>;
94
/// A model together with the display data the picker needs for it.
#[derive(Clone)]
struct ModelInfo {
    model: Arc<dyn LanguageModel>,
    /// Icon shown next to the model; `ModelInfo::new` takes it from the provider.
    icon: IconOrSvg,
    /// Whether the model is listed among the user's favorite models.
    is_favorite: bool,
}
101
102impl ModelInfo {
103 fn new(
104 provider: &dyn LanguageModelProvider,
105 model: Arc<dyn LanguageModel>,
106 favorites_index: &FavoritesIndex,
107 ) -> Self {
108 let is_favorite = favorites_index
109 .get(&provider.id())
110 .map_or(false, |set| set.contains(&model.id()));
111
112 Self {
113 model,
114 icon: provider.icon(),
115 is_favorite,
116 }
117 }
118}
119
/// [`PickerDelegate`] that lists language models in "Favorite", "Recommended", and
/// per-provider sections.
pub struct LanguageModelPickerDelegate {
    /// Called with the model the user confirms or cycles to.
    on_model_changed: OnModelChanged,
    /// Reports the currently active model; used to pre-select and mark it.
    get_active_model: GetActiveModel,
    /// Called when a model's favorite toggle is clicked, with the requested new state.
    on_toggle_favorite: OnToggleFavorite,
    /// Models from all visible providers; rebuilt on provider registry events.
    all_models: Arc<GroupedModels>,
    /// Rows (section separators and models) currently shown for the active query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Holds the provider-registry subscription created in `new`.
    _subscriptions: Vec<Subscription>,
    /// Whether the picker is shown as a popover; also enables the settings footer.
    popover_styles: bool,
    /// Focus handle passed to the settings footer.
    focus_handle: FocusHandle,
}
131
132impl LanguageModelPickerDelegate {
133 fn new(
134 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
135 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
136 on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
137 popover_styles: bool,
138 focus_handle: FocusHandle,
139 window: &mut Window,
140 cx: &mut Context<Picker<Self>>,
141 ) -> Self {
142 let on_model_changed = Arc::new(on_model_changed);
143 let models = all_models(cx);
144 let entries = models.entries();
145
146 Self {
147 on_model_changed,
148 all_models: Arc::new(models),
149 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
150 filtered_entries: entries,
151 get_active_model: Arc::new(get_active_model),
152 on_toggle_favorite: Arc::new(on_toggle_favorite),
153 _subscriptions: vec![cx.subscribe_in(
154 &LanguageModelRegistry::global(cx),
155 window,
156 |picker, _, event, window, cx| {
157 match event {
158 language_model::Event::ProviderStateChanged(_)
159 | language_model::Event::AddedProvider(_)
160 | language_model::Event::RemovedProvider(_) => {
161 let query = picker.query(cx);
162 picker.delegate.all_models = Arc::new(all_models(cx));
163 // Update matches will automatically drop the previous task
164 // if we get a provider event again
165 picker.update_matches(query, window, cx)
166 }
167 _ => {}
168 }
169 },
170 )],
171 popover_styles,
172 focus_handle,
173 }
174 }
175
176 fn get_active_model_index(
177 entries: &[LanguageModelPickerEntry],
178 active_model: Option<ConfiguredModel>,
179 ) -> usize {
180 entries
181 .iter()
182 .position(|entry| {
183 if let LanguageModelPickerEntry::Model(model) = entry {
184 active_model
185 .as_ref()
186 .map(|active_model| {
187 active_model.model.id() == model.model.id()
188 && active_model.provider.id() == model.model.provider_id()
189 })
190 .unwrap_or_default()
191 } else {
192 false
193 }
194 })
195 .unwrap_or(0)
196 }
197
198 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
199 (self.get_active_model)(cx)
200 }
201
202 pub fn favorites_count(&self) -> usize {
203 self.all_models.favorites.len()
204 }
205
206 pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
207 if self.all_models.favorites.is_empty() {
208 return;
209 }
210
211 let active_model = (self.get_active_model)(cx);
212 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
213 let active_model_id = active_model.as_ref().map(|m| m.model.id());
214
215 let current_index = self
216 .all_models
217 .favorites
218 .iter()
219 .position(|info| {
220 Some(info.model.provider_id()) == active_provider_id
221 && Some(info.model.id()) == active_model_id
222 })
223 .unwrap_or(usize::MAX);
224
225 let next_index = if current_index == usize::MAX {
226 0
227 } else {
228 (current_index + 1) % self.all_models.favorites.len()
229 };
230
231 let next_model = self.all_models.favorites[next_index].model.clone();
232
233 (self.on_model_changed)(next_model, cx);
234
235 // Align the picker selection with the newly-active model
236 let new_index =
237 Self::get_active_model_index(&self.filtered_entries, (self.get_active_model)(cx));
238 self.set_selected_index(new_index, window, cx);
239 }
240}
241
/// Models grouped into the sections the picker displays.
struct GroupedModels {
    /// Models marked as favorites, in the order they appear in the `all` input.
    favorites: Vec<ModelInfo>,
    /// Models the providers recommend.
    recommended: Vec<ModelInfo>,
    /// All models keyed by provider id, with providers in first-seen order.
    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
247
248impl GroupedModels {
249 pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
250 let favorites = all
251 .iter()
252 .filter(|info| info.is_favorite)
253 .cloned()
254 .collect();
255
256 let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
257 for model in all {
258 let provider = model.model.provider_id();
259 if let Some(models) = all_by_provider.get_mut(&provider) {
260 models.push(model);
261 } else {
262 all_by_provider.insert(provider, vec![model]);
263 }
264 }
265
266 Self {
267 favorites,
268 recommended,
269 all: all_by_provider,
270 }
271 }
272
273 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
274 let mut entries = Vec::new();
275
276 if !self.favorites.is_empty() {
277 entries.push(LanguageModelPickerEntry::Separator("Favorite".into()));
278 for info in &self.favorites {
279 entries.push(LanguageModelPickerEntry::Model(info.clone()));
280 }
281 }
282
283 if !self.recommended.is_empty() {
284 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
285 for info in &self.recommended {
286 entries.push(LanguageModelPickerEntry::Model(info.clone()));
287 }
288 }
289
290 for models in self.all.values() {
291 if models.is_empty() {
292 continue;
293 }
294 entries.push(LanguageModelPickerEntry::Separator(
295 models[0].model.provider_name().0,
296 ));
297 for info in models {
298 entries.push(LanguageModelPickerEntry::Model(info.clone()));
299 }
300 }
301
302 entries
303 }
304}
305
/// A row in the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable section header with the given title.
    Separator(SharedString),
}
310
/// Searches a fixed list of models by name.
struct ModelMatcher {
    models: Vec<ModelInfo>,
    /// Used to block on the fuzzy matcher from the synchronous `fuzzy_search`.
    fg_executor: ForegroundExecutor,
    /// Handed to `match_strings`, which runs the fuzzy matching.
    bg_executor: BackgroundExecutor,
    /// One `"{provider}/{model}"` candidate per entry of `models`, whose `id` is the
    /// model's index in `models`.
    candidates: Vec<StringMatchCandidate>,
}
317
318impl ModelMatcher {
319 fn new(
320 models: Vec<ModelInfo>,
321 fg_executor: ForegroundExecutor,
322 bg_executor: BackgroundExecutor,
323 ) -> ModelMatcher {
324 let candidates = Self::make_match_candidates(&models);
325 Self {
326 models,
327 fg_executor,
328 bg_executor,
329 candidates,
330 }
331 }
332
333 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
334 let mut matches = self.fg_executor.block_on(match_strings(
335 &self.candidates,
336 query,
337 false,
338 true,
339 100,
340 &Default::default(),
341 self.bg_executor.clone(),
342 ));
343
344 let sorting_key = |mat: &StringMatch| {
345 let candidate = &self.candidates[mat.candidate_id];
346 (Reverse(OrderedFloat(mat.score)), candidate.id)
347 };
348 matches.sort_unstable_by_key(sorting_key);
349
350 let matched_models: Vec<_> = matches
351 .into_iter()
352 .map(|mat| self.models[mat.candidate_id].clone())
353 .collect();
354
355 matched_models
356 }
357
358 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
359 self.models
360 .iter()
361 .filter(|m| {
362 m.model
363 .name()
364 .0
365 .to_lowercase()
366 .contains(&query.to_lowercase())
367 })
368 .cloned()
369 .collect::<Vec<_>>()
370 }
371
372 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
373 model_infos
374 .iter()
375 .enumerate()
376 .map(|(index, model)| {
377 StringMatchCandidate::new(
378 index,
379 &format!(
380 "{}/{}",
381 &model.model.provider_name().0,
382 &model.model.name().0
383 ),
384 )
385 })
386 .collect::<Vec<_>>()
387 }
388}
389
390impl PickerDelegate for LanguageModelPickerDelegate {
391 type ListItem = AnyElement;
392
393 fn match_count(&self) -> usize {
394 self.filtered_entries.len()
395 }
396
397 fn selected_index(&self) -> usize {
398 self.selected_index
399 }
400
401 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
402 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
403 cx.notify();
404 }
405
406 fn can_select(&self, ix: usize, _window: &mut Window, _cx: &mut Context<Picker<Self>>) -> bool {
407 match self.filtered_entries.get(ix) {
408 Some(LanguageModelPickerEntry::Model(_)) => true,
409 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
410 }
411 }
412
413 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
414 "Select a model…".into()
415 }
416
417 fn update_matches(
418 &mut self,
419 query: String,
420 window: &mut Window,
421 cx: &mut Context<Picker<Self>>,
422 ) -> Task<()> {
423 let all_models = self.all_models.clone();
424 let active_model = (self.get_active_model)(cx);
425 let fg_executor = cx.foreground_executor();
426 let bg_executor = cx.background_executor();
427
428 let language_model_registry = LanguageModelRegistry::global(cx);
429
430 let configured_providers = language_model_registry
431 .read(cx)
432 .visible_providers()
433 .into_iter()
434 .filter(|provider| provider.is_authenticated(cx))
435 .collect::<Vec<_>>();
436
437 let configured_provider_ids = configured_providers
438 .iter()
439 .map(|provider| provider.id())
440 .collect::<Vec<_>>();
441
442 let recommended_models = all_models
443 .recommended
444 .iter()
445 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
446 .cloned()
447 .collect::<Vec<_>>();
448
449 let available_models = all_models
450 .all
451 .values()
452 .flat_map(|models| models.iter())
453 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
454 .cloned()
455 .collect::<Vec<_>>();
456
457 let matcher_rec =
458 ModelMatcher::new(recommended_models, fg_executor.clone(), bg_executor.clone());
459 let matcher_all =
460 ModelMatcher::new(available_models, fg_executor.clone(), bg_executor.clone());
461
462 let recommended = matcher_rec.exact_search(&query);
463 let all = matcher_all.fuzzy_search(&query);
464
465 let filtered_models = GroupedModels::new(all, recommended);
466
467 cx.spawn_in(window, async move |this, cx| {
468 this.update_in(cx, |this, window, cx| {
469 this.delegate.filtered_entries = filtered_models.entries();
470 // Finds the currently selected model in the list
471 let new_index =
472 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
473 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
474 cx.notify();
475 })
476 .ok();
477 })
478 }
479
480 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
481 if let Some(LanguageModelPickerEntry::Model(model_info)) =
482 self.filtered_entries.get(self.selected_index)
483 {
484 let model = model_info.model.clone();
485 (self.on_model_changed)(model.clone(), cx);
486
487 let current_index = self.selected_index;
488 self.set_selected_index(current_index, window, cx);
489
490 cx.emit(DismissEvent);
491 }
492 }
493
494 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
495 cx.emit(DismissEvent);
496 }
497
498 fn render_match(
499 &self,
500 ix: usize,
501 selected: bool,
502 _: &mut Window,
503 cx: &mut Context<Picker<Self>>,
504 ) -> Option<Self::ListItem> {
505 match self.filtered_entries.get(ix)? {
506 LanguageModelPickerEntry::Separator(title) => {
507 Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
508 }
509 LanguageModelPickerEntry::Model(model_info) => {
510 let active_model = (self.get_active_model)(cx);
511 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
512 let active_model_id = active_model.map(|m| m.model.id());
513
514 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
515 && Some(model_info.model.id()) == active_model_id;
516
517 let model_cost = model_info
518 .model
519 .model_cost_info()
520 .map(|cost| cost.to_shared_string());
521
522 let is_favorite = model_info.is_favorite;
523 let handle_action_click = {
524 let model = model_info.model.clone();
525 let on_toggle_favorite = self.on_toggle_favorite.clone();
526 cx.listener(move |picker, _, window, cx| {
527 on_toggle_favorite(model.clone(), !is_favorite, cx);
528 picker.refresh(window, cx);
529 })
530 };
531
532 Some(
533 ModelSelectorListItem::new(ix, model_info.model.name().0)
534 .map(|this| match &model_info.icon {
535 IconOrSvg::Icon(icon_name) => this.icon(*icon_name),
536 IconOrSvg::Svg(icon_path) => this.icon_path(icon_path.clone()),
537 })
538 .is_selected(is_selected)
539 .is_focused(selected)
540 .is_latest(model_info.model.is_latest())
541 .is_favorite(is_favorite)
542 .cost_info(model_cost)
543 .on_toggle_favorite(handle_action_click)
544 .into_any_element(),
545 )
546 }
547 }
548 }
549
550 fn render_footer(
551 &self,
552 _window: &mut Window,
553 _cx: &mut Context<Picker<Self>>,
554 ) -> Option<gpui::AnyElement> {
555 let focus_handle = self.focus_handle.clone();
556
557 if !self.popover_styles {
558 return None;
559 }
560
561 Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
562 }
563}
564
// Unit tests for `ModelMatcher` and `GroupedModels`, using an in-memory
// `TestLanguageModel` instead of real providers.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal model for tests.
    ///
    /// Its id and name are both `name`, and its provider id and provider name are both
    /// `provider`.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // `"{provider}/{name}"`; the tests use this as each model's identity.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        // Completions are never exercised by these tests.
        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds non-favorite `ModelInfo`s from `(provider, name)` pairs.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        create_models_with_favorites(model_specs, vec![])
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs.
    ///
    /// A model is marked as a favorite when its `(provider, name)` pair appears in
    /// `favorites`.
    fn create_models_with_favorites(
        model_specs: Vec<(&str, &str)>,
        favorites: Vec<(&str, &str)>,
    ) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| {
                let is_favorite = favorites
                    .iter()
                    .any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name);
                ModelInfo {
                    model: Arc::new(TestLanguageModel::new(name, provider)),
                    icon: IconOrSvg::Icon(IconName::ZedAgent),
                    is_favorite,
                }
            })
            .collect()
    }

    /// Asserts that `result` lists exactly the `"{provider}/{name}"` ids in
    /// `expected`, in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-5"),
            ("zed", "gpt-5-mini"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-5"),
            ("openai", "gpt-5-mini"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(
            models,
            cx.foreground_executor().clone(),
            cx.background_executor.clone(),
        );

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-5");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-5",
                "zed/gpt-5-mini",
                "openai/gpt-5",
                "openai/gpt-5-mini",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-5"),
            ("zed", "gpt-5-mini"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-5"),
            ("openai", "gpt-5-mini"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(
            models,
            cx.foreground_executor().clone(),
            cx.background_executor.clone(),
        );

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-5-mini` and `openai/gpt-5-mini` have identical
        // similarity scores, but `zed/gpt-5-mini` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("mini");
        assert_models_eq(results, vec!["zed/gpt-5-mini", "openai/gpt-5-mini"]);

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search - search for Claude to get the Thinking variant
        let results = matcher.fuzzy_search("thinking");
        assert_models_eq(results, vec!["zed/Claude 3.7 Sonnet Thinking"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // Flatten the per-provider groups back into a single list, in group order.
        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "gemini")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        // The "Favorite" section, when present, is the first row.
        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert_models_eq(grouped_models.favorites, vec!["zed/gemini"]);
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![("zed", "claude"), ("zed", "gemini")]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        // Without favorites, the list starts with the "Recommended" section.
        assert!(matches!(
            entries.first(),
            Some(LanguageModelPickerEntry::Separator(s)) if s == "Recommended"
        ));

        assert!(grouped_models.favorites.is_empty());
    }

    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], vec![("zed", "claude")]);
        let all_models = create_models_with_favorites(
            vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
            vec![("zed", "claude")],
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);
        let entries = grouped_models.entries();

        // Every row for `zed/claude` (in any section) is a favorite; no other row is.
        for entry in &entries {
            if let LanguageModelPickerEntry::Model(info) = entry {
                if info.model.telemetry_id() == "zed/claude" {
                    assert!(info.is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(
                        !info.is_favorite,
                        "{} should not be a favorite",
                        info.model.telemetry_id()
                    );
                }
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_other_sections(_cx: &mut TestAppContext) {
        let favorites = vec![("zed", "gemini"), ("openai", "gpt-4")];

        let recommended_models =
            create_models_with_favorites(vec![("zed", "claude")], favorites.clone());

        let all_models = create_models_with_favorites(
            vec![
                ("zed", "claude"),
                ("zed", "gemini"),
                ("openai", "gpt-4"),
                ("openai", "gpt-3.5"),
            ],
            favorites,
        );

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // Favorites are listed in their own section and still appear in "all".
        assert_models_eq(grouped_models.favorites, vec!["zed/gemini", "openai/gpt-4"]);
        assert_models_eq(grouped_models.recommended, vec!["zed/claude"]);
        assert_models_eq(
            grouped_models.all.values().flatten().cloned().collect(),
            vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"],
        );
    }
}