model_selector.rs

  1use std::{cmp::Reverse, rc::Rc, sync::Arc};
  2
  3use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
  4use agent_client_protocol as acp;
  5use anyhow::Result;
  6use collections::IndexMap;
  7use futures::FutureExt;
  8use fuzzy::{StringMatchCandidate, match_strings};
  9use gpui::{Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity};
 10use ordered_float::OrderedFloat;
 11use picker::{Picker, PickerDelegate};
 12use ui::{
 13    AnyElement, App, Context, IntoElement, ListItem, ListItemSpacing, SharedString, Window,
 14    prelude::*, rems,
 15};
 16use util::ResultExt;
 17
 18pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
 19
 20pub fn acp_model_selector(
 21    session_id: acp::SessionId,
 22    selector: Rc<dyn AgentModelSelector>,
 23    window: &mut Window,
 24    cx: &mut Context<AcpModelSelector>,
 25) -> AcpModelSelector {
 26    let delegate = AcpModelPickerDelegate::new(session_id, selector, window, cx);
 27    Picker::list(delegate, window, cx)
 28        .show_scrollbar(true)
 29        .width(rems(20.))
 30        .max_height(Some(rems(20.).into()))
 31}
 32
 33enum AcpModelPickerEntry {
 34    Separator(SharedString),
 35    Model(AgentModelInfo),
 36}
 37
 38pub struct AcpModelPickerDelegate {
 39    session_id: acp::SessionId,
 40    selector: Rc<dyn AgentModelSelector>,
 41    filtered_entries: Vec<AcpModelPickerEntry>,
 42    models: Option<AgentModelList>,
 43    selected_index: usize,
 44    selected_model: Option<AgentModelInfo>,
 45    _refresh_models_task: Task<()>,
 46}
 47
 48impl AcpModelPickerDelegate {
 49    fn new(
 50        session_id: acp::SessionId,
 51        selector: Rc<dyn AgentModelSelector>,
 52        window: &mut Window,
 53        cx: &mut Context<AcpModelSelector>,
 54    ) -> Self {
 55        let mut rx = selector.watch(cx);
 56        let refresh_models_task = cx.spawn_in(window, {
 57            let session_id = session_id.clone();
 58            async move |this, cx| {
 59                async fn refresh(
 60                    this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
 61                    session_id: &acp::SessionId,
 62                    cx: &mut AsyncWindowContext,
 63                ) -> Result<()> {
 64                    let (models_task, selected_model_task) = this.update(cx, |this, cx| {
 65                        (
 66                            this.delegate.selector.list_models(cx),
 67                            this.delegate.selector.selected_model(session_id, cx),
 68                        )
 69                    })?;
 70
 71                    let (models, selected_model) = futures::join!(models_task, selected_model_task);
 72
 73                    this.update_in(cx, |this, window, cx| {
 74                        this.delegate.models = models.log_err();
 75                        this.delegate.selected_model = selected_model.ok();
 76                        this.refresh(window, cx)
 77                    })
 78                }
 79
 80                refresh(&this, &session_id, cx).await.log_err();
 81                while let Ok(()) = rx.recv().await {
 82                    refresh(&this, &session_id, cx).await.log_err();
 83                }
 84            }
 85        });
 86
 87        Self {
 88            session_id,
 89            selector,
 90            filtered_entries: Vec::new(),
 91            models: None,
 92            selected_model: None,
 93            selected_index: 0,
 94            _refresh_models_task: refresh_models_task,
 95        }
 96    }
 97
 98    pub fn active_model(&self) -> Option<&AgentModelInfo> {
 99        self.selected_model.as_ref()
100    }
101}
102
103impl PickerDelegate for AcpModelPickerDelegate {
104    type ListItem = AnyElement;
105
106    fn match_count(&self) -> usize {
107        self.filtered_entries.len()
108    }
109
110    fn selected_index(&self) -> usize {
111        self.selected_index
112    }
113
114    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
115        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
116        cx.notify();
117    }
118
119    fn can_select(
120        &mut self,
121        ix: usize,
122        _window: &mut Window,
123        _cx: &mut Context<Picker<Self>>,
124    ) -> bool {
125        match self.filtered_entries.get(ix) {
126            Some(AcpModelPickerEntry::Model(_)) => true,
127            Some(AcpModelPickerEntry::Separator(_)) | None => false,
128        }
129    }
130
131    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
132        "Select a model…".into()
133    }
134
135    fn update_matches(
136        &mut self,
137        query: String,
138        window: &mut Window,
139        cx: &mut Context<Picker<Self>>,
140    ) -> Task<()> {
141        cx.spawn_in(window, async move |this, cx| {
142            let filtered_models = match this
143                .read_with(cx, |this, cx| {
144                    if let Some(models) = this.delegate.models.as_ref() {
145                        log::debug!("Filtering {} models.", models.len());
146                    } else {
147                        log::debug!("No models available.");
148                    }
149                    this.delegate.models.clone().map(move |models| {
150                        fuzzy_search(models, query, cx.background_executor().clone())
151                    })
152                })
153                .ok()
154                .flatten()
155            {
156                Some(task) => task.await,
157                None => AgentModelList::Flat(vec![]),
158            };
159
160            log::debug!("Filtered models. {} available.", filtered_models.len());
161
162            this.update_in(cx, |this, window, cx| {
163                this.delegate.filtered_entries =
164                    info_list_to_picker_entries(filtered_models).collect();
165                // Finds the currently selected model in the list
166                let new_index = this
167                    .delegate
168                    .selected_model
169                    .as_ref()
170                    .and_then(|selected| {
171                        this.delegate.filtered_entries.iter().position(|entry| {
172                            if let AcpModelPickerEntry::Model(model_info) = entry {
173                                model_info.id == selected.id
174                            } else {
175                                false
176                            }
177                        })
178                    })
179                    .unwrap_or(0);
180                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
181                cx.notify();
182            })
183            .ok();
184        })
185    }
186
187    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
188        if let Some(AcpModelPickerEntry::Model(model_info)) =
189            self.filtered_entries.get(self.selected_index)
190        {
191            self.selector
192                .select_model(self.session_id.clone(), model_info.id.clone(), cx)
193                .detach_and_log_err(cx);
194            self.selected_model = Some(model_info.clone());
195            let current_index = self.selected_index;
196            self.set_selected_index(current_index, window, cx);
197
198            cx.emit(DismissEvent);
199        }
200    }
201
202    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
203        cx.emit(DismissEvent);
204    }
205
206    fn render_match(
207        &self,
208        ix: usize,
209        selected: bool,
210        _: &mut Window,
211        cx: &mut Context<Picker<Self>>,
212    ) -> Option<Self::ListItem> {
213        match self.filtered_entries.get(ix)? {
214            AcpModelPickerEntry::Separator(title) => Some(
215                div()
216                    .px_2()
217                    .pb_1()
218                    .when(ix > 1, |this| {
219                        this.mt_1()
220                            .pt_2()
221                            .border_t_1()
222                            .border_color(cx.theme().colors().border_variant)
223                    })
224                    .child(
225                        Label::new(title)
226                            .size(LabelSize::XSmall)
227                            .color(Color::Muted),
228                    )
229                    .into_any_element(),
230            ),
231            AcpModelPickerEntry::Model(model_info) => {
232                let is_selected = Some(model_info) == self.selected_model.as_ref();
233
234                let model_icon_color = if is_selected {
235                    Color::Accent
236                } else {
237                    Color::Muted
238                };
239
240                Some(
241                    ListItem::new(ix)
242                        .inset(true)
243                        .spacing(ListItemSpacing::Sparse)
244                        .toggle_state(selected)
245                        .start_slot::<Icon>(model_info.icon.map(|icon| {
246                            Icon::new(icon)
247                                .color(model_icon_color)
248                                .size(IconSize::Small)
249                        }))
250                        .child(
251                            h_flex()
252                                .w_full()
253                                .pl_0p5()
254                                .gap_1p5()
255                                .w(px(240.))
256                                .child(Label::new(model_info.name.clone()).truncate()),
257                        )
258                        .end_slot(div().pr_3().when(is_selected, |this| {
259                            this.child(
260                                Icon::new(IconName::Check)
261                                    .color(Color::Accent)
262                                    .size(IconSize::Small),
263                            )
264                        }))
265                        .into_any_element(),
266                )
267            }
268        }
269    }
270
271    fn render_footer(
272        &self,
273        _: &mut Window,
274        cx: &mut Context<Picker<Self>>,
275    ) -> Option<gpui::AnyElement> {
276        Some(
277            h_flex()
278                .w_full()
279                .border_t_1()
280                .border_color(cx.theme().colors().border_variant)
281                .p_1()
282                .gap_4()
283                .justify_between()
284                .child(
285                    Button::new("configure", "Configure")
286                        .icon(IconName::Settings)
287                        .icon_size(IconSize::Small)
288                        .icon_color(Color::Muted)
289                        .icon_position(IconPosition::Start)
290                        .on_click(|_, window, cx| {
291                            window.dispatch_action(
292                                zed_actions::agent::OpenSettings.boxed_clone(),
293                                cx,
294                            );
295                        }),
296                )
297                .into_any(),
298        )
299    }
300}
301
302fn info_list_to_picker_entries(
303    model_list: AgentModelList,
304) -> impl Iterator<Item = AcpModelPickerEntry> {
305    match model_list {
306        AgentModelList::Flat(list) => {
307            itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
308        }
309        AgentModelList::Grouped(index_map) => {
310            itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
311                std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
312                    .chain(models.into_iter().map(AcpModelPickerEntry::Model))
313            }))
314        }
315    }
316}
317
318async fn fuzzy_search(
319    model_list: AgentModelList,
320    query: String,
321    executor: BackgroundExecutor,
322) -> AgentModelList {
323    async fn fuzzy_search_list(
324        model_list: Vec<AgentModelInfo>,
325        query: &str,
326        executor: BackgroundExecutor,
327    ) -> Vec<AgentModelInfo> {
328        let candidates = model_list
329            .iter()
330            .enumerate()
331            .map(|(ix, model)| {
332                StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
333            })
334            .collect::<Vec<_>>();
335        let mut matches = match_strings(
336            &candidates,
337            query,
338            false,
339            true,
340            100,
341            &Default::default(),
342            executor,
343        )
344        .await;
345
346        matches.sort_unstable_by_key(|mat| {
347            let candidate = &candidates[mat.candidate_id];
348            (Reverse(OrderedFloat(mat.score)), candidate.id)
349        });
350
351        matches
352            .into_iter()
353            .map(|mat| model_list[mat.candidate_id].clone())
354            .collect()
355    }
356
357    match model_list {
358        AgentModelList::Flat(model_list) => {
359            AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
360        }
361        AgentModelList::Grouped(index_map) => {
362            let groups =
363                futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
364                    fuzzy_search_list(models, &query, executor.clone())
365                        .map(|results| (group_name, results))
366                }))
367                .await;
368            AgentModelList::Grouped(IndexMap::from_iter(
369                groups
370                    .into_iter()
371                    .filter(|(_, results)| !results.is_empty()),
372            ))
373        }
374    }
375}
376
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list from `(group, [model])` pairs.
    ///
    /// Each model's id and display name are both the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp_thread::AgentModelId(model.to_string().into()),
                            name: model.to_string().into(),
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names
    /// match `expected`, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );

        // A query no model matches drops every group instead of leaving
        // empty groups behind.
        let results = fuzzy_search(models, "xyz".into(), cx.executor()).await;
        assert_models_eq(results, vec![]);
    }
}