1use std::{cmp::Reverse, rc::Rc, sync::Arc};
2
3use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
4use agent_servers::AgentServer;
5use anyhow::Result;
6use collections::IndexMap;
7use fs::Fs;
8use futures::FutureExt;
9use fuzzy::{StringMatchCandidate, match_strings};
10use gpui::{
11 Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Task, WeakEntity,
12};
13use ordered_float::OrderedFloat;
14use picker::{Picker, PickerDelegate};
15use ui::{
16 DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, KeyBinding, ListItem,
17 ListItemSpacing, prelude::*,
18};
19use util::ResultExt;
20use zed_actions::agent::OpenSettings;
21
22use crate::ui::HoldForDefault;
23
24pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
25
26pub fn acp_model_selector(
27 selector: Rc<dyn AgentModelSelector>,
28 agent_server: Rc<dyn AgentServer>,
29 fs: Arc<dyn Fs>,
30 focus_handle: FocusHandle,
31 window: &mut Window,
32 cx: &mut Context<AcpModelSelector>,
33) -> AcpModelSelector {
34 let delegate =
35 AcpModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
36 Picker::list(delegate, window, cx)
37 .show_scrollbar(true)
38 .width(rems(20.))
39 .max_height(Some(rems(20.).into()))
40}
41
42enum AcpModelPickerEntry {
43 Separator(SharedString),
44 Model(AgentModelInfo),
45}
46
47pub struct AcpModelPickerDelegate {
48 selector: Rc<dyn AgentModelSelector>,
49 agent_server: Rc<dyn AgentServer>,
50 fs: Arc<dyn Fs>,
51 filtered_entries: Vec<AcpModelPickerEntry>,
52 models: Option<AgentModelList>,
53 selected_index: usize,
54 selected_description: Option<(usize, SharedString, bool)>,
55 selected_model: Option<AgentModelInfo>,
56 _refresh_models_task: Task<()>,
57 focus_handle: FocusHandle,
58}
59
60impl AcpModelPickerDelegate {
61 fn new(
62 selector: Rc<dyn AgentModelSelector>,
63 agent_server: Rc<dyn AgentServer>,
64 fs: Arc<dyn Fs>,
65 focus_handle: FocusHandle,
66 window: &mut Window,
67 cx: &mut Context<AcpModelSelector>,
68 ) -> Self {
69 let rx = selector.watch(cx);
70 let refresh_models_task = {
71 cx.spawn_in(window, {
72 async move |this, cx| {
73 async fn refresh(
74 this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
75 cx: &mut AsyncWindowContext,
76 ) -> Result<()> {
77 let (models_task, selected_model_task) = this.update(cx, |this, cx| {
78 (
79 this.delegate.selector.list_models(cx),
80 this.delegate.selector.selected_model(cx),
81 )
82 })?;
83
84 let (models, selected_model) =
85 futures::join!(models_task, selected_model_task);
86
87 this.update_in(cx, |this, window, cx| {
88 this.delegate.models = models.ok();
89 this.delegate.selected_model = selected_model.ok();
90 this.refresh(window, cx)
91 })
92 }
93
94 refresh(&this, cx).await.log_err();
95 if let Some(mut rx) = rx {
96 while let Ok(()) = rx.recv().await {
97 refresh(&this, cx).await.log_err();
98 }
99 }
100 }
101 })
102 };
103
104 Self {
105 selector,
106 agent_server,
107 fs,
108 filtered_entries: Vec::new(),
109 models: None,
110 selected_model: None,
111 selected_index: 0,
112 selected_description: None,
113 _refresh_models_task: refresh_models_task,
114 focus_handle,
115 }
116 }
117
118 pub fn active_model(&self) -> Option<&AgentModelInfo> {
119 self.selected_model.as_ref()
120 }
121}
122
123impl PickerDelegate for AcpModelPickerDelegate {
124 type ListItem = AnyElement;
125
126 fn match_count(&self) -> usize {
127 self.filtered_entries.len()
128 }
129
130 fn selected_index(&self) -> usize {
131 self.selected_index
132 }
133
134 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
135 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
136 cx.notify();
137 }
138
139 fn can_select(
140 &mut self,
141 ix: usize,
142 _window: &mut Window,
143 _cx: &mut Context<Picker<Self>>,
144 ) -> bool {
145 match self.filtered_entries.get(ix) {
146 Some(AcpModelPickerEntry::Model(_)) => true,
147 Some(AcpModelPickerEntry::Separator(_)) | None => false,
148 }
149 }
150
151 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
152 "Select a model…".into()
153 }
154
155 fn update_matches(
156 &mut self,
157 query: String,
158 window: &mut Window,
159 cx: &mut Context<Picker<Self>>,
160 ) -> Task<()> {
161 cx.spawn_in(window, async move |this, cx| {
162 let filtered_models = match this
163 .read_with(cx, |this, cx| {
164 this.delegate.models.clone().map(move |models| {
165 fuzzy_search(models, query, cx.background_executor().clone())
166 })
167 })
168 .ok()
169 .flatten()
170 {
171 Some(task) => task.await,
172 None => AgentModelList::Flat(vec![]),
173 };
174
175 this.update_in(cx, |this, window, cx| {
176 this.delegate.filtered_entries =
177 info_list_to_picker_entries(filtered_models).collect();
178 // Finds the currently selected model in the list
179 let new_index = this
180 .delegate
181 .selected_model
182 .as_ref()
183 .and_then(|selected| {
184 this.delegate.filtered_entries.iter().position(|entry| {
185 if let AcpModelPickerEntry::Model(model_info) = entry {
186 model_info.id == selected.id
187 } else {
188 false
189 }
190 })
191 })
192 .unwrap_or(0);
193 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
194 cx.notify();
195 })
196 .ok();
197 })
198 }
199
200 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
201 if let Some(AcpModelPickerEntry::Model(model_info)) =
202 self.filtered_entries.get(self.selected_index)
203 {
204 if window.modifiers().secondary() {
205 let default_model = self.agent_server.default_model(cx);
206 let is_default = default_model.as_ref() == Some(&model_info.id);
207
208 self.agent_server.set_default_model(
209 if is_default {
210 None
211 } else {
212 Some(model_info.id.clone())
213 },
214 self.fs.clone(),
215 cx,
216 );
217 }
218
219 self.selector
220 .select_model(model_info.id.clone(), cx)
221 .detach_and_log_err(cx);
222 self.selected_model = Some(model_info.clone());
223 let current_index = self.selected_index;
224 self.set_selected_index(current_index, window, cx);
225
226 cx.emit(DismissEvent);
227 }
228 }
229
230 fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
231 cx.defer_in(window, |picker, window, cx| {
232 picker.set_query("", window, cx);
233 });
234 }
235
236 fn render_match(
237 &self,
238 ix: usize,
239 selected: bool,
240 _: &mut Window,
241 cx: &mut Context<Picker<Self>>,
242 ) -> Option<Self::ListItem> {
243 match self.filtered_entries.get(ix)? {
244 AcpModelPickerEntry::Separator(title) => Some(
245 div()
246 .px_2()
247 .pb_1()
248 .when(ix > 1, |this| {
249 this.mt_1()
250 .pt_2()
251 .border_t_1()
252 .border_color(cx.theme().colors().border_variant)
253 })
254 .child(
255 Label::new(title)
256 .size(LabelSize::XSmall)
257 .color(Color::Muted),
258 )
259 .into_any_element(),
260 ),
261 AcpModelPickerEntry::Model(model_info) => {
262 let is_selected = Some(model_info) == self.selected_model.as_ref();
263 let default_model = self.agent_server.default_model(cx);
264 let is_default = default_model.as_ref() == Some(&model_info.id);
265
266 let model_icon_color = if is_selected {
267 Color::Accent
268 } else {
269 Color::Muted
270 };
271
272 Some(
273 div()
274 .id(("model-picker-menu-child", ix))
275 .when_some(model_info.description.clone(), |this, description| {
276 this
277 .on_hover(cx.listener(move |menu, hovered, _, cx| {
278 if *hovered {
279 menu.delegate.selected_description = Some((ix, description.clone(), is_default));
280 } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
281 menu.delegate.selected_description = None;
282 }
283 cx.notify();
284 }))
285 })
286 .child(
287 ListItem::new(ix)
288 .inset(true)
289 .spacing(ListItemSpacing::Sparse)
290 .toggle_state(selected)
291 .child(
292 h_flex()
293 .w_full()
294 .gap_1p5()
295 .map(|this| match &model_info.icon {
296 Some(icon) => this.child(
297 match icon {
298 AgentModelIcon::Path(path) => Icon::from_external_svg(path.clone()),
299 AgentModelIcon::Named(icon) => Icon::new(*icon)
300 }
301 .color(model_icon_color)
302 .size(IconSize::Small)
303 ),
304 None => this,
305 })
306 .child(Label::new(model_info.name.clone()).truncate()),
307 )
308 .end_slot(div().pr_3().when(is_selected, |this| {
309 this.child(
310 Icon::new(IconName::Check)
311 .color(Color::Accent)
312 .size(IconSize::Small),
313 )
314 })),
315 )
316 .into_any_element()
317 )
318 }
319 }
320 }
321
322 fn documentation_aside(
323 &self,
324 _window: &mut Window,
325 _cx: &mut Context<Picker<Self>>,
326 ) -> Option<ui::DocumentationAside> {
327 self.selected_description
328 .as_ref()
329 .map(|(_, description, is_default)| {
330 let description = description.clone();
331 let is_default = *is_default;
332
333 DocumentationAside::new(
334 DocumentationSide::Left,
335 DocumentationEdge::Top,
336 Rc::new(move |_| {
337 v_flex()
338 .gap_1()
339 .child(Label::new(description.clone()))
340 .child(HoldForDefault::new(is_default))
341 .into_any_element()
342 }),
343 )
344 })
345 }
346
347 fn render_footer(
348 &self,
349 _window: &mut Window,
350 cx: &mut Context<Picker<Self>>,
351 ) -> Option<AnyElement> {
352 let focus_handle = self.focus_handle.clone();
353
354 if !self.selector.should_render_footer() {
355 return None;
356 }
357
358 Some(
359 h_flex()
360 .w_full()
361 .p_1p5()
362 .border_t_1()
363 .border_color(cx.theme().colors().border_variant)
364 .child(
365 Button::new("configure", "Configure")
366 .full_width()
367 .style(ButtonStyle::Outlined)
368 .key_binding(
369 KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
370 .map(|kb| kb.size(rems_from_px(12.))),
371 )
372 .on_click(|_, window, cx| {
373 window.dispatch_action(OpenSettings.boxed_clone(), cx);
374 }),
375 )
376 .into_any(),
377 )
378 }
379}
380
381fn info_list_to_picker_entries(
382 model_list: AgentModelList,
383) -> impl Iterator<Item = AcpModelPickerEntry> {
384 match model_list {
385 AgentModelList::Flat(list) => {
386 itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
387 }
388 AgentModelList::Grouped(index_map) => {
389 itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
390 std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
391 .chain(models.into_iter().map(AcpModelPickerEntry::Model))
392 }))
393 }
394 }
395}
396
397async fn fuzzy_search(
398 model_list: AgentModelList,
399 query: String,
400 executor: BackgroundExecutor,
401) -> AgentModelList {
402 async fn fuzzy_search_list(
403 model_list: Vec<AgentModelInfo>,
404 query: &str,
405 executor: BackgroundExecutor,
406 ) -> Vec<AgentModelInfo> {
407 let candidates = model_list
408 .iter()
409 .enumerate()
410 .map(|(ix, model)| StringMatchCandidate::new(ix, model.name.as_ref()))
411 .collect::<Vec<_>>();
412 let mut matches = match_strings(
413 &candidates,
414 query,
415 false,
416 true,
417 100,
418 &Default::default(),
419 executor,
420 )
421 .await;
422
423 matches.sort_unstable_by_key(|mat| {
424 let candidate = &candidates[mat.candidate_id];
425 (Reverse(OrderedFloat(mat.score)), candidate.id)
426 });
427
428 matches
429 .into_iter()
430 .map(|mat| model_list[mat.candidate_id].clone())
431 .collect()
432 }
433
434 match model_list {
435 AgentModelList::Flat(model_list) => {
436 AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
437 }
438 AgentModelList::Grouped(index_map) => {
439 let groups =
440 futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
441 fuzzy_search_list(models, &query, executor.clone())
442 .map(|results| (group_name, results))
443 }))
444 .await;
445 AgentModelList::Grouped(IndexMap::from_iter(
446 groups
447 .into_iter()
448 .filter(|(_, results)| !results.is_empty()),
449 ))
450 }
451 }
452}
453
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped `AgentModelList` from `(group name, model names)`
    /// pairs. Each model uses its name as its id and has no description or
    /// icon.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId::new(model.to_string()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names
    /// match `expected` exactly, including order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            // The type under test is `AgentModelList`; name it correctly so a
            // failure points at the right type.
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }
}