model_selector.rs

  1use std::{cmp::Reverse, rc::Rc, sync::Arc};
  2
  3use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
  4use agent_client_protocol::ModelId;
  5use agent_servers::AgentServer;
  6use agent_settings::AgentSettings;
  7use anyhow::Result;
  8use collections::{HashSet, IndexMap};
  9use fs::Fs;
 10use futures::FutureExt;
 11use fuzzy::{StringMatchCandidate, match_strings};
 12use gpui::{
 13    Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Task, WeakEntity,
 14};
 15use itertools::Itertools;
 16use ordered_float::OrderedFloat;
 17use picker::{Picker, PickerDelegate};
 18use settings::Settings;
 19use ui::{DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, prelude::*};
 20use util::ResultExt;
 21use zed_actions::agent::OpenSettings;
 22
 23use crate::ui::{HoldForDefault, ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
 24
 25pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
 26
 27pub fn acp_model_selector(
 28    selector: Rc<dyn AgentModelSelector>,
 29    agent_server: Rc<dyn AgentServer>,
 30    fs: Arc<dyn Fs>,
 31    focus_handle: FocusHandle,
 32    window: &mut Window,
 33    cx: &mut Context<AcpModelSelector>,
 34) -> AcpModelSelector {
 35    let delegate =
 36        AcpModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
 37    Picker::list(delegate, window, cx)
 38        .show_scrollbar(true)
 39        .width(rems(20.))
 40        .max_height(Some(rems(20.).into()))
 41}
 42
 43enum AcpModelPickerEntry {
 44    Separator(SharedString),
 45    Model(AgentModelInfo, bool),
 46}
 47
 48pub struct AcpModelPickerDelegate {
 49    selector: Rc<dyn AgentModelSelector>,
 50    agent_server: Rc<dyn AgentServer>,
 51    fs: Arc<dyn Fs>,
 52    filtered_entries: Vec<AcpModelPickerEntry>,
 53    models: Option<AgentModelList>,
 54    selected_index: usize,
 55    selected_description: Option<(usize, SharedString, bool)>,
 56    selected_model: Option<AgentModelInfo>,
 57    _refresh_models_task: Task<()>,
 58    focus_handle: FocusHandle,
 59}
 60
 61impl AcpModelPickerDelegate {
 62    fn new(
 63        selector: Rc<dyn AgentModelSelector>,
 64        agent_server: Rc<dyn AgentServer>,
 65        fs: Arc<dyn Fs>,
 66        focus_handle: FocusHandle,
 67        window: &mut Window,
 68        cx: &mut Context<AcpModelSelector>,
 69    ) -> Self {
 70        let rx = selector.watch(cx);
 71        let refresh_models_task = {
 72            cx.spawn_in(window, {
 73                async move |this, cx| {
 74                    async fn refresh(
 75                        this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
 76                        cx: &mut AsyncWindowContext,
 77                    ) -> Result<()> {
 78                        let (models_task, selected_model_task) = this.update(cx, |this, cx| {
 79                            (
 80                                this.delegate.selector.list_models(cx),
 81                                this.delegate.selector.selected_model(cx),
 82                            )
 83                        })?;
 84
 85                        let (models, selected_model) =
 86                            futures::join!(models_task, selected_model_task);
 87
 88                        this.update_in(cx, |this, window, cx| {
 89                            this.delegate.models = models.ok();
 90                            this.delegate.selected_model = selected_model.ok();
 91                            this.refresh(window, cx)
 92                        })
 93                    }
 94
 95                    refresh(&this, cx).await.log_err();
 96                    if let Some(mut rx) = rx {
 97                        while let Ok(()) = rx.recv().await {
 98                            refresh(&this, cx).await.log_err();
 99                        }
100                    }
101                }
102            })
103        };
104
105        Self {
106            selector,
107            agent_server,
108            fs,
109            filtered_entries: Vec::new(),
110            models: None,
111            selected_model: None,
112            selected_index: 0,
113            selected_description: None,
114            _refresh_models_task: refresh_models_task,
115            focus_handle,
116        }
117    }
118
119    pub fn active_model(&self) -> Option<&AgentModelInfo> {
120        self.selected_model.as_ref()
121    }
122}
123
124impl PickerDelegate for AcpModelPickerDelegate {
125    type ListItem = AnyElement;
126
127    fn match_count(&self) -> usize {
128        self.filtered_entries.len()
129    }
130
131    fn selected_index(&self) -> usize {
132        self.selected_index
133    }
134
135    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
136        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
137        cx.notify();
138    }
139
140    fn can_select(
141        &mut self,
142        ix: usize,
143        _window: &mut Window,
144        _cx: &mut Context<Picker<Self>>,
145    ) -> bool {
146        match self.filtered_entries.get(ix) {
147            Some(AcpModelPickerEntry::Model(_, _)) => true,
148            Some(AcpModelPickerEntry::Separator(_)) | None => false,
149        }
150    }
151
152    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
153        "Select a model…".into()
154    }
155
156    fn update_matches(
157        &mut self,
158        query: String,
159        window: &mut Window,
160        cx: &mut Context<Picker<Self>>,
161    ) -> Task<()> {
162        let favorites = if self.selector.supports_favorites() {
163            Arc::new(AgentSettings::get_global(cx).favorite_model_ids())
164        } else {
165            Default::default()
166        };
167
168        cx.spawn_in(window, async move |this, cx| {
169            let filtered_models = match this
170                .read_with(cx, |this, cx| {
171                    this.delegate.models.clone().map(move |models| {
172                        fuzzy_search(models, query, cx.background_executor().clone())
173                    })
174                })
175                .ok()
176                .flatten()
177            {
178                Some(task) => task.await,
179                None => AgentModelList::Flat(vec![]),
180            };
181
182            this.update_in(cx, |this, window, cx| {
183                this.delegate.filtered_entries =
184                    info_list_to_picker_entries(filtered_models, favorites);
185                // Finds the currently selected model in the list
186                let new_index = this
187                    .delegate
188                    .selected_model
189                    .as_ref()
190                    .and_then(|selected| {
191                        this.delegate.filtered_entries.iter().position(|entry| {
192                            if let AcpModelPickerEntry::Model(model_info, _) = entry {
193                                model_info.id == selected.id
194                            } else {
195                                false
196                            }
197                        })
198                    })
199                    .unwrap_or(0);
200                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
201                cx.notify();
202            })
203            .ok();
204        })
205    }
206
207    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
208        if let Some(AcpModelPickerEntry::Model(model_info, _)) =
209            self.filtered_entries.get(self.selected_index)
210        {
211            if window.modifiers().secondary() {
212                let default_model = self.agent_server.default_model(cx);
213                let is_default = default_model.as_ref() == Some(&model_info.id);
214
215                self.agent_server.set_default_model(
216                    if is_default {
217                        None
218                    } else {
219                        Some(model_info.id.clone())
220                    },
221                    self.fs.clone(),
222                    cx,
223                );
224            }
225
226            self.selector
227                .select_model(model_info.id.clone(), cx)
228                .detach_and_log_err(cx);
229            self.selected_model = Some(model_info.clone());
230            let current_index = self.selected_index;
231            self.set_selected_index(current_index, window, cx);
232
233            cx.emit(DismissEvent);
234        }
235    }
236
237    fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
238        cx.defer_in(window, |picker, window, cx| {
239            picker.set_query("", window, cx);
240        });
241    }
242
243    fn render_match(
244        &self,
245        ix: usize,
246        selected: bool,
247        _: &mut Window,
248        cx: &mut Context<Picker<Self>>,
249    ) -> Option<Self::ListItem> {
250        match self.filtered_entries.get(ix)? {
251            AcpModelPickerEntry::Separator(title) => {
252                Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
253            }
254            AcpModelPickerEntry::Model(model_info, is_favorite) => {
255                let is_selected = Some(model_info) == self.selected_model.as_ref();
256                let default_model = self.agent_server.default_model(cx);
257                let is_default = default_model.as_ref() == Some(&model_info.id);
258
259                let supports_favorites = self.selector.supports_favorites();
260
261                let is_favorite = *is_favorite;
262                let handle_action_click = {
263                    let model_id = model_info.id.clone();
264                    let fs = self.fs.clone();
265
266                    move |cx: &App| {
267                        crate::favorite_models::toggle_model_id_in_settings(
268                            model_id.clone(),
269                            !is_favorite,
270                            fs.clone(),
271                            cx,
272                        );
273                    }
274                };
275
276                Some(
277                    div()
278                        .id(("model-picker-menu-child", ix))
279                        .when_some(model_info.description.clone(), |this, description| {
280                            this.on_hover(cx.listener(move |menu, hovered, _, cx| {
281                                if *hovered {
282                                    menu.delegate.selected_description =
283                                        Some((ix, description.clone(), is_default));
284                                } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
285                                    menu.delegate.selected_description = None;
286                                }
287                                cx.notify();
288                            }))
289                        })
290                        .child(
291                            ModelSelectorListItem::new(ix, model_info.name.clone())
292                                .when_some(model_info.icon, |this, icon| this.icon(icon))
293                                .is_selected(is_selected)
294                                .is_focused(selected)
295                                .when(supports_favorites, |this| {
296                                    this.is_favorite(is_favorite)
297                                        .on_toggle_favorite(handle_action_click)
298                                }),
299                        )
300                        .into_any_element(),
301                )
302            }
303        }
304    }
305
306    fn documentation_aside(
307        &self,
308        _window: &mut Window,
309        _cx: &mut Context<Picker<Self>>,
310    ) -> Option<ui::DocumentationAside> {
311        self.selected_description
312            .as_ref()
313            .map(|(_, description, is_default)| {
314                let description = description.clone();
315                let is_default = *is_default;
316
317                DocumentationAside::new(
318                    DocumentationSide::Left,
319                    DocumentationEdge::Top,
320                    Rc::new(move |_| {
321                        v_flex()
322                            .gap_1()
323                            .child(Label::new(description.clone()))
324                            .child(HoldForDefault::new(is_default))
325                            .into_any_element()
326                    }),
327                )
328            })
329    }
330
331    fn render_footer(
332        &self,
333        _window: &mut Window,
334        _cx: &mut Context<Picker<Self>>,
335    ) -> Option<AnyElement> {
336        let focus_handle = self.focus_handle.clone();
337
338        if !self.selector.should_render_footer() {
339            return None;
340        }
341
342        Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
343    }
344}
345
346fn info_list_to_picker_entries(
347    model_list: AgentModelList,
348    favorites: Arc<HashSet<ModelId>>,
349) -> Vec<AcpModelPickerEntry> {
350    let mut entries = Vec::new();
351
352    let all_models: Vec<_> = match &model_list {
353        AgentModelList::Flat(list) => list.iter().collect(),
354        AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
355    };
356
357    let favorite_models: Vec<_> = all_models
358        .iter()
359        .filter(|m| favorites.contains(&m.id))
360        .unique_by(|m| &m.id)
361        .collect();
362
363    let has_favorites = !favorite_models.is_empty();
364    if has_favorites {
365        entries.push(AcpModelPickerEntry::Separator("Favorite".into()));
366        for model in favorite_models {
367            entries.push(AcpModelPickerEntry::Model((*model).clone(), true));
368        }
369    }
370
371    match model_list {
372        AgentModelList::Flat(list) => {
373            if has_favorites {
374                entries.push(AcpModelPickerEntry::Separator("All".into()));
375            }
376            for model in list {
377                let is_favorite = favorites.contains(&model.id);
378                entries.push(AcpModelPickerEntry::Model(model, is_favorite));
379            }
380        }
381        AgentModelList::Grouped(index_map) => {
382            for (group_name, models) in index_map {
383                entries.push(AcpModelPickerEntry::Separator(group_name.0));
384                for model in models {
385                    let is_favorite = favorites.contains(&model.id);
386                    entries.push(AcpModelPickerEntry::Model(model, is_favorite));
387                }
388            }
389        }
390    }
391
392    entries
393}
394
395async fn fuzzy_search(
396    model_list: AgentModelList,
397    query: String,
398    executor: BackgroundExecutor,
399) -> AgentModelList {
400    async fn fuzzy_search_list(
401        model_list: Vec<AgentModelInfo>,
402        query: &str,
403        executor: BackgroundExecutor,
404    ) -> Vec<AgentModelInfo> {
405        let candidates = model_list
406            .iter()
407            .enumerate()
408            .map(|(ix, model)| StringMatchCandidate::new(ix, model.name.as_ref()))
409            .collect::<Vec<_>>();
410        let mut matches = match_strings(
411            &candidates,
412            query,
413            false,
414            true,
415            100,
416            &Default::default(),
417            executor,
418        )
419        .await;
420
421        matches.sort_unstable_by_key(|mat| {
422            let candidate = &candidates[mat.candidate_id];
423            (Reverse(OrderedFloat(mat.score)), candidate.id)
424        });
425
426        matches
427            .into_iter()
428            .map(|mat| model_list[mat.candidate_id].clone())
429            .collect()
430    }
431
432    match model_list {
433        AgentModelList::Flat(model_list) => {
434            AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
435        }
436        AgentModelList::Grouped(index_map) => {
437            let groups =
438                futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
439                    fuzzy_search_list(models, &query, executor.clone())
440                        .map(|results| (group_name, results))
441                }))
442                .await;
443            AgentModelList::Grouped(IndexMap::from_iter(
444                groups
445                    .into_iter()
446                    .filter(|(_, results)| !results.is_empty()),
447            ))
448        }
449    }
450}
451
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list where each model's id and name are the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId::new(model.to_string()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is grouped and matches `expected` group by group,
    /// comparing model names in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    /// Builds a favorites set from model id strings.
    fn create_favorites(models: Vec<&str>) -> Arc<HashSet<ModelId>> {
        Arc::new(
            models
                .into_iter()
                .map(|m| ModelId::new(m.to_string()))
                .collect(),
        )
    }

    /// Ids of the model rows, in order, skipping headers.
    fn get_entry_model_ids(entries: &[AcpModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .filter_map(|entry| match entry {
                AcpModelPickerEntry::Model(info, _) => Some(info.id.0.as_ref()),
                _ => None,
            })
            .collect()
    }

    /// Every row as a label: a model's id or a header's title.
    fn get_entry_labels(entries: &[AcpModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .map(|entry| match entry {
                AcpModelPickerEntry::Model(info, _) => info.id.0.as_ref(),
                AcpModelPickerEntry::Separator(s) => s,
            })
            .collect()
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, favorites);

        assert!(matches!(
            entries.first(),
            Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        let model_ids = get_entry_model_ids(&entries);
        assert_eq!(model_ids[0], "zed/gemini");
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![("zed", vec!["zed/claude", "zed/gemini"])]);
        let favorites = create_favorites(vec![]);

        let entries = info_list_to_picker_entries(models, favorites);

        assert!(matches!(
            entries.first(),
            Some(AcpModelPickerEntry::Separator(s)) if s == "zed"
        ));
    }

    /// Each model row's favorite flag must reflect membership in the favorites set.
    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, favorites);

        for entry in &entries {
            if let AcpModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "zed/claude" {
                    assert!(is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(!is_favorite, "{} should not be a favorite", info.id.0);
                }
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_both_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5", "openai/gpt-4"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini", "openai/gpt-5"]);

        let entries = info_list_to_picker_entries(models, favorites);
        let model_ids = get_entry_model_ids(&entries);

        assert_eq!(model_ids[0], "zed/gemini");
        assert_eq!(model_ids[1], "openai/gpt-5");

        assert!(model_ids[2..].contains(&"zed/gemini"));
        assert!(model_ids[2..].contains(&"openai/gpt-5"));
    }

    #[gpui::test]
    fn test_favorites_are_not_duplicated_when_repeated_in_other_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("Recommended", vec!["zed/claude", "anthropic/claude"]),
            ("Zed", vec!["zed/claude", "zed/gpt-5"]),
            ("Anthropic", vec!["anthropic/claude"]),
            ("OpenAI", vec!["openai/gpt-5"]),
        ]);

        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, favorites);
        let labels = get_entry_labels(&entries);

        assert_eq!(
            labels,
            vec![
                "Favorite",
                "zed/claude",
                "Recommended",
                "zed/claude",
                "anthropic/claude",
                "Zed",
                "zed/claude",
                "zed/gpt-5",
                "Anthropic",
                "anthropic/claude",
                "OpenAI",
                "openai/gpt-5"
            ]
        );
    }

    #[gpui::test]
    fn test_flat_model_list_with_favorites(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/claude".to_string()),
                name: "Claude".into(),
                description: None,
                icon: None,
            },
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/gemini".to_string()),
                name: "Gemini".into(),
                description: None,
                icon: None,
            },
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, favorites);

        assert!(matches!(
            entries.first(),
            Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert!(entries.iter().any(|e| matches!(
            e,
            AcpModelPickerEntry::Separator(s) if s == "All"
        )));
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }
}