1use std::{cmp::Reverse, rc::Rc, sync::Arc};
2
3use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
4use agent_client_protocol::ModelId;
5use agent_servers::AgentServer;
6use agent_settings::AgentSettings;
7use anyhow::Result;
8use collections::{HashSet, IndexMap};
9use fs::Fs;
10use futures::FutureExt;
11use fuzzy::{StringMatchCandidate, match_strings};
12use gpui::{
13 Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Task, WeakEntity,
14};
15use itertools::Itertools;
16use ordered_float::OrderedFloat;
17use picker::{Picker, PickerDelegate};
18use settings::Settings;
19use ui::{DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, prelude::*};
20use util::ResultExt;
21use zed_actions::agent::OpenSettings;
22
23use crate::ui::{HoldForDefault, ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
24
/// Picker widget listing the models an ACP agent exposes; its behavior lives in
/// [`AcpModelPickerDelegate`]. Construct it with [`acp_model_selector`].
pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
26
27pub fn acp_model_selector(
28 selector: Rc<dyn AgentModelSelector>,
29 agent_server: Rc<dyn AgentServer>,
30 fs: Arc<dyn Fs>,
31 focus_handle: FocusHandle,
32 window: &mut Window,
33 cx: &mut Context<AcpModelSelector>,
34) -> AcpModelSelector {
35 let delegate =
36 AcpModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
37 Picker::list(delegate, window, cx)
38 .show_scrollbar(true)
39 .width(rems(20.))
40 .max_height(Some(rems(20.).into()))
41}
42
/// One row of the model picker.
enum AcpModelPickerEntry {
    /// A section header ("Favorite", "All", or a group name); never selectable.
    Separator(SharedString),
    /// A selectable model. The `bool` is `true` when the model's id is in the favorites set.
    Model(AgentModelInfo, bool),
}
47
/// [`PickerDelegate`] that lists the models reported by an [`AgentModelSelector`], supports
/// fuzzy filtering, favorites, and toggling the agent's default model.
pub struct AcpModelPickerDelegate {
    /// Source of the model list and of the currently selected model.
    selector: Rc<dyn AgentModelSelector>,
    /// Used to read and set the agent's default model when confirming with the secondary modifier.
    agent_server: Rc<dyn AgentServer>,
    /// Filesystem handle passed along when persisting the default model or favorites.
    fs: Arc<dyn Fs>,
    /// Rows currently shown: section headers plus the models matching the query.
    filtered_entries: Vec<AcpModelPickerEntry>,
    /// Full, unfiltered model list from the last refresh; `None` before loading or if it failed.
    models: Option<AgentModelList>,
    /// Index into `filtered_entries` of the keyboard-focused row.
    selected_index: usize,
    /// Documentation aside state for the hovered row:
    /// (row index, model description, whether that model is the agent's default).
    selected_description: Option<(usize, SharedString, bool)>,
    /// Model currently in use, as reported by the selector or last confirmed by the user.
    selected_model: Option<AgentModelInfo>,
    /// Held only to keep the background model-refresh task alive; never read.
    _refresh_models_task: Task<()>,
    /// Focus handle forwarded to the footer.
    focus_handle: FocusHandle,
}
60
61impl AcpModelPickerDelegate {
62 fn new(
63 selector: Rc<dyn AgentModelSelector>,
64 agent_server: Rc<dyn AgentServer>,
65 fs: Arc<dyn Fs>,
66 focus_handle: FocusHandle,
67 window: &mut Window,
68 cx: &mut Context<AcpModelSelector>,
69 ) -> Self {
70 let rx = selector.watch(cx);
71 let refresh_models_task = {
72 cx.spawn_in(window, {
73 async move |this, cx| {
74 async fn refresh(
75 this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
76 cx: &mut AsyncWindowContext,
77 ) -> Result<()> {
78 let (models_task, selected_model_task) = this.update(cx, |this, cx| {
79 (
80 this.delegate.selector.list_models(cx),
81 this.delegate.selector.selected_model(cx),
82 )
83 })?;
84
85 let (models, selected_model) =
86 futures::join!(models_task, selected_model_task);
87
88 this.update_in(cx, |this, window, cx| {
89 this.delegate.models = models.ok();
90 this.delegate.selected_model = selected_model.ok();
91 this.refresh(window, cx)
92 })
93 }
94
95 refresh(&this, cx).await.log_err();
96 if let Some(mut rx) = rx {
97 while let Ok(()) = rx.recv().await {
98 refresh(&this, cx).await.log_err();
99 }
100 }
101 }
102 })
103 };
104
105 Self {
106 selector,
107 agent_server,
108 fs,
109 filtered_entries: Vec::new(),
110 models: None,
111 selected_model: None,
112 selected_index: 0,
113 selected_description: None,
114 _refresh_models_task: refresh_models_task,
115 focus_handle,
116 }
117 }
118
119 pub fn active_model(&self) -> Option<&AgentModelInfo> {
120 self.selected_model.as_ref()
121 }
122}
123
124impl PickerDelegate for AcpModelPickerDelegate {
125 type ListItem = AnyElement;
126
127 fn match_count(&self) -> usize {
128 self.filtered_entries.len()
129 }
130
131 fn selected_index(&self) -> usize {
132 self.selected_index
133 }
134
135 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
136 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
137 cx.notify();
138 }
139
140 fn can_select(
141 &mut self,
142 ix: usize,
143 _window: &mut Window,
144 _cx: &mut Context<Picker<Self>>,
145 ) -> bool {
146 match self.filtered_entries.get(ix) {
147 Some(AcpModelPickerEntry::Model(_, _)) => true,
148 Some(AcpModelPickerEntry::Separator(_)) | None => false,
149 }
150 }
151
152 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
153 "Select a model…".into()
154 }
155
156 fn update_matches(
157 &mut self,
158 query: String,
159 window: &mut Window,
160 cx: &mut Context<Picker<Self>>,
161 ) -> Task<()> {
162 let favorites = if self.selector.supports_favorites() {
163 Arc::new(AgentSettings::get_global(cx).favorite_model_ids())
164 } else {
165 Default::default()
166 };
167
168 cx.spawn_in(window, async move |this, cx| {
169 let filtered_models = match this
170 .read_with(cx, |this, cx| {
171 this.delegate.models.clone().map(move |models| {
172 fuzzy_search(models, query, cx.background_executor().clone())
173 })
174 })
175 .ok()
176 .flatten()
177 {
178 Some(task) => task.await,
179 None => AgentModelList::Flat(vec![]),
180 };
181
182 this.update_in(cx, |this, window, cx| {
183 this.delegate.filtered_entries =
184 info_list_to_picker_entries(filtered_models, favorites);
185 // Finds the currently selected model in the list
186 let new_index = this
187 .delegate
188 .selected_model
189 .as_ref()
190 .and_then(|selected| {
191 this.delegate.filtered_entries.iter().position(|entry| {
192 if let AcpModelPickerEntry::Model(model_info, _) = entry {
193 model_info.id == selected.id
194 } else {
195 false
196 }
197 })
198 })
199 .unwrap_or(0);
200 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
201 cx.notify();
202 })
203 .ok();
204 })
205 }
206
207 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
208 if let Some(AcpModelPickerEntry::Model(model_info, _)) =
209 self.filtered_entries.get(self.selected_index)
210 {
211 if window.modifiers().secondary() {
212 let default_model = self.agent_server.default_model(cx);
213 let is_default = default_model.as_ref() == Some(&model_info.id);
214
215 self.agent_server.set_default_model(
216 if is_default {
217 None
218 } else {
219 Some(model_info.id.clone())
220 },
221 self.fs.clone(),
222 cx,
223 );
224 }
225
226 self.selector
227 .select_model(model_info.id.clone(), cx)
228 .detach_and_log_err(cx);
229 self.selected_model = Some(model_info.clone());
230 let current_index = self.selected_index;
231 self.set_selected_index(current_index, window, cx);
232
233 cx.emit(DismissEvent);
234 }
235 }
236
237 fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
238 cx.defer_in(window, |picker, window, cx| {
239 picker.set_query("", window, cx);
240 });
241 }
242
243 fn render_match(
244 &self,
245 ix: usize,
246 selected: bool,
247 _: &mut Window,
248 cx: &mut Context<Picker<Self>>,
249 ) -> Option<Self::ListItem> {
250 match self.filtered_entries.get(ix)? {
251 AcpModelPickerEntry::Separator(title) => {
252 Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
253 }
254 AcpModelPickerEntry::Model(model_info, is_favorite) => {
255 let is_selected = Some(model_info) == self.selected_model.as_ref();
256 let default_model = self.agent_server.default_model(cx);
257 let is_default = default_model.as_ref() == Some(&model_info.id);
258
259 let supports_favorites = self.selector.supports_favorites();
260
261 let is_favorite = *is_favorite;
262 let handle_action_click = {
263 let model_id = model_info.id.clone();
264 let fs = self.fs.clone();
265
266 move |cx: &App| {
267 crate::favorite_models::toggle_model_id_in_settings(
268 model_id.clone(),
269 !is_favorite,
270 fs.clone(),
271 cx,
272 );
273 }
274 };
275
276 Some(
277 div()
278 .id(("model-picker-menu-child", ix))
279 .when_some(model_info.description.clone(), |this, description| {
280 this.on_hover(cx.listener(move |menu, hovered, _, cx| {
281 if *hovered {
282 menu.delegate.selected_description =
283 Some((ix, description.clone(), is_default));
284 } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
285 menu.delegate.selected_description = None;
286 }
287 cx.notify();
288 }))
289 })
290 .child(
291 ModelSelectorListItem::new(ix, model_info.name.clone())
292 .when_some(model_info.icon, |this, icon| this.icon(icon))
293 .is_selected(is_selected)
294 .is_focused(selected)
295 .when(supports_favorites, |this| {
296 this.is_favorite(is_favorite)
297 .on_toggle_favorite(handle_action_click)
298 }),
299 )
300 .into_any_element(),
301 )
302 }
303 }
304 }
305
306 fn documentation_aside(
307 &self,
308 _window: &mut Window,
309 _cx: &mut Context<Picker<Self>>,
310 ) -> Option<ui::DocumentationAside> {
311 self.selected_description
312 .as_ref()
313 .map(|(_, description, is_default)| {
314 let description = description.clone();
315 let is_default = *is_default;
316
317 DocumentationAside::new(
318 DocumentationSide::Left,
319 DocumentationEdge::Top,
320 Rc::new(move |_| {
321 v_flex()
322 .gap_1()
323 .child(Label::new(description.clone()))
324 .child(HoldForDefault::new(is_default))
325 .into_any_element()
326 }),
327 )
328 })
329 }
330
331 fn render_footer(
332 &self,
333 _window: &mut Window,
334 _cx: &mut Context<Picker<Self>>,
335 ) -> Option<AnyElement> {
336 let focus_handle = self.focus_handle.clone();
337
338 if !self.selector.should_render_footer() {
339 return None;
340 }
341
342 Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
343 }
344}
345
346fn info_list_to_picker_entries(
347 model_list: AgentModelList,
348 favorites: Arc<HashSet<ModelId>>,
349) -> Vec<AcpModelPickerEntry> {
350 let mut entries = Vec::new();
351
352 let all_models: Vec<_> = match &model_list {
353 AgentModelList::Flat(list) => list.iter().collect(),
354 AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
355 };
356
357 let favorite_models: Vec<_> = all_models
358 .iter()
359 .filter(|m| favorites.contains(&m.id))
360 .unique_by(|m| &m.id)
361 .collect();
362
363 let has_favorites = !favorite_models.is_empty();
364 if has_favorites {
365 entries.push(AcpModelPickerEntry::Separator("Favorite".into()));
366 for model in favorite_models {
367 entries.push(AcpModelPickerEntry::Model((*model).clone(), true));
368 }
369 }
370
371 match model_list {
372 AgentModelList::Flat(list) => {
373 if has_favorites {
374 entries.push(AcpModelPickerEntry::Separator("All".into()));
375 }
376 for model in list {
377 let is_favorite = favorites.contains(&model.id);
378 entries.push(AcpModelPickerEntry::Model(model, is_favorite));
379 }
380 }
381 AgentModelList::Grouped(index_map) => {
382 for (group_name, models) in index_map {
383 entries.push(AcpModelPickerEntry::Separator(group_name.0));
384 for model in models {
385 let is_favorite = favorites.contains(&model.id);
386 entries.push(AcpModelPickerEntry::Model(model, is_favorite));
387 }
388 }
389 }
390 }
391
392 entries
393}
394
395async fn fuzzy_search(
396 model_list: AgentModelList,
397 query: String,
398 executor: BackgroundExecutor,
399) -> AgentModelList {
400 async fn fuzzy_search_list(
401 model_list: Vec<AgentModelInfo>,
402 query: &str,
403 executor: BackgroundExecutor,
404 ) -> Vec<AgentModelInfo> {
405 let candidates = model_list
406 .iter()
407 .enumerate()
408 .map(|(ix, model)| StringMatchCandidate::new(ix, model.name.as_ref()))
409 .collect::<Vec<_>>();
410 let mut matches = match_strings(
411 &candidates,
412 query,
413 false,
414 true,
415 100,
416 &Default::default(),
417 executor,
418 )
419 .await;
420
421 matches.sort_unstable_by_key(|mat| {
422 let candidate = &candidates[mat.candidate_id];
423 (Reverse(OrderedFloat(mat.score)), candidate.id)
424 });
425
426 matches
427 .into_iter()
428 .map(|mat| model_list[mat.candidate_id].clone())
429 .collect()
430 }
431
432 match model_list {
433 AgentModelList::Flat(model_list) => {
434 AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
435 }
436 AgentModelList::Grouped(index_map) => {
437 let groups =
438 futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
439 fuzzy_search_list(models, &query, executor.clone())
440 .map(|results| (group_name, results))
441 }))
442 .await;
443 AgentModelList::Grouped(IndexMap::from_iter(
444 groups
445 .into_iter()
446 .filter(|(_, results)| !results.is_empty()),
447 ))
448 }
449 }
450}
451
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list where each model's id and name are the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId::new(model.to_string()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is grouped and matches `expected` (group names and model names,
    /// in order).
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    fn create_favorites(models: Vec<&str>) -> Arc<HashSet<ModelId>> {
        Arc::new(
            models
                .into_iter()
                .map(|m| ModelId::new(m.to_string()))
                .collect(),
        )
    }

    /// Ids of the model rows, skipping section headers.
    fn get_entry_model_ids(entries: &[AcpModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .filter_map(|entry| match entry {
                AcpModelPickerEntry::Model(info, _) => Some(info.id.0.as_ref()),
                _ => None,
            })
            .collect()
    }

    /// Model ids and header titles of every row, in display order.
    fn get_entry_labels(entries: &[AcpModelPickerEntry]) -> Vec<&str> {
        entries
            .iter()
            .map(|entry| match entry {
                AcpModelPickerEntry::Model(info, _) => info.id.0.as_ref(),
                AcpModelPickerEntry::Separator(s) => &s,
            })
            .collect()
    }

    #[gpui::test]
    fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, favorites);

        assert!(matches!(
            entries.first(),
            Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        let model_ids = get_entry_model_ids(&entries);
        assert_eq!(model_ids[0], "zed/gemini");
    }

    #[gpui::test]
    fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![("zed", vec!["zed/claude", "zed/gemini"])]);
        let favorites = create_favorites(vec![]);

        let entries = info_list_to_picker_entries(models, favorites);

        assert!(matches!(
            entries.first(),
            Some(AcpModelPickerEntry::Separator(s)) if s == "zed"
        ));
    }

    #[gpui::test]
    fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5"]),
        ]);
        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, favorites);

        for entry in &entries {
            if let AcpModelPickerEntry::Model(info, is_favorite) = entry {
                if info.id.0.as_ref() == "zed/claude" {
                    assert!(is_favorite, "zed/claude should be a favorite");
                } else {
                    assert!(!is_favorite, "{} should not be a favorite", info.id.0);
                }
            }
        }
    }

    #[gpui::test]
    fn test_favorites_appear_in_both_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["zed/claude", "zed/gemini"]),
            ("openai", vec!["openai/gpt-5", "openai/gpt-4"]),
        ]);
        let favorites = create_favorites(vec!["zed/gemini", "openai/gpt-5"]);

        let entries = info_list_to_picker_entries(models, favorites);
        let model_ids = get_entry_model_ids(&entries);

        assert_eq!(model_ids[0], "zed/gemini");
        assert_eq!(model_ids[1], "openai/gpt-5");

        assert!(model_ids[2..].contains(&"zed/gemini"));
        assert!(model_ids[2..].contains(&"openai/gpt-5"));
    }

    #[gpui::test]
    fn test_favorites_are_not_duplicated_when_repeated_in_other_sections(_cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("Recommended", vec!["zed/claude", "anthropic/claude"]),
            ("Zed", vec!["zed/claude", "zed/gpt-5"]),
            ("Anthropic", vec!["anthropic/claude"]),
            ("OpenAI", vec!["openai/gpt-5"]),
        ]);

        let favorites = create_favorites(vec!["zed/claude"]);

        let entries = info_list_to_picker_entries(models, favorites);
        let labels = get_entry_labels(&entries);

        assert_eq!(
            labels,
            vec![
                "Favorite",
                "zed/claude",
                "Recommended",
                "zed/claude",
                "anthropic/claude",
                "Zed",
                "zed/claude",
                "zed/gpt-5",
                "Anthropic",
                "anthropic/claude",
                "OpenAI",
                "openai/gpt-5"
            ]
        );
    }

    #[gpui::test]
    fn test_flat_model_list_with_favorites(_cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/claude".to_string()),
                name: "Claude".into(),
                description: None,
                icon: None,
            },
            acp_thread::AgentModelInfo {
                id: acp::ModelId::new("zed/gemini".to_string()),
                name: "Gemini".into(),
                description: None,
                icon: None,
            },
        ]);
        let favorites = create_favorites(vec!["zed/gemini"]);

        let entries = info_list_to_picker_entries(models, favorites);

        assert!(matches!(
            entries.first(),
            Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite"
        ));

        assert!(entries.iter().any(|e| matches!(
            e,
            AcpModelPickerEntry::Separator(s) if s == "All"
        )));
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }
}