model_selector.rs

  1use std::{cmp::Reverse, rc::Rc, sync::Arc};
  2
  3use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
  4use agent_servers::AgentServer;
  5use anyhow::Result;
  6use collections::IndexMap;
  7use fs::Fs;
  8use futures::FutureExt;
  9use fuzzy::{StringMatchCandidate, match_strings};
 10use gpui::{
 11    Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Task, WeakEntity,
 12};
 13use ordered_float::OrderedFloat;
 14use picker::{Picker, PickerDelegate};
 15use ui::{
 16    DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, KeyBinding, ListItem,
 17    ListItemSpacing, prelude::*,
 18};
 19use util::ResultExt;
 20use zed_actions::agent::OpenSettings;
 21
 22use crate::ui::HoldForDefault;
 23
 24pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
 25
 26pub fn acp_model_selector(
 27    selector: Rc<dyn AgentModelSelector>,
 28    agent_server: Rc<dyn AgentServer>,
 29    fs: Arc<dyn Fs>,
 30    focus_handle: FocusHandle,
 31    window: &mut Window,
 32    cx: &mut Context<AcpModelSelector>,
 33) -> AcpModelSelector {
 34    let delegate =
 35        AcpModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
 36    Picker::list(delegate, window, cx)
 37        .show_scrollbar(true)
 38        .width(rems(20.))
 39        .max_height(Some(rems(20.).into()))
 40}
 41
/// A single row in the model picker list.
enum AcpModelPickerEntry {
    /// Non-selectable header naming a model group (emitted for grouped model lists).
    Separator(SharedString),
    /// A selectable model row.
    Model(AgentModelInfo),
}
 46
