1use std::{cmp::Reverse, rc::Rc, sync::Arc};
2
3use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
4use agent_client_protocol::ModelId;
5use agent_servers::AgentServer;
6use anyhow::Result;
7use collections::{HashSet, IndexMap};
8use fs::Fs;
9use futures::FutureExt;
10use fuzzy::{StringMatchCandidate, match_strings};
11use gpui::{
12 Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
13 WeakEntity,
14};
15use itertools::Itertools;
16use ordered_float::OrderedFloat;
17use picker::{Picker, PickerDelegate};
18use settings::SettingsStore;
19use ui::{DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, prelude::*};
20use util::ResultExt;
21use zed_actions::agent::OpenSettings;
22
23use crate::ui::{HoldForDefault, ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
24
25pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
26
27pub fn acp_model_selector(
28 selector: Rc<dyn AgentModelSelector>,
29 agent_server: Rc<dyn AgentServer>,
30 fs: Arc<dyn Fs>,
31 focus_handle: FocusHandle,
32 window: &mut Window,
33 cx: &mut Context<AcpModelSelector>,
34) -> AcpModelSelector {
35 let delegate =
36 AcpModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
37 Picker::list(delegate, window, cx)
38 .show_scrollbar(true)
39 .width(rems(20.))
40 .max_height(Some(rems(20.).into()))
41}
42
/// One row of the model picker list.
enum AcpModelPickerEntry {
    /// A non-selectable section header, e.g. "Favorite", "All", or a model group name.
    Separator(SharedString),
    /// A selectable model row; the `bool` is `true` when the model's id is a favorite.
    Model(AgentModelInfo, bool),
}
47
/// Picker delegate that lists an ACP agent's models and lets the user select one,
/// mark favorites, and (with the secondary modifier held) set the default model.
pub struct AcpModelPickerDelegate {
    /// Source of the model list and selected model; also used to change the selection.
    selector: Rc<dyn AgentModelSelector>,
    /// Used to read/update the default model and the favorite model ids.
    agent_server: Rc<dyn AgentServer>,
    /// Passed to `agent_server` when updating the default model or favorites.
    fs: Arc<dyn Fs>,
    /// Rows currently shown, rebuilt by `update_matches` from `models` and `favorites`.
    filtered_entries: Vec<AcpModelPickerEntry>,
    /// Full model list from the last refresh; `None` before the first refresh completes
    /// or when listing models failed.
    models: Option<AgentModelList>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// `(row index, description, is_default)` of the hovered row that has a description,
    /// shown in the documentation aside.
    selected_description: Option<(usize, SharedString, bool)>,
    /// The model currently selected for the agent, if known.
    selected_model: Option<AgentModelInfo>,
    /// Favorite model ids, kept in sync with settings by `_settings_subscription`.
    favorites: HashSet<ModelId>,
    /// Held so the background model-refresh loop is not dropped.
    _refresh_models_task: Task<()>,
    /// Held so the settings observer that tracks favorites stays registered.
    _settings_subscription: Subscription,
    /// Focus handle passed to the picker footer.
    focus_handle: FocusHandle,
}
62
63impl AcpModelPickerDelegate {
64 fn new(
65 selector: Rc<dyn AgentModelSelector>,
66 agent_server: Rc<dyn AgentServer>,
67 fs: Arc<dyn Fs>,
68 focus_handle: FocusHandle,
69 window: &mut Window,
70 cx: &mut Context<AcpModelSelector>,
71 ) -> Self {
72 let rx = selector.watch(cx);
73 let refresh_models_task = {
74 cx.spawn_in(window, {
75 async move |this, cx| {
76 async fn refresh(
77 this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
78 cx: &mut AsyncWindowContext,
79 ) -> Result<()> {
80 let (models_task, selected_model_task) = this.update(cx, |this, cx| {
81 (
82 this.delegate.selector.list_models(cx),
83 this.delegate.selector.selected_model(cx),
84 )
85 })?;
86
87 let (models, selected_model) =
88 futures::join!(models_task, selected_model_task);
89
90 this.update_in(cx, |this, window, cx| {
91 this.delegate.models = models.ok();
92 this.delegate.selected_model = selected_model.ok();
93 this.refresh(window, cx)
94 })
95 }
96
97 refresh(&this, cx).await.log_err();
98 if let Some(mut rx) = rx {
99 while let Ok(()) = rx.recv().await {
100 refresh(&this, cx).await.log_err();
101 }
102 }
103 }
104 })
105 };
106
107 let agent_server_for_subscription = agent_server.clone();
108 let settings_subscription =
109 cx.observe_global_in::<SettingsStore>(window, move |picker, window, cx| {
110 // Only refresh if the favorites actually changed to avoid redundant work
111 // when other settings are modified (e.g., user editing settings.json)
112 let new_favorites = agent_server_for_subscription.favorite_model_ids(cx);
113 if new_favorites != picker.delegate.favorites {
114 picker.delegate.favorites = new_favorites;
115 picker.refresh(window, cx);
116 }
117 });
118 let favorites = agent_server.favorite_model_ids(cx);
119
120 Self {
121 selector,
122 agent_server,
123 fs,
124 filtered_entries: Vec::new(),
125 models: None,
126 selected_model: None,
127 selected_index: 0,
128 selected_description: None,
129 favorites,
130 _refresh_models_task: refresh_models_task,
131 _settings_subscription: settings_subscription,
132 focus_handle,
133 }
134 }
135
136 pub fn active_model(&self) -> Option<&AgentModelInfo> {
137 self.selected_model.as_ref()
138 }
139
140 pub fn favorites_count(&self) -> usize {
141 self.favorites.len()
142 }
143
144 pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
145 if self.favorites.is_empty() {
146 return;
147 }
148
149 let Some(models) = &self.models else {
150 return;
151 };
152
153 let all_models: Vec<&AgentModelInfo> = match models {
154 AgentModelList::Flat(list) => list.iter().collect(),
155 AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
156 };
157
158 let favorite_models: Vec<_> = all_models
159 .into_iter()
160 .filter(|model| self.favorites.contains(&model.id))
161 .unique_by(|model| &model.id)
162 .collect();
163
164 if favorite_models.is_empty() {
165 return;
166 }
167
168 let current_id = self.selected_model.as_ref().map(|m| &m.id);
169
170 let current_index_in_favorites = current_id
171 .and_then(|id| favorite_models.iter().position(|m| &m.id == id))
172 .unwrap_or(usize::MAX);
173
174 let next_index = if current_index_in_favorites == usize::MAX {
175 0
176 } else {
177 (current_index_in_favorites + 1) % favorite_models.len()
178 };
179
180 let next_model = favorite_models[next_index].clone();
181
182 self.selector
183 .select_model(next_model.id.clone(), cx)
184 .detach_and_log_err(cx);
185
186 self.selected_model = Some(next_model);
187
188 // Keep the picker selection aligned with the newly-selected model
189 if let Some(new_index) = self.filtered_entries.iter().position(|entry| {
190 matches!(entry, AcpModelPickerEntry::Model(model_info, _) if self.selected_model.as_ref().is_some_and(|selected| model_info.id == selected.id))
191 }) {
192 self.set_selected_index(new_index, window, cx);
193 } else {
194 cx.notify();
195 }
196 }
197}
198
199impl PickerDelegate for AcpModelPickerDelegate {
200 type ListItem = AnyElement;
201
202 fn match_count(&self) -> usize {
203 self.filtered_entries.len()
204 }
205
206 fn selected_index(&self) -> usize {
207 self.selected_index
208 }
209
210 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
211 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
212 cx.notify();
213 }
214
215 fn can_select(
216 &mut self,
217 ix: usize,
218 _window: &mut Window,
219 _cx: &mut Context<Picker<Self>>,
220 ) -> bool {
221 match self.filtered_entries.get(ix) {
222 Some(AcpModelPickerEntry::Model(_, _)) => true,
223 Some(AcpModelPickerEntry::Separator(_)) | None => false,
224 }
225 }
226
227 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
228 "Select a model…".into()
229 }
230
231 fn update_matches(
232 &mut self,
233 query: String,
234 window: &mut Window,
235 cx: &mut Context<Picker<Self>>,
236 ) -> Task<()> {
237 let favorites = self.favorites.clone();
238
239 cx.spawn_in(window, async move |this, cx| {
240 let filtered_models = match this
241 .read_with(cx, |this, cx| {
242 this.delegate.models.clone().map(move |models| {
243 fuzzy_search(models, query, cx.background_executor().clone())
244 })
245 })
246 .ok()
247 .flatten()
248 {
249 Some(task) => task.await,
250 None => AgentModelList::Flat(vec![]),
251 };
252
253 this.update_in(cx, |this, window, cx| {
254 this.delegate.filtered_entries =
255 info_list_to_picker_entries(filtered_models, &favorites);
256 // Finds the currently selected model in the list
257 let new_index = this
258 .delegate
259 .selected_model
260 .as_ref()
261 .and_then(|selected| {
262 this.delegate.filtered_entries.iter().position(|entry| {
263 if let AcpModelPickerEntry::Model(model_info, _) = entry {
264 model_info.id == selected.id
265 } else {
266 false
267 }
268 })
269 })
270 .unwrap_or(0);
271 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
272 cx.notify();
273 })
274 .ok();
275 })
276 }
277
278 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
279 if let Some(AcpModelPickerEntry::Model(model_info, _)) =
280 self.filtered_entries.get(self.selected_index)
281 {
282 if window.modifiers().secondary() {
283 let default_model = self.agent_server.default_model(cx);
284 let is_default = default_model.as_ref() == Some(&model_info.id);
285
286 self.agent_server.set_default_model(
287 if is_default {
288 None
289 } else {
290 Some(model_info.id.clone())
291 },
292 self.fs.clone(),
293 cx,
294 );
295 }
296
297 self.selector
298 .select_model(model_info.id.clone(), cx)
299 .detach_and_log_err(cx);
300 self.selected_model = Some(model_info.clone());
301 let current_index = self.selected_index;
302 self.set_selected_index(current_index, window, cx);
303
304 cx.emit(DismissEvent);
305 }
306 }
307
308 fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
309 cx.defer_in(window, |picker, window, cx| {
310 picker.set_query("", window, cx);
311 });
312 }
313
314 fn render_match(
315 &self,
316 ix: usize,
317 selected: bool,
318 _: &mut Window,
319 cx: &mut Context<Picker<Self>>,
320 ) -> Option<Self::ListItem> {
321 match self.filtered_entries.get(ix)? {
322 AcpModelPickerEntry::Separator(title) => {
323 Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
324 }
325 AcpModelPickerEntry::Model(model_info, is_favorite) => {
326 let is_selected = Some(model_info) == self.selected_model.as_ref();
327 let default_model = self.agent_server.default_model(cx);
328 let is_default = default_model.as_ref() == Some(&model_info.id);
329
330 let is_favorite = *is_favorite;
331 let handle_action_click = {
332 let model_id = model_info.id.clone();
333 let fs = self.fs.clone();
334 let agent_server = self.agent_server.clone();
335
336 cx.listener(move |_, _, _, cx| {
337 agent_server.toggle_favorite_model(
338 model_id.clone(),
339 !is_favorite,
340 fs.clone(),
341 cx,
342 );
343 })
344 };
345
346 Some(
347 div()
348 .id(("model-picker-menu-child", ix))
349 .when_some(model_info.description.clone(), |this, description| {
350 this.on_hover(cx.listener(move |menu, hovered, _, cx| {
351 if *hovered {
352 menu.delegate.selected_description =
353 Some((ix, description.clone(), is_default));
354 } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
355 menu.delegate.selected_description = None;
356 }
357 cx.notify();
358 }))
359 })
360 .child(
361 ModelSelectorListItem::new(ix, model_info.name.clone())
362 .map(|this| match &model_info.icon {
363 Some(AgentModelIcon::Path(path)) => this.icon_path(path.clone()),
364 Some(AgentModelIcon::Named(icon)) => this.icon(*icon),
365 None => this,
366 })
367 .is_selected(is_selected)
368 .is_focused(selected)
369 .is_favorite(is_favorite)
370 .on_toggle_favorite(handle_action_click),
371 )
372 .into_any_element(),
373 )
374 }
375 }
376 }
377
378 fn documentation_aside(
379 &self,
380 _window: &mut Window,
381 _cx: &mut Context<Picker<Self>>,
382 ) -> Option<ui::DocumentationAside> {
383 self.selected_description
384 .as_ref()
385 .map(|(_, description, is_default)| {
386 let description = description.clone();
387 let is_default = *is_default;
388
389 DocumentationAside::new(
390 DocumentationSide::Left,
391 DocumentationEdge::Top,
392 Rc::new(move |_| {
393 v_flex()
394 .gap_1()
395 .child(Label::new(description.clone()))
396 .child(HoldForDefault::new(is_default))
397 .into_any_element()
398 }),
399 )
400 })
401 }
402
403 fn render_footer(
404 &self,
405 _window: &mut Window,
406 _cx: &mut Context<Picker<Self>>,
407 ) -> Option<AnyElement> {
408 let focus_handle = self.focus_handle.clone();
409
410 if !self.selector.should_render_footer() {
411 return None;
412 }
413
414 Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
415 }
416}
417
418fn info_list_to_picker_entries(
419 model_list: AgentModelList,
420 favorites: &HashSet<ModelId>,
421) -> Vec<AcpModelPickerEntry> {
422 let mut entries = Vec::new();
423
424 let all_models: Vec<_> = match &model_list {
425 AgentModelList::Flat(list) => list.iter().collect(),
426 AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
427 };
428
429 let favorite_models: Vec<_> = all_models
430 .iter()
431 .filter(|m| favorites.contains(&m.id))
432 .unique_by(|m| &m.id)
433 .collect();
434
435 let has_favorites = !favorite_models.is_empty();
436 if has_favorites {
437 entries.push(AcpModelPickerEntry::Separator("Favorite".into()));
438 for model in favorite_models {
439 entries.push(AcpModelPickerEntry::Model((*model).clone(), true));
440 }
441 }
442
443 match model_list {
444 AgentModelList::Flat(list) => {
445 if has_favorites {
446 entries.push(AcpModelPickerEntry::Separator("All".into()));
447 }
448 for model in list {
449 let is_favorite = favorites.contains(&model.id);
450 entries.push(AcpModelPickerEntry::Model(model, is_favorite));
451 }
452 }
453 AgentModelList::Grouped(index_map) => {
454 for (group_name, models) in index_map {
455 entries.push(AcpModelPickerEntry::Separator(group_name.0));
456 for model in models {
457 let is_favorite = favorites.contains(&model.id);
458 entries.push(AcpModelPickerEntry::Model(model, is_favorite));
459 }
460 }
461 }
462 }
463
464 entries
465}
466
467async fn fuzzy_search(
468 model_list: AgentModelList,
469 query: String,
470 executor: BackgroundExecutor,
471) -> AgentModelList {
472 async fn fuzzy_search_list(
473 model_list: Vec<AgentModelInfo>,
474 query: &str,
475 executor: BackgroundExecutor,
476 ) -> Vec<AgentModelInfo> {
477 let candidates = model_list
478 .iter()
479 .enumerate()
480 .map(|(ix, model)| StringMatchCandidate::new(ix, model.name.as_ref()))
481 .collect::<Vec<_>>();
482 let mut matches = match_strings(
483 &candidates,
484 query,
485 false,
486 true,
487 100,
488 &Default::default(),
489 executor,
490 )
491 .await;
492
493 matches.sort_unstable_by_key(|mat| {
494 let candidate = &candidates[mat.candidate_id];
495 (Reverse(OrderedFloat(mat.score)), candidate.id)
496 });
497
498 matches
499 .into_iter()
500 .map(|mat| model_list[mat.candidate_id].clone())
501 .collect()
502 }
503
504 match model_list {
505 AgentModelList::Flat(model_list) => {
506 AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
507 }
508 AgentModelList::Grouped(index_map) => {
509 let groups =
510 futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
511 fuzzy_search_list(models, &query, executor.clone())
512 .map(|results| (group_name, results))
513 }))
514 .await;
515 AgentModelList::Grouped(IndexMap::from_iter(
516 groups
517 .into_iter()
518 .filter(|(_, results)| !results.is_empty()),
519 ))
520 }
521 }
522}
523
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list where each model's id and name are the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId::new(model.to_string()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names match
    /// `expected` exactly, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    /// Builds a favorites set from model id strings (duplicates collapse).
    fn create_favorites(models: Vec<&str>) -> HashSet<ModelId> {
        models
            .into_iter()
            .map(|m| ModelId::new(m.to_string()))
            .collect()
    }

    /// Returns the ids of all model rows, in order, skipping headers.
    fn get_entry_model_ids(entries: &[AcpModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .filter_map(|entry| match entry {
                AcpModelPickerEntry::Model(info, _) => Some(info.id.0.as_ref()),
                _ => None,
            })
            .collect()
    }

    /// Returns every row as a label: header text for headers, model id for models.
    fn get_entry_labels(entries: &[AcpModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .map(|entry| match entry {
                AcpModelPickerEntry::Model(info, _) => info.id.0.as_ref(),
                AcpModelPickerEntry::Separator(s) => &s,
            })
            .collect()
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        let model_ids = get_entry_model_ids(&entries);
        assert_eq!(model_ids[0], "zed/gemini");
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![("zed", vec!["zed/claude", "zed/gemini"])]);
        let favorites = create_favorites(vec![]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(AcpModelPickerEntry::Separator(s)) if s == "zed"
        ));
    }

    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        for entry in &entries {
            if let AcpModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "zed/claude" {
                    assert!(is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(!is_favorite, "{} should not be a favorite", info.id.0);
                }
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_both_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5", "openai/gpt-4"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini", "openai/gpt-5"]);

        let entries = info_list_to_picker_entries(models, &favorites);
        let model_ids = get_entry_model_ids(&entries);

        assert_eq!(model_ids[0], "zed/gemini");
        assert_eq!(model_ids[1], "openai/gpt-5");

        assert!(model_ids[2..].contains(&"zed/gemini"));
        assert!(model_ids[2..].contains(&"openai/gpt-5"));
    }

    #[gpui::test]
    fn test_favorites_are_not_duplicated_when_repeated_in_other_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("Recommended", vec!["zed/claude", "anthropic/claude"]),
            ("Zed", vec!["zed/claude", "zed/gpt-5"]),
            ("Anthropic", vec!["anthropic/claude"]),
            ("OpenAI", vec!["openai/gpt-5"]),
        ]);

        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, &favorites);
        let labels = get_entry_labels(&entries);

        assert_eq!(
            labels,
            vec![
                "Favorite",
                "zed/claude",
                "Recommended",
                "zed/claude",
                "anthropic/claude",
                "Zed",
                "zed/claude",
                "zed/gpt-5",
                "Anthropic",
                "anthropic/claude",
                "OpenAI",
                "openai/gpt-5"
            ]
        );
    }

    #[gpui::test]
    fn test_flat_model_list_with_favorites(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/claude".to_string()),
                name: "Claude".into(),
                description: None,
                icon: None,
            },
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/gemini".to_string()),
                name: "Gemini".into(),
                description: None,
                icon: None,
            },
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert!(entries.iter().any(|e| matches!(
            e,
            AcpModelPickerEntry::Separator(s) if s == "All"
        )));
    }

    // NOTE(review): this does not exercise `AcpModelPickerDelegate::favorites_count`
    // itself (constructing a delegate needs a live selector/agent server). It only
    // checks the `HashSet` length semantics that `favorites_count` returns, including
    // that duplicate ids collapse.
    #[gpui::test]
    fn test_favorites_count_returns_correct_count(_cx: &mut TestAppContext) {
        let empty_favorites: HashSet<ModelId> = HashSet::default();
        assert_eq!(empty_favorites.len(), 0);

        let one_favorite = create_favorites(vec!["model-a"]);
        assert_eq!(one_favorite.len(), 1);

        let multiple_favorites = create_favorites(vec!["model-a", "model-b", "model-c"]);
        assert_eq!(multiple_favorites.len(), 3);

        let with_duplicates = create_favorites(vec!["model-a", "model-a", "model-b"]);
        assert_eq!(with_duplicates.len(), 2);
    }

    #[gpui::test]
    fn test_is_favorite_flag_set_correctly_in_entries(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("favorite-model".to_string()),
                name: "Favorite".into(),
                description: None,
                icon: None,
            },
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("regular-model".to_string()),
                name: "Regular".into(),
                description: None,
                icon: None,
            },
        ]);
        let favorites = create_favorites(vec!["favorite-model"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        for entry in &entries {
            if let AcpModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "favorite-model" {
                    assert!(*is_favorite, "favorite-model should have is_favorite=true");
                } else if info.id.0.as_ref() == "regular-model" {
                    assert!(!*is_favorite, "regular-model should have is_favorite=false");
                }
            }
        }
    }
}