model_selector.rs

  1use std::{cmp::Reverse, rc::Rc, sync::Arc};
  2
  3use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
  4use agent_client_protocol::ModelId;
  5use agent_servers::AgentServer;
  6use anyhow::Result;
  7use collections::{HashSet, IndexMap};
  8use fs::Fs;
  9use futures::FutureExt;
 10use fuzzy::{StringMatchCandidate, match_strings};
 11use gpui::{
 12    Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
 13    WeakEntity,
 14};
 15use itertools::Itertools;
 16use ordered_float::OrderedFloat;
 17use picker::{Picker, PickerDelegate};
 18use settings::SettingsStore;
 19use ui::{DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, prelude::*};
 20use util::ResultExt;
 21use zed_actions::agent::OpenSettings;
 22
 23use crate::ui::{HoldForDefault, ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
 24
 25pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
 26
 27pub fn acp_model_selector(
 28    selector: Rc<dyn AgentModelSelector>,
 29    agent_server: Rc<dyn AgentServer>,
 30    fs: Arc<dyn Fs>,
 31    focus_handle: FocusHandle,
 32    window: &mut Window,
 33    cx: &mut Context<AcpModelSelector>,
 34) -> AcpModelSelector {
 35    let delegate =
 36        AcpModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
 37    Picker::list(delegate, window, cx)
 38        .show_scrollbar(true)
 39        .width(rems(20.))
 40        .max_height(Some(rems(20.).into()))
 41}
 42
/// One row of the model picker list.
enum AcpModelPickerEntry {
    /// A non-selectable section header (e.g. "Favorite", "All", or a model group name).
    Separator(SharedString),
    /// A selectable model row; the `bool` is whether the model's id is a favorite.
    Model(AgentModelInfo, bool),
}
 47
/// [`PickerDelegate`] that lists an ACP agent's models, with a favorites section and
/// support for toggling favorites and the default model.
pub struct AcpModelPickerDelegate {
    /// Source of the model list and the active model; also receives model selections.
    selector: Rc<dyn AgentModelSelector>,
    /// Agent server whose favorite/default-model settings are read and written.
    agent_server: Rc<dyn AgentServer>,
    /// Filesystem handle passed along when persisting favorite/default-model settings.
    fs: Arc<dyn Fs>,
    /// Rows currently displayed; rebuilt by `update_matches`.
    filtered_entries: Vec<AcpModelPickerEntry>,
    /// Full model list from the last refresh; `None` until loaded or if loading failed.
    models: Option<AgentModelList>,
    /// Index of the highlighted row in `filtered_entries`.
    selected_index: usize,
    /// Hovered row's (row index, description, is-default flag), shown in the documentation aside.
    selected_description: Option<(usize, SharedString, bool)>,
    /// Model last reported by the selector or last chosen in this picker.
    selected_model: Option<AgentModelInfo>,
    /// Favorite model ids, kept in sync with settings by `_settings_subscription`.
    favorites: HashSet<ModelId>,
    /// Held so the background model-refresh loop lives as long as the delegate.
    _refresh_models_task: Task<()>,
    /// Held so the settings observer that tracks favorite changes stays registered.
    _settings_subscription: Subscription,
    /// Focus handle passed to the footer.
    focus_handle: FocusHandle,
}
 62
 63impl AcpModelPickerDelegate {
 64    fn new(
 65        selector: Rc<dyn AgentModelSelector>,
 66        agent_server: Rc<dyn AgentServer>,
 67        fs: Arc<dyn Fs>,
 68        focus_handle: FocusHandle,
 69        window: &mut Window,
 70        cx: &mut Context<AcpModelSelector>,
 71    ) -> Self {
 72        let rx = selector.watch(cx);
 73        let refresh_models_task = {
 74            cx.spawn_in(window, {
 75                async move |this, cx| {
 76                    async fn refresh(
 77                        this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
 78                        cx: &mut AsyncWindowContext,
 79                    ) -> Result<()> {
 80                        let (models_task, selected_model_task) = this.update(cx, |this, cx| {
 81                            (
 82                                this.delegate.selector.list_models(cx),
 83                                this.delegate.selector.selected_model(cx),
 84                            )
 85                        })?;
 86
 87                        let (models, selected_model) =
 88                            futures::join!(models_task, selected_model_task);
 89
 90                        this.update_in(cx, |this, window, cx| {
 91                            this.delegate.models = models.ok();
 92                            this.delegate.selected_model = selected_model.ok();
 93                            this.refresh(window, cx)
 94                        })
 95                    }
 96
 97                    refresh(&this, cx).await.log_err();
 98                    if let Some(mut rx) = rx {
 99                        while let Ok(()) = rx.recv().await {
100                            refresh(&this, cx).await.log_err();
101                        }
102                    }
103                }
104            })
105        };
106
107        let agent_server_for_subscription = agent_server.clone();
108        let settings_subscription =
109            cx.observe_global_in::<SettingsStore>(window, move |picker, window, cx| {
110                // Only refresh if the favorites actually changed to avoid redundant work
111                // when other settings are modified (e.g., user editing settings.json)
112                let new_favorites = agent_server_for_subscription.favorite_model_ids(cx);
113                if new_favorites != picker.delegate.favorites {
114                    picker.delegate.favorites = new_favorites;
115                    picker.refresh(window, cx);
116                }
117            });
118        let favorites = agent_server.favorite_model_ids(cx);
119
120        Self {
121            selector,
122            agent_server,
123            fs,
124            filtered_entries: Vec::new(),
125            models: None,
126            selected_model: None,
127            selected_index: 0,
128            selected_description: None,
129            favorites,
130            _refresh_models_task: refresh_models_task,
131            _settings_subscription: settings_subscription,
132            focus_handle,
133        }
134    }
135
136    pub fn active_model(&self) -> Option<&AgentModelInfo> {
137        self.selected_model.as_ref()
138    }
139
140    pub fn favorites_count(&self) -> usize {
141        self.favorites.len()
142    }
143
144    pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
145        if self.favorites.is_empty() {
146            return;
147        }
148
149        let Some(models) = &self.models else {
150            return;
151        };
152
153        let all_models: Vec<&AgentModelInfo> = match models {
154            AgentModelList::Flat(list) => list.iter().collect(),
155            AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
156        };
157
158        let favorite_models: Vec<_> = all_models
159            .into_iter()
160            .filter(|model| self.favorites.contains(&model.id))
161            .unique_by(|model| &model.id)
162            .collect();
163
164        if favorite_models.is_empty() {
165            return;
166        }
167
168        let current_id = self.selected_model.as_ref().map(|m| &m.id);
169
170        let current_index_in_favorites = current_id
171            .and_then(|id| favorite_models.iter().position(|m| &m.id == id))
172            .unwrap_or(usize::MAX);
173
174        let next_index = if current_index_in_favorites == usize::MAX {
175            0
176        } else {
177            (current_index_in_favorites + 1) % favorite_models.len()
178        };
179
180        let next_model = favorite_models[next_index].clone();
181
182        self.selector
183            .select_model(next_model.id.clone(), cx)
184            .detach_and_log_err(cx);
185
186        self.selected_model = Some(next_model);
187
188        // Keep the picker selection aligned with the newly-selected model
189        if let Some(new_index) = self.filtered_entries.iter().position(|entry| {
190            matches!(entry, AcpModelPickerEntry::Model(model_info, _) if self.selected_model.as_ref().is_some_and(|selected| model_info.id == selected.id))
191        }) {
192            self.set_selected_index(new_index, window, cx);
193        } else {
194            cx.notify();
195        }
196    }
197}
198
199impl PickerDelegate for AcpModelPickerDelegate {
200    type ListItem = AnyElement;
201
202    fn match_count(&self) -> usize {
203        self.filtered_entries.len()
204    }
205
206    fn selected_index(&self) -> usize {
207        self.selected_index
208    }
209
210    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
211        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
212        cx.notify();
213    }
214
215    fn can_select(
216        &mut self,
217        ix: usize,
218        _window: &mut Window,
219        _cx: &mut Context<Picker<Self>>,
220    ) -> bool {
221        match self.filtered_entries.get(ix) {
222            Some(AcpModelPickerEntry::Model(_, _)) => true,
223            Some(AcpModelPickerEntry::Separator(_)) | None => false,
224        }
225    }
226
227    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
228        "Select a model…".into()
229    }
230
231    fn update_matches(
232        &mut self,
233        query: String,
234        window: &mut Window,
235        cx: &mut Context<Picker<Self>>,
236    ) -> Task<()> {
237        let favorites = self.favorites.clone();
238
239        cx.spawn_in(window, async move |this, cx| {
240            let filtered_models = match this
241                .read_with(cx, |this, cx| {
242                    this.delegate.models.clone().map(move |models| {
243                        fuzzy_search(models, query, cx.background_executor().clone())
244                    })
245                })
246                .ok()
247                .flatten()
248            {
249                Some(task) => task.await,
250                None => AgentModelList::Flat(vec![]),
251            };
252
253            this.update_in(cx, |this, window, cx| {
254                this.delegate.filtered_entries =
255                    info_list_to_picker_entries(filtered_models, &favorites);
256                // Finds the currently selected model in the list
257                let new_index = this
258                    .delegate
259                    .selected_model
260                    .as_ref()
261                    .and_then(|selected| {
262                        this.delegate.filtered_entries.iter().position(|entry| {
263                            if let AcpModelPickerEntry::Model(model_info, _) = entry {
264                                model_info.id == selected.id
265                            } else {
266                                false
267                            }
268                        })
269                    })
270                    .unwrap_or(0);
271                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
272                cx.notify();
273            })
274            .ok();
275        })
276    }
277
278    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
279        if let Some(AcpModelPickerEntry::Model(model_info, _)) =
280            self.filtered_entries.get(self.selected_index)
281        {
282            if window.modifiers().secondary() {
283                let default_model = self.agent_server.default_model(cx);
284                let is_default = default_model.as_ref() == Some(&model_info.id);
285
286                self.agent_server.set_default_model(
287                    if is_default {
288                        None
289                    } else {
290                        Some(model_info.id.clone())
291                    },
292                    self.fs.clone(),
293                    cx,
294                );
295            }
296
297            self.selector
298                .select_model(model_info.id.clone(), cx)
299                .detach_and_log_err(cx);
300            self.selected_model = Some(model_info.clone());
301            let current_index = self.selected_index;
302            self.set_selected_index(current_index, window, cx);
303
304            cx.emit(DismissEvent);
305        }
306    }
307
308    fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
309        cx.defer_in(window, |picker, window, cx| {
310            picker.set_query("", window, cx);
311        });
312    }
313
314    fn render_match(
315        &self,
316        ix: usize,
317        selected: bool,
318        _: &mut Window,
319        cx: &mut Context<Picker<Self>>,
320    ) -> Option<Self::ListItem> {
321        match self.filtered_entries.get(ix)? {
322            AcpModelPickerEntry::Separator(title) => {
323                Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
324            }
325            AcpModelPickerEntry::Model(model_info, is_favorite) => {
326                let is_selected = Some(model_info) == self.selected_model.as_ref();
327                let default_model = self.agent_server.default_model(cx);
328                let is_default = default_model.as_ref() == Some(&model_info.id);
329
330                let is_favorite = *is_favorite;
331                let handle_action_click = {
332                    let model_id = model_info.id.clone();
333                    let fs = self.fs.clone();
334                    let agent_server = self.agent_server.clone();
335
336                    cx.listener(move |_, _, _, cx| {
337                        agent_server.toggle_favorite_model(
338                            model_id.clone(),
339                            !is_favorite,
340                            fs.clone(),
341                            cx,
342                        );
343                    })
344                };
345
346                Some(
347                    div()
348                        .id(("model-picker-menu-child", ix))
349                        .when_some(model_info.description.clone(), |this, description| {
350                            this.on_hover(cx.listener(move |menu, hovered, _, cx| {
351                                if *hovered {
352                                    menu.delegate.selected_description =
353                                        Some((ix, description.clone(), is_default));
354                                } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
355                                    menu.delegate.selected_description = None;
356                                }
357                                cx.notify();
358                            }))
359                        })
360                        .child(
361                            ModelSelectorListItem::new(ix, model_info.name.clone())
362                                .map(|this| match &model_info.icon {
363                                    Some(AgentModelIcon::Path(path)) => this.icon_path(path.clone()),
364                                    Some(AgentModelIcon::Named(icon)) => this.icon(*icon),
365                                    None => this,
366                                })
367                                .is_selected(is_selected)
368                                .is_focused(selected)
369                                .is_favorite(is_favorite)
370                                .on_toggle_favorite(handle_action_click),
371                        )
372                        .into_any_element(),
373                )
374            }
375        }
376    }
377
378    fn documentation_aside(
379        &self,
380        _window: &mut Window,
381        _cx: &mut Context<Picker<Self>>,
382    ) -> Option<ui::DocumentationAside> {
383        self.selected_description
384            .as_ref()
385            .map(|(_, description, is_default)| {
386                let description = description.clone();
387                let is_default = *is_default;
388
389                DocumentationAside::new(
390                    DocumentationSide::Left,
391                    DocumentationEdge::Top,
392                    Rc::new(move |_| {
393                        v_flex()
394                            .gap_1()
395                            .child(Label::new(description.clone()))
396                            .child(HoldForDefault::new(is_default))
397                            .into_any_element()
398                    }),
399                )
400            })
401    }
402
403    fn render_footer(
404        &self,
405        _window: &mut Window,
406        _cx: &mut Context<Picker<Self>>,
407    ) -> Option<AnyElement> {
408        let focus_handle = self.focus_handle.clone();
409
410        if !self.selector.should_render_footer() {
411            return None;
412        }
413
414        Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
415    }
416}
417
418fn info_list_to_picker_entries(
419    model_list: AgentModelList,
420    favorites: &HashSet<ModelId>,
421) -> Vec<AcpModelPickerEntry> {
422    let mut entries = Vec::new();
423
424    let all_models: Vec<_> = match &model_list {
425        AgentModelList::Flat(list) => list.iter().collect(),
426        AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
427    };
428
429    let favorite_models: Vec<_> = all_models
430        .iter()
431        .filter(|m| favorites.contains(&m.id))
432        .unique_by(|m| &m.id)
433        .collect();
434
435    let has_favorites = !favorite_models.is_empty();
436    if has_favorites {
437        entries.push(AcpModelPickerEntry::Separator("Favorite".into()));
438        for model in favorite_models {
439            entries.push(AcpModelPickerEntry::Model((*model).clone(), true));
440        }
441    }
442
443    match model_list {
444        AgentModelList::Flat(list) => {
445            if has_favorites {
446                entries.push(AcpModelPickerEntry::Separator("All".into()));
447            }
448            for model in list {
449                let is_favorite = favorites.contains(&model.id);
450                entries.push(AcpModelPickerEntry::Model(model, is_favorite));
451            }
452        }
453        AgentModelList::Grouped(index_map) => {
454            for (group_name, models) in index_map {
455                entries.push(AcpModelPickerEntry::Separator(group_name.0));
456                for model in models {
457                    let is_favorite = favorites.contains(&model.id);
458                    entries.push(AcpModelPickerEntry::Model(model, is_favorite));
459                }
460            }
461        }
462    }
463
464    entries
465}
466
467async fn fuzzy_search(
468    model_list: AgentModelList,
469    query: String,
470    executor: BackgroundExecutor,
471) -> AgentModelList {
472    async fn fuzzy_search_list(
473        model_list: Vec<AgentModelInfo>,
474        query: &str,
475        executor: BackgroundExecutor,
476    ) -> Vec<AgentModelInfo> {
477        let candidates = model_list
478            .iter()
479            .enumerate()
480            .map(|(ix, model)| StringMatchCandidate::new(ix, model.name.as_ref()))
481            .collect::<Vec<_>>();
482        let mut matches = match_strings(
483            &candidates,
484            query,
485            false,
486            true,
487            100,
488            &Default::default(),
489            executor,
490        )
491        .await;
492
493        matches.sort_unstable_by_key(|mat| {
494            let candidate = &candidates[mat.candidate_id];
495            (Reverse(OrderedFloat(mat.score)), candidate.id)
496        });
497
498        matches
499            .into_iter()
500            .map(|mat| model_list[mat.candidate_id].clone())
501            .collect()
502    }
503
504    match model_list {
505        AgentModelList::Flat(model_list) => {
506            AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
507        }
508        AgentModelList::Grouped(index_map) => {
509            let groups =
510                futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
511                    fuzzy_search_list(models, &query, executor.clone())
512                        .map(|results| (group_name, results))
513                }))
514                .await;
515            AgentModelList::Grouped(IndexMap::from_iter(
516                groups
517                    .into_iter()
518                    .filter(|(_, results)| !results.is_empty()),
519            ))
520        }
521    }
522}
523
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list where each model's id and name are the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId::new(model.to_string()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is grouped and matches `expected` (group names and model
    /// names, in order).
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    fn create_favorites(models: Vec<&str>) -> HashSet<ModelId> {
        models
            .into_iter()
            .map(|m| ModelId::new(m.to_string()))
            .collect()
    }

    /// Model ids of all model rows, in display order (separators skipped).
    fn get_entry_model_ids(entries: &[AcpModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .filter_map(|entry| match entry {
                AcpModelPickerEntry::Model(info, _) => Some(info.id.0.as_ref()),
                _ => None,
            })
            .collect()
    }

    /// Every row as text: separator titles and model ids, in display order.
    fn get_entry_labels(entries: &[AcpModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .map(|entry| match entry {
                AcpModelPickerEntry::Model(info, _) => info.id.0.as_ref(),
                AcpModelPickerEntry::Separator(s) => &s,
            })
            .collect()
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        let model_ids = get_entry_model_ids(&entries);
        assert_eq!(model_ids[0], "zed/gemini");
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![("zed", vec!["zed/claude", "zed/gemini"])]);
        let favorites = create_favorites(vec![]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(AcpModelPickerEntry::Separator(s)) if s == "zed"
        ));
    }

    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        for entry in &entries {
            if let AcpModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "zed/claude" {
                    assert!(is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(!is_favorite, "{} should not be a favorite", info.id.0);
                }
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_both_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5", "openai/gpt-4"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini", "openai/gpt-5"]);

        let entries = info_list_to_picker_entries(models, &favorites);
        let model_ids = get_entry_model_ids(&entries);

        assert_eq!(model_ids[0], "zed/gemini");
        assert_eq!(model_ids[1], "openai/gpt-5");

        assert!(model_ids[2..].contains(&"zed/gemini"));
        assert!(model_ids[2..].contains(&"openai/gpt-5"));
    }

    #[gpui::test]
    fn test_favorites_are_not_duplicated_when_repeated_in_other_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("Recommended", vec!["zed/claude", "anthropic/claude"]),
            ("Zed", vec!["zed/claude", "zed/gpt-5"]),
            ("Antropic", vec!["anthropic/claude"]),
            ("OpenAI", vec!["openai/gpt-5"]),
        ]);

        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, &favorites);
        let labels = get_entry_labels(&entries);

        assert_eq!(
            labels,
            vec![
                "Favorite",
                "zed/claude",
                "Recommended",
                "zed/claude",
                "anthropic/claude",
                "Zed",
                "zed/claude",
                "zed/gpt-5",
                "Antropic",
                "anthropic/claude",
                "OpenAI",
                "openai/gpt-5"
            ]
        );
    }

    #[gpui::test]
    fn test_flat_model_list_with_favorites(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/claude".to_string()),
                name: "Claude".into(),
                description: None,
                icon: None,
            },
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/gemini".to_string()),
                name: "Gemini".into(),
                description: None,
                icon: None,
            },
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert!(entries.iter().any(|e| matches!(
            e,
            AcpModelPickerEntry::Separator(s) if s == "All"
        )));
    }

    #[gpui::test]
    fn test_flat_model_list_without_favorites_has_no_separators(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/claude".to_string()),
                name: "Claude".into(),
                description: None,
                icon: None,
            },
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/gemini".to_string()),
                name: "Gemini".into(),
                description: None,
                icon: None,
            },
        ]);
        let favorites = create_favorites(vec![]);

        let entries = info_list_to_picker_entries(models, &favorites);

        // Without favorites, a flat list gets neither a "Favorite" nor an "All" header.
        assert_eq!(get_entry_labels(&entries), vec!["zed/claude", "zed/gemini"]);
    }

    // NOTE(review): this only exercises `HashSet::len` on the favorites set; it does not
    // call `AcpModelPickerDelegate::favorites_count`, which needs a full delegate.
    #[gpui::test]
    fn test_favorites_count_returns_correct_count(_cx: &mut TestAppContext) {
        let empty_favorites: HashSet<ModelId> = HashSet::default();
        assert_eq!(empty_favorites.len(), 0);

        let one_favorite = create_favorites(vec!["model-a"]);
        assert_eq!(one_favorite.len(), 1);

        let multiple_favorites = create_favorites(vec!["model-a", "model-b", "model-c"]);
        assert_eq!(multiple_favorites.len(), 3);

        let with_duplicates = create_favorites(vec!["model-a", "model-a", "model-b"]);
        assert_eq!(with_duplicates.len(), 2);
    }

    #[gpui::test]
    fn test_is_favorite_flag_set_correctly_in_entries(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("favorite-model".to_string()),
                name: "Favorite".into(),
                description: None,
                icon: None,
            },
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("regular-model".to_string()),
                name: "Regular".into(),
                description: None,
                icon: None,
            },
        ]);
        let favorites = create_favorites(vec!["favorite-model"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        for entry in &entries {
            if let AcpModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "favorite-model" {
                    assert!(*is_favorite, "favorite-model should have is_favorite=true");
                } else if info.id.0.as_ref() == "regular-model" {
                    assert!(!*is_favorite, "regular-model should have is_favorite=false");
                }
            }
        }
    }
}