model_selector.rs

  1use std::{cmp::Reverse, rc::Rc, sync::Arc};
  2
  3use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
  4use agent_client_protocol::schema as acp;
  5use agent_servers::AgentServer;
  6use agent_settings::AgentSettings;
  7use anyhow::Result;
  8use collections::{HashSet, IndexMap};
  9use fs::Fs;
 10use futures::FutureExt;
 11use fuzzy::{StringMatchCandidate, match_strings};
 12use gpui::{
 13    Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
 14    WeakEntity,
 15};
 16use itertools::Itertools;
 17use ordered_float::OrderedFloat;
 18use picker::{Picker, PickerDelegate};
 19use settings::{Settings, SettingsStore};
 20use ui::{DocumentationAside, DocumentationSide, IntoElement, prelude::*};
 21use util::ResultExt;
 22use zed_actions::agent::OpenSettings;
 23
 24use crate::ui::{HoldForDefault, ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
 25
 26pub type ModelSelector = Picker<ModelPickerDelegate>;
 27
 28pub fn acp_model_selector(
 29    selector: Rc<dyn AgentModelSelector>,
 30    agent_server: Rc<dyn AgentServer>,
 31    fs: Arc<dyn Fs>,
 32    focus_handle: FocusHandle,
 33    window: &mut Window,
 34    cx: &mut Context<ModelSelector>,
 35) -> ModelSelector {
 36    let delegate = ModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
 37    Picker::list(delegate, window, cx)
 38        .show_scrollbar(true)
 39        .width(rems(20.))
 40        .max_height(Some(rems(20.).into()))
 41}
 42
 43enum ModelPickerEntry {
 44    Separator(SharedString),
 45    Model(AgentModelInfo, bool),
 46}
 47
 48pub struct ModelPickerDelegate {
 49    selector: Rc<dyn AgentModelSelector>,
 50    agent_server: Rc<dyn AgentServer>,
 51    fs: Arc<dyn Fs>,
 52    filtered_entries: Vec<ModelPickerEntry>,
 53    models: Option<AgentModelList>,
 54    selected_index: usize,
 55    selected_description: Option<(usize, SharedString, bool)>,
 56    selected_model: Option<AgentModelInfo>,
 57    favorites: HashSet<acp::ModelId>,
 58    _refresh_models_task: Task<()>,
 59    _settings_subscription: Subscription,
 60    focus_handle: FocusHandle,
 61}
 62
 63impl ModelPickerDelegate {
 64    fn new(
 65        selector: Rc<dyn AgentModelSelector>,
 66        agent_server: Rc<dyn AgentServer>,
 67        fs: Arc<dyn Fs>,
 68        focus_handle: FocusHandle,
 69        window: &mut Window,
 70        cx: &mut Context<ModelSelector>,
 71    ) -> Self {
 72        let rx = selector.watch(cx);
 73        let refresh_models_task = {
 74            cx.spawn_in(window, {
 75                async move |this, cx| {
 76                    async fn refresh(
 77                        this: &WeakEntity<Picker<ModelPickerDelegate>>,
 78                        cx: &mut AsyncWindowContext,
 79                    ) -> Result<()> {
 80                        let (models_task, selected_model_task) = this.update(cx, |this, cx| {
 81                            (
 82                                this.delegate.selector.list_models(cx),
 83                                this.delegate.selector.selected_model(cx),
 84                            )
 85                        })?;
 86
 87                        let (models, selected_model) =
 88                            futures::join!(models_task, selected_model_task);
 89
 90                        this.update_in(cx, |this, window, cx| {
 91                            this.delegate.models = models.ok();
 92                            this.delegate.selected_model = selected_model.ok();
 93                            this.refresh(window, cx)
 94                        })
 95                    }
 96
 97                    refresh(&this, cx).await.log_err();
 98                    if let Some(mut rx) = rx {
 99                        while let Ok(()) = rx.recv().await {
100                            refresh(&this, cx).await.log_err();
101                        }
102                    }
103                }
104            })
105        };
106
107        let agent_server_for_subscription = agent_server.clone();
108        let settings_subscription =
109            cx.observe_global_in::<SettingsStore>(window, move |picker, window, cx| {
110                // Only refresh if the favorites actually changed to avoid redundant work
111                // when other settings are modified (e.g., user editing settings.json)
112                let new_favorites = agent_server_for_subscription.favorite_model_ids(cx);
113                if new_favorites != picker.delegate.favorites {
114                    picker.delegate.favorites = new_favorites;
115                    picker.refresh(window, cx);
116                }
117            });
118        let favorites = agent_server.favorite_model_ids(cx);
119
120        Self {
121            selector,
122            agent_server,
123            fs,
124            filtered_entries: Vec::new(),
125            models: None,
126            selected_model: None,
127            selected_index: 0,
128            selected_description: None,
129            favorites,
130            _refresh_models_task: refresh_models_task,
131            _settings_subscription: settings_subscription,
132            focus_handle,
133        }
134    }
135
136    pub fn active_model(&self) -> Option<&AgentModelInfo> {
137        self.selected_model.as_ref()
138    }
139
140    pub fn favorites_count(&self) -> usize {
141        self.favorites.len()
142    }
143
144    pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
145        if self.favorites.is_empty() {
146            return;
147        }
148
149        let Some(models) = &self.models else {
150            return;
151        };
152
153        let all_models: Vec<&AgentModelInfo> = match models {
154            AgentModelList::Flat(list) => list.iter().collect(),
155            AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
156        };
157
158        let favorite_models: Vec<_> = all_models
159            .into_iter()
160            .filter(|model| self.favorites.contains(&model.id))
161            .unique_by(|model| &model.id)
162            .collect();
163
164        if favorite_models.is_empty() {
165            return;
166        }
167
168        let current_id = self.selected_model.as_ref().map(|m| &m.id);
169
170        let current_index_in_favorites = current_id
171            .and_then(|id| favorite_models.iter().position(|m| &m.id == id))
172            .unwrap_or(usize::MAX);
173
174        let next_index = if current_index_in_favorites == usize::MAX {
175            0
176        } else {
177            (current_index_in_favorites + 1) % favorite_models.len()
178        };
179
180        let next_model = favorite_models[next_index].clone();
181
182        self.selector
183            .select_model(next_model.id.clone(), cx)
184            .detach_and_log_err(cx);
185
186        self.selected_model = Some(next_model);
187
188        // Keep the picker selection aligned with the newly-selected model
189        if let Some(new_index) = self.filtered_entries.iter().position(|entry| {
190            matches!(entry, ModelPickerEntry::Model(model_info, _) if self.selected_model.as_ref().is_some_and(|selected| model_info.id == selected.id))
191        }) {
192            self.set_selected_index(new_index, window, cx);
193        } else {
194            cx.notify();
195        }
196    }
197}
198
199impl PickerDelegate for ModelPickerDelegate {
200    type ListItem = AnyElement;
201
202    fn match_count(&self) -> usize {
203        self.filtered_entries.len()
204    }
205
206    fn selected_index(&self) -> usize {
207        self.selected_index
208    }
209
210    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
211        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
212        cx.notify();
213    }
214
215    fn can_select(&self, ix: usize, _window: &mut Window, _cx: &mut Context<Picker<Self>>) -> bool {
216        match self.filtered_entries.get(ix) {
217            Some(ModelPickerEntry::Model(_, _)) => true,
218            Some(ModelPickerEntry::Separator(_)) | None => false,
219        }
220    }
221
222    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
223        "Select a model…".into()
224    }
225
226    fn update_matches(
227        &mut self,
228        query: String,
229        window: &mut Window,
230        cx: &mut Context<Picker<Self>>,
231    ) -> Task<()> {
232        let favorites = self.favorites.clone();
233
234        cx.spawn_in(window, async move |this, cx| {
235            let filtered_models = match this
236                .read_with(cx, |this, cx| {
237                    this.delegate.models.clone().map(move |models| {
238                        fuzzy_search(models, query, cx.background_executor().clone())
239                    })
240                })
241                .ok()
242                .flatten()
243            {
244                Some(task) => task.await,
245                None => AgentModelList::Flat(vec![]),
246            };
247
248            this.update_in(cx, |this, window, cx| {
249                this.delegate.filtered_entries =
250                    info_list_to_picker_entries(filtered_models, &favorites);
251                // Finds the currently selected model in the list
252                let new_index = this
253                    .delegate
254                    .selected_model
255                    .as_ref()
256                    .and_then(|selected| {
257                        this.delegate.filtered_entries.iter().position(|entry| {
258                            if let ModelPickerEntry::Model(model_info, _) = entry {
259                                model_info.id == selected.id
260                            } else {
261                                false
262                            }
263                        })
264                    })
265                    .unwrap_or(0);
266                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
267                cx.notify();
268            })
269            .ok();
270        })
271    }
272
273    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
274        if let Some(ModelPickerEntry::Model(model_info, _)) =
275            self.filtered_entries.get(self.selected_index)
276        {
277            if window.modifiers().secondary() {
278                let default_model = self.agent_server.default_model(cx);
279                let is_default = default_model.as_ref() == Some(&model_info.id);
280
281                self.agent_server.set_default_model(
282                    if is_default {
283                        None
284                    } else {
285                        Some(model_info.id.clone())
286                    },
287                    self.fs.clone(),
288                    cx,
289                );
290            }
291
292            self.selector
293                .select_model(model_info.id.clone(), cx)
294                .detach_and_log_err(cx);
295            self.selected_model = Some(model_info.clone());
296            let current_index = self.selected_index;
297            self.set_selected_index(current_index, window, cx);
298
299            cx.emit(DismissEvent);
300        }
301    }
302
303    fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
304        cx.defer_in(window, |picker, window, cx| {
305            picker.set_query("", window, cx);
306        });
307    }
308
309    fn render_match(
310        &self,
311        ix: usize,
312        selected: bool,
313        _: &mut Window,
314        cx: &mut Context<Picker<Self>>,
315    ) -> Option<Self::ListItem> {
316        match self.filtered_entries.get(ix)? {
317            ModelPickerEntry::Separator(title) => {
318                Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
319            }
320            ModelPickerEntry::Model(model_info, is_favorite) => {
321                let is_selected = Some(model_info) == self.selected_model.as_ref();
322                let default_model = self.agent_server.default_model(cx);
323                let is_default = default_model.as_ref() == Some(&model_info.id);
324
325                let is_favorite = *is_favorite;
326                let handle_action_click = {
327                    let model_id = model_info.id.clone();
328                    let fs = self.fs.clone();
329                    let agent_server = self.agent_server.clone();
330
331                    cx.listener(move |_, _, _, cx| {
332                        agent_server.toggle_favorite_model(
333                            model_id.clone(),
334                            !is_favorite,
335                            fs.clone(),
336                            cx,
337                        );
338                    })
339                };
340
341                let model_cost = model_info.cost.clone();
342
343                Some(
344                    div()
345                        .id(("model-picker-menu-child", ix))
346                        .when_some(model_info.description.clone(), |this, description| {
347                            this.on_hover(cx.listener(move |menu, hovered, _, cx| {
348                                if *hovered {
349                                    menu.delegate.selected_description =
350                                        Some((ix, description.clone(), is_default));
351                                } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
352                                    menu.delegate.selected_description = None;
353                                }
354                                cx.notify();
355                            }))
356                        })
357                        .child(
358                            ModelSelectorListItem::new(ix, model_info.name.clone())
359                                .map(|this| match &model_info.icon {
360                                    Some(AgentModelIcon::Path(path)) => this.icon_path(path.clone()),
361                                    Some(AgentModelIcon::Named(icon)) => this.icon(*icon),
362                                    None => this,
363                                })
364                                .is_selected(is_selected)
365                                .is_focused(selected)
366                                .is_latest(model_info.is_latest)
367                                .is_favorite(is_favorite)
368                                .on_toggle_favorite(handle_action_click)
369                                .cost_info(model_cost)
370                        )
371                        .into_any_element(),
372                )
373            }
374        }
375    }
376
377    fn documentation_aside(
378        &self,
379        _window: &mut Window,
380        cx: &mut Context<Picker<Self>>,
381    ) -> Option<ui::DocumentationAside> {
382        self.selected_description
383            .as_ref()
384            .map(|(_, description, is_default)| {
385                let description = description.clone();
386                let is_default = *is_default;
387
388                let settings = AgentSettings::get_global(cx);
389                let side = match settings.dock {
390                    settings::DockPosition::Left => DocumentationSide::Right,
391                    settings::DockPosition::Bottom | settings::DockPosition::Right => {
392                        DocumentationSide::Left
393                    }
394                };
395
396                DocumentationAside::new(
397                    side,
398                    Rc::new(move |_| {
399                        v_flex()
400                            .gap_1()
401                            .child(Label::new(description.clone()))
402                            .child(HoldForDefault::new(is_default))
403                            .into_any_element()
404                    }),
405                )
406            })
407    }
408
409    fn documentation_aside_index(&self) -> Option<usize> {
410        self.selected_description.as_ref().map(|(ix, _, _)| *ix)
411    }
412
413    fn render_footer(
414        &self,
415        _window: &mut Window,
416        _cx: &mut Context<Picker<Self>>,
417    ) -> Option<AnyElement> {
418        let focus_handle = self.focus_handle.clone();
419
420        if !self.selector.should_render_footer() {
421            return None;
422        }
423
424        Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
425    }
426}
427
428fn info_list_to_picker_entries(
429    model_list: AgentModelList,
430    favorites: &HashSet<acp::ModelId>,
431) -> Vec<ModelPickerEntry> {
432    let mut entries = Vec::new();
433
434    let all_models: Vec<_> = match &model_list {
435        AgentModelList::Flat(list) => list.iter().collect(),
436        AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
437    };
438
439    let favorite_models: Vec<_> = all_models
440        .iter()
441        .filter(|m| favorites.contains(&m.id))
442        .unique_by(|m| &m.id)
443        .collect();
444
445    let has_favorites = !favorite_models.is_empty();
446    if has_favorites {
447        entries.push(ModelPickerEntry::Separator("Favorite".into()));
448        for model in favorite_models {
449            entries.push(ModelPickerEntry::Model((*model).clone(), true));
450        }
451    }
452
453    match model_list {
454        AgentModelList::Flat(list) => {
455            if has_favorites {
456                entries.push(ModelPickerEntry::Separator("All".into()));
457            }
458            for model in list {
459                let is_favorite = favorites.contains(&model.id);
460                entries.push(ModelPickerEntry::Model(model, is_favorite));
461            }
462        }
463        AgentModelList::Grouped(index_map) => {
464            for (group_name, models) in index_map {
465                entries.push(ModelPickerEntry::Separator(group_name.0));
466                for model in models {
467                    let is_favorite = favorites.contains(&model.id);
468                    entries.push(ModelPickerEntry::Model(model, is_favorite));
469                }
470            }
471        }
472    }
473
474    entries
475}
476
477async fn fuzzy_search(
478    model_list: AgentModelList,
479    query: String,
480    executor: BackgroundExecutor,
481) -> AgentModelList {
482    async fn fuzzy_search_list(
483        model_list: Vec<AgentModelInfo>,
484        query: &str,
485        executor: BackgroundExecutor,
486    ) -> Vec<AgentModelInfo> {
487        let candidates = model_list
488            .iter()
489            .enumerate()
490            .map(|(ix, model)| StringMatchCandidate::new(ix, model.name.as_ref()))
491            .collect::<Vec<_>>();
492        let mut matches = match_strings(
493            &candidates,
494            query,
495            false,
496            true,
497            100,
498            &Default::default(),
499            executor,
500        )
501        .await;
502
503        matches.sort_unstable_by_key(|mat| {
504            let candidate = &candidates[mat.candidate_id];
505            (Reverse(OrderedFloat(mat.score)), candidate.id)
506        });
507
508        matches
509            .into_iter()
510            .map(|mat| model_list[mat.candidate_id].clone())
511            .collect()
512    }
513
514    match model_list {
515        AgentModelList::Flat(model_list) => {
516            AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
517        }
518        AgentModelList::Grouped(index_map) => {
519            let groups =
520                futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
521                    fuzzy_search_list(models, &query, executor.clone())
522                        .map(|results| (group_name, results))
523                }))
524                .await;
525            AgentModelList::Grouped(IndexMap::from_iter(
526                groups
527                    .into_iter()
528                    .filter(|(_, results)| !results.is_empty()),
529            ))
530        }
531    }
532}
533
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list where each model's id and name are the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId::new(model.to_string()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                            is_latest: false,
                            cost: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list with exactly the expected groups and
    /// model names, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    fn create_favorites(models: Vec<&str>) -> HashSet<acp::ModelId> {
        models
            .into_iter()
            .map(|m| acp::ModelId::new(m.to_string()))
            .collect()
    }

    /// Model ids of the model rows, in display order (headers skipped).
    fn get_entry_model_ids(entries: &[ModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .filter_map(|entry| match entry {
                ModelPickerEntry::Model(info, _) => Some(info.id.0.as_ref()),
                _ => None,
            })
            .collect()
    }

    /// Header titles and model ids of all rows, in display order.
    fn get_entry_labels(entries: &[ModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .map(|entry| match entry {
                ModelPickerEntry::Model(info, _) => info.id.0.as_ref(),
                ModelPickerEntry::Separator(s) => &s,
            })
            .collect()
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-5",
                    "gpt-5-mini",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-5", "gpt-5-mini"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-5-mini` and `openai/gpt-5-mini` have identical
        // similarity scores, but `zed/gpt-5-mini` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "mini".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![("zed", vec!["gpt-5-mini"]), ("openai", vec!["gpt-5-mini"])],
        );

        // Fuzzy search - test with specific model name
        let results = fuzzy_search(models.clone(), "mistral".into(), cx.executor()).await;
        assert_models_eq(results, vec![("ollama", vec!["mistral"])]);
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(ModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        let model_ids = get_entry_model_ids(&entries);
        assert_eq!(model_ids[0], "zed/gemini");
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![("zed", vec!["zed/claude", "zed/gemini"])]);
        let favorites = create_favorites(vec![]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(ModelPickerEntry::Separator(s)) if s == "zed"
        ));
    }

    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        for entry in &entries {
            if let ModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "zed/claude" {
                    assert!(is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(!is_favorite, "{} should not be a favorite", info.id.0);
                }
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_both_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5", "openai/gpt-4"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini", "openai/gpt-5"]);

        let entries = info_list_to_picker_entries(models, &favorites);
        let model_ids = get_entry_model_ids(&entries);

        assert_eq!(model_ids[0], "zed/gemini");
        assert_eq!(model_ids[1], "openai/gpt-5");

        assert!(model_ids[2..].contains(&"zed/gemini"));
        assert!(model_ids[2..].contains(&"openai/gpt-5"));
    }

    #[gpui::test]
    fn test_favorites_are_not_duplicated_when_repeated_in_other_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("Recommended", vec!["zed/claude", "anthropic/claude"]),
            ("Zed", vec!["zed/claude", "zed/gpt-5"]),
            ("Anthropic", vec!["anthropic/claude"]),
            ("OpenAI", vec!["openai/gpt-5"]),
        ]);

        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, &favorites);
        let labels = get_entry_labels(&entries);

        assert_eq!(
            labels,
            vec![
                "Favorite",
                "zed/claude",
                "Recommended",
                "zed/claude",
                "anthropic/claude",
                "Zed",
                "zed/claude",
                "zed/gpt-5",
                "Anthropic",
                "anthropic/claude",
                "OpenAI",
                "openai/gpt-5"
            ]
        );
    }

    #[gpui::test]
    fn test_flat_model_list_with_favorites(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/claude".to_string()),
                name: "Claude".into(),
                description: None,
                icon: None,
                is_latest: false,
                cost: None,
            },
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/gemini".to_string()),
                name: "Gemini".into(),
                description: None,
                icon: None,
                is_latest: false,
                cost: None,
            },
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(ModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert!(entries.iter().any(|e| matches!(
            e,
            ModelPickerEntry::Separator(s) if s == "All"
        )));
    }

    /// `ModelPickerDelegate::new` needs a `Window`, so this only checks the `HashSet`
    /// semantics (de-duplication of ids) that `favorites_count` reports on.
    #[gpui::test]
    fn test_favorites_count_returns_correct_count(_cx: &mut TestAppContext) {
        let empty_favorites: HashSet<acp::ModelId> = HashSet::default();
        assert_eq!(empty_favorites.len(), 0);

        let one_favorite = create_favorites(vec!["model-a"]);
        assert_eq!(one_favorite.len(), 1);

        let multiple_favorites = create_favorites(vec!["model-a", "model-b", "model-c"]);
        assert_eq!(multiple_favorites.len(), 3);

        let with_duplicates = create_favorites(vec!["model-a", "model-a", "model-b"]);
        assert_eq!(with_duplicates.len(), 2);
    }

    #[gpui::test]
    fn test_is_favorite_flag_set_correctly_in_entries(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("favorite-model".to_string()),
                name: "Favorite".into(),
                description: None,
                icon: None,
                is_latest: false,
                cost: None,
            },
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("regular-model".to_string()),
                name: "Regular".into(),
                description: None,
                icon: None,
                is_latest: false,
                cost: None,
            },
        ]);
        let favorites = create_favorites(vec!["favorite-model"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        for entry in &entries {
            if let ModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "favorite-model" {
                    assert!(*is_favorite, "favorite-model should have is_favorite=true");
                } else if info.id.0.as_ref() == "regular-model" {
                    assert!(!*is_favorite, "regular-model should have is_favorite=false");
                }
            }
        }
    }
}