model_selector.rs

  1use std::{cmp::Reverse, rc::Rc, sync::Arc};
  2
  3use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
  4use agent_servers::AgentServer;
  5use anyhow::Result;
  6use collections::IndexMap;
  7use fs::Fs;
  8use futures::FutureExt;
  9use fuzzy::{StringMatchCandidate, match_strings};
 10use gpui::{
 11    Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Task, WeakEntity,
 12};
 13use ordered_float::OrderedFloat;
 14use picker::{Picker, PickerDelegate};
 15use ui::{DocumentationAside, DocumentationEdge, DocumentationSide, prelude::*};
 16use util::ResultExt;
 17use zed_actions::agent::OpenSettings;
 18
 19use crate::ui::{HoldForDefault, ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
 20
 21pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
 22
 23pub fn acp_model_selector(
 24    selector: Rc<dyn AgentModelSelector>,
 25    agent_server: Rc<dyn AgentServer>,
 26    fs: Arc<dyn Fs>,
 27    focus_handle: FocusHandle,
 28    window: &mut Window,
 29    cx: &mut Context<AcpModelSelector>,
 30) -> AcpModelSelector {
 31    let delegate =
 32        AcpModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
 33    Picker::list(delegate, window, cx)
 34        .show_scrollbar(true)
 35        .width(rems(20.))
 36        .max_height(Some(rems(20.).into()))
 37}
 38
 39enum AcpModelPickerEntry {
 40    Separator(SharedString),
 41    Model(AgentModelInfo),
 42}
 43
 44pub struct AcpModelPickerDelegate {
 45    selector: Rc<dyn AgentModelSelector>,
 46    agent_server: Rc<dyn AgentServer>,
 47    fs: Arc<dyn Fs>,
 48    filtered_entries: Vec<AcpModelPickerEntry>,
 49    models: Option<AgentModelList>,
 50    selected_index: usize,
 51    selected_description: Option<(usize, SharedString, bool)>,
 52    selected_model: Option<AgentModelInfo>,
 53    _refresh_models_task: Task<()>,
 54    focus_handle: FocusHandle,
 55}
 56
 57impl AcpModelPickerDelegate {
 58    fn new(
 59        selector: Rc<dyn AgentModelSelector>,
 60        agent_server: Rc<dyn AgentServer>,
 61        fs: Arc<dyn Fs>,
 62        focus_handle: FocusHandle,
 63        window: &mut Window,
 64        cx: &mut Context<AcpModelSelector>,
 65    ) -> Self {
 66        let rx = selector.watch(cx);
 67        let refresh_models_task = {
 68            cx.spawn_in(window, {
 69                async move |this, cx| {
 70                    async fn refresh(
 71                        this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
 72                        cx: &mut AsyncWindowContext,
 73                    ) -> Result<()> {
 74                        let (models_task, selected_model_task) = this.update(cx, |this, cx| {
 75                            (
 76                                this.delegate.selector.list_models(cx),
 77                                this.delegate.selector.selected_model(cx),
 78                            )
 79                        })?;
 80
 81                        let (models, selected_model) =
 82                            futures::join!(models_task, selected_model_task);
 83
 84                        this.update_in(cx, |this, window, cx| {
 85                            this.delegate.models = models.ok();
 86                            this.delegate.selected_model = selected_model.ok();
 87                            this.refresh(window, cx)
 88                        })
 89                    }
 90
 91                    refresh(&this, cx).await.log_err();
 92                    if let Some(mut rx) = rx {
 93                        while let Ok(()) = rx.recv().await {
 94                            refresh(&this, cx).await.log_err();
 95                        }
 96                    }
 97                }
 98            })
 99        };
100
101        Self {
102            selector,
103            agent_server,
104            fs,
105            filtered_entries: Vec::new(),
106            models: None,
107            selected_model: None,
108            selected_index: 0,
109            selected_description: None,
110            _refresh_models_task: refresh_models_task,
111            focus_handle,
112        }
113    }
114
115    pub fn active_model(&self) -> Option<&AgentModelInfo> {
116        self.selected_model.as_ref()
117    }
118}
119
120impl PickerDelegate for AcpModelPickerDelegate {
121    type ListItem = AnyElement;
122
123    fn match_count(&self) -> usize {
124        self.filtered_entries.len()
125    }
126
127    fn selected_index(&self) -> usize {
128        self.selected_index
129    }
130
131    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
132        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
133        cx.notify();
134    }
135
136    fn can_select(
137        &mut self,
138        ix: usize,
139        _window: &mut Window,
140        _cx: &mut Context<Picker<Self>>,
141    ) -> bool {
142        match self.filtered_entries.get(ix) {
143            Some(AcpModelPickerEntry::Model(_)) => true,
144            Some(AcpModelPickerEntry::Separator(_)) | None => false,
145        }
146    }
147
148    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
149        "Select a model…".into()
150    }
151
152    fn update_matches(
153        &mut self,
154        query: String,
155        window: &mut Window,
156        cx: &mut Context<Picker<Self>>,
157    ) -> Task<()> {
158        cx.spawn_in(window, async move |this, cx| {
159            let filtered_models = match this
160                .read_with(cx, |this, cx| {
161                    this.delegate.models.clone().map(move |models| {
162                        fuzzy_search(models, query, cx.background_executor().clone())
163                    })
164                })
165                .ok()
166                .flatten()
167            {
168                Some(task) => task.await,
169                None => AgentModelList::Flat(vec![]),
170            };
171
172            this.update_in(cx, |this, window, cx| {
173                this.delegate.filtered_entries =
174                    info_list_to_picker_entries(filtered_models).collect();
175                // Finds the currently selected model in the list
176                let new_index = this
177                    .delegate
178                    .selected_model
179                    .as_ref()
180                    .and_then(|selected| {
181                        this.delegate.filtered_entries.iter().position(|entry| {
182                            if let AcpModelPickerEntry::Model(model_info) = entry {
183                                model_info.id == selected.id
184                            } else {
185                                false
186                            }
187                        })
188                    })
189                    .unwrap_or(0);
190                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
191                cx.notify();
192            })
193            .ok();
194        })
195    }
196
197    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
198        if let Some(AcpModelPickerEntry::Model(model_info)) =
199            self.filtered_entries.get(self.selected_index)
200        {
201            if window.modifiers().secondary() {
202                let default_model = self.agent_server.default_model(cx);
203                let is_default = default_model.as_ref() == Some(&model_info.id);
204
205                self.agent_server.set_default_model(
206                    if is_default {
207                        None
208                    } else {
209                        Some(model_info.id.clone())
210                    },
211                    self.fs.clone(),
212                    cx,
213                );
214            }
215
216            self.selector
217                .select_model(model_info.id.clone(), cx)
218                .detach_and_log_err(cx);
219            self.selected_model = Some(model_info.clone());
220            let current_index = self.selected_index;
221            self.set_selected_index(current_index, window, cx);
222
223            cx.emit(DismissEvent);
224        }
225    }
226
227    fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
228        cx.defer_in(window, |picker, window, cx| {
229            picker.set_query("", window, cx);
230        });
231    }
232
233    fn render_match(
234        &self,
235        ix: usize,
236        is_focused: bool,
237        _: &mut Window,
238        cx: &mut Context<Picker<Self>>,
239    ) -> Option<Self::ListItem> {
240        match self.filtered_entries.get(ix)? {
241            AcpModelPickerEntry::Separator(title) => {
242                Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
243            }
244            AcpModelPickerEntry::Model(model_info) => {
245                let is_selected = Some(model_info) == self.selected_model.as_ref();
246                let default_model = self.agent_server.default_model(cx);
247                let is_default = default_model.as_ref() == Some(&model_info.id);
248
249                Some(
250                    div()
251                        .id(("model-picker-menu-child", ix))
252                        .when_some(model_info.description.clone(), |this, description| {
253                            this
254                                .on_hover(cx.listener(move |menu, hovered, _, cx| {
255                                    if *hovered {
256                                        menu.delegate.selected_description = Some((ix, description.clone(), is_default));
257                                    } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
258                                        menu.delegate.selected_description = None;
259                                    }
260                                    cx.notify();
261                                }))
262                        })
263                        .child(
264                            ModelSelectorListItem::new(ix, model_info.name.clone())
265                                .is_focused(is_focused)
266                                .is_selected(is_selected)
267                                .when_some(model_info.icon, |this, icon| this.icon(icon)),
268                        )
269                        .into_any_element()
270                )
271            }
272        }
273    }
274
275    fn documentation_aside(
276        &self,
277        _window: &mut Window,
278        _cx: &mut Context<Picker<Self>>,
279    ) -> Option<ui::DocumentationAside> {
280        self.selected_description
281            .as_ref()
282            .map(|(_, description, is_default)| {
283                let description = description.clone();
284                let is_default = *is_default;
285
286                DocumentationAside::new(
287                    DocumentationSide::Left,
288                    DocumentationEdge::Top,
289                    Rc::new(move |_| {
290                        v_flex()
291                            .gap_1()
292                            .child(Label::new(description.clone()))
293                            .child(HoldForDefault::new(is_default))
294                            .into_any_element()
295                    }),
296                )
297            })
298    }
299
300    fn render_footer(
301        &self,
302        _window: &mut Window,
303        _cx: &mut Context<Picker<Self>>,
304    ) -> Option<AnyElement> {
305        let focus_handle = self.focus_handle.clone();
306
307        if !self.selector.should_render_footer() {
308            return None;
309        }
310
311        Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
312    }
313}
314
315fn info_list_to_picker_entries(
316    model_list: AgentModelList,
317) -> impl Iterator<Item = AcpModelPickerEntry> {
318    match model_list {
319        AgentModelList::Flat(list) => {
320            itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
321        }
322        AgentModelList::Grouped(index_map) => {
323            itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
324                std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
325                    .chain(models.into_iter().map(AcpModelPickerEntry::Model))
326            }))
327        }
328    }
329}
330
331async fn fuzzy_search(
332    model_list: AgentModelList,
333    query: String,
334    executor: BackgroundExecutor,
335) -> AgentModelList {
336    async fn fuzzy_search_list(
337        model_list: Vec<AgentModelInfo>,
338        query: &str,
339        executor: BackgroundExecutor,
340    ) -> Vec<AgentModelInfo> {
341        let candidates = model_list
342            .iter()
343            .enumerate()
344            .map(|(ix, model)| StringMatchCandidate::new(ix, model.name.as_ref()))
345            .collect::<Vec<_>>();
346        let mut matches = match_strings(
347            &candidates,
348            query,
349            false,
350            true,
351            100,
352            &Default::default(),
353            executor,
354        )
355        .await;
356
357        matches.sort_unstable_by_key(|mat| {
358            let candidate = &candidates[mat.candidate_id];
359            (Reverse(OrderedFloat(mat.score)), candidate.id)
360        });
361
362        matches
363            .into_iter()
364            .map(|mat| model_list[mat.candidate_id].clone())
365            .collect()
366    }
367
368    match model_list {
369        AgentModelList::Flat(model_list) => {
370            AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
371        }
372        AgentModelList::Grouped(index_map) => {
373            let groups =
374                futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
375                    fuzzy_search_list(models, &query, executor.clone())
376                        .map(|results| (group_name, results))
377                }))
378                .await;
379            AgentModelList::Grouped(IndexMap::from_iter(
380                groups
381                    .into_iter()
382                    .filter(|(_, results)| !results.is_empty()),
383            ))
384        }
385    }
386}
387
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list from `(group name, model names)` pairs.
    /// Every model uses its name as its id and has no description or icon.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId::new(model.to_string()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names
    /// equal `expected`, in the same order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve the order of the model list whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Non-contiguous query characters still match (fuzzy search).
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }
}