model_selector.rs

  1use std::{cmp::Reverse, rc::Rc, sync::Arc};
  2
  3use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
  4use agent_client_protocol as acp;
  5use anyhow::Result;
  6use collections::IndexMap;
  7use futures::FutureExt;
  8use fuzzy::{StringMatchCandidate, match_strings};
  9use gpui::{Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity};
 10use ordered_float::OrderedFloat;
 11use picker::{Picker, PickerDelegate};
 12use ui::{
 13    AnyElement, App, Context, IntoElement, ListItem, ListItemSpacing, SharedString, Window,
 14    prelude::*, rems,
 15};
 16use util::ResultExt;
 17
/// Picker widget that lists agent models, driven by [`AcpModelPickerDelegate`].
pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
 19
 20pub fn acp_model_selector(
 21    session_id: acp::SessionId,
 22    selector: Rc<dyn AgentModelSelector>,
 23    window: &mut Window,
 24    cx: &mut Context<AcpModelSelector>,
 25) -> AcpModelSelector {
 26    let delegate = AcpModelPickerDelegate::new(session_id, selector, window, cx);
 27    Picker::list(delegate, window, cx)
 28        .show_scrollbar(true)
 29        .width(rems(20.))
 30        .max_height(Some(rems(20.).into()))
 31}
 32
/// One row of the model picker list.
enum AcpModelPickerEntry {
    /// Group header showing the group's name; never selectable (see `can_select`).
    Separator(SharedString),
    /// A selectable model row.
    Model(AgentModelInfo),
}
 37