/// [`PickerDelegate`] backing the ACP model selector.
pub struct AcpModelPickerDelegate {
    /// Source of the model list and the active model; also used to switch models.
    selector: Rc<dyn AgentModelSelector>,
    /// Agent server whose default model is read (for display) and toggled on confirm.
    agent_server: Rc<dyn AgentServer>,
    /// Filesystem handle passed to `AgentServer::set_default_model`.
    fs: Arc<dyn Fs>,
    /// Rows currently shown: fuzzy-filtered models, interleaved with group separators.
    filtered_entries: Vec<AcpModelPickerEntry>,
    /// Full, unfiltered model list; `None` until loaded, or if loading failed.
    models: Option<AgentModelList>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Description for the documentation aside, set while a row with a description is
    /// hovered: `(row index, description, whether that model is the agent's default)`.
    selected_description: Option<(usize, SharedString, bool)>,
    /// Model currently in use, as last reported by `selector` or chosen via `confirm`.
    selected_model: Option<AgentModelInfo>,
    /// Held so the spawned model-refresh task is not dropped while the delegate exists.
    _refresh_models_task: Task<()>,
    /// Used to look up the key binding shown on the footer's "Configure" button.
    focus_handle: FocusHandle,
}
 59
 60impl AcpModelPickerDelegate {
 61    fn new(
 62        selector: Rc<dyn AgentModelSelector>,
 63        agent_server: Rc<dyn AgentServer>,
 64        fs: Arc<dyn Fs>,
 65        focus_handle: FocusHandle,
 66        window: &mut Window,
 67        cx: &mut Context<AcpModelSelector>,
 68    ) -> Self {
 69        let rx = selector.watch(cx);
 70        let refresh_models_task = {
 71            cx.spawn_in(window, {
 72                async move |this, cx| {
 73                    async fn refresh(
 74                        this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
 75                        cx: &mut AsyncWindowContext,
 76                    ) -> Result<()> {
 77                        let (models_task, selected_model_task) = this.update(cx, |this, cx| {
 78                            (
 79                                this.delegate.selector.list_models(cx),
 80                                this.delegate.selector.selected_model(cx),
 81                            )
 82                        })?;
 83
 84                        let (models, selected_model) =
 85                            futures::join!(models_task, selected_model_task);
 86
 87                        this.update_in(cx, |this, window, cx| {
 88                            this.delegate.models = models.ok();
 89                            this.delegate.selected_model = selected_model.ok();
 90                            this.refresh(window, cx)
 91                        })
 92                    }
 93
 94                    refresh(&this, cx).await.log_err();
 95                    if let Some(mut rx) = rx {
 96                        while let Ok(()) = rx.recv().await {
 97                            refresh(&this, cx).await.log_err();
 98                        }
 99                    }
100                }
101            })
102        };
103
104        Self {
105            selector,
106            agent_server,
107            fs,
108            filtered_entries: Vec::new(),
109            models: None,
110            selected_model: None,
111            selected_index: 0,
112            selected_description: None,
113            _refresh_models_task: refresh_models_task,
114            focus_handle,
115        }
116    }
117
118    pub fn active_model(&self) -> Option<&AgentModelInfo> {
119        self.selected_model.as_ref()
120    }
121}
122
123impl PickerDelegate for AcpModelPickerDelegate {
124    type ListItem = AnyElement;
125
126    fn match_count(&self) -> usize {
127        self.filtered_entries.len()
128    }
129
130    fn selected_index(&self) -> usize {
131        self.selected_index
132    }
133
134    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
135        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
136        cx.notify();
137    }
138
139    fn can_select(
140        &mut self,
141        ix: usize,
142        _window: &mut Window,
143        _cx: &mut Context<Picker<Self>>,
144    ) -> bool {
145        match self.filtered_entries.get(ix) {
146            Some(AcpModelPickerEntry::Model(_)) => true,
147            Some(AcpModelPickerEntry::Separator(_)) | None => false,
148        }
149    }
150
151    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
152        "Select a model…".into()
153    }
154
155    fn update_matches(
156        &mut self,
157        query: String,
158        window: &mut Window,
159        cx: &mut Context<Picker<Self>>,
160    ) -> Task<()> {
161        cx.spawn_in(window, async move |this, cx| {
162            let filtered_models = match this
163                .read_with(cx, |this, cx| {
164                    this.delegate.models.clone().map(move |models| {
165                        fuzzy_search(models, query, cx.background_executor().clone())
166                    })
167                })
168                .ok()
169                .flatten()
170            {
171                Some(task) => task.await,
172                None => AgentModelList::Flat(vec![]),
173            };
174
175            this.update_in(cx, |this, window, cx| {
176                this.delegate.filtered_entries =
177                    info_list_to_picker_entries(filtered_models).collect();
178                // Finds the currently selected model in the list
179                let new_index = this
180                    .delegate
181                    .selected_model
182                    .as_ref()
183                    .and_then(|selected| {
184                        this.delegate.filtered_entries.iter().position(|entry| {
185                            if let AcpModelPickerEntry::Model(model_info) = entry {
186                                model_info.id == selected.id
187                            } else {
188                                false
189                            }
190                        })
191                    })
192                    .unwrap_or(0);
193                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
194                cx.notify();
195            })
196            .ok();
197        })
198    }
199
200    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
201        if let Some(AcpModelPickerEntry::Model(model_info)) =
202            self.filtered_entries.get(self.selected_index)
203        {
204            if window.modifiers().secondary() {
205                let default_model = self.agent_server.default_model(cx);
206                let is_default = default_model.as_ref() == Some(&model_info.id);
207
208                self.agent_server.set_default_model(
209                    if is_default {
210                        None
211                    } else {
212                        Some(model_info.id.clone())
213                    },
214                    self.fs.clone(),
215                    cx,
216                );
217            }
218
219            self.selector
220                .select_model(model_info.id.clone(), cx)
221                .detach_and_log_err(cx);
222            self.selected_model = Some(model_info.clone());
223            let current_index = self.selected_index;
224            self.set_selected_index(current_index, window, cx);
225
226            cx.emit(DismissEvent);
227        }
228    }
229
230    fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
231        cx.defer_in(window, |picker, window, cx| {
232            picker.set_query("", window, cx);
233        });
234    }
235
236    fn render_match(
237        &self,
238        ix: usize,
239        selected: bool,
240        _: &mut Window,
241        cx: &mut Context<Picker<Self>>,
242    ) -> Option<Self::ListItem> {
243        match self.filtered_entries.get(ix)? {
244            AcpModelPickerEntry::Separator(title) => Some(
245                div()
246                    .px_2()
247                    .pb_1()
248                    .when(ix > 1, |this| {
249                        this.mt_1()
250                            .pt_2()
251                            .border_t_1()
252                            .border_color(cx.theme().colors().border_variant)
253                    })
254                    .child(
255                        Label::new(title)
256                            .size(LabelSize::XSmall)
257                            .color(Color::Muted),
258                    )
259                    .into_any_element(),
260            ),
261            AcpModelPickerEntry::Model(model_info) => {
262                let is_selected = Some(model_info) == self.selected_model.as_ref();
263                let default_model = self.agent_server.default_model(cx);
264                let is_default = default_model.as_ref() == Some(&model_info.id);
265
266                let model_icon_color = if is_selected {
267                    Color::Accent
268                } else {
269                    Color::Muted
270                };
271
272                Some(
273                    div()
274                        .id(("model-picker-menu-child", ix))
275                        .when_some(model_info.description.clone(), |this, description| {
276                            this
277                                .on_hover(cx.listener(move |menu, hovered, _, cx| {
278                                    if *hovered {
279                                        menu.delegate.selected_description = Some((ix, description.clone(), is_default));
280                                    } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
281                                        menu.delegate.selected_description = None;
282                                    }
283                                    cx.notify();
284                                }))
285                        })
286                        .child(
287                            ListItem::new(ix)
288                                .inset(true)
289                                .spacing(ListItemSpacing::Sparse)
290                                .toggle_state(selected)
291                                .child(
292                                    h_flex()
293                                        .w_full()
294                                        .gap_1p5()
295                                        .map(|this| match &model_info.icon {
296                                            Some(AgentModelIcon::Path(path)) => this.child(
297                                                Icon::from_path(path.clone())
298                                                    .color(model_icon_color)
299                                                    .size(IconSize::Small),
300                                            ),
301                                            Some(AgentModelIcon::Named(icon)) => this.child(
302                                                Icon::new(*icon)
303                                                    .color(model_icon_color)
304                                                    .size(IconSize::Small),
305                                            ),
306                                            None => this,
307                                        })
308                                        .child(Label::new(model_info.name.clone()).truncate()),
309                                )
310                                .end_slot(div().pr_3().when(is_selected, |this| {
311                                    this.child(
312                                        Icon::new(IconName::Check)
313                                            .color(Color::Accent)
314                                            .size(IconSize::Small),
315                                    )
316                                })),
317                        )
318                        .into_any_element()
319                )
320            }
321        }
322    }
323
324    fn documentation_aside(
325        &self,
326        _window: &mut Window,
327        _cx: &mut Context<Picker<Self>>,
328    ) -> Option<ui::DocumentationAside> {
329        self.selected_description
330            .as_ref()
331            .map(|(_, description, is_default)| {
332                let description = description.clone();
333                let is_default = *is_default;
334
335                DocumentationAside::new(
336                    DocumentationSide::Left,
337                    DocumentationEdge::Top,
338                    Rc::new(move |_| {
339                        v_flex()
340                            .gap_1()
341                            .child(Label::new(description.clone()))
342                            .child(HoldForDefault::new(is_default))
343                            .into_any_element()
344                    }),
345                )
346            })
347    }
348
349    fn render_footer(
350        &self,
351        _window: &mut Window,
352        cx: &mut Context<Picker<Self>>,
353    ) -> Option<AnyElement> {
354        let focus_handle = self.focus_handle.clone();
355
356        if !self.selector.should_render_footer() {
357            return None;
358        }
359
360        Some(
361            h_flex()
362                .w_full()
363                .p_1p5()
364                .border_t_1()
365                .border_color(cx.theme().colors().border_variant)
366                .child(
367                    Button::new("configure", "Configure")
368                        .full_width()
369                        .style(ButtonStyle::Outlined)
370                        .key_binding(
371                            KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
372                                .map(|kb| kb.size(rems_from_px(12.))),
373                        )
374                        .on_click(|_, window, cx| {
375                            window.dispatch_action(OpenSettings.boxed_clone(), cx);
376                        }),
377                )
378                .into_any(),
379        )
380    }
381}
382
383fn info_list_to_picker_entries(
384    model_list: AgentModelList,
385) -> impl Iterator<Item = AcpModelPickerEntry> {
386    match model_list {
387        AgentModelList::Flat(list) => {
388            itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
389        }
390        AgentModelList::Grouped(index_map) => {
391            itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
392                std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
393                    .chain(models.into_iter().map(AcpModelPickerEntry::Model))
394            }))
395        }
396    }
397}
398
399async fn fuzzy_search(
400    model_list: AgentModelList,
401    query: String,
402    executor: BackgroundExecutor,
403) -> AgentModelList {
404    async fn fuzzy_search_list(
405        model_list: Vec<AgentModelInfo>,
406        query: &str,
407        executor: BackgroundExecutor,
408    ) -> Vec<AgentModelInfo> {
409        let candidates = model_list
410            .iter()
411            .enumerate()
412            .map(|(ix, model)| {
413                StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
414            })
415            .collect::<Vec<_>>();
416        let mut matches = match_strings(
417            &candidates,
418            query,
419            false,
420            true,
421            100,
422            &Default::default(),
423            executor,
424        )
425        .await;
426
427        matches.sort_unstable_by_key(|mat| {
428            let candidate = &candidates[mat.candidate_id];
429            (Reverse(OrderedFloat(mat.score)), candidate.id)
430        });
431
432        matches
433            .into_iter()
434            .map(|mat| model_list[mat.candidate_id].clone())
435            .collect()
436    }
437
438    match model_list {
439        AgentModelList::Flat(model_list) => {
440            AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
441        }
442        AgentModelList::Grouped(index_map) => {
443            let groups =
444                futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
445                    fuzzy_search_list(models, &query, executor.clone())
446                        .map(|results| (group_name, results))
447                }))
448                .await;
449            AgentModelList::Grouped(IndexMap::from_iter(
450                groups
451                    .into_iter()
452                    .filter(|(_, results)| !results.is_empty()),
453            ))
454        }
455    }
456}
457
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list in which every model's id and name are the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId::new(model.to_string()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names equal
    /// `expected`, in the same order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }

    #[gpui::test]
    async fn test_fuzzy_match_empty_query_keeps_everything(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["Claude 3.7 Sonnet", "gpt-4.1"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // The picker starts with an empty query, so an empty query must keep every
        // group and every model, in their original order.
        let results = fuzzy_search(models, String::new(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["Claude 3.7 Sonnet", "gpt-4.1"]),
                ("ollama", vec!["mistral", "deepseek"]),
            ],
        );
    }
}