model_selector.rs

  1use std::{cmp::Reverse, rc::Rc, sync::Arc};
  2
  3use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
  4use anyhow::Result;
  5use collections::IndexMap;
  6use futures::FutureExt;
  7use fuzzy::{StringMatchCandidate, match_strings};
  8use gpui::{Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity};
  9use ordered_float::OrderedFloat;
 10use picker::{Picker, PickerDelegate};
 11use ui::{
 12    AnyElement, App, Context, DocumentationAside, DocumentationEdge, DocumentationSide,
 13    IntoElement, ListItem, ListItemSpacing, SharedString, Window, prelude::*, rems,
 14};
 15use util::ResultExt;
 16
 17pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
 18
 19pub fn acp_model_selector(
 20    selector: Rc<dyn AgentModelSelector>,
 21    window: &mut Window,
 22    cx: &mut Context<AcpModelSelector>,
 23) -> AcpModelSelector {
 24    let delegate = AcpModelPickerDelegate::new(selector, window, cx);
 25    Picker::list(delegate, window, cx)
 26        .show_scrollbar(true)
 27        .width(rems(20.))
 28        .max_height(Some(rems(20.).into()))
 29}
 30
 31enum AcpModelPickerEntry {
 32    Separator(SharedString),
 33    Model(AgentModelInfo),
 34}
 35
 36pub struct AcpModelPickerDelegate {
 37    selector: Rc<dyn AgentModelSelector>,
 38    filtered_entries: Vec<AcpModelPickerEntry>,
 39    models: Option<AgentModelList>,
 40    selected_index: usize,
 41    selected_description: Option<(usize, SharedString)>,
 42    selected_model: Option<AgentModelInfo>,
 43    _refresh_models_task: Task<()>,
 44}
 45
 46impl AcpModelPickerDelegate {
 47    fn new(
 48        selector: Rc<dyn AgentModelSelector>,
 49        window: &mut Window,
 50        cx: &mut Context<AcpModelSelector>,
 51    ) -> Self {
 52        let rx = selector.watch(cx);
 53        let refresh_models_task = {
 54            cx.spawn_in(window, {
 55                async move |this, cx| {
 56                    async fn refresh(
 57                        this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
 58                        cx: &mut AsyncWindowContext,
 59                    ) -> Result<()> {
 60                        let (models_task, selected_model_task) = this.update(cx, |this, cx| {
 61                            (
 62                                this.delegate.selector.list_models(cx),
 63                                this.delegate.selector.selected_model(cx),
 64                            )
 65                        })?;
 66
 67                        let (models, selected_model) =
 68                            futures::join!(models_task, selected_model_task);
 69
 70                        this.update_in(cx, |this, window, cx| {
 71                            this.delegate.models = models.ok();
 72                            this.delegate.selected_model = selected_model.ok();
 73                            this.refresh(window, cx)
 74                        })
 75                    }
 76
 77                    refresh(&this, cx).await.log_err();
 78                    if let Some(mut rx) = rx {
 79                        while let Ok(()) = rx.recv().await {
 80                            refresh(&this, cx).await.log_err();
 81                        }
 82                    }
 83                }
 84            })
 85        };
 86
 87        Self {
 88            selector,
 89            filtered_entries: Vec::new(),
 90            models: None,
 91            selected_model: None,
 92            selected_index: 0,
 93            selected_description: None,
 94            _refresh_models_task: refresh_models_task,
 95        }
 96    }
 97
 98    pub fn active_model(&self) -> Option<&AgentModelInfo> {
 99        self.selected_model.as_ref()
100    }
101}
102
103impl PickerDelegate for AcpModelPickerDelegate {
104    type ListItem = AnyElement;
105
106    fn match_count(&self) -> usize {
107        self.filtered_entries.len()
108    }
109
110    fn selected_index(&self) -> usize {
111        self.selected_index
112    }
113
114    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
115        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
116        cx.notify();
117    }
118
119    fn can_select(
120        &mut self,
121        ix: usize,
122        _window: &mut Window,
123        _cx: &mut Context<Picker<Self>>,
124    ) -> bool {
125        match self.filtered_entries.get(ix) {
126            Some(AcpModelPickerEntry::Model(_)) => true,
127            Some(AcpModelPickerEntry::Separator(_)) | None => false,
128        }
129    }
130
131    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
132        "Select a model…".into()
133    }
134
135    fn update_matches(
136        &mut self,
137        query: String,
138        window: &mut Window,
139        cx: &mut Context<Picker<Self>>,
140    ) -> Task<()> {
141        cx.spawn_in(window, async move |this, cx| {
142            let filtered_models = match this
143                .read_with(cx, |this, cx| {
144                    this.delegate.models.clone().map(move |models| {
145                        fuzzy_search(models, query, cx.background_executor().clone())
146                    })
147                })
148                .ok()
149                .flatten()
150            {
151                Some(task) => task.await,
152                None => AgentModelList::Flat(vec![]),
153            };
154
155            this.update_in(cx, |this, window, cx| {
156                this.delegate.filtered_entries =
157                    info_list_to_picker_entries(filtered_models).collect();
158                // Finds the currently selected model in the list
159                let new_index = this
160                    .delegate
161                    .selected_model
162                    .as_ref()
163                    .and_then(|selected| {
164                        this.delegate.filtered_entries.iter().position(|entry| {
165                            if let AcpModelPickerEntry::Model(model_info) = entry {
166                                model_info.id == selected.id
167                            } else {
168                                false
169                            }
170                        })
171                    })
172                    .unwrap_or(0);
173                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
174                cx.notify();
175            })
176            .ok();
177        })
178    }
179
180    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
181        if let Some(AcpModelPickerEntry::Model(model_info)) =
182            self.filtered_entries.get(self.selected_index)
183        {
184            self.selector
185                .select_model(model_info.id.clone(), cx)
186                .detach_and_log_err(cx);
187            self.selected_model = Some(model_info.clone());
188            let current_index = self.selected_index;
189            self.set_selected_index(current_index, window, cx);
190
191            cx.emit(DismissEvent);
192        }
193    }
194
195    fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
196        cx.defer_in(window, |picker, window, cx| {
197            picker.set_query("", window, cx);
198        });
199    }
200
201    fn render_match(
202        &self,
203        ix: usize,
204        selected: bool,
205        _: &mut Window,
206        cx: &mut Context<Picker<Self>>,
207    ) -> Option<Self::ListItem> {
208        match self.filtered_entries.get(ix)? {
209            AcpModelPickerEntry::Separator(title) => Some(
210                div()
211                    .px_2()
212                    .pb_1()
213                    .when(ix > 1, |this| {
214                        this.mt_1()
215                            .pt_2()
216                            .border_t_1()
217                            .border_color(cx.theme().colors().border_variant)
218                    })
219                    .child(
220                        Label::new(title)
221                            .size(LabelSize::XSmall)
222                            .color(Color::Muted),
223                    )
224                    .into_any_element(),
225            ),
226            AcpModelPickerEntry::Model(model_info) => {
227                let is_selected = Some(model_info) == self.selected_model.as_ref();
228
229                let model_icon_color = if is_selected {
230                    Color::Accent
231                } else {
232                    Color::Muted
233                };
234
235                Some(
236                    div()
237                        .id(("model-picker-menu-child", ix))
238                        .when_some(model_info.description.clone(), |this, description| {
239                            this
240                                .on_hover(cx.listener(move |menu, hovered, _, cx| {
241                                    if *hovered {
242                                        menu.delegate.selected_description = Some((ix, description.clone()));
243                                    } else if matches!(menu.delegate.selected_description, Some((id, _)) if id == ix) {
244                                        menu.delegate.selected_description = None;
245                                    }
246                                    cx.notify();
247                                }))
248                        })
249                        .child(
250                            ListItem::new(ix)
251                                .inset(true)
252                                .spacing(ListItemSpacing::Sparse)
253                                .toggle_state(selected)
254                                .start_slot::<Icon>(model_info.icon.map(|icon| {
255                                    Icon::new(icon)
256                                        .color(model_icon_color)
257                                        .size(IconSize::Small)
258                                }))
259                                .child(
260                                    h_flex()
261                                        .w_full()
262                                        .pl_0p5()
263                                        .gap_1p5()
264                                        .w(px(240.))
265                                        .child(Label::new(model_info.name.clone()).truncate()),
266                                )
267                                .end_slot(div().pr_3().when(is_selected, |this| {
268                                    this.child(
269                                        Icon::new(IconName::Check)
270                                            .color(Color::Accent)
271                                            .size(IconSize::Small),
272                                    )
273                                })),
274                        )
275                        .into_any_element()
276                )
277            }
278        }
279    }
280
281    fn render_footer(
282        &self,
283        _: &mut Window,
284        cx: &mut Context<Picker<Self>>,
285    ) -> Option<gpui::AnyElement> {
286        Some(
287            h_flex()
288                .w_full()
289                .border_t_1()
290                .border_color(cx.theme().colors().border_variant)
291                .p_1()
292                .gap_4()
293                .justify_between()
294                .child(
295                    Button::new("configure", "Configure")
296                        .icon(IconName::Settings)
297                        .icon_size(IconSize::Small)
298                        .icon_color(Color::Muted)
299                        .icon_position(IconPosition::Start)
300                        .on_click(|_, window, cx| {
301                            window.dispatch_action(
302                                zed_actions::agent::OpenSettings.boxed_clone(),
303                                cx,
304                            );
305                        }),
306                )
307                .into_any(),
308        )
309    }
310
311    fn documentation_aside(
312        &self,
313        _window: &mut Window,
314        _cx: &mut Context<Picker<Self>>,
315    ) -> Option<ui::DocumentationAside> {
316        self.selected_description.as_ref().map(|(_, description)| {
317            let description = description.clone();
318            DocumentationAside::new(
319                DocumentationSide::Left,
320                DocumentationEdge::Bottom,
321                Rc::new(move |_| Label::new(description.clone()).into_any_element()),
322            )
323        })
324    }
325}
326
327fn info_list_to_picker_entries(
328    model_list: AgentModelList,
329) -> impl Iterator<Item = AcpModelPickerEntry> {
330    match model_list {
331        AgentModelList::Flat(list) => {
332            itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
333        }
334        AgentModelList::Grouped(index_map) => {
335            itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
336                std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
337                    .chain(models.into_iter().map(AcpModelPickerEntry::Model))
338            }))
339        }
340    }
341}
342
343async fn fuzzy_search(
344    model_list: AgentModelList,
345    query: String,
346    executor: BackgroundExecutor,
347) -> AgentModelList {
348    async fn fuzzy_search_list(
349        model_list: Vec<AgentModelInfo>,
350        query: &str,
351        executor: BackgroundExecutor,
352    ) -> Vec<AgentModelInfo> {
353        let candidates = model_list
354            .iter()
355            .enumerate()
356            .map(|(ix, model)| {
357                StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
358            })
359            .collect::<Vec<_>>();
360        let mut matches = match_strings(
361            &candidates,
362            query,
363            false,
364            true,
365            100,
366            &Default::default(),
367            executor,
368        )
369        .await;
370
371        matches.sort_unstable_by_key(|mat| {
372            let candidate = &candidates[mat.candidate_id];
373            (Reverse(OrderedFloat(mat.score)), candidate.id)
374        });
375
376        matches
377            .into_iter()
378            .map(|mat| model_list[mat.candidate_id].clone())
379            .collect()
380    }
381
382    match model_list {
383        AgentModelList::Flat(model_list) => {
384            AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
385        }
386        AgentModelList::Grouped(index_map) => {
387            let groups =
388                futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
389                    fuzzy_search_list(models, &query, executor.clone())
390                        .map(|results| (group_name, results))
391                }))
392                .await;
393            AgentModelList::Grouped(IndexMap::from_iter(
394                groups
395                    .into_iter()
396                    .filter(|(_, results)| !results.is_empty()),
397            ))
398        }
399    }
400}
401
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list from `(group name, model names)` pairs.
    /// Each model uses its name as its id and has no description or icon.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId(model.to_string().into()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names
    /// match `expected`, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve the order of the models list whenever
        // possible: groups keep their original order (`zed` before `openai`),
        // and groups without any match (`ollama`) are dropped.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search: the query characters need not be adjacent.
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }
}