model_selector.rs

  1use std::{cmp::Reverse, rc::Rc, sync::Arc};
  2
  3use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
  4use agent_servers::AgentServer;
  5use anyhow::Result;
  6use collections::IndexMap;
  7use fs::Fs;
  8use futures::FutureExt;
  9use fuzzy::{StringMatchCandidate, match_strings};
 10use gpui::{
 11    Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Task, WeakEntity,
 12};
 13use ordered_float::OrderedFloat;
 14use picker::{Picker, PickerDelegate};
 15use ui::{
 16    DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, KeyBinding, ListItem,
 17    ListItemSpacing, prelude::*,
 18};
 19use util::ResultExt;
 20use zed_actions::agent::OpenSettings;
 21
 22use crate::ui::HoldForDefault;
 23
 24pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
 25
 26pub fn acp_model_selector(
 27    selector: Rc<dyn AgentModelSelector>,
 28    agent_server: Rc<dyn AgentServer>,
 29    fs: Arc<dyn Fs>,
 30    focus_handle: FocusHandle,
 31    window: &mut Window,
 32    cx: &mut Context<AcpModelSelector>,
 33) -> AcpModelSelector {
 34    let delegate =
 35        AcpModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
 36    Picker::list(delegate, window, cx)
 37        .show_scrollbar(true)
 38        .width(rems(20.))
 39        .max_height(Some(rems(20.).into()))
 40}
 41
 42enum AcpModelPickerEntry {
 43    Separator(SharedString),
 44    Model(AgentModelInfo),
 45}
 46
 47pub struct AcpModelPickerDelegate {
 48    selector: Rc<dyn AgentModelSelector>,
 49    agent_server: Rc<dyn AgentServer>,
 50    fs: Arc<dyn Fs>,
 51    filtered_entries: Vec<AcpModelPickerEntry>,
 52    models: Option<AgentModelList>,
 53    selected_index: usize,
 54    selected_description: Option<(usize, SharedString, bool)>,
 55    selected_model: Option<AgentModelInfo>,
 56    _refresh_models_task: Task<()>,
 57    focus_handle: FocusHandle,
 58}
 59
 60impl AcpModelPickerDelegate {
 61    fn new(
 62        selector: Rc<dyn AgentModelSelector>,
 63        agent_server: Rc<dyn AgentServer>,
 64        fs: Arc<dyn Fs>,
 65        focus_handle: FocusHandle,
 66        window: &mut Window,
 67        cx: &mut Context<AcpModelSelector>,
 68    ) -> Self {
 69        let rx = selector.watch(cx);
 70        let refresh_models_task = {
 71            cx.spawn_in(window, {
 72                async move |this, cx| {
 73                    async fn refresh(
 74                        this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
 75                        cx: &mut AsyncWindowContext,
 76                    ) -> Result<()> {
 77                        let (models_task, selected_model_task) = this.update(cx, |this, cx| {
 78                            (
 79                                this.delegate.selector.list_models(cx),
 80                                this.delegate.selector.selected_model(cx),
 81                            )
 82                        })?;
 83
 84                        let (models, selected_model) =
 85                            futures::join!(models_task, selected_model_task);
 86
 87                        this.update_in(cx, |this, window, cx| {
 88                            this.delegate.models = models.ok();
 89                            this.delegate.selected_model = selected_model.ok();
 90                            this.refresh(window, cx)
 91                        })
 92                    }
 93
 94                    refresh(&this, cx).await.log_err();
 95                    if let Some(mut rx) = rx {
 96                        while let Ok(()) = rx.recv().await {
 97                            refresh(&this, cx).await.log_err();
 98                        }
 99                    }
100                }
101            })
102        };
103
104        Self {
105            selector,
106            agent_server,
107            fs,
108            filtered_entries: Vec::new(),
109            models: None,
110            selected_model: None,
111            selected_index: 0,
112            selected_description: None,
113            _refresh_models_task: refresh_models_task,
114            focus_handle,
115        }
116    }
117
118    pub fn active_model(&self) -> Option<&AgentModelInfo> {
119        self.selected_model.as_ref()
120    }
121}
122
123impl PickerDelegate for AcpModelPickerDelegate {
124    type ListItem = AnyElement;
125
126    fn match_count(&self) -> usize {
127        self.filtered_entries.len()
128    }
129
130    fn selected_index(&self) -> usize {
131        self.selected_index
132    }
133
134    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
135        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
136        cx.notify();
137    }
138
139    fn can_select(
140        &mut self,
141        ix: usize,
142        _window: &mut Window,
143        _cx: &mut Context<Picker<Self>>,
144    ) -> bool {
145        match self.filtered_entries.get(ix) {
146            Some(AcpModelPickerEntry::Model(_)) => true,
147            Some(AcpModelPickerEntry::Separator(_)) | None => false,
148        }
149    }
150
151    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
152        "Select a model…".into()
153    }
154
155    fn update_matches(
156        &mut self,
157        query: String,
158        window: &mut Window,
159        cx: &mut Context<Picker<Self>>,
160    ) -> Task<()> {
161        cx.spawn_in(window, async move |this, cx| {
162            let filtered_models = match this
163                .read_with(cx, |this, cx| {
164                    this.delegate.models.clone().map(move |models| {
165                        fuzzy_search(models, query, cx.background_executor().clone())
166                    })
167                })
168                .ok()
169                .flatten()
170            {
171                Some(task) => task.await,
172                None => AgentModelList::Flat(vec![]),
173            };
174
175            this.update_in(cx, |this, window, cx| {
176                this.delegate.filtered_entries =
177                    info_list_to_picker_entries(filtered_models).collect();
178                // Finds the currently selected model in the list
179                let new_index = this
180                    .delegate
181                    .selected_model
182                    .as_ref()
183                    .and_then(|selected| {
184                        this.delegate.filtered_entries.iter().position(|entry| {
185                            if let AcpModelPickerEntry::Model(model_info) = entry {
186                                model_info.id == selected.id
187                            } else {
188                                false
189                            }
190                        })
191                    })
192                    .unwrap_or(0);
193                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
194                cx.notify();
195            })
196            .ok();
197        })
198    }
199
200    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
201        if let Some(AcpModelPickerEntry::Model(model_info)) =
202            self.filtered_entries.get(self.selected_index)
203        {
204            if window.modifiers().secondary() {
205                let default_model = self.agent_server.default_model(cx);
206                let is_default = default_model.as_ref() == Some(&model_info.id);
207
208                self.agent_server.set_default_model(
209                    if is_default {
210                        None
211                    } else {
212                        Some(model_info.id.clone())
213                    },
214                    self.fs.clone(),
215                    cx,
216                );
217            }
218
219            self.selector
220                .select_model(model_info.id.clone(), cx)
221                .detach_and_log_err(cx);
222            self.selected_model = Some(model_info.clone());
223            let current_index = self.selected_index;
224            self.set_selected_index(current_index, window, cx);
225
226            cx.emit(DismissEvent);
227        }
228    }
229
230    fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
231        cx.defer_in(window, |picker, window, cx| {
232            picker.set_query("", window, cx);
233        });
234    }
235
236    fn render_match(
237        &self,
238        ix: usize,
239        selected: bool,
240        _: &mut Window,
241        cx: &mut Context<Picker<Self>>,
242    ) -> Option<Self::ListItem> {
243        match self.filtered_entries.get(ix)? {
244            AcpModelPickerEntry::Separator(title) => Some(
245                div()
246                    .px_2()
247                    .pb_1()
248                    .when(ix > 1, |this| {
249                        this.mt_1()
250                            .pt_2()
251                            .border_t_1()
252                            .border_color(cx.theme().colors().border_variant)
253                    })
254                    .child(
255                        Label::new(title)
256                            .size(LabelSize::XSmall)
257                            .color(Color::Muted),
258                    )
259                    .into_any_element(),
260            ),
261            AcpModelPickerEntry::Model(model_info) => {
262                let is_selected = Some(model_info) == self.selected_model.as_ref();
263                let default_model = self.agent_server.default_model(cx);
264                let is_default = default_model.as_ref() == Some(&model_info.id);
265
266                let model_icon_color = if is_selected {
267                    Color::Accent
268                } else {
269                    Color::Muted
270                };
271
272                Some(
273                    div()
274                        .id(("model-picker-menu-child", ix))
275                        .when_some(model_info.description.clone(), |this, description| {
276                            this
277                                .on_hover(cx.listener(move |menu, hovered, _, cx| {
278                                    if *hovered {
279                                        menu.delegate.selected_description = Some((ix, description.clone(), is_default));
280                                    } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
281                                        menu.delegate.selected_description = None;
282                                    }
283                                    cx.notify();
284                                }))
285                        })
286                        .child(
287                            ListItem::new(ix)
288                                .inset(true)
289                                .spacing(ListItemSpacing::Sparse)
290                                .toggle_state(selected)
291                                .child(
292                                    h_flex()
293                                        .w_full()
294                                        .gap_1p5()
295                                        .map(|this| match &model_info.icon {
296                                            Some(icon) => this.child(match icon {
297                                                    AgentModelIcon::Path(path) => Icon::from_path(path.clone()),
298                                                    AgentModelIcon::Named(name) => Icon::new(*name),
299                                                }
300                                                .color(model_icon_color)
301                                                .size(IconSize::Small)
302                                            ),
303                                            None => this,
304                                        })
305                                        .child(Label::new(model_info.name.clone()).truncate()),
306                                )
307                                .end_slot(div().pr_3().when(is_selected, |this| {
308                                    this.child(
309                                        Icon::new(IconName::Check)
310                                            .color(Color::Accent)
311                                            .size(IconSize::Small),
312                                    )
313                                })),
314                        )
315                        .into_any_element()
316                )
317            }
318        }
319    }
320
321    fn documentation_aside(
322        &self,
323        _window: &mut Window,
324        _cx: &mut Context<Picker<Self>>,
325    ) -> Option<ui::DocumentationAside> {
326        self.selected_description
327            .as_ref()
328            .map(|(_, description, is_default)| {
329                let description = description.clone();
330                let is_default = *is_default;
331
332                DocumentationAside::new(
333                    DocumentationSide::Left,
334                    DocumentationEdge::Top,
335                    Rc::new(move |_| {
336                        v_flex()
337                            .gap_1()
338                            .child(Label::new(description.clone()))
339                            .child(HoldForDefault::new(is_default))
340                            .into_any_element()
341                    }),
342                )
343            })
344    }
345
346    fn render_footer(
347        &self,
348        _window: &mut Window,
349        cx: &mut Context<Picker<Self>>,
350    ) -> Option<AnyElement> {
351        let focus_handle = self.focus_handle.clone();
352
353        if !self.selector.should_render_footer() {
354            return None;
355        }
356
357        Some(
358            h_flex()
359                .w_full()
360                .p_1p5()
361                .border_t_1()
362                .border_color(cx.theme().colors().border_variant)
363                .child(
364                    Button::new("configure", "Configure")
365                        .full_width()
366                        .style(ButtonStyle::Outlined)
367                        .key_binding(
368                            KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
369                                .map(|kb| kb.size(rems_from_px(12.))),
370                        )
371                        .on_click(|_, window, cx| {
372                            window.dispatch_action(OpenSettings.boxed_clone(), cx);
373                        }),
374                )
375                .into_any(),
376        )
377    }
378}
379
380fn info_list_to_picker_entries(
381    model_list: AgentModelList,
382) -> impl Iterator<Item = AcpModelPickerEntry> {
383    match model_list {
384        AgentModelList::Flat(list) => {
385            itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
386        }
387        AgentModelList::Grouped(index_map) => {
388            itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
389                std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
390                    .chain(models.into_iter().map(AcpModelPickerEntry::Model))
391            }))
392        }
393    }
394}
395
396async fn fuzzy_search(
397    model_list: AgentModelList,
398    query: String,
399    executor: BackgroundExecutor,
400) -> AgentModelList {
401    async fn fuzzy_search_list(
402        model_list: Vec<AgentModelInfo>,
403        query: &str,
404        executor: BackgroundExecutor,
405    ) -> Vec<AgentModelInfo> {
406        let candidates = model_list
407            .iter()
408            .enumerate()
409            .map(|(ix, model)| {
410                StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
411            })
412            .collect::<Vec<_>>();
413        let mut matches = match_strings(
414            &candidates,
415            query,
416            false,
417            true,
418            100,
419            &Default::default(),
420            executor,
421        )
422        .await;
423
424        matches.sort_unstable_by_key(|mat| {
425            let candidate = &candidates[mat.candidate_id];
426            (Reverse(OrderedFloat(mat.score)), candidate.id)
427        });
428
429        matches
430            .into_iter()
431            .map(|mat| model_list[mat.candidate_id].clone())
432            .collect()
433    }
434
435    match model_list {
436        AgentModelList::Flat(model_list) => {
437            AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
438        }
439        AgentModelList::Grouped(index_map) => {
440            let groups =
441                futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
442                    fuzzy_search_list(models, &query, executor.clone())
443                        .map(|results| (group_name, results))
444                }))
445                .await;
446            AgentModelList::Grouped(IndexMap::from_iter(
447                groups
448                    .into_iter()
449                    .filter(|(_, results)| !results.is_empty()),
450            ))
451        }
452    }
453}
454
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list in the given order; each model string is used as both the
    /// model's id and its display name.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId::new(model.to_string()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups, and the model names inside each
    /// group, match `expected` exactly and in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }
}