1use std::{cmp::Reverse, rc::Rc, sync::Arc};
2
3use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
4use agent_client_protocol::schema as acp;
5use agent_servers::AgentServer;
6use agent_settings::AgentSettings;
7use anyhow::Result;
8use collections::{HashSet, IndexMap};
9use fs::Fs;
10use futures::FutureExt;
11use fuzzy::{StringMatchCandidate, match_strings};
12use gpui::{
13 Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
14 WeakEntity,
15};
16use itertools::Itertools;
17use ordered_float::OrderedFloat;
18use picker::{Picker, PickerDelegate};
19use settings::{Settings, SettingsStore};
20use ui::{DocumentationAside, DocumentationSide, IntoElement, prelude::*};
21use util::ResultExt;
22use zed_actions::agent::OpenSettings;
23
24use crate::ui::{HoldForDefault, ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
25
/// The agent model selector: a [`Picker`] driven by [`ModelPickerDelegate`].
pub type ModelSelector = Picker<ModelPickerDelegate>;
27
28pub fn acp_model_selector(
29 selector: Rc<dyn AgentModelSelector>,
30 agent_server: Rc<dyn AgentServer>,
31 fs: Arc<dyn Fs>,
32 focus_handle: FocusHandle,
33 window: &mut Window,
34 cx: &mut Context<ModelSelector>,
35) -> ModelSelector {
36 let delegate = ModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
37 Picker::list(delegate, window, cx)
38 .show_scrollbar(true)
39 .width(rems(20.))
40 .max_height(Some(rems(20.).into()))
41}
42
/// A single row in the model picker list.
enum ModelPickerEntry {
    /// A non-selectable section header, e.g. "Favorite", "All", or a model group name.
    Separator(SharedString),
    /// A selectable model row; the flag records whether the model is in the favorites set.
    Model(AgentModelInfo, bool),
}
47
/// Picker delegate that lists an agent's models, with an optional favorites section.
pub struct ModelPickerDelegate {
    /// Source of the model list and of the currently selected model; also used to select models.
    selector: Rc<dyn AgentModelSelector>,
    /// Agent server queried/updated for favorite model IDs and the default model.
    agent_server: Rc<dyn AgentServer>,
    /// Filesystem handle passed along when persisting favorite/default model changes.
    fs: Arc<dyn Fs>,
    /// Rows currently displayed, rebuilt from `models` on every query update.
    filtered_entries: Vec<ModelPickerEntry>,
    /// Last fetched model list; `None` before the first fetch completes or if it failed.
    models: Option<AgentModelList>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Hovered row's (index, description, is-default flag), shown in the documentation aside.
    selected_description: Option<(usize, SharedString, bool)>,
    /// The model currently selected for the agent, if known.
    selected_model: Option<AgentModelInfo>,
    /// Cached favorite model IDs; updated when settings change.
    favorites: HashSet<acp::ModelId>,
    /// Background refresh loop, kept alive by being held here.
    _refresh_models_task: Task<()>,
    /// Observer on `SettingsStore` that keeps `favorites` in sync.
    _settings_subscription: Subscription,
    /// Focus handle handed to the picker footer.
    focus_handle: FocusHandle,
}
62
63impl ModelPickerDelegate {
64 fn new(
65 selector: Rc<dyn AgentModelSelector>,
66 agent_server: Rc<dyn AgentServer>,
67 fs: Arc<dyn Fs>,
68 focus_handle: FocusHandle,
69 window: &mut Window,
70 cx: &mut Context<ModelSelector>,
71 ) -> Self {
72 let rx = selector.watch(cx);
73 let refresh_models_task = {
74 cx.spawn_in(window, {
75 async move |this, cx| {
76 async fn refresh(
77 this: &WeakEntity<Picker<ModelPickerDelegate>>,
78 cx: &mut AsyncWindowContext,
79 ) -> Result<()> {
80 let (models_task, selected_model_task) = this.update(cx, |this, cx| {
81 (
82 this.delegate.selector.list_models(cx),
83 this.delegate.selector.selected_model(cx),
84 )
85 })?;
86
87 let (models, selected_model) =
88 futures::join!(models_task, selected_model_task);
89
90 this.update_in(cx, |this, window, cx| {
91 this.delegate.models = models.ok();
92 this.delegate.selected_model = selected_model.ok();
93 this.refresh(window, cx)
94 })
95 }
96
97 refresh(&this, cx).await.log_err();
98 if let Some(mut rx) = rx {
99 while let Ok(()) = rx.recv().await {
100 refresh(&this, cx).await.log_err();
101 }
102 }
103 }
104 })
105 };
106
107 let agent_server_for_subscription = agent_server.clone();
108 let settings_subscription =
109 cx.observe_global_in::<SettingsStore>(window, move |picker, window, cx| {
110 // Only refresh if the favorites actually changed to avoid redundant work
111 // when other settings are modified (e.g., user editing settings.json)
112 let new_favorites = agent_server_for_subscription.favorite_model_ids(cx);
113 if new_favorites != picker.delegate.favorites {
114 picker.delegate.favorites = new_favorites;
115 picker.refresh(window, cx);
116 }
117 });
118 let favorites = agent_server.favorite_model_ids(cx);
119
120 Self {
121 selector,
122 agent_server,
123 fs,
124 filtered_entries: Vec::new(),
125 models: None,
126 selected_model: None,
127 selected_index: 0,
128 selected_description: None,
129 favorites,
130 _refresh_models_task: refresh_models_task,
131 _settings_subscription: settings_subscription,
132 focus_handle,
133 }
134 }
135
136 pub fn active_model(&self) -> Option<&AgentModelInfo> {
137 self.selected_model.as_ref()
138 }
139
140 pub fn favorites_count(&self) -> usize {
141 self.favorites.len()
142 }
143
144 pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
145 if self.favorites.is_empty() {
146 return;
147 }
148
149 let Some(models) = &self.models else {
150 return;
151 };
152
153 let all_models: Vec<&AgentModelInfo> = match models {
154 AgentModelList::Flat(list) => list.iter().collect(),
155 AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
156 };
157
158 let favorite_models: Vec<_> = all_models
159 .into_iter()
160 .filter(|model| self.favorites.contains(&model.id))
161 .unique_by(|model| &model.id)
162 .collect();
163
164 if favorite_models.is_empty() {
165 return;
166 }
167
168 let current_id = self.selected_model.as_ref().map(|m| &m.id);
169
170 let current_index_in_favorites = current_id
171 .and_then(|id| favorite_models.iter().position(|m| &m.id == id))
172 .unwrap_or(usize::MAX);
173
174 let next_index = if current_index_in_favorites == usize::MAX {
175 0
176 } else {
177 (current_index_in_favorites + 1) % favorite_models.len()
178 };
179
180 let next_model = favorite_models[next_index].clone();
181
182 self.selector
183 .select_model(next_model.id.clone(), cx)
184 .detach_and_log_err(cx);
185
186 self.selected_model = Some(next_model);
187
188 // Keep the picker selection aligned with the newly-selected model
189 if let Some(new_index) = self.filtered_entries.iter().position(|entry| {
190 matches!(entry, ModelPickerEntry::Model(model_info, _) if self.selected_model.as_ref().is_some_and(|selected| model_info.id == selected.id))
191 }) {
192 self.set_selected_index(new_index, window, cx);
193 } else {
194 cx.notify();
195 }
196 }
197}
198
199impl PickerDelegate for ModelPickerDelegate {
200 type ListItem = AnyElement;
201
202 fn match_count(&self) -> usize {
203 self.filtered_entries.len()
204 }
205
206 fn selected_index(&self) -> usize {
207 self.selected_index
208 }
209
210 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
211 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
212 cx.notify();
213 }
214
215 fn can_select(&self, ix: usize, _window: &mut Window, _cx: &mut Context<Picker<Self>>) -> bool {
216 match self.filtered_entries.get(ix) {
217 Some(ModelPickerEntry::Model(_, _)) => true,
218 Some(ModelPickerEntry::Separator(_)) | None => false,
219 }
220 }
221
222 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
223 "Select a model…".into()
224 }
225
226 fn update_matches(
227 &mut self,
228 query: String,
229 window: &mut Window,
230 cx: &mut Context<Picker<Self>>,
231 ) -> Task<()> {
232 let favorites = self.favorites.clone();
233
234 cx.spawn_in(window, async move |this, cx| {
235 let filtered_models = match this
236 .read_with(cx, |this, cx| {
237 this.delegate.models.clone().map(move |models| {
238 fuzzy_search(models, query, cx.background_executor().clone())
239 })
240 })
241 .ok()
242 .flatten()
243 {
244 Some(task) => task.await,
245 None => AgentModelList::Flat(vec![]),
246 };
247
248 this.update_in(cx, |this, window, cx| {
249 this.delegate.filtered_entries =
250 info_list_to_picker_entries(filtered_models, &favorites);
251 // Finds the currently selected model in the list
252 let new_index = this
253 .delegate
254 .selected_model
255 .as_ref()
256 .and_then(|selected| {
257 this.delegate.filtered_entries.iter().position(|entry| {
258 if let ModelPickerEntry::Model(model_info, _) = entry {
259 model_info.id == selected.id
260 } else {
261 false
262 }
263 })
264 })
265 .unwrap_or(0);
266 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
267 cx.notify();
268 })
269 .ok();
270 })
271 }
272
273 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
274 if let Some(ModelPickerEntry::Model(model_info, _)) =
275 self.filtered_entries.get(self.selected_index)
276 {
277 if window.modifiers().secondary() {
278 let default_model = self.agent_server.default_model(cx);
279 let is_default = default_model.as_ref() == Some(&model_info.id);
280
281 self.agent_server.set_default_model(
282 if is_default {
283 None
284 } else {
285 Some(model_info.id.clone())
286 },
287 self.fs.clone(),
288 cx,
289 );
290 }
291
292 self.selector
293 .select_model(model_info.id.clone(), cx)
294 .detach_and_log_err(cx);
295 self.selected_model = Some(model_info.clone());
296 let current_index = self.selected_index;
297 self.set_selected_index(current_index, window, cx);
298
299 cx.emit(DismissEvent);
300 }
301 }
302
303 fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
304 cx.defer_in(window, |picker, window, cx| {
305 picker.set_query("", window, cx);
306 });
307 }
308
309 fn render_match(
310 &self,
311 ix: usize,
312 selected: bool,
313 _: &mut Window,
314 cx: &mut Context<Picker<Self>>,
315 ) -> Option<Self::ListItem> {
316 match self.filtered_entries.get(ix)? {
317 ModelPickerEntry::Separator(title) => {
318 Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
319 }
320 ModelPickerEntry::Model(model_info, is_favorite) => {
321 let is_selected = Some(model_info) == self.selected_model.as_ref();
322 let default_model = self.agent_server.default_model(cx);
323 let is_default = default_model.as_ref() == Some(&model_info.id);
324
325 let is_favorite = *is_favorite;
326 let handle_action_click = {
327 let model_id = model_info.id.clone();
328 let fs = self.fs.clone();
329 let agent_server = self.agent_server.clone();
330
331 cx.listener(move |_, _, _, cx| {
332 agent_server.toggle_favorite_model(
333 model_id.clone(),
334 !is_favorite,
335 fs.clone(),
336 cx,
337 );
338 })
339 };
340
341 let model_cost = model_info.cost.clone();
342
343 Some(
344 div()
345 .id(("model-picker-menu-child", ix))
346 .when_some(model_info.description.clone(), |this, description| {
347 this.on_hover(cx.listener(move |menu, hovered, _, cx| {
348 if *hovered {
349 menu.delegate.selected_description =
350 Some((ix, description.clone(), is_default));
351 } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
352 menu.delegate.selected_description = None;
353 }
354 cx.notify();
355 }))
356 })
357 .child(
358 ModelSelectorListItem::new(ix, model_info.name.clone())
359 .map(|this| match &model_info.icon {
360 Some(AgentModelIcon::Path(path)) => this.icon_path(path.clone()),
361 Some(AgentModelIcon::Named(icon)) => this.icon(*icon),
362 None => this,
363 })
364 .is_selected(is_selected)
365 .is_focused(selected)
366 .is_latest(model_info.is_latest)
367 .is_favorite(is_favorite)
368 .on_toggle_favorite(handle_action_click)
369 .cost_info(model_cost)
370 )
371 .into_any_element(),
372 )
373 }
374 }
375 }
376
377 fn documentation_aside(
378 &self,
379 _window: &mut Window,
380 cx: &mut Context<Picker<Self>>,
381 ) -> Option<ui::DocumentationAside> {
382 self.selected_description
383 .as_ref()
384 .map(|(_, description, is_default)| {
385 let description = description.clone();
386 let is_default = *is_default;
387
388 let settings = AgentSettings::get_global(cx);
389 let side = match settings.dock {
390 settings::DockPosition::Left => DocumentationSide::Right,
391 settings::DockPosition::Bottom | settings::DockPosition::Right => {
392 DocumentationSide::Left
393 }
394 };
395
396 DocumentationAside::new(
397 side,
398 Rc::new(move |_| {
399 v_flex()
400 .gap_1()
401 .child(Label::new(description.clone()))
402 .child(HoldForDefault::new(is_default))
403 .into_any_element()
404 }),
405 )
406 })
407 }
408
409 fn documentation_aside_index(&self) -> Option<usize> {
410 self.selected_description.as_ref().map(|(ix, _, _)| *ix)
411 }
412
413 fn render_footer(
414 &self,
415 _window: &mut Window,
416 _cx: &mut Context<Picker<Self>>,
417 ) -> Option<AnyElement> {
418 let focus_handle = self.focus_handle.clone();
419
420 if !self.selector.should_render_footer() {
421 return None;
422 }
423
424 Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
425 }
426}
427
428fn info_list_to_picker_entries(
429 model_list: AgentModelList,
430 favorites: &HashSet<acp::ModelId>,
431) -> Vec<ModelPickerEntry> {
432 let mut entries = Vec::new();
433
434 let all_models: Vec<_> = match &model_list {
435 AgentModelList::Flat(list) => list.iter().collect(),
436 AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
437 };
438
439 let favorite_models: Vec<_> = all_models
440 .iter()
441 .filter(|m| favorites.contains(&m.id))
442 .unique_by(|m| &m.id)
443 .collect();
444
445 let has_favorites = !favorite_models.is_empty();
446 if has_favorites {
447 entries.push(ModelPickerEntry::Separator("Favorite".into()));
448 for model in favorite_models {
449 entries.push(ModelPickerEntry::Model((*model).clone(), true));
450 }
451 }
452
453 match model_list {
454 AgentModelList::Flat(list) => {
455 if has_favorites {
456 entries.push(ModelPickerEntry::Separator("All".into()));
457 }
458 for model in list {
459 let is_favorite = favorites.contains(&model.id);
460 entries.push(ModelPickerEntry::Model(model, is_favorite));
461 }
462 }
463 AgentModelList::Grouped(index_map) => {
464 for (group_name, models) in index_map {
465 entries.push(ModelPickerEntry::Separator(group_name.0));
466 for model in models {
467 let is_favorite = favorites.contains(&model.id);
468 entries.push(ModelPickerEntry::Model(model, is_favorite));
469 }
470 }
471 }
472 }
473
474 entries
475}
476
477async fn fuzzy_search(
478 model_list: AgentModelList,
479 query: String,
480 executor: BackgroundExecutor,
481) -> AgentModelList {
482 async fn fuzzy_search_list(
483 model_list: Vec<AgentModelInfo>,
484 query: &str,
485 executor: BackgroundExecutor,
486 ) -> Vec<AgentModelInfo> {
487 let candidates = model_list
488 .iter()
489 .enumerate()
490 .map(|(ix, model)| StringMatchCandidate::new(ix, model.name.as_ref()))
491 .collect::<Vec<_>>();
492 let mut matches = match_strings(
493 &candidates,
494 query,
495 false,
496 true,
497 100,
498 &Default::default(),
499 executor,
500 )
501 .await;
502
503 matches.sort_unstable_by_key(|mat| {
504 let candidate = &candidates[mat.candidate_id];
505 (Reverse(OrderedFloat(mat.score)), candidate.id)
506 });
507
508 matches
509 .into_iter()
510 .map(|mat| model_list[mat.candidate_id].clone())
511 .collect()
512 }
513
514 match model_list {
515 AgentModelList::Flat(model_list) => {
516 AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
517 }
518 AgentModelList::Grouped(index_map) => {
519 let groups =
520 futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
521 fuzzy_search_list(models, &query, executor.clone())
522 .map(|results| (group_name, results))
523 }))
524 .await;
525 AgentModelList::Grouped(IndexMap::from_iter(
526 groups
527 .into_iter()
528 .filter(|(_, results)| !results.is_empty()),
529 ))
530 }
531 }
532}
533
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list in which every model's id and name are the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId::new(model.to_string()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                            is_latest: false,
                            cost: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names match
    /// `expected`, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    /// Builds a favorites set from model id strings (duplicates collapse).
    fn create_favorites(models: Vec<&str>) -> HashSet<acp::ModelId> {
        models
            .into_iter()
            .map(|m| acp::ModelId::new(m.to_string()))
            .collect()
    }

    /// Returns the ids of all model rows, skipping separators.
    fn get_entry_model_ids(entries: &[ModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .filter_map(|entry| match entry {
                ModelPickerEntry::Model(info, _) => Some(info.id.0.as_ref()),
                _ => None,
            })
            .collect()
    }

    /// Returns every row's label: the model id for models, the title for separators.
    fn get_entry_labels(entries: &[ModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .map(|entry| match entry {
                ModelPickerEntry::Model(info, _) => info.id.0.as_ref(),
                ModelPickerEntry::Separator(s) => &s,
            })
            .collect()
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-5",
                    "gpt-5-mini",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-5", "gpt-5-mini"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-5-mini` and `openai/gpt-5-mini` have identical
        // similarity scores, but `zed/gpt-5-mini` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "mini".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![("zed", vec!["gpt-5-mini"]), ("openai", vec!["gpt-5-mini"])],
        );

        // Fuzzy search - test with specific model name
        let results = fuzzy_search(models.clone(), "mistral".into(), cx.executor()).await;
        assert_models_eq(results, vec![("ollama", vec!["mistral"])]);
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(ModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        let model_ids = get_entry_model_ids(&entries);
        assert_eq!(model_ids[0], "zed/gemini");
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![("zed", vec!["zed/claude", "zed/gemini"])]);
        let favorites = create_favorites(vec![]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(ModelPickerEntry::Separator(s)) if s == "zed"
        ));
    }

    #[gpui::test]
    fn test_models_have_correct_favorite_flags(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        for entry in &entries {
            if let ModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "zed/claude" {
                    assert!(*is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(!*is_favorite, "{} should not be a favorite", info.id.0);
                }
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_both_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5", "openai/gpt-4"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini", "openai/gpt-5"]);

        let entries = info_list_to_picker_entries(models, &favorites);
        let model_ids = get_entry_model_ids(&entries);

        assert_eq!(model_ids[0], "zed/gemini");
        assert_eq!(model_ids[1], "openai/gpt-5");

        assert!(model_ids[2..].contains(&"zed/gemini"));
        assert!(model_ids[2..].contains(&"openai/gpt-5"));
    }

    #[gpui::test]
    fn test_favorites_are_not_duplicated_when_repeated_in_other_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("Recommended", vec!["zed/claude", "anthropic/claude"]),
            ("Zed", vec!["zed/claude", "zed/gpt-5"]),
            ("Anthropic", vec!["anthropic/claude"]),
            ("OpenAI", vec!["openai/gpt-5"]),
        ]);

        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, &favorites);
        let labels = get_entry_labels(&entries);

        assert_eq!(
            labels,
            vec![
                "Favorite",
                "zed/claude",
                "Recommended",
                "zed/claude",
                "anthropic/claude",
                "Zed",
                "zed/claude",
                "zed/gpt-5",
                "Anthropic",
                "anthropic/claude",
                "OpenAI",
                "openai/gpt-5"
            ]
        );
    }

    #[gpui::test]
    fn test_flat_model_list_with_favorites(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/claude".to_string()),
                name: "Claude".into(),
                description: None,
                icon: None,
                is_latest: false,
                cost: None,
            },
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/gemini".to_string()),
                name: "Gemini".into(),
                description: None,
                icon: None,
                is_latest: false,
                cost: None,
            },
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(ModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert!(entries.iter().any(|e| matches!(
            e,
            ModelPickerEntry::Separator(s) if s == "All"
        )));
    }

    /// Favorites are stored as a set, so repeated ids collapse into one entry.
    #[gpui::test]
    fn test_create_favorites_deduplicates_ids(_cx: &mut TestAppContext) {
        let empty_favorites: HashSet<acp::ModelId> = HashSet::default();
        assert_eq!(empty_favorites.len(), 0);

        let one_favorite = create_favorites(vec!["model-a"]);
        assert_eq!(one_favorite.len(), 1);

        let multiple_favorites = create_favorites(vec!["model-a", "model-b", "model-c"]);
        assert_eq!(multiple_favorites.len(), 3);

        let with_duplicates = create_favorites(vec!["model-a", "model-a", "model-b"]);
        assert_eq!(with_duplicates.len(), 2);
    }

    #[gpui::test]
    fn test_is_favorite_flag_set_correctly_in_entries(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("favorite-model".to_string()),
                name: "Favorite".into(),
                description: None,
                icon: None,
                is_latest: false,
                cost: None,
            },
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("regular-model".to_string()),
                name: "Regular".into(),
                description: None,
                icon: None,
                is_latest: false,
                cost: None,
            },
        ]);
        let favorites = create_favorites(vec!["favorite-model"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        for entry in &entries {
            if let ModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "favorite-model" {
                    assert!(*is_favorite, "favorite-model should have is_favorite=true");
                } else if info.id.0.as_ref() == "regular-model" {
                    assert!(!*is_favorite, "regular-model should have is_favorite=false");
                }
            }
        }
    }
}