model_selector.rs

  1use std::{cmp::Reverse, rc::Rc, sync::Arc};
  2
  3use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
  4use agent_client_protocol as acp;
  5use anyhow::Result;
  6use collections::IndexMap;
  7use futures::FutureExt;
  8use fuzzy::{StringMatchCandidate, match_strings};
  9use gpui::{Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity};
 10use ordered_float::OrderedFloat;
 11use picker::{Picker, PickerDelegate};
 12use ui::{
 13    AnyElement, App, Context, IntoElement, ListItem, ListItemSpacing, SharedString, Window,
 14    prelude::*, rems,
 15};
 16use util::ResultExt;
 17
/// The model picker: a [`Picker`] driven by [`AcpModelPickerDelegate`].
pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
 19
 20pub fn acp_model_selector(
 21    session_id: acp::SessionId,
 22    selector: Rc<dyn AgentModelSelector>,
 23    window: &mut Window,
 24    cx: &mut Context<AcpModelSelector>,
 25) -> AcpModelSelector {
 26    let delegate = AcpModelPickerDelegate::new(session_id, selector, window, cx);
 27    Picker::list(delegate, window, cx)
 28        .show_scrollbar(true)
 29        .width(rems(20.))
 30        .max_height(Some(rems(20.).into()))
 31}
 32
/// One row of the picker list.
enum AcpModelPickerEntry {
    /// Heading row carrying a group title; the delegate reports it as not selectable.
    Separator(SharedString),
    /// Row for a single model; the delegate reports it as selectable.
    Model(AgentModelInfo),
}
 37
