1use std::{cmp::Reverse, rc::Rc, sync::Arc};
2
3use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
4use agent_client_protocol::ModelId;
5use agent_servers::AgentServer;
6
7use anyhow::Result;
8use collections::{HashSet, IndexMap};
9use fs::Fs;
10use futures::FutureExt;
11use fuzzy::{StringMatchCandidate, match_strings};
12use gpui::{
13 Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
14 WeakEntity,
15};
16use itertools::Itertools;
17use ordered_float::OrderedFloat;
18use picker::{Picker, PickerDelegate};
19use settings::SettingsStore;
20use ui::{DocumentationAside, IntoElement, prelude::*};
21use util::ResultExt;
22use zed_actions::agent::OpenSettings;
23
24use crate::ui::{
25 HoldForDefault, ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem,
26 documentation_aside_side,
27};
28
29pub type ModelSelector = Picker<ModelPickerDelegate>;
30
31pub fn acp_model_selector(
32 selector: Rc<dyn AgentModelSelector>,
33 agent_server: Rc<dyn AgentServer>,
34 fs: Arc<dyn Fs>,
35 focus_handle: FocusHandle,
36 window: &mut Window,
37 cx: &mut Context<ModelSelector>,
38) -> ModelSelector {
39 let delegate = ModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
40 Picker::list(delegate, window, cx)
41 .show_scrollbar(true)
42 .width(rems(20.))
43 .max_height(Some(rems(20.).into()))
44}
45
/// One row of the model picker list.
enum ModelPickerEntry {
    /// A non-selectable section header, such as "Favorite", "All", or a model group name.
    Separator(SharedString),
    /// A selectable model row; the `bool` records whether the model is one of the user's favorites.
    Model(AgentModelInfo, bool),
}
50
/// Picker delegate that lists an agent's models, tracks the active model, and manages favorites.
pub struct ModelPickerDelegate {
    /// Source of the model list and the currently selected model; also applies selections.
    selector: Rc<dyn AgentModelSelector>,
    /// Agent server used to read/write favorite model ids and the default model.
    agent_server: Rc<dyn AgentServer>,
    /// Filesystem handle forwarded to the agent server when persisting favorites/default model.
    fs: Arc<dyn Fs>,
    /// Rows currently displayed: the fuzzy-filtered model list plus section separators.
    filtered_entries: Vec<ModelPickerEntry>,
    /// Last model list fetched from `selector`; `None` until fetched or if fetching failed.
    models: Option<AgentModelList>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Hovered row whose description feeds the documentation aside:
    /// (row index, description, whether that model is the default).
    selected_description: Option<(usize, SharedString, bool)>,
    /// The model currently in use by the agent, as last known to the picker.
    selected_model: Option<AgentModelInfo>,
    /// Favorite model ids, re-read from the agent server when settings change.
    favorites: HashSet<ModelId>,
    /// Background refresh loop, stored so it lives as long as the delegate
    /// (presumably dropping the `Task` would cancel it — gpui semantics, confirm).
    _refresh_models_task: Task<()>,
    /// Settings observer that keeps `favorites` in sync; stored to keep it registered.
    _settings_subscription: Subscription,
    /// Focus handle passed to the footer.
    focus_handle: FocusHandle,
}
65
66impl ModelPickerDelegate {
67 fn new(
68 selector: Rc<dyn AgentModelSelector>,
69 agent_server: Rc<dyn AgentServer>,
70 fs: Arc<dyn Fs>,
71 focus_handle: FocusHandle,
72 window: &mut Window,
73 cx: &mut Context<ModelSelector>,
74 ) -> Self {
75 let rx = selector.watch(cx);
76 let refresh_models_task = {
77 cx.spawn_in(window, {
78 async move |this, cx| {
79 async fn refresh(
80 this: &WeakEntity<Picker<ModelPickerDelegate>>,
81 cx: &mut AsyncWindowContext,
82 ) -> Result<()> {
83 let (models_task, selected_model_task) = this.update(cx, |this, cx| {
84 (
85 this.delegate.selector.list_models(cx),
86 this.delegate.selector.selected_model(cx),
87 )
88 })?;
89
90 let (models, selected_model) =
91 futures::join!(models_task, selected_model_task);
92
93 this.update_in(cx, |this, window, cx| {
94 this.delegate.models = models.ok();
95 this.delegate.selected_model = selected_model.ok();
96 this.refresh(window, cx)
97 })
98 }
99
100 refresh(&this, cx).await.log_err();
101 if let Some(mut rx) = rx {
102 while let Ok(()) = rx.recv().await {
103 refresh(&this, cx).await.log_err();
104 }
105 }
106 }
107 })
108 };
109
110 let agent_server_for_subscription = agent_server.clone();
111 let settings_subscription =
112 cx.observe_global_in::<SettingsStore>(window, move |picker, window, cx| {
113 // Only refresh if the favorites actually changed to avoid redundant work
114 // when other settings are modified (e.g., user editing settings.json)
115 let new_favorites = agent_server_for_subscription.favorite_model_ids(cx);
116 if new_favorites != picker.delegate.favorites {
117 picker.delegate.favorites = new_favorites;
118 picker.refresh(window, cx);
119 }
120 });
121 let favorites = agent_server.favorite_model_ids(cx);
122
123 Self {
124 selector,
125 agent_server,
126 fs,
127 filtered_entries: Vec::new(),
128 models: None,
129 selected_model: None,
130 selected_index: 0,
131 selected_description: None,
132 favorites,
133 _refresh_models_task: refresh_models_task,
134 _settings_subscription: settings_subscription,
135 focus_handle,
136 }
137 }
138
139 pub fn active_model(&self) -> Option<&AgentModelInfo> {
140 self.selected_model.as_ref()
141 }
142
143 pub fn favorites_count(&self) -> usize {
144 self.favorites.len()
145 }
146
147 pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
148 if self.favorites.is_empty() {
149 return;
150 }
151
152 let Some(models) = &self.models else {
153 return;
154 };
155
156 let all_models: Vec<&AgentModelInfo> = match models {
157 AgentModelList::Flat(list) => list.iter().collect(),
158 AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
159 };
160
161 let favorite_models: Vec<_> = all_models
162 .into_iter()
163 .filter(|model| self.favorites.contains(&model.id))
164 .unique_by(|model| &model.id)
165 .collect();
166
167 if favorite_models.is_empty() {
168 return;
169 }
170
171 let current_id = self.selected_model.as_ref().map(|m| &m.id);
172
173 let current_index_in_favorites = current_id
174 .and_then(|id| favorite_models.iter().position(|m| &m.id == id))
175 .unwrap_or(usize::MAX);
176
177 let next_index = if current_index_in_favorites == usize::MAX {
178 0
179 } else {
180 (current_index_in_favorites + 1) % favorite_models.len()
181 };
182
183 let next_model = favorite_models[next_index].clone();
184
185 self.selector
186 .select_model(next_model.id.clone(), cx)
187 .detach_and_log_err(cx);
188
189 self.selected_model = Some(next_model);
190
191 // Keep the picker selection aligned with the newly-selected model
192 if let Some(new_index) = self.filtered_entries.iter().position(|entry| {
193 matches!(entry, ModelPickerEntry::Model(model_info, _) if self.selected_model.as_ref().is_some_and(|selected| model_info.id == selected.id))
194 }) {
195 self.set_selected_index(new_index, window, cx);
196 } else {
197 cx.notify();
198 }
199 }
200}
201
202impl PickerDelegate for ModelPickerDelegate {
203 type ListItem = AnyElement;
204
205 fn match_count(&self) -> usize {
206 self.filtered_entries.len()
207 }
208
209 fn selected_index(&self) -> usize {
210 self.selected_index
211 }
212
213 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
214 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
215 cx.notify();
216 }
217
218 fn can_select(&self, ix: usize, _window: &mut Window, _cx: &mut Context<Picker<Self>>) -> bool {
219 match self.filtered_entries.get(ix) {
220 Some(ModelPickerEntry::Model(_, _)) => true,
221 Some(ModelPickerEntry::Separator(_)) | None => false,
222 }
223 }
224
225 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
226 "Select a model…".into()
227 }
228
229 fn update_matches(
230 &mut self,
231 query: String,
232 window: &mut Window,
233 cx: &mut Context<Picker<Self>>,
234 ) -> Task<()> {
235 let favorites = self.favorites.clone();
236
237 cx.spawn_in(window, async move |this, cx| {
238 let filtered_models = match this
239 .read_with(cx, |this, cx| {
240 this.delegate.models.clone().map(move |models| {
241 fuzzy_search(models, query, cx.background_executor().clone())
242 })
243 })
244 .ok()
245 .flatten()
246 {
247 Some(task) => task.await,
248 None => AgentModelList::Flat(vec![]),
249 };
250
251 this.update_in(cx, |this, window, cx| {
252 this.delegate.filtered_entries =
253 info_list_to_picker_entries(filtered_models, &favorites);
254 // Finds the currently selected model in the list
255 let new_index = this
256 .delegate
257 .selected_model
258 .as_ref()
259 .and_then(|selected| {
260 this.delegate.filtered_entries.iter().position(|entry| {
261 if let ModelPickerEntry::Model(model_info, _) = entry {
262 model_info.id == selected.id
263 } else {
264 false
265 }
266 })
267 })
268 .unwrap_or(0);
269 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
270 cx.notify();
271 })
272 .ok();
273 })
274 }
275
276 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
277 if let Some(ModelPickerEntry::Model(model_info, _)) =
278 self.filtered_entries.get(self.selected_index)
279 {
280 if window.modifiers().secondary() {
281 let default_model = self.agent_server.default_model(cx);
282 let is_default = default_model.as_ref() == Some(&model_info.id);
283
284 self.agent_server.set_default_model(
285 if is_default {
286 None
287 } else {
288 Some(model_info.id.clone())
289 },
290 self.fs.clone(),
291 cx,
292 );
293 }
294
295 self.selector
296 .select_model(model_info.id.clone(), cx)
297 .detach_and_log_err(cx);
298 self.selected_model = Some(model_info.clone());
299 let current_index = self.selected_index;
300 self.set_selected_index(current_index, window, cx);
301
302 cx.emit(DismissEvent);
303 }
304 }
305
306 fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
307 cx.defer_in(window, |picker, window, cx| {
308 picker.set_query("", window, cx);
309 });
310 }
311
312 fn render_match(
313 &self,
314 ix: usize,
315 selected: bool,
316 _: &mut Window,
317 cx: &mut Context<Picker<Self>>,
318 ) -> Option<Self::ListItem> {
319 match self.filtered_entries.get(ix)? {
320 ModelPickerEntry::Separator(title) => {
321 Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
322 }
323 ModelPickerEntry::Model(model_info, is_favorite) => {
324 let is_selected = Some(model_info) == self.selected_model.as_ref();
325 let default_model = self.agent_server.default_model(cx);
326 let is_default = default_model.as_ref() == Some(&model_info.id);
327
328 let is_favorite = *is_favorite;
329 let handle_action_click = {
330 let model_id = model_info.id.clone();
331 let fs = self.fs.clone();
332 let agent_server = self.agent_server.clone();
333
334 cx.listener(move |_, _, _, cx| {
335 agent_server.toggle_favorite_model(
336 model_id.clone(),
337 !is_favorite,
338 fs.clone(),
339 cx,
340 );
341 })
342 };
343
344 let model_cost = model_info.cost.clone();
345
346 Some(
347 div()
348 .id(("model-picker-menu-child", ix))
349 .when_some(model_info.description.clone(), |this, description| {
350 this.on_hover(cx.listener(move |menu, hovered, _, cx| {
351 if *hovered {
352 menu.delegate.selected_description =
353 Some((ix, description.clone(), is_default));
354 } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
355 menu.delegate.selected_description = None;
356 }
357 cx.notify();
358 }))
359 })
360 .child(
361 ModelSelectorListItem::new(ix, model_info.name.clone())
362 .map(|this| match &model_info.icon {
363 Some(AgentModelIcon::Path(path)) => this.icon_path(path.clone()),
364 Some(AgentModelIcon::Named(icon)) => this.icon(*icon),
365 None => this,
366 })
367 .is_selected(is_selected)
368 .is_focused(selected)
369 .is_latest(model_info.is_latest)
370 .is_favorite(is_favorite)
371 .on_toggle_favorite(handle_action_click)
372 .cost_info(model_cost)
373 )
374 .into_any_element(),
375 )
376 }
377 }
378 }
379
380 fn documentation_aside(
381 &self,
382 _window: &mut Window,
383 cx: &mut Context<Picker<Self>>,
384 ) -> Option<ui::DocumentationAside> {
385 self.selected_description
386 .as_ref()
387 .map(|(_, description, is_default)| {
388 let description = description.clone();
389 let is_default = *is_default;
390
391 let side = documentation_aside_side(cx);
392
393 DocumentationAside::new(
394 side,
395 Rc::new(move |_| {
396 v_flex()
397 .gap_1()
398 .child(Label::new(description.clone()))
399 .child(HoldForDefault::new(is_default))
400 .into_any_element()
401 }),
402 )
403 })
404 }
405
406 fn documentation_aside_index(&self) -> Option<usize> {
407 self.selected_description.as_ref().map(|(ix, _, _)| *ix)
408 }
409
410 fn render_footer(
411 &self,
412 _window: &mut Window,
413 _cx: &mut Context<Picker<Self>>,
414 ) -> Option<AnyElement> {
415 let focus_handle = self.focus_handle.clone();
416
417 if !self.selector.should_render_footer() {
418 return None;
419 }
420
421 Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
422 }
423}
424
425fn info_list_to_picker_entries(
426 model_list: AgentModelList,
427 favorites: &HashSet<ModelId>,
428) -> Vec<ModelPickerEntry> {
429 let mut entries = Vec::new();
430
431 let all_models: Vec<_> = match &model_list {
432 AgentModelList::Flat(list) => list.iter().collect(),
433 AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
434 };
435
436 let favorite_models: Vec<_> = all_models
437 .iter()
438 .filter(|m| favorites.contains(&m.id))
439 .unique_by(|m| &m.id)
440 .collect();
441
442 let has_favorites = !favorite_models.is_empty();
443 if has_favorites {
444 entries.push(ModelPickerEntry::Separator("Favorite".into()));
445 for model in favorite_models {
446 entries.push(ModelPickerEntry::Model((*model).clone(), true));
447 }
448 }
449
450 match model_list {
451 AgentModelList::Flat(list) => {
452 if has_favorites {
453 entries.push(ModelPickerEntry::Separator("All".into()));
454 }
455 for model in list {
456 let is_favorite = favorites.contains(&model.id);
457 entries.push(ModelPickerEntry::Model(model, is_favorite));
458 }
459 }
460 AgentModelList::Grouped(index_map) => {
461 for (group_name, models) in index_map {
462 entries.push(ModelPickerEntry::Separator(group_name.0));
463 for model in models {
464 let is_favorite = favorites.contains(&model.id);
465 entries.push(ModelPickerEntry::Model(model, is_favorite));
466 }
467 }
468 }
469 }
470
471 entries
472}
473
474async fn fuzzy_search(
475 model_list: AgentModelList,
476 query: String,
477 executor: BackgroundExecutor,
478) -> AgentModelList {
479 async fn fuzzy_search_list(
480 model_list: Vec<AgentModelInfo>,
481 query: &str,
482 executor: BackgroundExecutor,
483 ) -> Vec<AgentModelInfo> {
484 let candidates = model_list
485 .iter()
486 .enumerate()
487 .map(|(ix, model)| StringMatchCandidate::new(ix, model.name.as_ref()))
488 .collect::<Vec<_>>();
489 let mut matches = match_strings(
490 &candidates,
491 query,
492 false,
493 true,
494 100,
495 &Default::default(),
496 executor,
497 )
498 .await;
499
500 matches.sort_unstable_by_key(|mat| {
501 let candidate = &candidates[mat.candidate_id];
502 (Reverse(OrderedFloat(mat.score)), candidate.id)
503 });
504
505 matches
506 .into_iter()
507 .map(|mat| model_list[mat.candidate_id].clone())
508 .collect()
509 }
510
511 match model_list {
512 AgentModelList::Flat(model_list) => {
513 AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
514 }
515 AgentModelList::Grouped(index_map) => {
516 let groups =
517 futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
518 fuzzy_search_list(models, &query, executor.clone())
519 .map(|results| (group_name, results))
520 }))
521 .await;
522 AgentModelList::Grouped(IndexMap::from_iter(
523 groups
524 .into_iter()
525 .filter(|(_, results)| !results.is_empty()),
526 ))
527 }
528 }
529}
530
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list in which every model's id and display name are the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId::new(model.to_string()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                            is_latest: false,
                            cost: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names match `expected`, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    /// Builds a favorites set from model id strings (duplicates collapse).
    fn create_favorites(models: Vec<&str>) -> HashSet<ModelId> {
        models
            .into_iter()
            .map(|m| ModelId::new(m.to_string()))
            .collect()
    }

    /// Ids of the model rows in `entries`, in order, skipping separators.
    fn get_entry_model_ids(entries: &[ModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .filter_map(|entry| match entry {
                ModelPickerEntry::Model(info, _) => Some(info.id.0.as_ref()),
                _ => None,
            })
            .collect()
    }

    /// Every row rendered as a string: model ids for model rows, titles for separators.
    fn get_entry_labels(entries: &[ModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .map(|entry| match entry {
                ModelPickerEntry::Model(info, _) => info.id.0.as_ref(),
                ModelPickerEntry::Separator(s) => &s,
            })
            .collect()
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-5",
                    "gpt-5-mini",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-5", "gpt-5-mini"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-5-mini` and `openai/gpt-5-mini` have identical
        // similarity scores, but `zed/gpt-5-mini` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "mini".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![("zed", vec!["gpt-5-mini"]), ("openai", vec!["gpt-5-mini"])],
        );

        // Fuzzy search - test with specific model name
        let results = fuzzy_search(models.clone(), "mistral".into(), cx.executor()).await;
        assert_models_eq(results, vec![("ollama", vec!["mistral"])]);
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(ModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        let model_ids = get_entry_model_ids(&entries);
        assert_eq!(model_ids[0], "zed/gemini");
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![("zed", vec!["zed/claude", "zed/gemini"])]);
        let favorites = create_favorites(vec![]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(ModelPickerEntry::Separator(s)) if s == "zed"
        ));
    }

    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        for entry in &entries {
            if let ModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "zed/claude" {
                    assert!(*is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(!*is_favorite, "{} should not be a favorite", info.id.0);
                }
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_both_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5", "openai/gpt-4"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini", "openai/gpt-5"]);

        let entries = info_list_to_picker_entries(models, &favorites);
        let model_ids = get_entry_model_ids(&entries);

        assert_eq!(model_ids[0], "zed/gemini");
        assert_eq!(model_ids[1], "openai/gpt-5");

        assert!(model_ids[2..].contains(&"zed/gemini"));
        assert!(model_ids[2..].contains(&"openai/gpt-5"));
    }

    #[gpui::test]
    fn test_favorites_are_not_duplicated_when_repeated_in_other_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("Recommended", vec!["zed/claude", "anthropic/claude"]),
            ("Zed", vec!["zed/claude", "zed/gpt-5"]),
            ("Antropic", vec!["anthropic/claude"]),
            ("OpenAI", vec!["openai/gpt-5"]),
        ]);

        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, &favorites);
        let labels = get_entry_labels(&entries);

        assert_eq!(
            labels,
            vec![
                "Favorite",
                "zed/claude",
                "Recommended",
                "zed/claude",
                "anthropic/claude",
                "Zed",
                "zed/claude",
                "zed/gpt-5",
                "Antropic",
                "anthropic/claude",
                "OpenAI",
                "openai/gpt-5"
            ]
        );
    }

    #[gpui::test]
    fn test_flat_model_list_with_favorites(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/claude".to_string()),
                name: "Claude".into(),
                description: None,
                icon: None,
                is_latest: false,
                cost: None,
            },
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/gemini".to_string()),
                name: "Gemini".into(),
                description: None,
                icon: None,
                is_latest: false,
                cost: None,
            },
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(ModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert!(entries.iter().any(|e| matches!(
            e,
            ModelPickerEntry::Separator(s) if s == "All"
        )));
    }

    // NOTE(review): this only checks `HashSet` sizing/deduplication of the favorites set that
    // `favorites_count` reports via `len()`. It does not call `favorites_count` itself, since
    // that would need a fully constructed delegate.
    #[gpui::test]
    fn test_favorites_count_returns_correct_count(_cx: &mut TestAppContext) {
        let empty_favorites: HashSet<ModelId> = HashSet::default();
        assert_eq!(empty_favorites.len(), 0);

        let one_favorite = create_favorites(vec!["model-a"]);
        assert_eq!(one_favorite.len(), 1);

        let multiple_favorites = create_favorites(vec!["model-a", "model-b", "model-c"]);
        assert_eq!(multiple_favorites.len(), 3);

        let with_duplicates = create_favorites(vec!["model-a", "model-a", "model-b"]);
        assert_eq!(with_duplicates.len(), 2);
    }

    #[gpui::test]
    fn test_is_favorite_flag_set_correctly_in_entries(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("favorite-model".to_string()),
                name: "Favorite".into(),
                description: None,
                icon: None,
                is_latest: false,
                cost: None,
            },
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("regular-model".to_string()),
                name: "Regular".into(),
                description: None,
                icon: None,
                is_latest: false,
                cost: None,
            },
        ]);
        let favorites = create_favorites(vec!["favorite-model"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        for entry in &entries {
            if let ModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "favorite-model" {
                    assert!(*is_favorite, "favorite-model should have is_favorite=true");
                } else if info.id.0.as_ref() == "regular-model" {
                    assert!(!*is_favorite, "regular-model should have is_favorite=false");
                }
            }
        }
    }
}