1use std::{cmp::Reverse, rc::Rc, sync::Arc};
2
3use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
4use agent_client_protocol::schema as acp;
5use agent_servers::AgentServer;
6
7use anyhow::Result;
8use collections::{HashSet, IndexMap};
9use fs::Fs;
10use futures::FutureExt;
11use fuzzy::{StringMatchCandidate, match_strings};
12use gpui::{
13 Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
14 WeakEntity,
15};
16use itertools::Itertools;
17use ordered_float::OrderedFloat;
18use picker::{Picker, PickerDelegate};
19use settings::SettingsStore;
20use ui::{DocumentationAside, IntoElement, prelude::*};
21use util::ResultExt;
22use zed_actions::agent::OpenSettings;
23
24use crate::ui::{
25 HoldForDefault, ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem,
26 documentation_aside_side,
27};
28
/// The model selector popover: a [`Picker`] driven by [`ModelPickerDelegate`].
pub type ModelSelector = Picker<ModelPickerDelegate>;
30
31pub fn acp_model_selector(
32 selector: Rc<dyn AgentModelSelector>,
33 agent_server: Rc<dyn AgentServer>,
34 fs: Arc<dyn Fs>,
35 focus_handle: FocusHandle,
36 window: &mut Window,
37 cx: &mut Context<ModelSelector>,
38) -> ModelSelector {
39 let delegate = ModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
40 Picker::list(delegate, window, cx)
41 .show_scrollbar(true)
42 .width(rems(20.))
43 .max_height(Some(rems(20.).into()))
44}
45
/// A single row in the model picker.
enum ModelPickerEntry {
    /// A non-selectable section header (e.g. "Favorite", "All", or a group name).
    Separator(SharedString),
    /// A selectable model row; the `bool` records whether the model is a favorite.
    Model(AgentModelInfo, bool),
}
50
/// [`PickerDelegate`] that lists an agent's models, with a favorites section and
/// fuzzy filtering.
pub struct ModelPickerDelegate {
    // Source of the model list and of the currently selected model.
    selector: Rc<dyn AgentModelSelector>,
    // Used to look up and change favorite and default models.
    agent_server: Rc<dyn AgentServer>,
    // Passed to `agent_server` when toggling favorites or changing the default model.
    fs: Arc<dyn Fs>,
    // Rows currently shown in the picker; rebuilt on every `update_matches`.
    filtered_entries: Vec<ModelPickerEntry>,
    // Full, unfiltered model list; `None` until loaded (or when the last load failed).
    models: Option<AgentModelList>,
    // Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    // (row index, description, is_default) of the hovered row; drives the documentation aside.
    selected_description: Option<(usize, SharedString, bool)>,
    // The model currently chosen for the thread.
    selected_model: Option<AgentModelInfo>,
    // Ids of models marked as favorites.
    favorites: HashSet<acp::ModelId>,
    // Background task that (re)loads models; held for the delegate's lifetime.
    // NOTE(review): presumably dropping it would cancel the refresh loop — confirm gpui `Task` semantics.
    _refresh_models_task: Task<()>,
    // Keeps the `SettingsStore` observer that syncs `favorites` registered.
    _settings_subscription: Subscription,
    // Focus handle handed to the footer.
    focus_handle: FocusHandle,
}
65
66impl ModelPickerDelegate {
67 fn new(
68 selector: Rc<dyn AgentModelSelector>,
69 agent_server: Rc<dyn AgentServer>,
70 fs: Arc<dyn Fs>,
71 focus_handle: FocusHandle,
72 window: &mut Window,
73 cx: &mut Context<ModelSelector>,
74 ) -> Self {
75 let rx = selector.watch(cx);
76 let refresh_models_task = {
77 cx.spawn_in(window, {
78 async move |this, cx| {
79 async fn refresh(
80 this: &WeakEntity<Picker<ModelPickerDelegate>>,
81 cx: &mut AsyncWindowContext,
82 ) -> Result<()> {
83 let (models_task, selected_model_task) = this.update(cx, |this, cx| {
84 (
85 this.delegate.selector.list_models(cx),
86 this.delegate.selector.selected_model(cx),
87 )
88 })?;
89
90 let (models, selected_model) =
91 futures::join!(models_task, selected_model_task);
92
93 this.update_in(cx, |this, window, cx| {
94 this.delegate.models = models.ok();
95 this.delegate.selected_model = selected_model.ok();
96 this.refresh(window, cx)
97 })
98 }
99
100 refresh(&this, cx).await.log_err();
101 if let Some(mut rx) = rx {
102 while let Ok(()) = rx.recv().await {
103 refresh(&this, cx).await.log_err();
104 }
105 }
106 }
107 })
108 };
109
110 let agent_server_for_subscription = agent_server.clone();
111 let settings_subscription =
112 cx.observe_global_in::<SettingsStore>(window, move |picker, window, cx| {
113 // Only refresh if the favorites actually changed to avoid redundant work
114 // when other settings are modified (e.g., user editing settings.json)
115 let new_favorites = agent_server_for_subscription.favorite_model_ids(cx);
116 if new_favorites != picker.delegate.favorites {
117 picker.delegate.favorites = new_favorites;
118 picker.refresh(window, cx);
119 }
120 });
121 let favorites = agent_server.favorite_model_ids(cx);
122
123 Self {
124 selector,
125 agent_server,
126 fs,
127 filtered_entries: Vec::new(),
128 models: None,
129 selected_model: None,
130 selected_index: 0,
131 selected_description: None,
132 favorites,
133 _refresh_models_task: refresh_models_task,
134 _settings_subscription: settings_subscription,
135 focus_handle,
136 }
137 }
138
139 pub fn active_model(&self) -> Option<&AgentModelInfo> {
140 self.selected_model.as_ref()
141 }
142
143 pub fn favorites_count(&self) -> usize {
144 self.favorites.len()
145 }
146
147 pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
148 if self.favorites.is_empty() {
149 return;
150 }
151
152 let Some(models) = &self.models else {
153 return;
154 };
155
156 let all_models: Vec<&AgentModelInfo> = match models {
157 AgentModelList::Flat(list) => list.iter().collect(),
158 AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
159 };
160
161 let favorite_models: Vec<_> = all_models
162 .into_iter()
163 .filter(|model| self.favorites.contains(&model.id))
164 .unique_by(|model| &model.id)
165 .collect();
166
167 if favorite_models.is_empty() {
168 return;
169 }
170
171 let current_id = self.selected_model.as_ref().map(|m| &m.id);
172
173 let current_index_in_favorites = current_id
174 .and_then(|id| favorite_models.iter().position(|m| &m.id == id))
175 .unwrap_or(usize::MAX);
176
177 let next_index = if current_index_in_favorites == usize::MAX {
178 0
179 } else {
180 (current_index_in_favorites + 1) % favorite_models.len()
181 };
182
183 let next_model = favorite_models[next_index].clone();
184
185 self.selector
186 .select_model(next_model.id.clone(), cx)
187 .detach_and_log_err(cx);
188
189 self.selected_model = Some(next_model);
190
191 // Keep the picker selection aligned with the newly-selected model
192 if let Some(new_index) = self.filtered_entries.iter().position(|entry| {
193 matches!(entry, ModelPickerEntry::Model(model_info, _) if self.selected_model.as_ref().is_some_and(|selected| model_info.id == selected.id))
194 }) {
195 self.set_selected_index(new_index, window, cx);
196 } else {
197 cx.notify();
198 }
199 }
200}
201
202impl PickerDelegate for ModelPickerDelegate {
203 type ListItem = AnyElement;
204
205 fn match_count(&self) -> usize {
206 self.filtered_entries.len()
207 }
208
209 fn selected_index(&self) -> usize {
210 self.selected_index
211 }
212
213 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
214 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
215 cx.notify();
216 }
217
218 fn can_select(&self, ix: usize, _window: &mut Window, _cx: &mut Context<Picker<Self>>) -> bool {
219 match self.filtered_entries.get(ix) {
220 Some(ModelPickerEntry::Model(_, _)) => true,
221 Some(ModelPickerEntry::Separator(_)) | None => false,
222 }
223 }
224
225 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
226 "Select a model…".into()
227 }
228
229 fn update_matches(
230 &mut self,
231 query: String,
232 window: &mut Window,
233 cx: &mut Context<Picker<Self>>,
234 ) -> Task<()> {
235 let favorites = self.favorites.clone();
236
237 cx.spawn_in(window, async move |this, cx| {
238 let filtered_models = match this
239 .read_with(cx, |this, cx| {
240 this.delegate.models.clone().map(move |models| {
241 fuzzy_search(models, query, cx.background_executor().clone())
242 })
243 })
244 .ok()
245 .flatten()
246 {
247 Some(task) => task.await,
248 None => AgentModelList::Flat(vec![]),
249 };
250
251 this.update_in(cx, |this, window, cx| {
252 this.delegate.filtered_entries =
253 info_list_to_picker_entries(filtered_models, &favorites);
254 // Finds the currently selected model in the list
255 let new_index = this
256 .delegate
257 .selected_model
258 .as_ref()
259 .and_then(|selected| {
260 this.delegate.filtered_entries.iter().position(|entry| {
261 if let ModelPickerEntry::Model(model_info, _) = entry {
262 model_info.id == selected.id
263 } else {
264 false
265 }
266 })
267 })
268 .unwrap_or(0);
269 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
270 cx.notify();
271 })
272 .ok();
273 })
274 }
275
276 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
277 if let Some(ModelPickerEntry::Model(model_info, _)) =
278 self.filtered_entries.get(self.selected_index)
279 {
280 if window.modifiers().secondary() {
281 let default_model = self.agent_server.default_model(cx);
282 let is_default = default_model.as_ref() == Some(&model_info.id);
283
284 self.agent_server.set_default_model(
285 if is_default {
286 None
287 } else {
288 Some(model_info.id.clone())
289 },
290 self.fs.clone(),
291 cx,
292 );
293 }
294
295 self.selector
296 .select_model(model_info.id.clone(), cx)
297 .detach_and_log_err(cx);
298 self.selected_model = Some(model_info.clone());
299 let current_index = self.selected_index;
300 self.set_selected_index(current_index, window, cx);
301
302 cx.emit(DismissEvent);
303 }
304 }
305
306 fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
307 cx.defer_in(window, |picker, window, cx| {
308 picker.set_query("", window, cx);
309 });
310 }
311
312 fn render_match(
313 &self,
314 ix: usize,
315 selected: bool,
316 _: &mut Window,
317 cx: &mut Context<Picker<Self>>,
318 ) -> Option<Self::ListItem> {
319 match self.filtered_entries.get(ix)? {
320 ModelPickerEntry::Separator(title) => {
321 Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
322 }
323 ModelPickerEntry::Model(model_info, is_favorite) => {
324 let is_selected = Some(model_info) == self.selected_model.as_ref();
325 let default_model = self.agent_server.default_model(cx);
326 let is_default = default_model.as_ref() == Some(&model_info.id);
327
328 let is_favorite = *is_favorite;
329 let handle_action_click = {
330 let model_id = model_info.id.clone();
331 let fs = self.fs.clone();
332 let agent_server = self.agent_server.clone();
333
334 cx.listener(move |_, _, _, cx| {
335 agent_server.toggle_favorite_model(
336 model_id.clone(),
337 !is_favorite,
338 fs.clone(),
339 cx,
340 );
341 })
342 };
343
344 let model_cost = model_info.cost.clone();
345
346 Some(
347 div()
348 .id(("model-picker-menu-child", ix))
349 .when_some(model_info.description.clone(), |this, description| {
350 this.on_hover(cx.listener(move |menu, hovered, _, cx| {
351 if *hovered {
352 menu.delegate.selected_description =
353 Some((ix, description.clone(), is_default));
354 } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
355 menu.delegate.selected_description = None;
356 }
357 cx.notify();
358 }))
359 })
360 .child(
361 ModelSelectorListItem::new(ix, model_info.name.clone())
362 .map(|this| match &model_info.icon {
363 Some(AgentModelIcon::Path(path)) => this.icon_path(path.clone()),
364 Some(AgentModelIcon::Named(icon)) => this.icon(*icon),
365 None => this,
366 })
367 .is_selected(is_selected)
368 .is_focused(selected)
369 .is_latest(model_info.is_latest)
370 .is_favorite(is_favorite)
371 .on_toggle_favorite(handle_action_click)
372 .cost_info(model_cost)
373 )
374 .into_any_element(),
375 )
376 }
377 }
378 }
379
380 fn documentation_aside(
381 &self,
382 _window: &mut Window,
383 cx: &mut Context<Picker<Self>>,
384 ) -> Option<ui::DocumentationAside> {
385 self.selected_description
386 .as_ref()
387 .map(|(_, description, is_default)| {
388 let description = description.clone();
389 let is_default = *is_default;
390
391 let side = documentation_aside_side(cx);
392
393 DocumentationAside::new(
394 side,
395 Rc::new(move |_| {
396 v_flex()
397 .gap_1()
398 .child(Label::new(description.clone()))
399 .child(HoldForDefault::new(is_default))
400 .into_any_element()
401 }),
402 )
403 })
404 }
405
406 fn documentation_aside_index(&self) -> Option<usize> {
407 self.selected_description.as_ref().map(|(ix, _, _)| *ix)
408 }
409
410 fn render_footer(
411 &self,
412 _window: &mut Window,
413 _cx: &mut Context<Picker<Self>>,
414 ) -> Option<AnyElement> {
415 let focus_handle = self.focus_handle.clone();
416
417 if !self.selector.should_render_footer() {
418 return None;
419 }
420
421 Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
422 }
423}
424
425fn info_list_to_picker_entries(
426 model_list: AgentModelList,
427 favorites: &HashSet<acp::ModelId>,
428) -> Vec<ModelPickerEntry> {
429 let mut entries = Vec::new();
430
431 let all_models: Vec<_> = match &model_list {
432 AgentModelList::Flat(list) => list.iter().collect(),
433 AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
434 };
435
436 let favorite_models: Vec<_> = all_models
437 .iter()
438 .filter(|m| favorites.contains(&m.id))
439 .unique_by(|m| &m.id)
440 .collect();
441
442 let has_favorites = !favorite_models.is_empty();
443 if has_favorites {
444 entries.push(ModelPickerEntry::Separator("Favorite".into()));
445 for model in favorite_models {
446 entries.push(ModelPickerEntry::Model((*model).clone(), true));
447 }
448 }
449
450 match model_list {
451 AgentModelList::Flat(list) => {
452 if has_favorites {
453 entries.push(ModelPickerEntry::Separator("All".into()));
454 }
455 for model in list {
456 let is_favorite = favorites.contains(&model.id);
457 entries.push(ModelPickerEntry::Model(model, is_favorite));
458 }
459 }
460 AgentModelList::Grouped(index_map) => {
461 for (group_name, models) in index_map {
462 entries.push(ModelPickerEntry::Separator(group_name.0));
463 for model in models {
464 let is_favorite = favorites.contains(&model.id);
465 entries.push(ModelPickerEntry::Model(model, is_favorite));
466 }
467 }
468 }
469 }
470
471 entries
472}
473
474async fn fuzzy_search(
475 model_list: AgentModelList,
476 query: String,
477 executor: BackgroundExecutor,
478) -> AgentModelList {
479 async fn fuzzy_search_list(
480 model_list: Vec<AgentModelInfo>,
481 query: &str,
482 executor: BackgroundExecutor,
483 ) -> Vec<AgentModelInfo> {
484 let candidates = model_list
485 .iter()
486 .enumerate()
487 .map(|(ix, model)| StringMatchCandidate::new(ix, model.name.as_ref()))
488 .collect::<Vec<_>>();
489 let mut matches = match_strings(
490 &candidates,
491 query,
492 false,
493 true,
494 100,
495 &Default::default(),
496 executor,
497 )
498 .await;
499
500 matches.sort_unstable_by_key(|mat| {
501 let candidate = &candidates[mat.candidate_id];
502 (Reverse(OrderedFloat(mat.score)), candidate.id)
503 });
504
505 matches
506 .into_iter()
507 .map(|mat| model_list[mat.candidate_id].clone())
508 .collect()
509 }
510
511 match model_list {
512 AgentModelList::Flat(model_list) => {
513 AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
514 }
515 AgentModelList::Grouped(index_map) => {
516 let groups =
517 futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
518 fuzzy_search_list(models, &query, executor.clone())
519 .map(|results| (group_name, results))
520 }))
521 .await;
522 AgentModelList::Grouped(IndexMap::from_iter(
523 groups
524 .into_iter()
525 .filter(|(_, results)| !results.is_empty()),
526 ))
527 }
528 }
529}
530
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list in which each model's id and display name are the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId::new(model.to_string()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                            is_latest: false,
                            cost: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names match `expected`
    /// exactly and in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    /// Builds a favorites set from model id strings.
    fn create_favorites(models: Vec<&str>) -> HashSet<acp::ModelId> {
        models
            .into_iter()
            .map(|m| acp::ModelId::new(m.to_string()))
            .collect()
    }

    /// Model ids of all model rows, in order, skipping separators.
    fn get_entry_model_ids(entries: &[ModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .filter_map(|entry| match entry {
                ModelPickerEntry::Model(info, _) => Some(info.id.0.as_ref()),
                _ => None,
            })
            .collect()
    }

    /// Labels of every row in order: the model id for model rows, the title for separators.
    fn get_entry_labels(entries: &[ModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .map(|entry| match entry {
                ModelPickerEntry::Model(info, _) => info.id.0.as_ref(),
                ModelPickerEntry::Separator(s) => &s,
            })
            .collect()
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-5",
                    "gpt-5-mini",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-5", "gpt-5-mini"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-5-mini` and `openai/gpt-5-mini` have identical
        // similarity scores, but `zed/gpt-5-mini` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "mini".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![("zed", vec!["gpt-5-mini"]), ("openai", vec!["gpt-5-mini"])],
        );

        // Fuzzy search - test with specific model name
        let results = fuzzy_search(models.clone(), "mistral".into(), cx.executor()).await;
        assert_models_eq(results, vec![("ollama", vec!["mistral"])]);
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(ModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        let model_ids = get_entry_model_ids(&entries);
        assert_eq!(model_ids[0], "zed/gemini");
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![("zed", vec!["zed/claude", "zed/gemini"])]);
        let favorites = create_favorites(vec![]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(ModelPickerEntry::Separator(s)) if s == "zed"
        ));
    }

    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        for entry in &entries {
            if let ModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "zed/claude" {
                    assert!(is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(!is_favorite, "{} should not be a favorite", info.id.0);
                }
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_both_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5", "openai/gpt-4"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini", "openai/gpt-5"]);

        let entries = info_list_to_picker_entries(models, &favorites);
        let model_ids = get_entry_model_ids(&entries);

        assert_eq!(model_ids[0], "zed/gemini");
        assert_eq!(model_ids[1], "openai/gpt-5");

        assert!(model_ids[2..].contains(&"zed/gemini"));
        assert!(model_ids[2..].contains(&"openai/gpt-5"));
    }

    #[gpui::test]
    fn test_favorites_are_not_duplicated_when_repeated_in_other_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("Recommended", vec!["zed/claude", "anthropic/claude"]),
            ("Zed", vec!["zed/claude", "zed/gpt-5"]),
            ("Anthropic", vec!["anthropic/claude"]),
            ("OpenAI", vec!["openai/gpt-5"]),
        ]);

        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, &favorites);
        let labels = get_entry_labels(&entries);

        assert_eq!(
            labels,
            vec![
                "Favorite",
                "zed/claude",
                "Recommended",
                "zed/claude",
                "anthropic/claude",
                "Zed",
                "zed/claude",
                "zed/gpt-5",
                "Anthropic",
                "anthropic/claude",
                "OpenAI",
                "openai/gpt-5"
            ]
        );
    }

    #[gpui::test]
    fn test_flat_model_list_with_favorites(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/claude".to_string()),
                name: "Claude".into(),
                description: None,
                icon: None,
                is_latest: false,
                cost: None,
            },
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/gemini".to_string()),
                name: "Gemini".into(),
                description: None,
                icon: None,
                is_latest: false,
                cost: None,
            },
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(ModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert!(entries.iter().any(|e| matches!(
            e,
            ModelPickerEntry::Separator(s) if s == "All"
        )));
    }

    // NOTE: `ModelPickerDelegate::favorites_count` cannot be called here without a window and
    // a live agent server. This test only checks the property it relies on: the favorites
    // collection is a `HashSet`, so duplicate ids are counted once.
    #[gpui::test]
    fn test_favorites_count_returns_correct_count(_cx: &mut TestAppContext) {
        let empty_favorites: HashSet<acp::ModelId> = HashSet::default();
        assert_eq!(empty_favorites.len(), 0);

        let one_favorite = create_favorites(vec!["model-a"]);
        assert_eq!(one_favorite.len(), 1);

        let multiple_favorites = create_favorites(vec!["model-a", "model-b", "model-c"]);
        assert_eq!(multiple_favorites.len(), 3);

        let with_duplicates = create_favorites(vec!["model-a", "model-a", "model-b"]);
        assert_eq!(with_duplicates.len(), 2);
    }

    #[gpui::test]
    fn test_is_favorite_flag_set_correctly_in_entries(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("favorite-model".to_string()),
                name: "Favorite".into(),
                description: None,
                icon: None,
                is_latest: false,
                cost: None,
            },
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("regular-model".to_string()),
                name: "Regular".into(),
                description: None,
                icon: None,
                is_latest: false,
                cost: None,
            },
        ]);
        let favorites = create_favorites(vec!["favorite-model"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        for entry in &entries {
            if let ModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "favorite-model" {
                    assert!(*is_favorite, "favorite-model should have is_favorite=true");
                } else if info.id.0.as_ref() == "regular-model" {
                    assert!(!*is_favorite, "regular-model should have is_favorite=false");
                }
            }
        }
    }
}