model_selector.rs

  1use std::{cmp::Reverse, rc::Rc, sync::Arc};
  2
  3use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
  4use agent_servers::AgentServer;
  5use anyhow::Result;
  6use collections::IndexMap;
  7use fs::Fs;
  8use futures::FutureExt;
  9use fuzzy::{StringMatchCandidate, match_strings};
 10use gpui::{AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity};
 11use ordered_float::OrderedFloat;
 12use picker::{Picker, PickerDelegate};
 13use ui::{
 14    DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, ListItem,
 15    ListItemSpacing, prelude::*,
 16};
 17use util::ResultExt;
 18
 19use crate::ui::HoldForDefault;
 20
 21pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
 22
 23pub fn acp_model_selector(
 24    selector: Rc<dyn AgentModelSelector>,
 25    agent_server: Rc<dyn AgentServer>,
 26    fs: Arc<dyn Fs>,
 27    window: &mut Window,
 28    cx: &mut Context<AcpModelSelector>,
 29) -> AcpModelSelector {
 30    let delegate = AcpModelPickerDelegate::new(selector, agent_server, fs, window, cx);
 31    Picker::list(delegate, window, cx)
 32        .show_scrollbar(true)
 33        .width(rems(20.))
 34        .max_height(Some(rems(20.).into()))
 35}
 36
/// One row of the model picker list.
enum AcpModelPickerEntry {
    /// A group heading; carries the group's display title. Not selectable.
    Separator(SharedString),
    /// A selectable model row.
    Model(AgentModelInfo),
}
 41
