model_selector.rs

  1use std::{cmp::Reverse, rc::Rc, sync::Arc};
  2
  3use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
  4use agent_client_protocol::ModelId;
  5use agent_servers::AgentServer;
  6use agent_settings::AgentSettings;
  7use anyhow::Result;
  8use collections::{HashSet, IndexMap};
  9use fs::Fs;
 10use futures::FutureExt;
 11use fuzzy::{StringMatchCandidate, match_strings};
 12use gpui::{
 13    Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
 14    WeakEntity,
 15};
 16use itertools::Itertools;
 17use ordered_float::OrderedFloat;
 18use picker::{Picker, PickerDelegate};
 19use settings::{Settings, SettingsStore};
 20use ui::{DocumentationAside, DocumentationSide, IntoElement, prelude::*};
 21use util::ResultExt;
 22use zed_actions::agent::OpenSettings;
 23
 24use crate::ui::{HoldForDefault, ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
 25
/// The model selector widget: a [`Picker`] driven by an [`AcpModelPickerDelegate`].
pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
 27
 28pub fn acp_model_selector(
 29    selector: Rc<dyn AgentModelSelector>,
 30    agent_server: Rc<dyn AgentServer>,
 31    fs: Arc<dyn Fs>,
 32    focus_handle: FocusHandle,
 33    window: &mut Window,
 34    cx: &mut Context<AcpModelSelector>,
 35) -> AcpModelSelector {
 36    let delegate =
 37        AcpModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
 38    Picker::list(delegate, window, cx)
 39        .show_scrollbar(true)
 40        .width(rems(20.))
 41        .max_height(Some(rems(20.).into()))
 42}
 43
/// One row in the model picker.
enum AcpModelPickerEntry {
    /// A non-selectable section header: "Favorite", "All", or a model group name.
    Separator(SharedString),
    /// A selectable model; the flag is `true` when the model's id is in the favorites set.
    Model(AgentModelInfo, bool),
}
 48
/// [`PickerDelegate`] that lists the models exposed by an ACP agent's [`AgentModelSelector`],
/// with a "Favorite" section and per-model favorite/default toggles.
pub struct AcpModelPickerDelegate {
    /// Source of the model list and the current selection; also applies new selections.
    selector: Rc<dyn AgentModelSelector>,
    /// Agent server used to read and change the favorite and default models.
    agent_server: Rc<dyn AgentServer>,
    /// Filesystem handle passed to the agent server when changing favorite/default models.
    fs: Arc<dyn Fs>,
    /// Rows currently shown, after fuzzy filtering and favorites grouping.
    filtered_entries: Vec<AcpModelPickerEntry>,
    /// Last model list fetched from `selector`; `None` before the first load or after a failed one.
    models: Option<AgentModelList>,
    /// Index of the highlighted row in `filtered_entries`.
    selected_index: usize,
    /// Documentation-aside content for the hovered row:
    /// `(row index, model description, whether that model is the default)`.
    selected_description: Option<(usize, SharedString, bool)>,
    /// Model currently active for the agent, as last fetched or chosen in this picker.
    selected_model: Option<AgentModelInfo>,
    /// Cached favorite model ids, updated when the settings store changes them.
    favorites: HashSet<ModelId>,
    /// Held only so the delegate owns the model-refresh task (presumably dropping it cancels it).
    _refresh_models_task: Task<()>,
    /// Held only so the settings observer stays registered for the delegate's lifetime.
    _settings_subscription: Subscription,
    /// Focus handle handed to the footer.
    focus_handle: FocusHandle,
}
 63
 64impl AcpModelPickerDelegate {
 65    fn new(
 66        selector: Rc<dyn AgentModelSelector>,
 67        agent_server: Rc<dyn AgentServer>,
 68        fs: Arc<dyn Fs>,
 69        focus_handle: FocusHandle,
 70        window: &mut Window,
 71        cx: &mut Context<AcpModelSelector>,
 72    ) -> Self {
 73        let rx = selector.watch(cx);
 74        let refresh_models_task = {
 75            cx.spawn_in(window, {
 76                async move |this, cx| {
 77                    async fn refresh(
 78                        this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
 79                        cx: &mut AsyncWindowContext,
 80                    ) -> Result<()> {
 81                        let (models_task, selected_model_task) = this.update(cx, |this, cx| {
 82                            (
 83                                this.delegate.selector.list_models(cx),
 84                                this.delegate.selector.selected_model(cx),
 85                            )
 86                        })?;
 87
 88                        let (models, selected_model) =
 89                            futures::join!(models_task, selected_model_task);
 90
 91                        this.update_in(cx, |this, window, cx| {
 92                            this.delegate.models = models.ok();
 93                            this.delegate.selected_model = selected_model.ok();
 94                            this.refresh(window, cx)
 95                        })
 96                    }
 97
 98                    refresh(&this, cx).await.log_err();
 99                    if let Some(mut rx) = rx {
100                        while let Ok(()) = rx.recv().await {
101                            refresh(&this, cx).await.log_err();
102                        }
103                    }
104                }
105            })
106        };
107
108        let agent_server_for_subscription = agent_server.clone();
109        let settings_subscription =
110            cx.observe_global_in::<SettingsStore>(window, move |picker, window, cx| {
111                // Only refresh if the favorites actually changed to avoid redundant work
112                // when other settings are modified (e.g., user editing settings.json)
113                let new_favorites = agent_server_for_subscription.favorite_model_ids(cx);
114                if new_favorites != picker.delegate.favorites {
115                    picker.delegate.favorites = new_favorites;
116                    picker.refresh(window, cx);
117                }
118            });
119        let favorites = agent_server.favorite_model_ids(cx);
120
121        Self {
122            selector,
123            agent_server,
124            fs,
125            filtered_entries: Vec::new(),
126            models: None,
127            selected_model: None,
128            selected_index: 0,
129            selected_description: None,
130            favorites,
131            _refresh_models_task: refresh_models_task,
132            _settings_subscription: settings_subscription,
133            focus_handle,
134        }
135    }
136
137    pub fn active_model(&self) -> Option<&AgentModelInfo> {
138        self.selected_model.as_ref()
139    }
140
141    pub fn favorites_count(&self) -> usize {
142        self.favorites.len()
143    }
144
145    pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
146        if self.favorites.is_empty() {
147            return;
148        }
149
150        let Some(models) = &self.models else {
151            return;
152        };
153
154        let all_models: Vec<&AgentModelInfo> = match models {
155            AgentModelList::Flat(list) => list.iter().collect(),
156            AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
157        };
158
159        let favorite_models: Vec<_> = all_models
160            .into_iter()
161            .filter(|model| self.favorites.contains(&model.id))
162            .unique_by(|model| &model.id)
163            .collect();
164
165        if favorite_models.is_empty() {
166            return;
167        }
168
169        let current_id = self.selected_model.as_ref().map(|m| &m.id);
170
171        let current_index_in_favorites = current_id
172            .and_then(|id| favorite_models.iter().position(|m| &m.id == id))
173            .unwrap_or(usize::MAX);
174
175        let next_index = if current_index_in_favorites == usize::MAX {
176            0
177        } else {
178            (current_index_in_favorites + 1) % favorite_models.len()
179        };
180
181        let next_model = favorite_models[next_index].clone();
182
183        self.selector
184            .select_model(next_model.id.clone(), cx)
185            .detach_and_log_err(cx);
186
187        self.selected_model = Some(next_model);
188
189        // Keep the picker selection aligned with the newly-selected model
190        if let Some(new_index) = self.filtered_entries.iter().position(|entry| {
191            matches!(entry, AcpModelPickerEntry::Model(model_info, _) if self.selected_model.as_ref().is_some_and(|selected| model_info.id == selected.id))
192        }) {
193            self.set_selected_index(new_index, window, cx);
194        } else {
195            cx.notify();
196        }
197    }
198}
199
200impl PickerDelegate for AcpModelPickerDelegate {
201    type ListItem = AnyElement;
202
203    fn match_count(&self) -> usize {
204        self.filtered_entries.len()
205    }
206
207    fn selected_index(&self) -> usize {
208        self.selected_index
209    }
210
211    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
212        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
213        cx.notify();
214    }
215
216    fn can_select(
217        &mut self,
218        ix: usize,
219        _window: &mut Window,
220        _cx: &mut Context<Picker<Self>>,
221    ) -> bool {
222        match self.filtered_entries.get(ix) {
223            Some(AcpModelPickerEntry::Model(_, _)) => true,
224            Some(AcpModelPickerEntry::Separator(_)) | None => false,
225        }
226    }
227
228    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
229        "Select a model…".into()
230    }
231
232    fn update_matches(
233        &mut self,
234        query: String,
235        window: &mut Window,
236        cx: &mut Context<Picker<Self>>,
237    ) -> Task<()> {
238        let favorites = self.favorites.clone();
239
240        cx.spawn_in(window, async move |this, cx| {
241            let filtered_models = match this
242                .read_with(cx, |this, cx| {
243                    this.delegate.models.clone().map(move |models| {
244                        fuzzy_search(models, query, cx.background_executor().clone())
245                    })
246                })
247                .ok()
248                .flatten()
249            {
250                Some(task) => task.await,
251                None => AgentModelList::Flat(vec![]),
252            };
253
254            this.update_in(cx, |this, window, cx| {
255                this.delegate.filtered_entries =
256                    info_list_to_picker_entries(filtered_models, &favorites);
257                // Finds the currently selected model in the list
258                let new_index = this
259                    .delegate
260                    .selected_model
261                    .as_ref()
262                    .and_then(|selected| {
263                        this.delegate.filtered_entries.iter().position(|entry| {
264                            if let AcpModelPickerEntry::Model(model_info, _) = entry {
265                                model_info.id == selected.id
266                            } else {
267                                false
268                            }
269                        })
270                    })
271                    .unwrap_or(0);
272                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
273                cx.notify();
274            })
275            .ok();
276        })
277    }
278
279    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
280        if let Some(AcpModelPickerEntry::Model(model_info, _)) =
281            self.filtered_entries.get(self.selected_index)
282        {
283            if window.modifiers().secondary() {
284                let default_model = self.agent_server.default_model(cx);
285                let is_default = default_model.as_ref() == Some(&model_info.id);
286
287                self.agent_server.set_default_model(
288                    if is_default {
289                        None
290                    } else {
291                        Some(model_info.id.clone())
292                    },
293                    self.fs.clone(),
294                    cx,
295                );
296            }
297
298            self.selector
299                .select_model(model_info.id.clone(), cx)
300                .detach_and_log_err(cx);
301            self.selected_model = Some(model_info.clone());
302            let current_index = self.selected_index;
303            self.set_selected_index(current_index, window, cx);
304
305            cx.emit(DismissEvent);
306        }
307    }
308
309    fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
310        cx.defer_in(window, |picker, window, cx| {
311            picker.set_query("", window, cx);
312        });
313    }
314
315    fn render_match(
316        &self,
317        ix: usize,
318        selected: bool,
319        _: &mut Window,
320        cx: &mut Context<Picker<Self>>,
321    ) -> Option<Self::ListItem> {
322        match self.filtered_entries.get(ix)? {
323            AcpModelPickerEntry::Separator(title) => {
324                Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
325            }
326            AcpModelPickerEntry::Model(model_info, is_favorite) => {
327                let is_selected = Some(model_info) == self.selected_model.as_ref();
328                let default_model = self.agent_server.default_model(cx);
329                let is_default = default_model.as_ref() == Some(&model_info.id);
330
331                let is_favorite = *is_favorite;
332                let handle_action_click = {
333                    let model_id = model_info.id.clone();
334                    let fs = self.fs.clone();
335                    let agent_server = self.agent_server.clone();
336
337                    cx.listener(move |_, _, _, cx| {
338                        agent_server.toggle_favorite_model(
339                            model_id.clone(),
340                            !is_favorite,
341                            fs.clone(),
342                            cx,
343                        );
344                    })
345                };
346
347                Some(
348                    div()
349                        .id(("model-picker-menu-child", ix))
350                        .when_some(model_info.description.clone(), |this, description| {
351                            this.on_hover(cx.listener(move |menu, hovered, _, cx| {
352                                if *hovered {
353                                    menu.delegate.selected_description =
354                                        Some((ix, description.clone(), is_default));
355                                } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
356                                    menu.delegate.selected_description = None;
357                                }
358                                cx.notify();
359                            }))
360                        })
361                        .child(
362                            ModelSelectorListItem::new(ix, model_info.name.clone())
363                                .map(|this| match &model_info.icon {
364                                    Some(AgentModelIcon::Path(path)) => this.icon_path(path.clone()),
365                                    Some(AgentModelIcon::Named(icon)) => this.icon(*icon),
366                                    None => this,
367                                })
368                                .is_selected(is_selected)
369                                .is_focused(selected)
370                                .is_latest(model_info.is_latest)
371                                .is_favorite(is_favorite)
372                                .on_toggle_favorite(handle_action_click),
373                        )
374                        .into_any_element(),
375                )
376            }
377        }
378    }
379
380    fn documentation_aside(
381        &self,
382        _window: &mut Window,
383        cx: &mut Context<Picker<Self>>,
384    ) -> Option<ui::DocumentationAside> {
385        self.selected_description
386            .as_ref()
387            .map(|(_, description, is_default)| {
388                let description = description.clone();
389                let is_default = *is_default;
390
391                let settings = AgentSettings::get_global(cx);
392                let side = match settings.dock {
393                    settings::DockPosition::Left => DocumentationSide::Right,
394                    settings::DockPosition::Bottom | settings::DockPosition::Right => {
395                        DocumentationSide::Left
396                    }
397                };
398
399                DocumentationAside::new(
400                    side,
401                    Rc::new(move |_| {
402                        v_flex()
403                            .gap_1()
404                            .child(Label::new(description.clone()))
405                            .child(HoldForDefault::new(is_default))
406                            .into_any_element()
407                    }),
408                )
409            })
410    }
411
412    fn documentation_aside_index(&self) -> Option<usize> {
413        self.selected_description.as_ref().map(|(ix, _, _)| *ix)
414    }
415
416    fn render_footer(
417        &self,
418        _window: &mut Window,
419        _cx: &mut Context<Picker<Self>>,
420    ) -> Option<AnyElement> {
421        let focus_handle = self.focus_handle.clone();
422
423        if !self.selector.should_render_footer() {
424            return None;
425        }
426
427        Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
428    }
429}
430
431fn info_list_to_picker_entries(
432    model_list: AgentModelList,
433    favorites: &HashSet<ModelId>,
434) -> Vec<AcpModelPickerEntry> {
435    let mut entries = Vec::new();
436
437    let all_models: Vec<_> = match &model_list {
438        AgentModelList::Flat(list) => list.iter().collect(),
439        AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
440    };
441
442    let favorite_models: Vec<_> = all_models
443        .iter()
444        .filter(|m| favorites.contains(&m.id))
445        .unique_by(|m| &m.id)
446        .collect();
447
448    let has_favorites = !favorite_models.is_empty();
449    if has_favorites {
450        entries.push(AcpModelPickerEntry::Separator("Favorite".into()));
451        for model in favorite_models {
452            entries.push(AcpModelPickerEntry::Model((*model).clone(), true));
453        }
454    }
455
456    match model_list {
457        AgentModelList::Flat(list) => {
458            if has_favorites {
459                entries.push(AcpModelPickerEntry::Separator("All".into()));
460            }
461            for model in list {
462                let is_favorite = favorites.contains(&model.id);
463                entries.push(AcpModelPickerEntry::Model(model, is_favorite));
464            }
465        }
466        AgentModelList::Grouped(index_map) => {
467            for (group_name, models) in index_map {
468                entries.push(AcpModelPickerEntry::Separator(group_name.0));
469                for model in models {
470                    let is_favorite = favorites.contains(&model.id);
471                    entries.push(AcpModelPickerEntry::Model(model, is_favorite));
472                }
473            }
474        }
475    }
476
477    entries
478}
479
480async fn fuzzy_search(
481    model_list: AgentModelList,
482    query: String,
483    executor: BackgroundExecutor,
484) -> AgentModelList {
485    async fn fuzzy_search_list(
486        model_list: Vec<AgentModelInfo>,
487        query: &str,
488        executor: BackgroundExecutor,
489    ) -> Vec<AgentModelInfo> {
490        let candidates = model_list
491            .iter()
492            .enumerate()
493            .map(|(ix, model)| StringMatchCandidate::new(ix, model.name.as_ref()))
494            .collect::<Vec<_>>();
495        let mut matches = match_strings(
496            &candidates,
497            query,
498            false,
499            true,
500            100,
501            &Default::default(),
502            executor,
503        )
504        .await;
505
506        matches.sort_unstable_by_key(|mat| {
507            let candidate = &candidates[mat.candidate_id];
508            (Reverse(OrderedFloat(mat.score)), candidate.id)
509        });
510
511        matches
512            .into_iter()
513            .map(|mat| model_list[mat.candidate_id].clone())
514            .collect()
515    }
516
517    match model_list {
518        AgentModelList::Flat(model_list) => {
519            AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
520        }
521        AgentModelList::Grouped(index_map) => {
522            let groups =
523                futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
524                    fuzzy_search_list(models, &query, executor.clone())
525                        .map(|results| (group_name, results))
526                }))
527                .await;
528            AgentModelList::Grouped(IndexMap::from_iter(
529                groups
530                    .into_iter()
531                    .filter(|(_, results)| !results.is_empty()),
532            ))
533        }
534    }
535}
536
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list where every model's id and name are the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId::new(model.to_string()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                            is_latest: false,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names match
    /// `expected` exactly, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    fn create_favorites(models: Vec<&str>) -> HashSet<ModelId> {
        models
            .into_iter()
            .map(|m| ModelId::new(m.to_string()))
            .collect()
    }

    /// Model ids of the model rows, in display order (separators skipped).
    fn get_entry_model_ids(entries: &[AcpModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .filter_map(|entry| match entry {
                AcpModelPickerEntry::Model(info, _) => Some(info.id.0.as_ref()),
                _ => None,
            })
            .collect()
    }

    /// Every row as a label: the model id for model rows, the title for separators.
    fn get_entry_labels(entries: &[AcpModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .map(|entry| match entry {
                AcpModelPickerEntry::Model(info, _) => info.id.0.as_ref(),
                AcpModelPickerEntry::Separator(s) => &s,
            })
            .collect()
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        let model_ids = get_entry_model_ids(&entries);
        assert_eq!(model_ids[0], "zed/gemini");
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![("zed", vec!["zed/claude", "zed/gemini"])]);
        let favorites = create_favorites(vec![]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(AcpModelPickerEntry::Separator(s)) if s == "zed"
        ));
    }

    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        for entry in &entries {
            if let AcpModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "zed/claude" {
                    assert!(is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(!is_favorite, "{} should not be a favorite", info.id.0);
                }
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_both_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5", "openai/gpt-4"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini", "openai/gpt-5"]);

        let entries = info_list_to_picker_entries(models, &favorites);
        let model_ids = get_entry_model_ids(&entries);

        assert_eq!(model_ids[0], "zed/gemini");
        assert_eq!(model_ids[1], "openai/gpt-5");

        assert!(model_ids[2..].contains(&"zed/gemini"));
        assert!(model_ids[2..].contains(&"openai/gpt-5"));
    }

    #[gpui::test]
    fn test_favorites_are_not_duplicated_when_repeated_in_other_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("Recommended", vec!["zed/claude", "anthropic/claude"]),
            ("Zed", vec!["zed/claude", "zed/gpt-5"]),
            ("Anthropic", vec!["anthropic/claude"]),
            ("OpenAI", vec!["openai/gpt-5"]),
        ]);

        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, &favorites);
        let labels = get_entry_labels(&entries);

        assert_eq!(
            labels,
            vec![
                "Favorite",
                "zed/claude",
                "Recommended",
                "zed/claude",
                "anthropic/claude",
                "Zed",
                "zed/claude",
                "zed/gpt-5",
                "Anthropic",
                "anthropic/claude",
                "OpenAI",
                "openai/gpt-5"
            ]
        );
    }

    #[gpui::test]
    fn test_flat_model_list_with_favorites(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/claude".to_string()),
                name: "Claude".into(),
                description: None,
                icon: None,
                is_latest: false,
            },
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/gemini".to_string()),
                name: "Gemini".into(),
                description: None,
                icon: None,
                is_latest: false,
            },
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert!(entries.iter().any(|e| matches!(
            e,
            AcpModelPickerEntry::Separator(s) if s == "All"
        )));
    }

    // NOTE: this only exercises how the favorites `HashSet` counts (and deduplicates)
    // ids; `AcpModelPickerDelegate::favorites_count` simply returns that set's length.
    #[gpui::test]
    fn test_favorites_count_returns_correct_count(_cx: &mut TestAppContext) {
        let empty_favorites: HashSet<ModelId> = HashSet::default();
        assert_eq!(empty_favorites.len(), 0);

        let one_favorite = create_favorites(vec!["model-a"]);
        assert_eq!(one_favorite.len(), 1);

        let multiple_favorites = create_favorites(vec!["model-a", "model-b", "model-c"]);
        assert_eq!(multiple_favorites.len(), 3);

        let with_duplicates = create_favorites(vec!["model-a", "model-a", "model-b"]);
        assert_eq!(with_duplicates.len(), 2);
    }

    #[gpui::test]
    fn test_is_favorite_flag_set_correctly_in_entries(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("favorite-model".to_string()),
                name: "Favorite".into(),
                description: None,
                icon: None,
                is_latest: false,
            },
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("regular-model".to_string()),
                name: "Regular".into(),
                description: None,
                icon: None,
                is_latest: false,
            },
        ]);
        let favorites = create_favorites(vec!["favorite-model"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        for entry in &entries {
            if let AcpModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "favorite-model" {
                    assert!(*is_favorite, "favorite-model should have is_favorite=true");
                } else if info.id.0.as_ref() == "regular-model" {
                    assert!(!*is_favorite, "regular-model should have is_favorite=false");
                }
            }
        }
    }
}