// model_selector.rs

  1use std::{cmp::Reverse, rc::Rc, sync::Arc};
  2
  3use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
  4use agent_client_protocol::schema as acp;
  5use agent_servers::AgentServer;
  6
  7use anyhow::Result;
  8use collections::{HashSet, IndexMap};
  9use fs::Fs;
 10use futures::FutureExt;
 11use fuzzy::{StringMatchCandidate, match_strings};
 12use gpui::{
 13    Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
 14    WeakEntity,
 15};
 16use itertools::Itertools;
 17use ordered_float::OrderedFloat;
 18use picker::{Picker, PickerDelegate};
 19use settings::SettingsStore;
 20use ui::{DocumentationAside, IntoElement, prelude::*};
 21use util::ResultExt;
 22use zed_actions::agent::OpenSettings;
 23
 24use crate::ui::{
 25    HoldForDefault, ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem,
 26    documentation_aside_side,
 27};
 28
 29pub type ModelSelector = Picker<ModelPickerDelegate>;
 30
 31pub fn acp_model_selector(
 32    selector: Rc<dyn AgentModelSelector>,
 33    agent_server: Rc<dyn AgentServer>,
 34    fs: Arc<dyn Fs>,
 35    focus_handle: FocusHandle,
 36    window: &mut Window,
 37    cx: &mut Context<ModelSelector>,
 38) -> ModelSelector {
 39    let delegate = ModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
 40    Picker::list(delegate, window, cx)
 41        .show_scrollbar(true)
 42        .width(rems(20.))
 43        .max_height(Some(rems(20.).into()))
 44}
 45
 46enum ModelPickerEntry {
 47    Separator(SharedString),
 48    Model(AgentModelInfo, bool),
 49}
 50
 51pub struct ModelPickerDelegate {
 52    selector: Rc<dyn AgentModelSelector>,
 53    agent_server: Rc<dyn AgentServer>,
 54    fs: Arc<dyn Fs>,
 55    filtered_entries: Vec<ModelPickerEntry>,
 56    models: Option<AgentModelList>,
 57    selected_index: usize,
 58    selected_description: Option<(usize, SharedString, bool)>,
 59    selected_model: Option<AgentModelInfo>,
 60    favorites: HashSet<acp::ModelId>,
 61    _refresh_models_task: Task<()>,
 62    _settings_subscription: Subscription,
 63    focus_handle: FocusHandle,
 64}
 65
 66impl ModelPickerDelegate {
 67    fn new(
 68        selector: Rc<dyn AgentModelSelector>,
 69        agent_server: Rc<dyn AgentServer>,
 70        fs: Arc<dyn Fs>,
 71        focus_handle: FocusHandle,
 72        window: &mut Window,
 73        cx: &mut Context<ModelSelector>,
 74    ) -> Self {
 75        let rx = selector.watch(cx);
 76        let refresh_models_task = {
 77            cx.spawn_in(window, {
 78                async move |this, cx| {
 79                    async fn refresh(
 80                        this: &WeakEntity<Picker<ModelPickerDelegate>>,
 81                        cx: &mut AsyncWindowContext,
 82                    ) -> Result<()> {
 83                        let (models_task, selected_model_task) = this.update(cx, |this, cx| {
 84                            (
 85                                this.delegate.selector.list_models(cx),
 86                                this.delegate.selector.selected_model(cx),
 87                            )
 88                        })?;
 89
 90                        let (models, selected_model) =
 91                            futures::join!(models_task, selected_model_task);
 92
 93                        this.update_in(cx, |this, window, cx| {
 94                            this.delegate.models = models.ok();
 95                            this.delegate.selected_model = selected_model.ok();
 96                            this.refresh(window, cx)
 97                        })
 98                    }
 99
100                    refresh(&this, cx).await.log_err();
101                    if let Some(mut rx) = rx {
102                        while let Ok(()) = rx.recv().await {
103                            refresh(&this, cx).await.log_err();
104                        }
105                    }
106                }
107            })
108        };
109
110        let agent_server_for_subscription = agent_server.clone();
111        let settings_subscription =
112            cx.observe_global_in::<SettingsStore>(window, move |picker, window, cx| {
113                // Only refresh if the favorites actually changed to avoid redundant work
114                // when other settings are modified (e.g., user editing settings.json)
115                let new_favorites = agent_server_for_subscription.favorite_model_ids(cx);
116                if new_favorites != picker.delegate.favorites {
117                    picker.delegate.favorites = new_favorites;
118                    picker.refresh(window, cx);
119                }
120            });
121        let favorites = agent_server.favorite_model_ids(cx);
122
123        Self {
124            selector,
125            agent_server,
126            fs,
127            filtered_entries: Vec::new(),
128            models: None,
129            selected_model: None,
130            selected_index: 0,
131            selected_description: None,
132            favorites,
133            _refresh_models_task: refresh_models_task,
134            _settings_subscription: settings_subscription,
135            focus_handle,
136        }
137    }
138
139    pub fn active_model(&self) -> Option<&AgentModelInfo> {
140        self.selected_model.as_ref()
141    }
142
143    pub fn favorites_count(&self) -> usize {
144        self.favorites.len()
145    }
146
147    pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
148        if self.favorites.is_empty() {
149            return;
150        }
151
152        let Some(models) = &self.models else {
153            return;
154        };
155
156        let all_models: Vec<&AgentModelInfo> = match models {
157            AgentModelList::Flat(list) => list.iter().collect(),
158            AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
159        };
160
161        let favorite_models: Vec<_> = all_models
162            .into_iter()
163            .filter(|model| self.favorites.contains(&model.id))
164            .unique_by(|model| &model.id)
165            .collect();
166
167        if favorite_models.is_empty() {
168            return;
169        }
170
171        let current_id = self.selected_model.as_ref().map(|m| &m.id);
172
173        let current_index_in_favorites = current_id
174            .and_then(|id| favorite_models.iter().position(|m| &m.id == id))
175            .unwrap_or(usize::MAX);
176
177        let next_index = if current_index_in_favorites == usize::MAX {
178            0
179        } else {
180            (current_index_in_favorites + 1) % favorite_models.len()
181        };
182
183        let next_model = favorite_models[next_index].clone();
184
185        self.selector
186            .select_model(next_model.id.clone(), cx)
187            .detach_and_log_err(cx);
188
189        self.selected_model = Some(next_model);
190
191        // Keep the picker selection aligned with the newly-selected model
192        if let Some(new_index) = self.filtered_entries.iter().position(|entry| {
193            matches!(entry, ModelPickerEntry::Model(model_info, _) if self.selected_model.as_ref().is_some_and(|selected| model_info.id == selected.id))
194        }) {
195            self.set_selected_index(new_index, window, cx);
196        } else {
197            cx.notify();
198        }
199    }
200}
201
202impl PickerDelegate for ModelPickerDelegate {
203    type ListItem = AnyElement;
204
205    fn match_count(&self) -> usize {
206        self.filtered_entries.len()
207    }
208
209    fn selected_index(&self) -> usize {
210        self.selected_index
211    }
212
213    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
214        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
215        cx.notify();
216    }
217
218    fn can_select(&self, ix: usize, _window: &mut Window, _cx: &mut Context<Picker<Self>>) -> bool {
219        match self.filtered_entries.get(ix) {
220            Some(ModelPickerEntry::Model(_, _)) => true,
221            Some(ModelPickerEntry::Separator(_)) | None => false,
222        }
223    }
224
225    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
226        "Select a model…".into()
227    }
228
229    fn update_matches(
230        &mut self,
231        query: String,
232        window: &mut Window,
233        cx: &mut Context<Picker<Self>>,
234    ) -> Task<()> {
235        let favorites = self.favorites.clone();
236
237        cx.spawn_in(window, async move |this, cx| {
238            let filtered_models = match this
239                .read_with(cx, |this, cx| {
240                    this.delegate.models.clone().map(move |models| {
241                        fuzzy_search(models, query, cx.background_executor().clone())
242                    })
243                })
244                .ok()
245                .flatten()
246            {
247                Some(task) => task.await,
248                None => AgentModelList::Flat(vec![]),
249            };
250
251            this.update_in(cx, |this, window, cx| {
252                this.delegate.filtered_entries =
253                    info_list_to_picker_entries(filtered_models, &favorites);
254                // Finds the currently selected model in the list
255                let new_index = this
256                    .delegate
257                    .selected_model
258                    .as_ref()
259                    .and_then(|selected| {
260                        this.delegate.filtered_entries.iter().position(|entry| {
261                            if let ModelPickerEntry::Model(model_info, _) = entry {
262                                model_info.id == selected.id
263                            } else {
264                                false
265                            }
266                        })
267                    })
268                    .unwrap_or(0);
269                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
270                cx.notify();
271            })
272            .ok();
273        })
274    }
275
276    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
277        if let Some(ModelPickerEntry::Model(model_info, _)) =
278            self.filtered_entries.get(self.selected_index)
279        {
280            if window.modifiers().secondary() {
281                let default_model = self.agent_server.default_model(cx);
282                let is_default = default_model.as_ref() == Some(&model_info.id);
283
284                self.agent_server.set_default_model(
285                    if is_default {
286                        None
287                    } else {
288                        Some(model_info.id.clone())
289                    },
290                    self.fs.clone(),
291                    cx,
292                );
293            }
294
295            self.selector
296                .select_model(model_info.id.clone(), cx)
297                .detach_and_log_err(cx);
298            self.selected_model = Some(model_info.clone());
299            let current_index = self.selected_index;
300            self.set_selected_index(current_index, window, cx);
301
302            cx.emit(DismissEvent);
303        }
304    }
305
306    fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
307        cx.defer_in(window, |picker, window, cx| {
308            picker.set_query("", window, cx);
309        });
310    }
311
312    fn render_match(
313        &self,
314        ix: usize,
315        selected: bool,
316        _: &mut Window,
317        cx: &mut Context<Picker<Self>>,
318    ) -> Option<Self::ListItem> {
319        match self.filtered_entries.get(ix)? {
320            ModelPickerEntry::Separator(title) => {
321                Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
322            }
323            ModelPickerEntry::Model(model_info, is_favorite) => {
324                let is_selected = Some(model_info) == self.selected_model.as_ref();
325                let default_model = self.agent_server.default_model(cx);
326                let is_default = default_model.as_ref() == Some(&model_info.id);
327
328                let is_favorite = *is_favorite;
329                let handle_action_click = {
330                    let model_id = model_info.id.clone();
331                    let fs = self.fs.clone();
332                    let agent_server = self.agent_server.clone();
333
334                    cx.listener(move |_, _, _, cx| {
335                        agent_server.toggle_favorite_model(
336                            model_id.clone(),
337                            !is_favorite,
338                            fs.clone(),
339                            cx,
340                        );
341                    })
342                };
343
344                let model_cost = model_info.cost.clone();
345
346                Some(
347                    div()
348                        .id(("model-picker-menu-child", ix))
349                        .when_some(model_info.description.clone(), |this, description| {
350                            this.on_hover(cx.listener(move |menu, hovered, _, cx| {
351                                if *hovered {
352                                    menu.delegate.selected_description =
353                                        Some((ix, description.clone(), is_default));
354                                } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
355                                    menu.delegate.selected_description = None;
356                                }
357                                cx.notify();
358                            }))
359                        })
360                        .child(
361                            ModelSelectorListItem::new(ix, model_info.name.clone())
362                                .map(|this| match &model_info.icon {
363                                    Some(AgentModelIcon::Path(path)) => this.icon_path(path.clone()),
364                                    Some(AgentModelIcon::Named(icon)) => this.icon(*icon),
365                                    None => this,
366                                })
367                                .is_selected(is_selected)
368                                .is_focused(selected)
369                                .is_latest(model_info.is_latest)
370                                .is_favorite(is_favorite)
371                                .on_toggle_favorite(handle_action_click)
372                                .cost_info(model_cost)
373                        )
374                        .into_any_element(),
375                )
376            }
377        }
378    }
379
380    fn documentation_aside(
381        &self,
382        _window: &mut Window,
383        cx: &mut Context<Picker<Self>>,
384    ) -> Option<ui::DocumentationAside> {
385        self.selected_description
386            .as_ref()
387            .map(|(_, description, is_default)| {
388                let description = description.clone();
389                let is_default = *is_default;
390
391                let side = documentation_aside_side(cx);
392
393                DocumentationAside::new(
394                    side,
395                    Rc::new(move |_| {
396                        v_flex()
397                            .gap_1()
398                            .child(Label::new(description.clone()))
399                            .child(HoldForDefault::new(is_default))
400                            .into_any_element()
401                    }),
402                )
403            })
404    }
405
406    fn documentation_aside_index(&self) -> Option<usize> {
407        self.selected_description.as_ref().map(|(ix, _, _)| *ix)
408    }
409
410    fn render_footer(
411        &self,
412        _window: &mut Window,
413        _cx: &mut Context<Picker<Self>>,
414    ) -> Option<AnyElement> {
415        let focus_handle = self.focus_handle.clone();
416
417        if !self.selector.should_render_footer() {
418            return None;
419        }
420
421        Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
422    }
423}
424
425fn info_list_to_picker_entries(
426    model_list: AgentModelList,
427    favorites: &HashSet<acp::ModelId>,
428) -> Vec<ModelPickerEntry> {
429    let mut entries = Vec::new();
430
431    let all_models: Vec<_> = match &model_list {
432        AgentModelList::Flat(list) => list.iter().collect(),
433        AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
434    };
435
436    let favorite_models: Vec<_> = all_models
437        .iter()
438        .filter(|m| favorites.contains(&m.id))
439        .unique_by(|m| &m.id)
440        .collect();
441
442    let has_favorites = !favorite_models.is_empty();
443    if has_favorites {
444        entries.push(ModelPickerEntry::Separator("Favorite".into()));
445        for model in favorite_models {
446            entries.push(ModelPickerEntry::Model((*model).clone(), true));
447        }
448    }
449
450    match model_list {
451        AgentModelList::Flat(list) => {
452            if has_favorites {
453                entries.push(ModelPickerEntry::Separator("All".into()));
454            }
455            for model in list {
456                let is_favorite = favorites.contains(&model.id);
457                entries.push(ModelPickerEntry::Model(model, is_favorite));
458            }
459        }
460        AgentModelList::Grouped(index_map) => {
461            for (group_name, models) in index_map {
462                entries.push(ModelPickerEntry::Separator(group_name.0));
463                for model in models {
464                    let is_favorite = favorites.contains(&model.id);
465                    entries.push(ModelPickerEntry::Model(model, is_favorite));
466                }
467            }
468        }
469    }
470
471    entries
472}
473
474async fn fuzzy_search(
475    model_list: AgentModelList,
476    query: String,
477    executor: BackgroundExecutor,
478) -> AgentModelList {
479    async fn fuzzy_search_list(
480        model_list: Vec<AgentModelInfo>,
481        query: &str,
482        executor: BackgroundExecutor,
483    ) -> Vec<AgentModelInfo> {
484        let candidates = model_list
485            .iter()
486            .enumerate()
487            .map(|(ix, model)| StringMatchCandidate::new(ix, model.name.as_ref()))
488            .collect::<Vec<_>>();
489        let mut matches = match_strings(
490            &candidates,
491            query,
492            false,
493            true,
494            100,
495            &Default::default(),
496            executor,
497        )
498        .await;
499
500        matches.sort_unstable_by_key(|mat| {
501            let candidate = &candidates[mat.candidate_id];
502            (Reverse(OrderedFloat(mat.score)), candidate.id)
503        });
504
505        matches
506            .into_iter()
507            .map(|mat| model_list[mat.candidate_id].clone())
508            .collect()
509    }
510
511    match model_list {
512        AgentModelList::Flat(model_list) => {
513            AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
514        }
515        AgentModelList::Grouped(index_map) => {
516            let groups =
517                futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
518                    fuzzy_search_list(models, &query, executor.clone())
519                        .map(|results| (group_name, results))
520                }))
521                .await;
522            AgentModelList::Grouped(IndexMap::from_iter(
523                groups
524                    .into_iter()
525                    .filter(|(_, results)| !results.is_empty()),
526            ))
527        }
528    }
529}
530
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;

    use super::*;

    /// Builds a model with the given id and display name and no description,
    /// icon, or cost.
    fn model_info(id: &str, name: &str) -> AgentModelInfo {
        acp_thread::AgentModelInfo {
            id: acp::ModelId::new(id.to_string()),
            name: name.to_string().into(),
            description: None,
            icon: None,
            is_latest: false,
            cost: None,
        }
    }

    /// Builds a grouped model list; each model's id and name are the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| model_info(model, model))
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list with exactly the expected groups
    /// and model names, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    fn create_favorites(models: Vec<&str>) -> HashSet<acp::ModelId> {
        models
            .into_iter()
            .map(|m| acp::ModelId::new(m.to_string()))
            .collect()
    }

    /// Ids of the model rows, in order, skipping headings.
    fn get_entry_model_ids(entries: &[ModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .filter_map(|entry| match entry {
                ModelPickerEntry::Model(info, _) => Some(info.id.0.as_ref()),
                _ => None,
            })
            .collect()
    }

    /// Every row as text: model ids for model rows, titles for headings.
    fn get_entry_labels(entries: &[ModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .map(|entry| match entry {
                ModelPickerEntry::Model(info, _) => info.id.0.as_ref(),
                ModelPickerEntry::Separator(s) => &s,
            })
            .collect()
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-5",
                    "gpt-5-mini",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-5", "gpt-5-mini"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-5-mini` and `openai/gpt-5-mini` have identical
        // similarity scores, but `zed/gpt-5-mini` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "mini".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![("zed", vec!["gpt-5-mini"]), ("openai", vec!["gpt-5-mini"])],
        );

        // Fuzzy search - test with specific model name
        let results = fuzzy_search(models.clone(), "mistral".into(), cx.executor()).await;
        assert_models_eq(results, vec![("ollama", vec!["mistral"])]);
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(ModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        let model_ids = get_entry_model_ids(&entries);
        assert_eq!(model_ids[0], "zed/gemini");
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![("zed", vec!["zed/claude", "zed/gemini"])]);
        let favorites = create_favorites(vec![]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(ModelPickerEntry::Separator(s)) if s == "zed"
        ));
    }

    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        for entry in &entries {
            if let ModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "zed/claude" {
                    assert!(is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(!is_favorite, "{} should not be a favorite", info.id.0);
                }
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_both_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5", "openai/gpt-4"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini", "openai/gpt-5"]);

        let entries = info_list_to_picker_entries(models, &favorites);
        let model_ids = get_entry_model_ids(&entries);

        assert_eq!(model_ids[0], "zed/gemini");
        assert_eq!(model_ids[1], "openai/gpt-5");

        assert!(model_ids[2..].contains(&"zed/gemini"));
        assert!(model_ids[2..].contains(&"openai/gpt-5"));
    }

    #[gpui::test]
    fn test_favorites_are_not_duplicated_when_repeated_in_other_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("Recommended", vec!["zed/claude", "anthropic/claude"]),
            ("Zed", vec!["zed/claude", "zed/gpt-5"]),
            ("Anthropic", vec!["anthropic/claude"]),
            ("OpenAI", vec!["openai/gpt-5"]),
        ]);

        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, &favorites);
        let labels = get_entry_labels(&entries);

        assert_eq!(
            labels,
            vec![
                "Favorite",
                "zed/claude",
                "Recommended",
                "zed/claude",
                "anthropic/claude",
                "Zed",
                "zed/claude",
                "zed/gpt-5",
                "Anthropic",
                "anthropic/claude",
                "OpenAI",
                "openai/gpt-5"
            ]
        );
    }

    #[gpui::test]
    fn test_flat_model_list_with_favorites(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            model_info("zed/claude", "Claude"),
            model_info("zed/gemini", "Gemini"),
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        assert!(matches!(
            entries.first(),
            Some(ModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert!(entries.iter().any(|e| matches!(
            e,
            ModelPickerEntry::Separator(s) if s == "All"
        )));
    }

    // NOTE(review): this only exercises `HashSet` de-duplication of favorite ids;
    // it does not call `ModelPickerDelegate::favorites_count`, which needs a
    // fully constructed delegate.
    #[gpui::test]
    fn test_favorites_count_returns_correct_count(_cx: &mut TestAppContext) {
        let empty_favorites: HashSet<acp::ModelId> = HashSet::default();
        assert_eq!(empty_favorites.len(), 0);

        let one_favorite = create_favorites(vec!["model-a"]);
        assert_eq!(one_favorite.len(), 1);

        let multiple_favorites = create_favorites(vec!["model-a", "model-b", "model-c"]);
        assert_eq!(multiple_favorites.len(), 3);

        let with_duplicates = create_favorites(vec!["model-a", "model-a", "model-b"]);
        assert_eq!(with_duplicates.len(), 2);
    }

    #[gpui::test]
    fn test_is_favorite_flag_set_correctly_in_entries(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            model_info("favorite-model", "Favorite"),
            model_info("regular-model", "Regular"),
        ]);
        let favorites = create_favorites(vec!["favorite-model"]);

        let entries = info_list_to_picker_entries(models, &favorites);

        for entry in &entries {
            if let ModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "favorite-model" {
                    assert!(*is_favorite, "favorite-model should have is_favorite=true");
                } else if info.id.0.as_ref() == "regular-model" {
                    assert!(!*is_favorite, "regular-model should have is_favorite=false");
                }
            }
        }
    }
}