/// Picker delegate that lists an agent's models and lets the user pick one.
pub struct AcpModelPickerDelegate {
    /// Source of the model list and of the current selection; also receives
    /// the user's choice on confirm.
    selector: Rc<dyn AgentModelSelector>,
    /// Used to read and toggle the agent's default model.
    agent_server: Rc<dyn AgentServer>,
    /// Passed along when the default model is changed.
    fs: Arc<dyn Fs>,
    /// Rows currently displayed (separators and models) after filtering.
    filtered_entries: Vec<AcpModelPickerEntry>,
    /// Last model list fetched from `selector`; `None` until a fetch succeeds.
    models: Option<AgentModelList>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Hovered row as `(row index, description, is_default)`; drives the
    /// documentation aside.
    selected_description: Option<(usize, SharedString, bool)>,
    /// The model currently in use, if known.
    selected_model: Option<AgentModelInfo>,
    // Held so the refresh task stays owned by the delegate
    // (presumably the task is cancelled when dropped — TODO confirm).
    _refresh_models_task: Task<()>,
}
 53
 54impl AcpModelPickerDelegate {
 55    fn new(
 56        selector: Rc<dyn AgentModelSelector>,
 57        agent_server: Rc<dyn AgentServer>,
 58        fs: Arc<dyn Fs>,
 59        window: &mut Window,
 60        cx: &mut Context<AcpModelSelector>,
 61    ) -> Self {
 62        let rx = selector.watch(cx);
 63        let refresh_models_task = {
 64            cx.spawn_in(window, {
 65                async move |this, cx| {
 66                    async fn refresh(
 67                        this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
 68                        cx: &mut AsyncWindowContext,
 69                    ) -> Result<()> {
 70                        let (models_task, selected_model_task) = this.update(cx, |this, cx| {
 71                            (
 72                                this.delegate.selector.list_models(cx),
 73                                this.delegate.selector.selected_model(cx),
 74                            )
 75                        })?;
 76
 77                        let (models, selected_model) =
 78                            futures::join!(models_task, selected_model_task);
 79
 80                        this.update_in(cx, |this, window, cx| {
 81                            this.delegate.models = models.ok();
 82                            this.delegate.selected_model = selected_model.ok();
 83                            this.refresh(window, cx)
 84                        })
 85                    }
 86
 87                    refresh(&this, cx).await.log_err();
 88                    if let Some(mut rx) = rx {
 89                        while let Ok(()) = rx.recv().await {
 90                            refresh(&this, cx).await.log_err();
 91                        }
 92                    }
 93                }
 94            })
 95        };
 96
 97        Self {
 98            selector,
 99            agent_server,
100            fs,
101            filtered_entries: Vec::new(),
102            models: None,
103            selected_model: None,
104            selected_index: 0,
105            selected_description: None,
106            _refresh_models_task: refresh_models_task,
107        }
108    }
109
110    pub fn active_model(&self) -> Option<&AgentModelInfo> {
111        self.selected_model.as_ref()
112    }
113}
114
115impl PickerDelegate for AcpModelPickerDelegate {
116    type ListItem = AnyElement;
117
118    fn match_count(&self) -> usize {
119        self.filtered_entries.len()
120    }
121
122    fn selected_index(&self) -> usize {
123        self.selected_index
124    }
125
126    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
127        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
128        cx.notify();
129    }
130
131    fn can_select(
132        &mut self,
133        ix: usize,
134        _window: &mut Window,
135        _cx: &mut Context<Picker<Self>>,
136    ) -> bool {
137        match self.filtered_entries.get(ix) {
138            Some(AcpModelPickerEntry::Model(_)) => true,
139            Some(AcpModelPickerEntry::Separator(_)) | None => false,
140        }
141    }
142
143    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
144        "Select a model…".into()
145    }
146
147    fn update_matches(
148        &mut self,
149        query: String,
150        window: &mut Window,
151        cx: &mut Context<Picker<Self>>,
152    ) -> Task<()> {
153        cx.spawn_in(window, async move |this, cx| {
154            let filtered_models = match this
155                .read_with(cx, |this, cx| {
156                    this.delegate.models.clone().map(move |models| {
157                        fuzzy_search(models, query, cx.background_executor().clone())
158                    })
159                })
160                .ok()
161                .flatten()
162            {
163                Some(task) => task.await,
164                None => AgentModelList::Flat(vec![]),
165            };
166
167            this.update_in(cx, |this, window, cx| {
168                this.delegate.filtered_entries =
169                    info_list_to_picker_entries(filtered_models).collect();
170                // Finds the currently selected model in the list
171                let new_index = this
172                    .delegate
173                    .selected_model
174                    .as_ref()
175                    .and_then(|selected| {
176                        this.delegate.filtered_entries.iter().position(|entry| {
177                            if let AcpModelPickerEntry::Model(model_info) = entry {
178                                model_info.id == selected.id
179                            } else {
180                                false
181                            }
182                        })
183                    })
184                    .unwrap_or(0);
185                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
186                cx.notify();
187            })
188            .ok();
189        })
190    }
191
192    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
193        if let Some(AcpModelPickerEntry::Model(model_info)) =
194            self.filtered_entries.get(self.selected_index)
195        {
196            if window.modifiers().secondary() {
197                let default_model = self.agent_server.default_model(cx);
198                let is_default = default_model.as_ref() == Some(&model_info.id);
199
200                self.agent_server.set_default_model(
201                    if is_default {
202                        None
203                    } else {
204                        Some(model_info.id.clone())
205                    },
206                    self.fs.clone(),
207                    cx,
208                );
209            }
210
211            self.selector
212                .select_model(model_info.id.clone(), cx)
213                .detach_and_log_err(cx);
214            self.selected_model = Some(model_info.clone());
215            let current_index = self.selected_index;
216            self.set_selected_index(current_index, window, cx);
217
218            cx.emit(DismissEvent);
219        }
220    }
221
222    fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
223        cx.defer_in(window, |picker, window, cx| {
224            picker.set_query("", window, cx);
225        });
226    }
227
228    fn render_match(
229        &self,
230        ix: usize,
231        selected: bool,
232        _: &mut Window,
233        cx: &mut Context<Picker<Self>>,
234    ) -> Option<Self::ListItem> {
235        match self.filtered_entries.get(ix)? {
236            AcpModelPickerEntry::Separator(title) => Some(
237                div()
238                    .px_2()
239                    .pb_1()
240                    .when(ix > 1, |this| {
241                        this.mt_1()
242                            .pt_2()
243                            .border_t_1()
244                            .border_color(cx.theme().colors().border_variant)
245                    })
246                    .child(
247                        Label::new(title)
248                            .size(LabelSize::XSmall)
249                            .color(Color::Muted),
250                    )
251                    .into_any_element(),
252            ),
253            AcpModelPickerEntry::Model(model_info) => {
254                let is_selected = Some(model_info) == self.selected_model.as_ref();
255                let default_model = self.agent_server.default_model(cx);
256                let is_default = default_model.as_ref() == Some(&model_info.id);
257
258                let model_icon_color = if is_selected {
259                    Color::Accent
260                } else {
261                    Color::Muted
262                };
263
264                Some(
265                    div()
266                        .id(("model-picker-menu-child", ix))
267                        .when_some(model_info.description.clone(), |this, description| {
268                            this
269                                .on_hover(cx.listener(move |menu, hovered, _, cx| {
270                                    if *hovered {
271                                        menu.delegate.selected_description = Some((ix, description.clone(), is_default));
272                                    } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
273                                        menu.delegate.selected_description = None;
274                                    }
275                                    cx.notify();
276                                }))
277                        })
278                        .child(
279                            ListItem::new(ix)
280                                .inset(true)
281                                .spacing(ListItemSpacing::Sparse)
282                                .toggle_state(selected)
283                                .child(
284                                    h_flex()
285                                        .w_full()
286                                        .gap_1p5()
287                                        .when_some(model_info.icon, |this, icon| {
288                                            this.child(
289                                                Icon::new(icon)
290                                                    .color(model_icon_color)
291                                                    .size(IconSize::Small)
292                                            )
293                                        })
294                                        .child(Label::new(model_info.name.clone()).truncate()),
295                                )
296                                .end_slot(div().pr_3().when(is_selected, |this| {
297                                    this.child(
298                                        Icon::new(IconName::Check)
299                                            .color(Color::Accent)
300                                            .size(IconSize::Small),
301                                    )
302                                })),
303                        )
304                        .into_any_element()
305                )
306            }
307        }
308    }
309
310    fn documentation_aside(
311        &self,
312        _window: &mut Window,
313        _cx: &mut Context<Picker<Self>>,
314    ) -> Option<ui::DocumentationAside> {
315        self.selected_description
316            .as_ref()
317            .map(|(_, description, is_default)| {
318                let description = description.clone();
319                let is_default = *is_default;
320
321                DocumentationAside::new(
322                    DocumentationSide::Left,
323                    DocumentationEdge::Top,
324                    Rc::new(move |_| {
325                        v_flex()
326                            .gap_1()
327                            .child(Label::new(description.clone()))
328                            .child(HoldForDefault::new(is_default))
329                            .into_any_element()
330                    }),
331                )
332            })
333    }
334}
335
336fn info_list_to_picker_entries(
337    model_list: AgentModelList,
338) -> impl Iterator<Item = AcpModelPickerEntry> {
339    match model_list {
340        AgentModelList::Flat(list) => {
341            itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
342        }
343        AgentModelList::Grouped(index_map) => {
344            itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
345                std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
346                    .chain(models.into_iter().map(AcpModelPickerEntry::Model))
347            }))
348        }
349    }
350}
351
352async fn fuzzy_search(
353    model_list: AgentModelList,
354    query: String,
355    executor: BackgroundExecutor,
356) -> AgentModelList {
357    async fn fuzzy_search_list(
358        model_list: Vec<AgentModelInfo>,
359        query: &str,
360        executor: BackgroundExecutor,
361    ) -> Vec<AgentModelInfo> {
362        let candidates = model_list
363            .iter()
364            .enumerate()
365            .map(|(ix, model)| {
366                StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
367            })
368            .collect::<Vec<_>>();
369        let mut matches = match_strings(
370            &candidates,
371            query,
372            false,
373            true,
374            100,
375            &Default::default(),
376            executor,
377        )
378        .await;
379
380        matches.sort_unstable_by_key(|mat| {
381            let candidate = &candidates[mat.candidate_id];
382            (Reverse(OrderedFloat(mat.score)), candidate.id)
383        });
384
385        matches
386            .into_iter()
387            .map(|mat| model_list[mat.candidate_id].clone())
388            .collect()
389    }
390
391    match model_list {
392        AgentModelList::Flat(model_list) => {
393            AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
394        }
395        AgentModelList::Grouped(index_map) => {
396            let groups =
397                futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
398                    fuzzy_search_list(models, &query, executor.clone())
399                        .map(|results| (group_name, results))
400                }))
401                .await;
402            AgentModelList::Grouped(IndexMap::from_iter(
403                groups
404                    .into_iter()
405                    .filter(|(_, results)| !results.is_empty()),
406            ))
407        }
408    }
409}
410
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list from `(group, model names)` pairs.
    ///
    /// Each model uses the given string as both its id and its name and has no
    /// description or icon.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId(model.to_string().into()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names
    /// match `expected` exactly, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }
}