model_selector.rs

  1use std::{cmp::Reverse, rc::Rc, sync::Arc};
  2
  3use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
  4use agent_servers::AgentServer;
  5use anyhow::Result;
  6use collections::IndexMap;
  7use fs::Fs;
  8use futures::FutureExt;
  9use fuzzy::{StringMatchCandidate, match_strings};
 10use gpui::{
 11    Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Task, WeakEntity,
 12};
 13use ordered_float::OrderedFloat;
 14use picker::{Picker, PickerDelegate};
 15use ui::{
 16    DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, KeyBinding, ListItem,
 17    ListItemSpacing, prelude::*,
 18};
 19use util::ResultExt;
 20use zed_actions::agent::OpenSettings;
 21
 22use crate::ui::HoldForDefault;
 23
 24pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
 25
 26pub fn acp_model_selector(
 27    selector: Rc<dyn AgentModelSelector>,
 28    agent_server: Rc<dyn AgentServer>,
 29    fs: Arc<dyn Fs>,
 30    focus_handle: FocusHandle,
 31    window: &mut Window,
 32    cx: &mut Context<AcpModelSelector>,
 33) -> AcpModelSelector {
 34    let delegate =
 35        AcpModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
 36    Picker::list(delegate, window, cx)
 37        .show_scrollbar(true)
 38        .width(rems(20.))
 39        .max_height(Some(rems(20.).into()))
 40}
 41
 42enum AcpModelPickerEntry {
 43    Separator(SharedString),
 44    Model(AgentModelInfo),
 45}
 46
 47pub struct AcpModelPickerDelegate {
 48    selector: Rc<dyn AgentModelSelector>,
 49    agent_server: Rc<dyn AgentServer>,
 50    fs: Arc<dyn Fs>,
 51    filtered_entries: Vec<AcpModelPickerEntry>,
 52    models: Option<AgentModelList>,
 53    selected_index: usize,
 54    selected_description: Option<(usize, SharedString, bool)>,
 55    selected_model: Option<AgentModelInfo>,
 56    _refresh_models_task: Task<()>,
 57    focus_handle: FocusHandle,
 58}
 59
 60impl AcpModelPickerDelegate {
 61    fn new(
 62        selector: Rc<dyn AgentModelSelector>,
 63        agent_server: Rc<dyn AgentServer>,
 64        fs: Arc<dyn Fs>,
 65        focus_handle: FocusHandle,
 66        window: &mut Window,
 67        cx: &mut Context<AcpModelSelector>,
 68    ) -> Self {
 69        let rx = selector.watch(cx);
 70        let refresh_models_task = {
 71            cx.spawn_in(window, {
 72                async move |this, cx| {
 73                    async fn refresh(
 74                        this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
 75                        cx: &mut AsyncWindowContext,
 76                    ) -> Result<()> {
 77                        let (models_task, selected_model_task) = this.update(cx, |this, cx| {
 78                            (
 79                                this.delegate.selector.list_models(cx),
 80                                this.delegate.selector.selected_model(cx),
 81                            )
 82                        })?;
 83
 84                        let (models, selected_model) =
 85                            futures::join!(models_task, selected_model_task);
 86
 87                        this.update_in(cx, |this, window, cx| {
 88                            this.delegate.models = models.ok();
 89                            this.delegate.selected_model = selected_model.ok();
 90                            this.refresh(window, cx)
 91                        })
 92                    }
 93
 94                    refresh(&this, cx).await.log_err();
 95                    if let Some(mut rx) = rx {
 96                        while let Ok(()) = rx.recv().await {
 97                            refresh(&this, cx).await.log_err();
 98                        }
 99                    }
100                }
101            })
102        };
103
104        Self {
105            selector,
106            agent_server,
107            fs,
108            filtered_entries: Vec::new(),
109            models: None,
110            selected_model: None,
111            selected_index: 0,
112            selected_description: None,
113            _refresh_models_task: refresh_models_task,
114            focus_handle,
115        }
116    }
117
118    pub fn active_model(&self) -> Option<&AgentModelInfo> {
119        self.selected_model.as_ref()
120    }
121}
122
123impl PickerDelegate for AcpModelPickerDelegate {
124    type ListItem = AnyElement;
125
126    fn match_count(&self) -> usize {
127        self.filtered_entries.len()
128    }
129
130    fn selected_index(&self) -> usize {
131        self.selected_index
132    }
133
134    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
135        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
136        cx.notify();
137    }
138
139    fn can_select(
140        &mut self,
141        ix: usize,
142        _window: &mut Window,
143        _cx: &mut Context<Picker<Self>>,
144    ) -> bool {
145        match self.filtered_entries.get(ix) {
146            Some(AcpModelPickerEntry::Model(_)) => true,
147            Some(AcpModelPickerEntry::Separator(_)) | None => false,
148        }
149    }
150
151    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
152        "Select a model…".into()
153    }
154
155    fn update_matches(
156        &mut self,
157        query: String,
158        window: &mut Window,
159        cx: &mut Context<Picker<Self>>,
160    ) -> Task<()> {
161        cx.spawn_in(window, async move |this, cx| {
162            let filtered_models = match this
163                .read_with(cx, |this, cx| {
164                    this.delegate.models.clone().map(move |models| {
165                        fuzzy_search(models, query, cx.background_executor().clone())
166                    })
167                })
168                .ok()
169                .flatten()
170            {
171                Some(task) => task.await,
172                None => AgentModelList::Flat(vec![]),
173            };
174
175            this.update_in(cx, |this, window, cx| {
176                this.delegate.filtered_entries =
177                    info_list_to_picker_entries(filtered_models).collect();
178                // Finds the currently selected model in the list
179                let new_index = this
180                    .delegate
181                    .selected_model
182                    .as_ref()
183                    .and_then(|selected| {
184                        this.delegate.filtered_entries.iter().position(|entry| {
185                            if let AcpModelPickerEntry::Model(model_info) = entry {
186                                model_info.id == selected.id
187                            } else {
188                                false
189                            }
190                        })
191                    })
192                    .unwrap_or(0);
193                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
194                cx.notify();
195            })
196            .ok();
197        })
198    }
199
200    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
201        if let Some(AcpModelPickerEntry::Model(model_info)) =
202            self.filtered_entries.get(self.selected_index)
203        {
204            if window.modifiers().secondary() {
205                let default_model = self.agent_server.default_model(cx);
206                let is_default = default_model.as_ref() == Some(&model_info.id);
207
208                self.agent_server.set_default_model(
209                    if is_default {
210                        None
211                    } else {
212                        Some(model_info.id.clone())
213                    },
214                    self.fs.clone(),
215                    cx,
216                );
217            }
218
219            self.selector
220                .select_model(model_info.id.clone(), cx)
221                .detach_and_log_err(cx);
222            self.selected_model = Some(model_info.clone());
223            let current_index = self.selected_index;
224            self.set_selected_index(current_index, window, cx);
225
226            cx.emit(DismissEvent);
227        }
228    }
229
230    fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
231        cx.defer_in(window, |picker, window, cx| {
232            picker.set_query("", window, cx);
233        });
234    }
235
236    fn render_match(
237        &self,
238        ix: usize,
239        selected: bool,
240        _: &mut Window,
241        cx: &mut Context<Picker<Self>>,
242    ) -> Option<Self::ListItem> {
243        match self.filtered_entries.get(ix)? {
244            AcpModelPickerEntry::Separator(title) => Some(
245                div()
246                    .px_2()
247                    .pb_1()
248                    .when(ix > 1, |this| {
249                        this.mt_1()
250                            .pt_2()
251                            .border_t_1()
252                            .border_color(cx.theme().colors().border_variant)
253                    })
254                    .child(
255                        Label::new(title)
256                            .size(LabelSize::XSmall)
257                            .color(Color::Muted),
258                    )
259                    .into_any_element(),
260            ),
261            AcpModelPickerEntry::Model(model_info) => {
262                let is_selected = Some(model_info) == self.selected_model.as_ref();
263                let default_model = self.agent_server.default_model(cx);
264                let is_default = default_model.as_ref() == Some(&model_info.id);
265
266                let model_icon_color = if is_selected {
267                    Color::Accent
268                } else {
269                    Color::Muted
270                };
271
272                Some(
273                    div()
274                        .id(("model-picker-menu-child", ix))
275                        .when_some(model_info.description.clone(), |this, description| {
276                            this
277                                .on_hover(cx.listener(move |menu, hovered, _, cx| {
278                                    if *hovered {
279                                        menu.delegate.selected_description = Some((ix, description.clone(), is_default));
280                                    } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
281                                        menu.delegate.selected_description = None;
282                                    }
283                                    cx.notify();
284                                }))
285                        })
286                        .child(
287                            ListItem::new(ix)
288                                .inset(true)
289                                .spacing(ListItemSpacing::Sparse)
290                                .toggle_state(selected)
291                                .child(
292                                    h_flex()
293                                        .w_full()
294                                        .gap_1p5()
295                                        .when_some(model_info.icon, |this, icon| {
296                                            this.child(
297                                                Icon::new(icon)
298                                                    .color(model_icon_color)
299                                                    .size(IconSize::Small)
300                                            )
301                                        })
302                                        .child(Label::new(model_info.name.clone()).truncate()),
303                                )
304                                .end_slot(div().pr_3().when(is_selected, |this| {
305                                    this.child(
306                                        Icon::new(IconName::Check)
307                                            .color(Color::Accent)
308                                            .size(IconSize::Small),
309                                    )
310                                })),
311                        )
312                        .into_any_element()
313                )
314            }
315        }
316    }
317
318    fn documentation_aside(
319        &self,
320        _window: &mut Window,
321        _cx: &mut Context<Picker<Self>>,
322    ) -> Option<ui::DocumentationAside> {
323        self.selected_description
324            .as_ref()
325            .map(|(_, description, is_default)| {
326                let description = description.clone();
327                let is_default = *is_default;
328
329                DocumentationAside::new(
330                    DocumentationSide::Left,
331                    DocumentationEdge::Top,
332                    Rc::new(move |_| {
333                        v_flex()
334                            .gap_1()
335                            .child(Label::new(description.clone()))
336                            .child(HoldForDefault::new(is_default))
337                            .into_any_element()
338                    }),
339                )
340            })
341    }
342
343    fn render_footer(
344        &self,
345        _window: &mut Window,
346        cx: &mut Context<Picker<Self>>,
347    ) -> Option<AnyElement> {
348        let focus_handle = self.focus_handle.clone();
349
350        if !self.selector.should_render_footer() {
351            return None;
352        }
353
354        Some(
355            h_flex()
356                .w_full()
357                .p_1p5()
358                .border_t_1()
359                .border_color(cx.theme().colors().border_variant)
360                .child(
361                    Button::new("configure", "Configure")
362                        .full_width()
363                        .style(ButtonStyle::Outlined)
364                        .key_binding(
365                            KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
366                                .map(|kb| kb.size(rems_from_px(12.))),
367                        )
368                        .on_click(|_, window, cx| {
369                            window.dispatch_action(OpenSettings.boxed_clone(), cx);
370                        }),
371                )
372                .into_any(),
373        )
374    }
375}
376
377fn info_list_to_picker_entries(
378    model_list: AgentModelList,
379) -> impl Iterator<Item = AcpModelPickerEntry> {
380    match model_list {
381        AgentModelList::Flat(list) => {
382            itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
383        }
384        AgentModelList::Grouped(index_map) => {
385            itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
386                std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
387                    .chain(models.into_iter().map(AcpModelPickerEntry::Model))
388            }))
389        }
390    }
391}
392
393async fn fuzzy_search(
394    model_list: AgentModelList,
395    query: String,
396    executor: BackgroundExecutor,
397) -> AgentModelList {
398    async fn fuzzy_search_list(
399        model_list: Vec<AgentModelInfo>,
400        query: &str,
401        executor: BackgroundExecutor,
402    ) -> Vec<AgentModelInfo> {
403        let candidates = model_list
404            .iter()
405            .enumerate()
406            .map(|(ix, model)| {
407                StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
408            })
409            .collect::<Vec<_>>();
410        let mut matches = match_strings(
411            &candidates,
412            query,
413            false,
414            true,
415            100,
416            &Default::default(),
417            executor,
418        )
419        .await;
420
421        matches.sort_unstable_by_key(|mat| {
422            let candidate = &candidates[mat.candidate_id];
423            (Reverse(OrderedFloat(mat.score)), candidate.id)
424        });
425
426        matches
427            .into_iter()
428            .map(|mat| model_list[mat.candidate_id].clone())
429            .collect()
430    }
431
432    match model_list {
433        AgentModelList::Flat(model_list) => {
434            AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
435        }
436        AgentModelList::Grouped(index_map) => {
437            let groups =
438                futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
439                    fuzzy_search_list(models, &query, executor.clone())
440                        .map(|results| (group_name, results))
441                }))
442                .await;
443            AgentModelList::Grouped(IndexMap::from_iter(
444                groups
445                    .into_iter()
446                    .filter(|(_, results)| !results.is_empty()),
447            ))
448        }
449    }
450}
451
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list from `(group, [model names])` pairs.
    /// Each model's id and name are both set to the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId(model.to_string().into()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose group names and model names
    /// match `expected` exactly, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }

    #[gpui::test]
    async fn test_fuzzy_match_drops_groups_without_matches(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["gpt-4.1"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // No model contains a 'z', so every group is empty after filtering and
        // each one should be removed from the result.
        let results = fuzzy_search(models, "zzz".into(), cx.executor()).await;
        assert_models_eq(results, vec![]);
    }
}