/// Picker delegate that lists the models offered by an [`AgentModelSelector`]
/// and applies the user's choice to one session.
pub struct AcpModelPickerDelegate {
    /// Session whose model is read and changed.
    session_id: acp::SessionId,
    /// Source of the model list and of the session's current model.
    selector: Rc<dyn AgentModelSelector>,
    /// Rows currently displayed (the result of the last `update_matches`).
    filtered_entries: Vec<AcpModelPickerEntry>,
    /// Last model list fetched from `selector`; `None` until a fetch succeeds.
    models: Option<AgentModelList>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Model last reported as (or confirmed to be) active for the session.
    selected_model: Option<AgentModelInfo>,
    /// Handle of the task spawned in `new` that refreshes `models` and
    /// `selected_model`; stored so the task is owned by the delegate.
    _refresh_models_task: Task<()>,
}
 47
 48impl AcpModelPickerDelegate {
 49    fn new(
 50        session_id: acp::SessionId,
 51        selector: Rc<dyn AgentModelSelector>,
 52        window: &mut Window,
 53        cx: &mut Context<AcpModelSelector>,
 54    ) -> Self {
 55        let mut rx = selector.watch(cx);
 56        let refresh_models_task = cx.spawn_in(window, {
 57            let session_id = session_id.clone();
 58            async move |this, cx| {
 59                async fn refresh(
 60                    this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
 61                    session_id: &acp::SessionId,
 62                    cx: &mut AsyncWindowContext,
 63                ) -> Result<()> {
 64                    let (models_task, selected_model_task) = this.update(cx, |this, cx| {
 65                        (
 66                            this.delegate.selector.list_models(cx),
 67                            this.delegate.selector.selected_model(session_id, cx),
 68                        )
 69                    })?;
 70
 71                    let (models, selected_model) = futures::join!(models_task, selected_model_task);
 72
 73                    this.update_in(cx, |this, window, cx| {
 74                        this.delegate.models = models.log_err();
 75                        this.delegate.selected_model = selected_model.ok();
 76                        this.delegate.update_matches(this.query(cx), window, cx)
 77                    })?
 78                    .await;
 79
 80                    Ok(())
 81                }
 82
 83                refresh(&this, &session_id, cx).await.log_err();
 84                while let Ok(()) = rx.recv().await {
 85                    refresh(&this, &session_id, cx).await.log_err();
 86                }
 87            }
 88        });
 89
 90        Self {
 91            session_id,
 92            selector,
 93            filtered_entries: Vec::new(),
 94            models: None,
 95            selected_model: None,
 96            selected_index: 0,
 97            _refresh_models_task: refresh_models_task,
 98        }
 99    }
100
101    pub fn active_model(&self) -> Option<&AgentModelInfo> {
102        self.selected_model.as_ref()
103    }
104}
105
106impl PickerDelegate for AcpModelPickerDelegate {
107    type ListItem = AnyElement;
108
109    fn match_count(&self) -> usize {
110        self.filtered_entries.len()
111    }
112
113    fn selected_index(&self) -> usize {
114        self.selected_index
115    }
116
117    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
118        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
119        cx.notify();
120    }
121
122    fn can_select(
123        &mut self,
124        ix: usize,
125        _window: &mut Window,
126        _cx: &mut Context<Picker<Self>>,
127    ) -> bool {
128        match self.filtered_entries.get(ix) {
129            Some(AcpModelPickerEntry::Model(_)) => true,
130            Some(AcpModelPickerEntry::Separator(_)) | None => false,
131        }
132    }
133
134    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
135        "Select a model…".into()
136    }
137
138    fn update_matches(
139        &mut self,
140        query: String,
141        window: &mut Window,
142        cx: &mut Context<Picker<Self>>,
143    ) -> Task<()> {
144        cx.spawn_in(window, async move |this, cx| {
145            let filtered_models = match this
146                .read_with(cx, |this, cx| {
147                    if let Some(models) = this.delegate.models.as_ref() {
148                        log::debug!("Filtering {} models.", models.len());
149                    } else {
150                        log::debug!("No models available.");
151                    }
152                    this.delegate.models.clone().map(move |models| {
153                        fuzzy_search(models, query, cx.background_executor().clone())
154                    })
155                })
156                .ok()
157                .flatten()
158            {
159                Some(task) => task.await,
160                None => AgentModelList::Flat(vec![]),
161            };
162
163            log::debug!("Filtered models. {} available.", filtered_models.len());
164
165            this.update_in(cx, |this, window, cx| {
166                this.delegate.filtered_entries =
167                    info_list_to_picker_entries(filtered_models).collect();
168                // Finds the currently selected model in the list
169                let new_index = this
170                    .delegate
171                    .selected_model
172                    .as_ref()
173                    .and_then(|selected| {
174                        this.delegate.filtered_entries.iter().position(|entry| {
175                            if let AcpModelPickerEntry::Model(model_info) = entry {
176                                model_info.id == selected.id
177                            } else {
178                                false
179                            }
180                        })
181                    })
182                    .unwrap_or(0);
183                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
184                cx.notify();
185            })
186            .ok();
187        })
188    }
189
190    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
191        if let Some(AcpModelPickerEntry::Model(model_info)) =
192            self.filtered_entries.get(self.selected_index)
193        {
194            self.selector
195                .select_model(self.session_id.clone(), model_info.id.clone(), cx)
196                .detach_and_log_err(cx);
197            self.selected_model = Some(model_info.clone());
198            let current_index = self.selected_index;
199            self.set_selected_index(current_index, window, cx);
200
201            cx.emit(DismissEvent);
202        }
203    }
204
205    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
206        cx.emit(DismissEvent);
207    }
208
209    fn render_match(
210        &self,
211        ix: usize,
212        selected: bool,
213        _: &mut Window,
214        cx: &mut Context<Picker<Self>>,
215    ) -> Option<Self::ListItem> {
216        match self.filtered_entries.get(ix)? {
217            AcpModelPickerEntry::Separator(title) => Some(
218                div()
219                    .px_2()
220                    .pb_1()
221                    .when(ix > 1, |this| {
222                        this.mt_1()
223                            .pt_2()
224                            .border_t_1()
225                            .border_color(cx.theme().colors().border_variant)
226                    })
227                    .child(
228                        Label::new(title)
229                            .size(LabelSize::XSmall)
230                            .color(Color::Muted),
231                    )
232                    .into_any_element(),
233            ),
234            AcpModelPickerEntry::Model(model_info) => {
235                let is_selected = Some(model_info) == self.selected_model.as_ref();
236
237                let model_icon_color = if is_selected {
238                    Color::Accent
239                } else {
240                    Color::Muted
241                };
242
243                Some(
244                    ListItem::new(ix)
245                        .inset(true)
246                        .spacing(ListItemSpacing::Sparse)
247                        .toggle_state(selected)
248                        .start_slot::<Icon>(model_info.icon.map(|icon| {
249                            Icon::new(icon)
250                                .color(model_icon_color)
251                                .size(IconSize::Small)
252                        }))
253                        .child(
254                            h_flex()
255                                .w_full()
256                                .pl_0p5()
257                                .gap_1p5()
258                                .w(px(240.))
259                                .child(Label::new(model_info.name.clone()).truncate()),
260                        )
261                        .end_slot(div().pr_3().when(is_selected, |this| {
262                            this.child(
263                                Icon::new(IconName::Check)
264                                    .color(Color::Accent)
265                                    .size(IconSize::Small),
266                            )
267                        }))
268                        .into_any_element(),
269                )
270            }
271        }
272    }
273
274    fn render_footer(
275        &self,
276        _: &mut Window,
277        cx: &mut Context<Picker<Self>>,
278    ) -> Option<gpui::AnyElement> {
279        Some(
280            h_flex()
281                .w_full()
282                .border_t_1()
283                .border_color(cx.theme().colors().border_variant)
284                .p_1()
285                .gap_4()
286                .justify_between()
287                .child(
288                    Button::new("configure", "Configure")
289                        .icon(IconName::Settings)
290                        .icon_size(IconSize::Small)
291                        .icon_color(Color::Muted)
292                        .icon_position(IconPosition::Start)
293                        .on_click(|_, window, cx| {
294                            window.dispatch_action(
295                                zed_actions::agent::OpenSettings.boxed_clone(),
296                                cx,
297                            );
298                        }),
299                )
300                .into_any(),
301        )
302    }
303}
304
305fn info_list_to_picker_entries(
306    model_list: AgentModelList,
307) -> impl Iterator<Item = AcpModelPickerEntry> {
308    match model_list {
309        AgentModelList::Flat(list) => {
310            itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
311        }
312        AgentModelList::Grouped(index_map) => {
313            itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
314                std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
315                    .chain(models.into_iter().map(AcpModelPickerEntry::Model))
316            }))
317        }
318    }
319}
320
321async fn fuzzy_search(
322    model_list: AgentModelList,
323    query: String,
324    executor: BackgroundExecutor,
325) -> AgentModelList {
326    async fn fuzzy_search_list(
327        model_list: Vec<AgentModelInfo>,
328        query: &str,
329        executor: BackgroundExecutor,
330    ) -> Vec<AgentModelInfo> {
331        let candidates = model_list
332            .iter()
333            .enumerate()
334            .map(|(ix, model)| {
335                StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
336            })
337            .collect::<Vec<_>>();
338        let mut matches = match_strings(
339            &candidates,
340            query,
341            false,
342            true,
343            100,
344            &Default::default(),
345            executor,
346        )
347        .await;
348
349        matches.sort_unstable_by_key(|mat| {
350            let candidate = &candidates[mat.candidate_id];
351            (Reverse(OrderedFloat(mat.score)), candidate.id)
352        });
353
354        matches
355            .into_iter()
356            .map(|mat| model_list[mat.candidate_id].clone())
357            .collect()
358    }
359
360    match model_list {
361        AgentModelList::Flat(model_list) => {
362            AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
363        }
364        AgentModelList::Grouped(index_map) => {
365            let groups =
366                futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
367                    fuzzy_search_list(models, &query, executor.clone())
368                        .map(|results| (group_name, results))
369                }))
370                .await;
371            AgentModelList::Grouped(IndexMap::from_iter(
372                groups
373                    .into_iter()
374                    .filter(|(_, results)| !results.is_empty()),
375            ))
376        }
377    }
378}
379
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list from `(group, model names)` pairs; every
    /// model uses its name as its id and has no icon.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp_thread::AgentModelId(model.to_string().into()),
                            name: model.to_string().into(),
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names
    /// equal `expected`, in the same order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Each group is searched on its own, so the surviving groups keep the
        // order they had in the input (`zed` before `openai`), and groups
        // without any match (`ollama`) are dropped. Within a group, results
        // are ordered by score, with ties broken by original position.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Non-contiguous query characters still match.
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }
}