/// [`PickerDelegate`] behind [`AcpModelSelector`].
///
/// Caches the models reported by an [`AgentModelSelector`] together with the
/// model selected for one session, and owns the task that keeps both fresh.
pub struct AcpModelPickerDelegate {
    /// Session whose model is displayed and changed by this picker.
    session_id: acp::SessionId,
    /// Source of the model list and of the per-session selection.
    selector: Rc<dyn AgentModelSelector>,
    /// Rows currently displayed: the cached models after fuzzy filtering.
    filtered_entries: Vec<AcpModelPickerEntry>,
    /// Last loaded model list; `None` before the first load or if loading failed.
    models: Option<AgentModelList>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Model currently selected for the session, if known.
    selected_model: Option<AgentModelInfo>,
    /// Handle to the refresh loop spawned in `new`; held so the loop stays
    /// attached to this delegate (gpui tasks are presumably cancelled on drop).
    _refresh_models_task: Task<()>,
}
 47
 48impl AcpModelPickerDelegate {
 49    fn new(
 50        session_id: acp::SessionId,
 51        selector: Rc<dyn AgentModelSelector>,
 52        window: &mut Window,
 53        cx: &mut Context<AcpModelSelector>,
 54    ) -> Self {
 55        let mut rx = selector.watch(cx);
 56        let refresh_models_task = cx.spawn_in(window, {
 57            let session_id = session_id.clone();
 58            async move |this, cx| {
 59                async fn refresh(
 60                    this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
 61                    session_id: &acp::SessionId,
 62                    cx: &mut AsyncWindowContext,
 63                ) -> Result<()> {
 64                    let (models_task, selected_model_task) = this.update(cx, |this, cx| {
 65                        (
 66                            this.delegate.selector.list_models(cx),
 67                            this.delegate.selector.selected_model(session_id, cx),
 68                        )
 69                    })?;
 70
 71                    let (models, selected_model) = futures::join!(models_task, selected_model_task);
 72
 73                    this.update_in(cx, |this, window, cx| {
 74                        this.delegate.models = models.ok();
 75                        this.delegate.selected_model = selected_model.ok();
 76                        this.refresh(window, cx)
 77                    })
 78                }
 79
 80                refresh(&this, &session_id, cx).await.log_err();
 81                while let Ok(()) = rx.recv().await {
 82                    refresh(&this, &session_id, cx).await.log_err();
 83                }
 84            }
 85        });
 86
 87        Self {
 88            session_id,
 89            selector,
 90            filtered_entries: Vec::new(),
 91            models: None,
 92            selected_model: None,
 93            selected_index: 0,
 94            _refresh_models_task: refresh_models_task,
 95        }
 96    }
 97
 98    pub fn active_model(&self) -> Option<&AgentModelInfo> {
 99        self.selected_model.as_ref()
100    }
101}
102
103impl PickerDelegate for AcpModelPickerDelegate {
104    type ListItem = AnyElement;
105
106    fn match_count(&self) -> usize {
107        self.filtered_entries.len()
108    }
109
110    fn selected_index(&self) -> usize {
111        self.selected_index
112    }
113
114    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
115        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
116        cx.notify();
117    }
118
119    fn can_select(
120        &mut self,
121        ix: usize,
122        _window: &mut Window,
123        _cx: &mut Context<Picker<Self>>,
124    ) -> bool {
125        match self.filtered_entries.get(ix) {
126            Some(AcpModelPickerEntry::Model(_)) => true,
127            Some(AcpModelPickerEntry::Separator(_)) | None => false,
128        }
129    }
130
131    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
132        "Select a model…".into()
133    }
134
135    fn update_matches(
136        &mut self,
137        query: String,
138        window: &mut Window,
139        cx: &mut Context<Picker<Self>>,
140    ) -> Task<()> {
141        cx.spawn_in(window, async move |this, cx| {
142            let filtered_models = match this
143                .read_with(cx, |this, cx| {
144                    this.delegate.models.clone().map(move |models| {
145                        fuzzy_search(models, query, cx.background_executor().clone())
146                    })
147                })
148                .ok()
149                .flatten()
150            {
151                Some(task) => task.await,
152                None => AgentModelList::Flat(vec![]),
153            };
154
155            this.update_in(cx, |this, window, cx| {
156                this.delegate.filtered_entries =
157                    info_list_to_picker_entries(filtered_models).collect();
158                // Finds the currently selected model in the list
159                let new_index = this
160                    .delegate
161                    .selected_model
162                    .as_ref()
163                    .and_then(|selected| {
164                        this.delegate.filtered_entries.iter().position(|entry| {
165                            if let AcpModelPickerEntry::Model(model_info) = entry {
166                                model_info.id == selected.id
167                            } else {
168                                false
169                            }
170                        })
171                    })
172                    .unwrap_or(0);
173                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
174                cx.notify();
175            })
176            .ok();
177        })
178    }
179
180    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
181        if let Some(AcpModelPickerEntry::Model(model_info)) =
182            self.filtered_entries.get(self.selected_index)
183        {
184            self.selector
185                .select_model(self.session_id.clone(), model_info.id.clone(), cx)
186                .detach_and_log_err(cx);
187            self.selected_model = Some(model_info.clone());
188            let current_index = self.selected_index;
189            self.set_selected_index(current_index, window, cx);
190
191            cx.emit(DismissEvent);
192        }
193    }
194
195    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
196        cx.emit(DismissEvent);
197    }
198
199    fn render_match(
200        &self,
201        ix: usize,
202        selected: bool,
203        _: &mut Window,
204        cx: &mut Context<Picker<Self>>,
205    ) -> Option<Self::ListItem> {
206        match self.filtered_entries.get(ix)? {
207            AcpModelPickerEntry::Separator(title) => Some(
208                div()
209                    .px_2()
210                    .pb_1()
211                    .when(ix > 1, |this| {
212                        this.mt_1()
213                            .pt_2()
214                            .border_t_1()
215                            .border_color(cx.theme().colors().border_variant)
216                    })
217                    .child(
218                        Label::new(title)
219                            .size(LabelSize::XSmall)
220                            .color(Color::Muted),
221                    )
222                    .into_any_element(),
223            ),
224            AcpModelPickerEntry::Model(model_info) => {
225                let is_selected = Some(model_info) == self.selected_model.as_ref();
226
227                let model_icon_color = if is_selected {
228                    Color::Accent
229                } else {
230                    Color::Muted
231                };
232
233                Some(
234                    ListItem::new(ix)
235                        .inset(true)
236                        .spacing(ListItemSpacing::Sparse)
237                        .toggle_state(selected)
238                        .start_slot::<Icon>(model_info.icon.map(|icon| {
239                            Icon::new(icon)
240                                .color(model_icon_color)
241                                .size(IconSize::Small)
242                        }))
243                        .child(
244                            h_flex()
245                                .w_full()
246                                .pl_0p5()
247                                .gap_1p5()
248                                .w(px(240.))
249                                .child(Label::new(model_info.name.clone()).truncate()),
250                        )
251                        .end_slot(div().pr_3().when(is_selected, |this| {
252                            this.child(
253                                Icon::new(IconName::Check)
254                                    .color(Color::Accent)
255                                    .size(IconSize::Small),
256                            )
257                        }))
258                        .into_any_element(),
259                )
260            }
261        }
262    }
263
264    fn render_footer(
265        &self,
266        _: &mut Window,
267        cx: &mut Context<Picker<Self>>,
268    ) -> Option<gpui::AnyElement> {
269        Some(
270            h_flex()
271                .w_full()
272                .border_t_1()
273                .border_color(cx.theme().colors().border_variant)
274                .p_1()
275                .gap_4()
276                .justify_between()
277                .child(
278                    Button::new("configure", "Configure")
279                        .icon(IconName::Settings)
280                        .icon_size(IconSize::Small)
281                        .icon_color(Color::Muted)
282                        .icon_position(IconPosition::Start)
283                        .on_click(|_, window, cx| {
284                            window.dispatch_action(
285                                zed_actions::agent::OpenSettings.boxed_clone(),
286                                cx,
287                            );
288                        }),
289                )
290                .into_any(),
291        )
292    }
293}
294
295fn info_list_to_picker_entries(
296    model_list: AgentModelList,
297) -> impl Iterator<Item = AcpModelPickerEntry> {
298    match model_list {
299        AgentModelList::Flat(list) => {
300            itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
301        }
302        AgentModelList::Grouped(index_map) => {
303            itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
304                std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
305                    .chain(models.into_iter().map(AcpModelPickerEntry::Model))
306            }))
307        }
308    }
309}
310
311async fn fuzzy_search(
312    model_list: AgentModelList,
313    query: String,
314    executor: BackgroundExecutor,
315) -> AgentModelList {
316    async fn fuzzy_search_list(
317        model_list: Vec<AgentModelInfo>,
318        query: &str,
319        executor: BackgroundExecutor,
320    ) -> Vec<AgentModelInfo> {
321        let candidates = model_list
322            .iter()
323            .enumerate()
324            .map(|(ix, model)| {
325                StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
326            })
327            .collect::<Vec<_>>();
328        let mut matches = match_strings(
329            &candidates,
330            query,
331            false,
332            true,
333            100,
334            &Default::default(),
335            executor,
336        )
337        .await;
338
339        matches.sort_unstable_by_key(|mat| {
340            let candidate = &candidates[mat.candidate_id];
341            (Reverse(OrderedFloat(mat.score)), candidate.id)
342        });
343
344        matches
345            .into_iter()
346            .map(|mat| model_list[mat.candidate_id].clone())
347            .collect()
348    }
349
350    match model_list {
351        AgentModelList::Flat(model_list) => {
352            AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
353        }
354        AgentModelList::Grouped(index_map) => {
355            let groups =
356                futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
357                    fuzzy_search_list(models, &query, executor.clone())
358                        .map(|results| (group_name, results))
359                }))
360                .await;
361            AgentModelList::Grouped(IndexMap::from_iter(
362                groups
363                    .into_iter()
364                    .filter(|(_, results)| !results.is_empty()),
365            ))
366        }
367    }
368}
369
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list; each model's id and name are both the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp_thread::AgentModelId(model.to_string().into()),
                            name: model.to_string().into(),
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose group names and model
    /// names equal `expected`, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }
}