1use std::{cmp::Reverse, rc::Rc, sync::Arc};
2
3use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
4use agent_servers::AgentServer;
5use anyhow::Result;
6use collections::IndexMap;
7use fs::Fs;
8use futures::FutureExt;
9use fuzzy::{StringMatchCandidate, match_strings};
10use gpui::{
11 Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Task, WeakEntity,
12};
13use ordered_float::OrderedFloat;
14use picker::{Picker, PickerDelegate};
15use ui::{
16 DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, KeyBinding, ListItem,
17 ListItemSpacing, prelude::*,
18};
19use util::ResultExt;
20use zed_actions::agent::OpenSettings;
21
22use crate::ui::HoldForDefault;
23
24pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
25
26pub fn acp_model_selector(
27 selector: Rc<dyn AgentModelSelector>,
28 agent_server: Rc<dyn AgentServer>,
29 fs: Arc<dyn Fs>,
30 focus_handle: FocusHandle,
31 window: &mut Window,
32 cx: &mut Context<AcpModelSelector>,
33) -> AcpModelSelector {
34 let delegate =
35 AcpModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
36 Picker::list(delegate, window, cx)
37 .show_scrollbar(true)
38 .width(rems(20.))
39 .max_height(Some(rems(20.).into()))
40}
41
42enum AcpModelPickerEntry {
43 Separator(SharedString),
44 Model(AgentModelInfo),
45}
46
/// [`PickerDelegate`] backing the ACP model selector.
///
/// Holds the full model list fetched from the agent's [`AgentModelSelector`],
/// the rows that currently match the query, and the picker's selection/hover state.
pub struct AcpModelPickerDelegate {
    /// Source of the model list and of the active model; also used to switch models.
    selector: Rc<dyn AgentModelSelector>,
    /// Agent server used to read and persist the default model.
    agent_server: Rc<dyn AgentServer>,
    /// Filesystem handle forwarded to `AgentServer::set_default_model`.
    fs: Arc<dyn Fs>,
    /// Rows currently shown: models matching the query, interleaved with group separators.
    filtered_entries: Vec<AcpModelPickerEntry>,
    /// Last model list fetched from `selector`; `None` before the first fetch or if it failed.
    models: Option<AgentModelList>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// `(row index, description, is_default)` of the hovered row, shown in the documentation aside.
    selected_description: Option<(usize, SharedString, bool)>,
    /// Model reported as active by `selector`, or the one last confirmed in the picker.
    selected_model: Option<AgentModelInfo>,
    /// Owned here so the background refresh task lives as long as the delegate.
    _refresh_models_task: Task<()>,
    /// Focus handle used to resolve the key binding displayed in the footer.
    focus_handle: FocusHandle,
}
59
60impl AcpModelPickerDelegate {
61 fn new(
62 selector: Rc<dyn AgentModelSelector>,
63 agent_server: Rc<dyn AgentServer>,
64 fs: Arc<dyn Fs>,
65 focus_handle: FocusHandle,
66 window: &mut Window,
67 cx: &mut Context<AcpModelSelector>,
68 ) -> Self {
69 let rx = selector.watch(cx);
70 let refresh_models_task = {
71 cx.spawn_in(window, {
72 async move |this, cx| {
73 async fn refresh(
74 this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
75 cx: &mut AsyncWindowContext,
76 ) -> Result<()> {
77 let (models_task, selected_model_task) = this.update(cx, |this, cx| {
78 (
79 this.delegate.selector.list_models(cx),
80 this.delegate.selector.selected_model(cx),
81 )
82 })?;
83
84 let (models, selected_model) =
85 futures::join!(models_task, selected_model_task);
86
87 this.update_in(cx, |this, window, cx| {
88 this.delegate.models = models.ok();
89 this.delegate.selected_model = selected_model.ok();
90 this.refresh(window, cx)
91 })
92 }
93
94 refresh(&this, cx).await.log_err();
95 if let Some(mut rx) = rx {
96 while let Ok(()) = rx.recv().await {
97 refresh(&this, cx).await.log_err();
98 }
99 }
100 }
101 })
102 };
103
104 Self {
105 selector,
106 agent_server,
107 fs,
108 filtered_entries: Vec::new(),
109 models: None,
110 selected_model: None,
111 selected_index: 0,
112 selected_description: None,
113 _refresh_models_task: refresh_models_task,
114 focus_handle,
115 }
116 }
117
118 pub fn active_model(&self) -> Option<&AgentModelInfo> {
119 self.selected_model.as_ref()
120 }
121}
122
123impl PickerDelegate for AcpModelPickerDelegate {
124 type ListItem = AnyElement;
125
126 fn match_count(&self) -> usize {
127 self.filtered_entries.len()
128 }
129
130 fn selected_index(&self) -> usize {
131 self.selected_index
132 }
133
134 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
135 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
136 cx.notify();
137 }
138
139 fn can_select(
140 &mut self,
141 ix: usize,
142 _window: &mut Window,
143 _cx: &mut Context<Picker<Self>>,
144 ) -> bool {
145 match self.filtered_entries.get(ix) {
146 Some(AcpModelPickerEntry::Model(_)) => true,
147 Some(AcpModelPickerEntry::Separator(_)) | None => false,
148 }
149 }
150
151 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
152 "Select a model…".into()
153 }
154
155 fn update_matches(
156 &mut self,
157 query: String,
158 window: &mut Window,
159 cx: &mut Context<Picker<Self>>,
160 ) -> Task<()> {
161 cx.spawn_in(window, async move |this, cx| {
162 let filtered_models = match this
163 .read_with(cx, |this, cx| {
164 this.delegate.models.clone().map(move |models| {
165 fuzzy_search(models, query, cx.background_executor().clone())
166 })
167 })
168 .ok()
169 .flatten()
170 {
171 Some(task) => task.await,
172 None => AgentModelList::Flat(vec![]),
173 };
174
175 this.update_in(cx, |this, window, cx| {
176 this.delegate.filtered_entries =
177 info_list_to_picker_entries(filtered_models).collect();
178 // Finds the currently selected model in the list
179 let new_index = this
180 .delegate
181 .selected_model
182 .as_ref()
183 .and_then(|selected| {
184 this.delegate.filtered_entries.iter().position(|entry| {
185 if let AcpModelPickerEntry::Model(model_info) = entry {
186 model_info.id == selected.id
187 } else {
188 false
189 }
190 })
191 })
192 .unwrap_or(0);
193 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
194 cx.notify();
195 })
196 .ok();
197 })
198 }
199
200 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
201 if let Some(AcpModelPickerEntry::Model(model_info)) =
202 self.filtered_entries.get(self.selected_index)
203 {
204 if window.modifiers().secondary() {
205 let default_model = self.agent_server.default_model(cx);
206 let is_default = default_model.as_ref() == Some(&model_info.id);
207
208 self.agent_server.set_default_model(
209 if is_default {
210 None
211 } else {
212 Some(model_info.id.clone())
213 },
214 self.fs.clone(),
215 cx,
216 );
217 }
218
219 self.selector
220 .select_model(model_info.id.clone(), cx)
221 .detach_and_log_err(cx);
222 self.selected_model = Some(model_info.clone());
223 let current_index = self.selected_index;
224 self.set_selected_index(current_index, window, cx);
225
226 cx.emit(DismissEvent);
227 }
228 }
229
230 fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
231 cx.defer_in(window, |picker, window, cx| {
232 picker.set_query("", window, cx);
233 });
234 }
235
236 fn render_match(
237 &self,
238 ix: usize,
239 selected: bool,
240 _: &mut Window,
241 cx: &mut Context<Picker<Self>>,
242 ) -> Option<Self::ListItem> {
243 match self.filtered_entries.get(ix)? {
244 AcpModelPickerEntry::Separator(title) => Some(
245 div()
246 .px_2()
247 .pb_1()
248 .when(ix > 1, |this| {
249 this.mt_1()
250 .pt_2()
251 .border_t_1()
252 .border_color(cx.theme().colors().border_variant)
253 })
254 .child(
255 Label::new(title)
256 .size(LabelSize::XSmall)
257 .color(Color::Muted),
258 )
259 .into_any_element(),
260 ),
261 AcpModelPickerEntry::Model(model_info) => {
262 let is_selected = Some(model_info) == self.selected_model.as_ref();
263 let default_model = self.agent_server.default_model(cx);
264 let is_default = default_model.as_ref() == Some(&model_info.id);
265
266 let model_icon_color = if is_selected {
267 Color::Accent
268 } else {
269 Color::Muted
270 };
271
272 Some(
273 div()
274 .id(("model-picker-menu-child", ix))
275 .when_some(model_info.description.clone(), |this, description| {
276 this
277 .on_hover(cx.listener(move |menu, hovered, _, cx| {
278 if *hovered {
279 menu.delegate.selected_description = Some((ix, description.clone(), is_default));
280 } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
281 menu.delegate.selected_description = None;
282 }
283 cx.notify();
284 }))
285 })
286 .child(
287 ListItem::new(ix)
288 .inset(true)
289 .spacing(ListItemSpacing::Sparse)
290 .toggle_state(selected)
291 .child(
292 h_flex()
293 .w_full()
294 .gap_1p5()
295 .when_some(model_info.icon, |this, icon| {
296 this.child(
297 Icon::new(icon)
298 .color(model_icon_color)
299 .size(IconSize::Small)
300 )
301 })
302 .child(Label::new(model_info.name.clone()).truncate()),
303 )
304 .end_slot(div().pr_3().when(is_selected, |this| {
305 this.child(
306 Icon::new(IconName::Check)
307 .color(Color::Accent)
308 .size(IconSize::Small),
309 )
310 })),
311 )
312 .into_any_element()
313 )
314 }
315 }
316 }
317
318 fn documentation_aside(
319 &self,
320 _window: &mut Window,
321 _cx: &mut Context<Picker<Self>>,
322 ) -> Option<ui::DocumentationAside> {
323 self.selected_description
324 .as_ref()
325 .map(|(_, description, is_default)| {
326 let description = description.clone();
327 let is_default = *is_default;
328
329 DocumentationAside::new(
330 DocumentationSide::Left,
331 DocumentationEdge::Top,
332 Rc::new(move |_| {
333 v_flex()
334 .gap_1()
335 .child(Label::new(description.clone()))
336 .child(HoldForDefault::new(is_default))
337 .into_any_element()
338 }),
339 )
340 })
341 }
342
343 fn render_footer(
344 &self,
345 _window: &mut Window,
346 cx: &mut Context<Picker<Self>>,
347 ) -> Option<AnyElement> {
348 let focus_handle = self.focus_handle.clone();
349
350 if !self.selector.should_render_footer() {
351 return None;
352 }
353
354 Some(
355 h_flex()
356 .w_full()
357 .p_1p5()
358 .border_t_1()
359 .border_color(cx.theme().colors().border_variant)
360 .child(
361 Button::new("configure", "Configure")
362 .full_width()
363 .style(ButtonStyle::Outlined)
364 .key_binding(
365 KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
366 .map(|kb| kb.size(rems_from_px(12.))),
367 )
368 .on_click(|_, window, cx| {
369 window.dispatch_action(OpenSettings.boxed_clone(), cx);
370 }),
371 )
372 .into_any(),
373 )
374 }
375}
376
377fn info_list_to_picker_entries(
378 model_list: AgentModelList,
379) -> impl Iterator<Item = AcpModelPickerEntry> {
380 match model_list {
381 AgentModelList::Flat(list) => {
382 itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
383 }
384 AgentModelList::Grouped(index_map) => {
385 itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
386 std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
387 .chain(models.into_iter().map(AcpModelPickerEntry::Model))
388 }))
389 }
390 }
391}
392
393async fn fuzzy_search(
394 model_list: AgentModelList,
395 query: String,
396 executor: BackgroundExecutor,
397) -> AgentModelList {
398 async fn fuzzy_search_list(
399 model_list: Vec<AgentModelInfo>,
400 query: &str,
401 executor: BackgroundExecutor,
402 ) -> Vec<AgentModelInfo> {
403 let candidates = model_list
404 .iter()
405 .enumerate()
406 .map(|(ix, model)| {
407 StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
408 })
409 .collect::<Vec<_>>();
410 let mut matches = match_strings(
411 &candidates,
412 query,
413 false,
414 true,
415 100,
416 &Default::default(),
417 executor,
418 )
419 .await;
420
421 matches.sort_unstable_by_key(|mat| {
422 let candidate = &candidates[mat.candidate_id];
423 (Reverse(OrderedFloat(mat.score)), candidate.id)
424 });
425
426 matches
427 .into_iter()
428 .map(|mat| model_list[mat.candidate_id].clone())
429 .collect()
430 }
431
432 match model_list {
433 AgentModelList::Flat(model_list) => {
434 AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
435 }
436 AgentModelList::Grouped(index_map) => {
437 let groups =
438 futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
439 fuzzy_search_list(models, &query, executor.clone())
440 .map(|results| (group_name, results))
441 }))
442 .await;
443 AgentModelList::Grouped(IndexMap::from_iter(
444 groups
445 .into_iter()
446 .filter(|(_, results)| !results.is_empty()),
447 ))
448 }
449 }
450}
451
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list in which every model's id and name are the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId(model.to_string().into()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names match
    /// `expected` exactly, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );

        // An empty query keeps every group and every model in its original order.
        let results = fuzzy_search(models.clone(), "".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                (
                    "zed",
                    vec![
                        "Claude 3.7 Sonnet",
                        "Claude 3.7 Sonnet Thinking",
                        "gpt-4.1",
                        "gpt-4.1-nano",
                    ],
                ),
                ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
                ("ollama", vec!["mistral", "deepseek"]),
            ],
        );

        // Groups left without any matching model are dropped entirely.
        let results = fuzzy_search(models, "xyz".into(), cx.executor()).await;
        assert_models_eq(results, vec![]);
    }
}