// model_selector.rs

  1use std::{cmp::Reverse, rc::Rc, sync::Arc};
  2
  3use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
  4use agent_client_protocol as acp;
  5use anyhow::Result;
  6use collections::IndexMap;
  7use futures::FutureExt;
  8use fuzzy::{StringMatchCandidate, match_strings};
  9use gpui::{Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity};
 10use ordered_float::OrderedFloat;
 11use picker::{Picker, PickerDelegate};
 12use ui::{
 13    AnyElement, App, Context, IntoElement, ListItem, ListItemSpacing, SharedString, Window,
 14    prelude::*, rems,
 15};
 16use util::ResultExt;
 17
 18pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
 19
 20pub fn acp_model_selector(
 21    session_id: acp::SessionId,
 22    selector: Rc<dyn AgentModelSelector>,
 23    window: &mut Window,
 24    cx: &mut Context<AcpModelSelector>,
 25) -> AcpModelSelector {
 26    let delegate = AcpModelPickerDelegate::new(session_id, selector, window, cx);
 27    Picker::list(delegate, window, cx)
 28        .show_scrollbar(true)
 29        .width(rems(20.))
 30        .max_height(Some(rems(20.).into()))
 31}
 32
 33enum AcpModelPickerEntry {
 34    Separator(SharedString),
 35    Model(AgentModelInfo),
 36}
 37
 38pub struct AcpModelPickerDelegate {
 39    session_id: acp::SessionId,
 40    selector: Rc<dyn AgentModelSelector>,
 41    filtered_entries: Vec<AcpModelPickerEntry>,
 42    models: Option<AgentModelList>,
 43    selected_index: usize,
 44    selected_model: Option<AgentModelInfo>,
 45    _refresh_models_task: Task<()>,
 46}
 47
 48impl AcpModelPickerDelegate {
 49    fn new(
 50        session_id: acp::SessionId,
 51        selector: Rc<dyn AgentModelSelector>,
 52        window: &mut Window,
 53        cx: &mut Context<AcpModelSelector>,
 54    ) -> Self {
 55        let mut rx = selector.watch(cx);
 56        let refresh_models_task = cx.spawn_in(window, {
 57            let session_id = session_id.clone();
 58            async move |this, cx| {
 59                async fn refresh(
 60                    this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
 61                    session_id: &acp::SessionId,
 62                    cx: &mut AsyncWindowContext,
 63                ) -> Result<()> {
 64                    let (models_task, selected_model_task) = this.update(cx, |this, cx| {
 65                        (
 66                            this.delegate.selector.list_models(cx),
 67                            this.delegate.selector.selected_model(session_id, cx),
 68                        )
 69                    })?;
 70
 71                    let (models, selected_model) = futures::join!(models_task, selected_model_task);
 72
 73                    this.update_in(cx, |this, window, cx| {
 74                        this.delegate.models = models.ok();
 75                        this.delegate.selected_model = selected_model.ok();
 76                        this.delegate.update_matches(this.query(cx), window, cx)
 77                    })?
 78                    .await;
 79
 80                    Ok(())
 81                }
 82
 83                refresh(&this, &session_id, cx).await.log_err();
 84                while let Ok(()) = rx.recv().await {
 85                    refresh(&this, &session_id, cx).await.log_err();
 86                }
 87            }
 88        });
 89
 90        Self {
 91            session_id,
 92            selector,
 93            filtered_entries: Vec::new(),
 94            models: None,
 95            selected_model: None,
 96            selected_index: 0,
 97            _refresh_models_task: refresh_models_task,
 98        }
 99    }
100
101    pub fn active_model(&self) -> Option<&AgentModelInfo> {
102        self.selected_model.as_ref()
103    }
104}
105
106impl PickerDelegate for AcpModelPickerDelegate {
107    type ListItem = AnyElement;
108
109    fn match_count(&self) -> usize {
110        self.filtered_entries.len()
111    }
112
113    fn selected_index(&self) -> usize {
114        self.selected_index
115    }
116
117    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
118        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
119        cx.notify();
120    }
121
122    fn can_select(
123        &mut self,
124        ix: usize,
125        _window: &mut Window,
126        _cx: &mut Context<Picker<Self>>,
127    ) -> bool {
128        match self.filtered_entries.get(ix) {
129            Some(AcpModelPickerEntry::Model(_)) => true,
130            Some(AcpModelPickerEntry::Separator(_)) | None => false,
131        }
132    }
133
134    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
135        "Select a model…".into()
136    }
137
138    fn update_matches(
139        &mut self,
140        query: String,
141        window: &mut Window,
142        cx: &mut Context<Picker<Self>>,
143    ) -> Task<()> {
144        cx.spawn_in(window, async move |this, cx| {
145            let filtered_models = match this
146                .read_with(cx, |this, cx| {
147                    this.delegate.models.clone().map(move |models| {
148                        fuzzy_search(models, query, cx.background_executor().clone())
149                    })
150                })
151                .ok()
152                .flatten()
153            {
154                Some(task) => task.await,
155                None => AgentModelList::Flat(vec![]),
156            };
157
158            this.update_in(cx, |this, window, cx| {
159                this.delegate.filtered_entries =
160                    info_list_to_picker_entries(filtered_models).collect();
161                // Finds the currently selected model in the list
162                let new_index = this
163                    .delegate
164                    .selected_model
165                    .as_ref()
166                    .and_then(|selected| {
167                        this.delegate.filtered_entries.iter().position(|entry| {
168                            if let AcpModelPickerEntry::Model(model_info) = entry {
169                                model_info.id == selected.id
170                            } else {
171                                false
172                            }
173                        })
174                    })
175                    .unwrap_or(0);
176                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
177                cx.notify();
178            })
179            .ok();
180        })
181    }
182
183    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
184        if let Some(AcpModelPickerEntry::Model(model_info)) =
185            self.filtered_entries.get(self.selected_index)
186        {
187            self.selector
188                .select_model(self.session_id.clone(), model_info.id.clone(), cx)
189                .detach_and_log_err(cx);
190            self.selected_model = Some(model_info.clone());
191            let current_index = self.selected_index;
192            self.set_selected_index(current_index, window, cx);
193
194            cx.emit(DismissEvent);
195        }
196    }
197
198    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
199        cx.emit(DismissEvent);
200    }
201
202    fn render_match(
203        &self,
204        ix: usize,
205        selected: bool,
206        _: &mut Window,
207        cx: &mut Context<Picker<Self>>,
208    ) -> Option<Self::ListItem> {
209        match self.filtered_entries.get(ix)? {
210            AcpModelPickerEntry::Separator(title) => Some(
211                div()
212                    .px_2()
213                    .pb_1()
214                    .when(ix > 1, |this| {
215                        this.mt_1()
216                            .pt_2()
217                            .border_t_1()
218                            .border_color(cx.theme().colors().border_variant)
219                    })
220                    .child(
221                        Label::new(title)
222                            .size(LabelSize::XSmall)
223                            .color(Color::Muted),
224                    )
225                    .into_any_element(),
226            ),
227            AcpModelPickerEntry::Model(model_info) => {
228                let is_selected = Some(model_info) == self.selected_model.as_ref();
229
230                let model_icon_color = if is_selected {
231                    Color::Accent
232                } else {
233                    Color::Muted
234                };
235
236                Some(
237                    ListItem::new(ix)
238                        .inset(true)
239                        .spacing(ListItemSpacing::Sparse)
240                        .toggle_state(selected)
241                        .start_slot::<Icon>(model_info.icon.map(|icon| {
242                            Icon::new(icon)
243                                .color(model_icon_color)
244                                .size(IconSize::Small)
245                        }))
246                        .child(
247                            h_flex()
248                                .w_full()
249                                .pl_0p5()
250                                .gap_1p5()
251                                .w(px(240.))
252                                .child(Label::new(model_info.name.clone()).truncate()),
253                        )
254                        .end_slot(div().pr_3().when(is_selected, |this| {
255                            this.child(
256                                Icon::new(IconName::Check)
257                                    .color(Color::Accent)
258                                    .size(IconSize::Small),
259                            )
260                        }))
261                        .into_any_element(),
262                )
263            }
264        }
265    }
266
267    fn render_footer(
268        &self,
269        _: &mut Window,
270        cx: &mut Context<Picker<Self>>,
271    ) -> Option<gpui::AnyElement> {
272        Some(
273            h_flex()
274                .w_full()
275                .border_t_1()
276                .border_color(cx.theme().colors().border_variant)
277                .p_1()
278                .gap_4()
279                .justify_between()
280                .child(
281                    Button::new("configure", "Configure")
282                        .icon(IconName::Settings)
283                        .icon_size(IconSize::Small)
284                        .icon_color(Color::Muted)
285                        .icon_position(IconPosition::Start)
286                        .on_click(|_, window, cx| {
287                            window.dispatch_action(
288                                zed_actions::agent::OpenSettings.boxed_clone(),
289                                cx,
290                            );
291                        }),
292                )
293                .into_any(),
294        )
295    }
296}
297
298fn info_list_to_picker_entries(
299    model_list: AgentModelList,
300) -> impl Iterator<Item = AcpModelPickerEntry> {
301    match model_list {
302        AgentModelList::Flat(list) => {
303            itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
304        }
305        AgentModelList::Grouped(index_map) => {
306            itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
307                std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
308                    .chain(models.into_iter().map(AcpModelPickerEntry::Model))
309            }))
310        }
311    }
312}
313
314async fn fuzzy_search(
315    model_list: AgentModelList,
316    query: String,
317    executor: BackgroundExecutor,
318) -> AgentModelList {
319    async fn fuzzy_search_list(
320        model_list: Vec<AgentModelInfo>,
321        query: &str,
322        executor: BackgroundExecutor,
323    ) -> Vec<AgentModelInfo> {
324        let candidates = model_list
325            .iter()
326            .enumerate()
327            .map(|(ix, model)| {
328                StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
329            })
330            .collect::<Vec<_>>();
331        let mut matches = match_strings(
332            &candidates,
333            query,
334            false,
335            true,
336            100,
337            &Default::default(),
338            executor,
339        )
340        .await;
341
342        matches.sort_unstable_by_key(|mat| {
343            let candidate = &candidates[mat.candidate_id];
344            (Reverse(OrderedFloat(mat.score)), candidate.id)
345        });
346
347        matches
348            .into_iter()
349            .map(|mat| model_list[mat.candidate_id].clone())
350            .collect()
351    }
352
353    match model_list {
354        AgentModelList::Flat(model_list) => {
355            AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
356        }
357        AgentModelList::Grouped(index_map) => {
358            let groups =
359                futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
360                    fuzzy_search_list(models, &query, executor.clone())
361                        .map(|results| (group_name, results))
362                }))
363                .await;
364            AgentModelList::Grouped(IndexMap::from_iter(
365                groups
366                    .into_iter()
367                    .filter(|(_, results)| !results.is_empty()),
368            ))
369        }
370    }
371}
372
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;

    use super::*;

    /// Builds an `AgentModelList::Grouped` from `(group, [model])` pairs. Each
    /// model's id and display name are both set to the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp_thread::AgentModelId(model.to_string().into()),
                            name: model.to_string().into(),
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose group names and model names
    /// match `expected`, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search: non-contiguous query characters still match.
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }
}