1use std::{cmp::Reverse, rc::Rc, sync::Arc};
2
3use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
4use agent_servers::AgentServer;
5use anyhow::Result;
6use collections::IndexMap;
7use fs::Fs;
8use futures::FutureExt;
9use fuzzy::{StringMatchCandidate, match_strings};
10use gpui::{
11 Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Task, WeakEntity,
12};
13use ordered_float::OrderedFloat;
14use picker::{Picker, PickerDelegate};
15use ui::{
16 DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, KeyBinding, ListItem,
17 ListItemSpacing, prelude::*,
18};
19use util::ResultExt;
20use zed_actions::agent::OpenSettings;
21
22use crate::ui::HoldForDefault;
23
24pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
25
26pub fn acp_model_selector(
27 selector: Rc<dyn AgentModelSelector>,
28 agent_server: Rc<dyn AgentServer>,
29 fs: Arc<dyn Fs>,
30 focus_handle: FocusHandle,
31 window: &mut Window,
32 cx: &mut Context<AcpModelSelector>,
33) -> AcpModelSelector {
34 let delegate =
35 AcpModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
36 Picker::list(delegate, window, cx)
37 .show_scrollbar(true)
38 .width(rems(20.))
39 .max_height(Some(rems(20.).into()))
40}
41
/// A single row in the model picker list.
enum AcpModelPickerEntry {
    /// Group header row carrying the group's name; `can_select` rejects these rows.
    Separator(SharedString),
    /// A selectable row for one model.
    Model(AgentModelInfo),
}
46
/// Picker delegate that lists an agent's models and lets the user select one.
///
/// Selecting a model while holding the secondary modifier also toggles it as the agent
/// server's default model.
pub struct AcpModelPickerDelegate {
    /// Provides the model list and current selection, and applies new selections.
    selector: Rc<dyn AgentModelSelector>,
    /// Queried for, and updated with, the default model.
    agent_server: Rc<dyn AgentServer>,
    /// Filesystem handle passed through to `set_default_model`.
    fs: Arc<dyn Fs>,
    /// Rows currently displayed, rebuilt by `update_matches`.
    filtered_entries: Vec<AcpModelPickerEntry>,
    /// Unfiltered model list from the last refresh; `None` before the first load or if
    /// `list_models` failed.
    models: Option<AgentModelList>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// `(row index, description, is_default)` of the hovered row, shown in the
    /// documentation aside.
    selected_description: Option<(usize, SharedString, bool)>,
    /// The model currently selected for the agent, if known.
    selected_model: Option<AgentModelInfo>,
    /// Background task that reloads `models` / `selected_model`. It is stored only to keep it
    /// owned by the delegate (gpui tasks are presumably cancelled on drop — confirm).
    _refresh_models_task: Task<()>,
    /// Focus handle used to resolve the footer's key binding.
    focus_handle: FocusHandle,
}
59
60impl AcpModelPickerDelegate {
61 fn new(
62 selector: Rc<dyn AgentModelSelector>,
63 agent_server: Rc<dyn AgentServer>,
64 fs: Arc<dyn Fs>,
65 focus_handle: FocusHandle,
66 window: &mut Window,
67 cx: &mut Context<AcpModelSelector>,
68 ) -> Self {
69 let rx = selector.watch(cx);
70 let refresh_models_task = {
71 cx.spawn_in(window, {
72 async move |this, cx| {
73 async fn refresh(
74 this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
75 cx: &mut AsyncWindowContext,
76 ) -> Result<()> {
77 let (models_task, selected_model_task) = this.update(cx, |this, cx| {
78 (
79 this.delegate.selector.list_models(cx),
80 this.delegate.selector.selected_model(cx),
81 )
82 })?;
83
84 let (models, selected_model) =
85 futures::join!(models_task, selected_model_task);
86
87 this.update_in(cx, |this, window, cx| {
88 this.delegate.models = models.ok();
89 this.delegate.selected_model = selected_model.ok();
90 this.refresh(window, cx)
91 })
92 }
93
94 refresh(&this, cx).await.log_err();
95 if let Some(mut rx) = rx {
96 while let Ok(()) = rx.recv().await {
97 refresh(&this, cx).await.log_err();
98 }
99 }
100 }
101 })
102 };
103
104 Self {
105 selector,
106 agent_server,
107 fs,
108 filtered_entries: Vec::new(),
109 models: None,
110 selected_model: None,
111 selected_index: 0,
112 selected_description: None,
113 _refresh_models_task: refresh_models_task,
114 focus_handle,
115 }
116 }
117
118 pub fn active_model(&self) -> Option<&AgentModelInfo> {
119 self.selected_model.as_ref()
120 }
121}
122
123impl PickerDelegate for AcpModelPickerDelegate {
124 type ListItem = AnyElement;
125
126 fn match_count(&self) -> usize {
127 self.filtered_entries.len()
128 }
129
130 fn selected_index(&self) -> usize {
131 self.selected_index
132 }
133
134 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
135 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
136 cx.notify();
137 }
138
139 fn can_select(
140 &mut self,
141 ix: usize,
142 _window: &mut Window,
143 _cx: &mut Context<Picker<Self>>,
144 ) -> bool {
145 match self.filtered_entries.get(ix) {
146 Some(AcpModelPickerEntry::Model(_)) => true,
147 Some(AcpModelPickerEntry::Separator(_)) | None => false,
148 }
149 }
150
151 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
152 "Select a model…".into()
153 }
154
155 fn update_matches(
156 &mut self,
157 query: String,
158 window: &mut Window,
159 cx: &mut Context<Picker<Self>>,
160 ) -> Task<()> {
161 cx.spawn_in(window, async move |this, cx| {
162 let filtered_models = match this
163 .read_with(cx, |this, cx| {
164 this.delegate.models.clone().map(move |models| {
165 fuzzy_search(models, query, cx.background_executor().clone())
166 })
167 })
168 .ok()
169 .flatten()
170 {
171 Some(task) => task.await,
172 None => AgentModelList::Flat(vec![]),
173 };
174
175 this.update_in(cx, |this, window, cx| {
176 this.delegate.filtered_entries =
177 info_list_to_picker_entries(filtered_models).collect();
178 // Finds the currently selected model in the list
179 let new_index = this
180 .delegate
181 .selected_model
182 .as_ref()
183 .and_then(|selected| {
184 this.delegate.filtered_entries.iter().position(|entry| {
185 if let AcpModelPickerEntry::Model(model_info) = entry {
186 model_info.id == selected.id
187 } else {
188 false
189 }
190 })
191 })
192 .unwrap_or(0);
193 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
194 cx.notify();
195 })
196 .ok();
197 })
198 }
199
200 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
201 if let Some(AcpModelPickerEntry::Model(model_info)) =
202 self.filtered_entries.get(self.selected_index)
203 {
204 if window.modifiers().secondary() {
205 let default_model = self.agent_server.default_model(cx);
206 let is_default = default_model.as_ref() == Some(&model_info.id);
207
208 self.agent_server.set_default_model(
209 if is_default {
210 None
211 } else {
212 Some(model_info.id.clone())
213 },
214 self.fs.clone(),
215 cx,
216 );
217 }
218
219 self.selector
220 .select_model(model_info.id.clone(), cx)
221 .detach_and_log_err(cx);
222 self.selected_model = Some(model_info.clone());
223 let current_index = self.selected_index;
224 self.set_selected_index(current_index, window, cx);
225
226 cx.emit(DismissEvent);
227 }
228 }
229
230 fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
231 cx.defer_in(window, |picker, window, cx| {
232 picker.set_query("", window, cx);
233 });
234 }
235
236 fn render_match(
237 &self,
238 ix: usize,
239 selected: bool,
240 _: &mut Window,
241 cx: &mut Context<Picker<Self>>,
242 ) -> Option<Self::ListItem> {
243 match self.filtered_entries.get(ix)? {
244 AcpModelPickerEntry::Separator(title) => Some(
245 div()
246 .px_2()
247 .pb_1()
248 .when(ix > 1, |this| {
249 this.mt_1()
250 .pt_2()
251 .border_t_1()
252 .border_color(cx.theme().colors().border_variant)
253 })
254 .child(
255 Label::new(title)
256 .size(LabelSize::XSmall)
257 .color(Color::Muted),
258 )
259 .into_any_element(),
260 ),
261 AcpModelPickerEntry::Model(model_info) => {
262 let is_selected = Some(model_info) == self.selected_model.as_ref();
263 let default_model = self.agent_server.default_model(cx);
264 let is_default = default_model.as_ref() == Some(&model_info.id);
265
266 let model_icon_color = if is_selected {
267 Color::Accent
268 } else {
269 Color::Muted
270 };
271
272 Some(
273 div()
274 .id(("model-picker-menu-child", ix))
275 .when_some(model_info.description.clone(), |this, description| {
276 this
277 .on_hover(cx.listener(move |menu, hovered, _, cx| {
278 if *hovered {
279 menu.delegate.selected_description = Some((ix, description.clone(), is_default));
280 } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
281 menu.delegate.selected_description = None;
282 }
283 cx.notify();
284 }))
285 })
286 .child(
287 ListItem::new(ix)
288 .inset(true)
289 .spacing(ListItemSpacing::Sparse)
290 .toggle_state(selected)
291 .child(
292 h_flex()
293 .w_full()
294 .gap_1p5()
295 .map(|this| match &model_info.icon {
296 Some(AgentModelIcon::Path(path)) => this.child(
297 Icon::from_path(path.clone())
298 .color(model_icon_color)
299 .size(IconSize::Small),
300 ),
301 Some(AgentModelIcon::Named(icon)) => this.child(
302 Icon::new(*icon)
303 .color(model_icon_color)
304 .size(IconSize::Small),
305 ),
306 None => this,
307 })
308 .child(Label::new(model_info.name.clone()).truncate()),
309 )
310 .end_slot(div().pr_3().when(is_selected, |this| {
311 this.child(
312 Icon::new(IconName::Check)
313 .color(Color::Accent)
314 .size(IconSize::Small),
315 )
316 })),
317 )
318 .into_any_element()
319 )
320 }
321 }
322 }
323
324 fn documentation_aside(
325 &self,
326 _window: &mut Window,
327 _cx: &mut Context<Picker<Self>>,
328 ) -> Option<ui::DocumentationAside> {
329 self.selected_description
330 .as_ref()
331 .map(|(_, description, is_default)| {
332 let description = description.clone();
333 let is_default = *is_default;
334
335 DocumentationAside::new(
336 DocumentationSide::Left,
337 DocumentationEdge::Top,
338 Rc::new(move |_| {
339 v_flex()
340 .gap_1()
341 .child(Label::new(description.clone()))
342 .child(HoldForDefault::new(is_default))
343 .into_any_element()
344 }),
345 )
346 })
347 }
348
349 fn render_footer(
350 &self,
351 _window: &mut Window,
352 cx: &mut Context<Picker<Self>>,
353 ) -> Option<AnyElement> {
354 let focus_handle = self.focus_handle.clone();
355
356 if !self.selector.should_render_footer() {
357 return None;
358 }
359
360 Some(
361 h_flex()
362 .w_full()
363 .p_1p5()
364 .border_t_1()
365 .border_color(cx.theme().colors().border_variant)
366 .child(
367 Button::new("configure", "Configure")
368 .full_width()
369 .style(ButtonStyle::Outlined)
370 .key_binding(
371 KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
372 .map(|kb| kb.size(rems_from_px(12.))),
373 )
374 .on_click(|_, window, cx| {
375 window.dispatch_action(OpenSettings.boxed_clone(), cx);
376 }),
377 )
378 .into_any(),
379 )
380 }
381}
382
383fn info_list_to_picker_entries(
384 model_list: AgentModelList,
385) -> impl Iterator<Item = AcpModelPickerEntry> {
386 match model_list {
387 AgentModelList::Flat(list) => {
388 itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
389 }
390 AgentModelList::Grouped(index_map) => {
391 itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
392 std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
393 .chain(models.into_iter().map(AcpModelPickerEntry::Model))
394 }))
395 }
396 }
397}
398
399async fn fuzzy_search(
400 model_list: AgentModelList,
401 query: String,
402 executor: BackgroundExecutor,
403) -> AgentModelList {
404 async fn fuzzy_search_list(
405 model_list: Vec<AgentModelInfo>,
406 query: &str,
407 executor: BackgroundExecutor,
408 ) -> Vec<AgentModelInfo> {
409 let candidates = model_list
410 .iter()
411 .enumerate()
412 .map(|(ix, model)| {
413 StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
414 })
415 .collect::<Vec<_>>();
416 let mut matches = match_strings(
417 &candidates,
418 query,
419 false,
420 true,
421 100,
422 &Default::default(),
423 executor,
424 )
425 .await;
426
427 matches.sort_unstable_by_key(|mat| {
428 let candidate = &candidates[mat.candidate_id];
429 (Reverse(OrderedFloat(mat.score)), candidate.id)
430 });
431
432 matches
433 .into_iter()
434 .map(|mat| model_list[mat.candidate_id].clone())
435 .collect()
436 }
437
438 match model_list {
439 AgentModelList::Flat(model_list) => {
440 AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
441 }
442 AgentModelList::Grouped(index_map) => {
443 let groups =
444 futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
445 fuzzy_search_list(models, &query, executor.clone())
446 .map(|results| (group_name, results))
447 }))
448 .await;
449 AgentModelList::Grouped(IndexMap::from_iter(
450 groups
451 .into_iter()
452 .filter(|(_, results)| !results.is_empty()),
453 ))
454 }
455 }
456}
457
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list. Each model's id and display name are both set to the
    /// given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId::new(model.to_string()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose group names and model names match
    /// `expected` exactly, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search: non-contiguous query characters still match.
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }
}