model_selector.rs

  1use std::{cmp::Reverse, rc::Rc, sync::Arc};
  2
  3use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
  4use agent_servers::AgentServer;
  5use anyhow::Result;
  6use collections::IndexMap;
  7use fs::Fs;
  8use futures::FutureExt;
  9use fuzzy::{StringMatchCandidate, match_strings};
 10use gpui::{
 11    Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Task, WeakEntity,
 12};
 13use ordered_float::OrderedFloat;
 14use picker::{Picker, PickerDelegate};
 15use ui::{
 16    DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, KeyBinding, ListItem,
 17    ListItemSpacing, prelude::*,
 18};
 19use util::ResultExt;
 20use zed_actions::agent::OpenSettings;
 21
 22use crate::ui::HoldForDefault;
 23
 24pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
 25
 26pub fn acp_model_selector(
 27    selector: Rc<dyn AgentModelSelector>,
 28    agent_server: Rc<dyn AgentServer>,
 29    fs: Arc<dyn Fs>,
 30    focus_handle: FocusHandle,
 31    window: &mut Window,
 32    cx: &mut Context<AcpModelSelector>,
 33) -> AcpModelSelector {
 34    let delegate =
 35        AcpModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
 36    Picker::list(delegate, window, cx)
 37        .show_scrollbar(true)
 38        .width(rems(20.))
 39        .max_height(Some(rems(20.).into()))
 40}
 41
 42enum AcpModelPickerEntry {
 43    Separator(SharedString),
 44    Model(AgentModelInfo),
 45}
 46
 47pub struct AcpModelPickerDelegate {
 48    selector: Rc<dyn AgentModelSelector>,
 49    agent_server: Rc<dyn AgentServer>,
 50    fs: Arc<dyn Fs>,
 51    filtered_entries: Vec<AcpModelPickerEntry>,
 52    models: Option<AgentModelList>,
 53    selected_index: usize,
 54    selected_description: Option<(usize, SharedString, bool)>,
 55    selected_model: Option<AgentModelInfo>,
 56    _refresh_models_task: Task<()>,
 57    focus_handle: FocusHandle,
 58}
 59
 60impl AcpModelPickerDelegate {
 61    fn new(
 62        selector: Rc<dyn AgentModelSelector>,
 63        agent_server: Rc<dyn AgentServer>,
 64        fs: Arc<dyn Fs>,
 65        focus_handle: FocusHandle,
 66        window: &mut Window,
 67        cx: &mut Context<AcpModelSelector>,
 68    ) -> Self {
 69        let rx = selector.watch(cx);
 70        let refresh_models_task = {
 71            cx.spawn_in(window, {
 72                async move |this, cx| {
 73                    async fn refresh(
 74                        this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
 75                        cx: &mut AsyncWindowContext,
 76                    ) -> Result<()> {
 77                        let (models_task, selected_model_task) = this.update(cx, |this, cx| {
 78                            (
 79                                this.delegate.selector.list_models(cx),
 80                                this.delegate.selector.selected_model(cx),
 81                            )
 82                        })?;
 83
 84                        let (models, selected_model) =
 85                            futures::join!(models_task, selected_model_task);
 86
 87                        this.update_in(cx, |this, window, cx| {
 88                            this.delegate.models = models.ok();
 89                            this.delegate.selected_model = selected_model.ok();
 90                            this.refresh(window, cx)
 91                        })
 92                    }
 93
 94                    refresh(&this, cx).await.log_err();
 95                    if let Some(mut rx) = rx {
 96                        while let Ok(()) = rx.recv().await {
 97                            refresh(&this, cx).await.log_err();
 98                        }
 99                    }
100                }
101            })
102        };
103
104        Self {
105            selector,
106            agent_server,
107            fs,
108            filtered_entries: Vec::new(),
109            models: None,
110            selected_model: None,
111            selected_index: 0,
112            selected_description: None,
113            _refresh_models_task: refresh_models_task,
114            focus_handle,
115        }
116    }
117
118    pub fn active_model(&self) -> Option<&AgentModelInfo> {
119        self.selected_model.as_ref()
120    }
121}
122
123impl PickerDelegate for AcpModelPickerDelegate {
124    type ListItem = AnyElement;
125
126    fn match_count(&self) -> usize {
127        self.filtered_entries.len()
128    }
129
130    fn selected_index(&self) -> usize {
131        self.selected_index
132    }
133
134    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
135        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
136        cx.notify();
137    }
138
139    fn can_select(
140        &mut self,
141        ix: usize,
142        _window: &mut Window,
143        _cx: &mut Context<Picker<Self>>,
144    ) -> bool {
145        match self.filtered_entries.get(ix) {
146            Some(AcpModelPickerEntry::Model(_)) => true,
147            Some(AcpModelPickerEntry::Separator(_)) | None => false,
148        }
149    }
150
151    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
152        "Select a model…".into()
153    }
154
155    fn update_matches(
156        &mut self,
157        query: String,
158        window: &mut Window,
159        cx: &mut Context<Picker<Self>>,
160    ) -> Task<()> {
161        cx.spawn_in(window, async move |this, cx| {
162            let filtered_models = match this
163                .read_with(cx, |this, cx| {
164                    this.delegate.models.clone().map(move |models| {
165                        fuzzy_search(models, query, cx.background_executor().clone())
166                    })
167                })
168                .ok()
169                .flatten()
170            {
171                Some(task) => task.await,
172                None => AgentModelList::Flat(vec![]),
173            };
174
175            this.update_in(cx, |this, window, cx| {
176                this.delegate.filtered_entries =
177                    info_list_to_picker_entries(filtered_models).collect();
178                // Finds the currently selected model in the list
179                let new_index = this
180                    .delegate
181                    .selected_model
182                    .as_ref()
183                    .and_then(|selected| {
184                        this.delegate.filtered_entries.iter().position(|entry| {
185                            if let AcpModelPickerEntry::Model(model_info) = entry {
186                                model_info.id == selected.id
187                            } else {
188                                false
189                            }
190                        })
191                    })
192                    .unwrap_or(0);
193                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
194                cx.notify();
195            })
196            .ok();
197        })
198    }
199
200    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
201        if let Some(AcpModelPickerEntry::Model(model_info)) =
202            self.filtered_entries.get(self.selected_index)
203        {
204            if window.modifiers().secondary() {
205                let default_model = self.agent_server.default_model(cx);
206                let is_default = default_model.as_ref() == Some(&model_info.id);
207
208                self.agent_server.set_default_model(
209                    if is_default {
210                        None
211                    } else {
212                        Some(model_info.id.clone())
213                    },
214                    self.fs.clone(),
215                    cx,
216                );
217            }
218
219            self.selector
220                .select_model(model_info.id.clone(), cx)
221                .detach_and_log_err(cx);
222            self.selected_model = Some(model_info.clone());
223            let current_index = self.selected_index;
224            self.set_selected_index(current_index, window, cx);
225
226            cx.emit(DismissEvent);
227        }
228    }
229
230    fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
231        cx.defer_in(window, |picker, window, cx| {
232            picker.set_query("", window, cx);
233        });
234    }
235
236    fn render_match(
237        &self,
238        ix: usize,
239        selected: bool,
240        _: &mut Window,
241        cx: &mut Context<Picker<Self>>,
242    ) -> Option<Self::ListItem> {
243        match self.filtered_entries.get(ix)? {
244            AcpModelPickerEntry::Separator(title) => Some(
245                div()
246                    .px_2()
247                    .pb_1()
248                    .when(ix > 1, |this| {
249                        this.mt_1()
250                            .pt_2()
251                            .border_t_1()
252                            .border_color(cx.theme().colors().border_variant)
253                    })
254                    .child(
255                        Label::new(title)
256                            .size(LabelSize::XSmall)
257                            .color(Color::Muted),
258                    )
259                    .into_any_element(),
260            ),
261            AcpModelPickerEntry::Model(model_info) => {
262                let is_selected = Some(model_info) == self.selected_model.as_ref();
263                let default_model = self.agent_server.default_model(cx);
264                let is_default = default_model.as_ref() == Some(&model_info.id);
265
266                let model_icon_color = if is_selected {
267                    Color::Accent
268                } else {
269                    Color::Muted
270                };
271
272                Some(
273                    div()
274                        .id(("model-picker-menu-child", ix))
275                        .when_some(model_info.description.clone(), |this, description| {
276                            this
277                                .on_hover(cx.listener(move |menu, hovered, _, cx| {
278                                    if *hovered {
279                                        menu.delegate.selected_description = Some((ix, description.clone(), is_default));
280                                    } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
281                                        menu.delegate.selected_description = None;
282                                    }
283                                    cx.notify();
284                                }))
285                        })
286                        .child(
287                            ListItem::new(ix)
288                                .inset(true)
289                                .spacing(ListItemSpacing::Sparse)
290                                .toggle_state(selected)
291                                .child(
292                                    h_flex()
293                                        .w_full()
294                                        .gap_1p5()
295                                        .map(|this| match &model_info.icon {
296                                            Some(icon) => this.child(
297                                                match icon {
298                                                    AgentModelIcon::Path(path) => Icon::from_external_svg(path.clone()),
299                                                    AgentModelIcon::Named(icon) => Icon::new(*icon)
300                                                }
301                                                .color(model_icon_color)
302                                                .size(IconSize::Small)
303                                            ),
304                                            None => this,
305                                        })
306                                        .child(Label::new(model_info.name.clone()).truncate()),
307                                )
308                                .end_slot(div().pr_3().when(is_selected, |this| {
309                                    this.child(
310                                        Icon::new(IconName::Check)
311                                            .color(Color::Accent)
312                                            .size(IconSize::Small),
313                                    )
314                                })),
315                        )
316                        .into_any_element()
317                )
318            }
319        }
320    }
321
322    fn documentation_aside(
323        &self,
324        _window: &mut Window,
325        _cx: &mut Context<Picker<Self>>,
326    ) -> Option<ui::DocumentationAside> {
327        self.selected_description
328            .as_ref()
329            .map(|(_, description, is_default)| {
330                let description = description.clone();
331                let is_default = *is_default;
332
333                DocumentationAside::new(
334                    DocumentationSide::Left,
335                    DocumentationEdge::Top,
336                    Rc::new(move |_| {
337                        v_flex()
338                            .gap_1()
339                            .child(Label::new(description.clone()))
340                            .child(HoldForDefault::new(is_default))
341                            .into_any_element()
342                    }),
343                )
344            })
345    }
346
347    fn render_footer(
348        &self,
349        _window: &mut Window,
350        cx: &mut Context<Picker<Self>>,
351    ) -> Option<AnyElement> {
352        let focus_handle = self.focus_handle.clone();
353
354        if !self.selector.should_render_footer() {
355            return None;
356        }
357
358        Some(
359            h_flex()
360                .w_full()
361                .p_1p5()
362                .border_t_1()
363                .border_color(cx.theme().colors().border_variant)
364                .child(
365                    Button::new("configure", "Configure")
366                        .full_width()
367                        .style(ButtonStyle::Outlined)
368                        .key_binding(
369                            KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
370                                .map(|kb| kb.size(rems_from_px(12.))),
371                        )
372                        .on_click(|_, window, cx| {
373                            window.dispatch_action(OpenSettings.boxed_clone(), cx);
374                        }),
375                )
376                .into_any(),
377        )
378    }
379}
380
381fn info_list_to_picker_entries(
382    model_list: AgentModelList,
383) -> impl Iterator<Item = AcpModelPickerEntry> {
384    match model_list {
385        AgentModelList::Flat(list) => {
386            itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
387        }
388        AgentModelList::Grouped(index_map) => {
389            itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
390                std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
391                    .chain(models.into_iter().map(AcpModelPickerEntry::Model))
392            }))
393        }
394    }
395}
396
397async fn fuzzy_search(
398    model_list: AgentModelList,
399    query: String,
400    executor: BackgroundExecutor,
401) -> AgentModelList {
402    async fn fuzzy_search_list(
403        model_list: Vec<AgentModelInfo>,
404        query: &str,
405        executor: BackgroundExecutor,
406    ) -> Vec<AgentModelInfo> {
407        let candidates = model_list
408            .iter()
409            .enumerate()
410            .map(|(ix, model)| StringMatchCandidate::new(ix, model.name.as_ref()))
411            .collect::<Vec<_>>();
412        let mut matches = match_strings(
413            &candidates,
414            query,
415            false,
416            true,
417            100,
418            &Default::default(),
419            executor,
420        )
421        .await;
422
423        matches.sort_unstable_by_key(|mat| {
424            let candidate = &candidates[mat.candidate_id];
425            (Reverse(OrderedFloat(mat.score)), candidate.id)
426        });
427
428        matches
429            .into_iter()
430            .map(|mat| model_list[mat.candidate_id].clone())
431            .collect()
432    }
433
434    match model_list {
435        AgentModelList::Flat(model_list) => {
436            AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
437        }
438        AgentModelList::Grouped(index_map) => {
439            let groups =
440                futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
441                    fuzzy_search_list(models, &query, executor.clone())
442                        .map(|results| (group_name, results))
443                }))
444                .await;
445            AgentModelList::Grouped(IndexMap::from_iter(
446                groups
447                    .into_iter()
448                    .filter(|(_, results)| !results.is_empty()),
449            ))
450        }
451    }
452}
453
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list from `(group name, model names)` pairs.
    /// Each model uses its name as both id and display name, with no
    /// description or icon.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId::new(model.to_string()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names
    /// match `expected`, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }
}
555}