1use std::{cmp::Reverse, rc::Rc, sync::Arc};
2
3use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
4use anyhow::Result;
5use collections::IndexMap;
6use futures::FutureExt;
7use fuzzy::{StringMatchCandidate, match_strings};
8use gpui::{Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity};
9use ordered_float::OrderedFloat;
10use picker::{Picker, PickerDelegate};
11use ui::{
12 AnyElement, App, Context, DocumentationAside, DocumentationEdge, DocumentationSide,
13 IntoElement, ListItem, ListItemSpacing, SharedString, Window, prelude::*, rems,
14};
15use util::ResultExt;
16
17pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
18
19pub fn acp_model_selector(
20 selector: Rc<dyn AgentModelSelector>,
21 window: &mut Window,
22 cx: &mut Context<AcpModelSelector>,
23) -> AcpModelSelector {
24 let delegate = AcpModelPickerDelegate::new(selector, window, cx);
25 Picker::list(delegate, window, cx)
26 .show_scrollbar(true)
27 .width(rems(20.))
28 .max_height(Some(rems(20.).into()))
29}
30
/// One row of the model picker list.
enum AcpModelPickerEntry {
    /// Group header row showing the group's name; `can_select` rejects it.
    Separator(SharedString),
    /// A selectable model row.
    Model(AgentModelInfo),
}
35
/// [`PickerDelegate`] that lets the user choose a model offered by an
/// [`AgentModelSelector`].
pub struct AcpModelPickerDelegate {
    /// Source of the model list, the current selection, and change notifications.
    selector: Rc<dyn AgentModelSelector>,
    /// Rows currently shown: `models` filtered by the query and flattened into entries.
    filtered_entries: Vec<AcpModelPickerEntry>,
    /// Last model list fetched from `selector`; `None` before the first load or
    /// when fetching it failed.
    models: Option<AgentModelList>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Description of the hovered model together with its row index; drives
    /// the documentation aside.
    selected_description: Option<(usize, SharedString)>,
    /// Model reported as selected by `selector`, or the one last confirmed here.
    selected_model: Option<AgentModelInfo>,
    /// Handle to the refresh loop spawned in `new`, stored so the task lives
    /// as long as the delegate.
    _refresh_models_task: Task<()>,
}
45
46impl AcpModelPickerDelegate {
47 fn new(
48 selector: Rc<dyn AgentModelSelector>,
49 window: &mut Window,
50 cx: &mut Context<AcpModelSelector>,
51 ) -> Self {
52 let rx = selector.watch(cx);
53 let refresh_models_task = {
54 cx.spawn_in(window, {
55 async move |this, cx| {
56 async fn refresh(
57 this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
58 cx: &mut AsyncWindowContext,
59 ) -> Result<()> {
60 let (models_task, selected_model_task) = this.update(cx, |this, cx| {
61 (
62 this.delegate.selector.list_models(cx),
63 this.delegate.selector.selected_model(cx),
64 )
65 })?;
66
67 let (models, selected_model) =
68 futures::join!(models_task, selected_model_task);
69
70 this.update_in(cx, |this, window, cx| {
71 this.delegate.models = models.ok();
72 this.delegate.selected_model = selected_model.ok();
73 this.refresh(window, cx)
74 })
75 }
76
77 refresh(&this, cx).await.log_err();
78 if let Some(mut rx) = rx {
79 while let Ok(()) = rx.recv().await {
80 refresh(&this, cx).await.log_err();
81 }
82 }
83 }
84 })
85 };
86
87 Self {
88 selector,
89 filtered_entries: Vec::new(),
90 models: None,
91 selected_model: None,
92 selected_index: 0,
93 selected_description: None,
94 _refresh_models_task: refresh_models_task,
95 }
96 }
97
98 pub fn active_model(&self) -> Option<&AgentModelInfo> {
99 self.selected_model.as_ref()
100 }
101}
102
103impl PickerDelegate for AcpModelPickerDelegate {
104 type ListItem = AnyElement;
105
106 fn match_count(&self) -> usize {
107 self.filtered_entries.len()
108 }
109
110 fn selected_index(&self) -> usize {
111 self.selected_index
112 }
113
114 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
115 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
116 cx.notify();
117 }
118
119 fn can_select(
120 &mut self,
121 ix: usize,
122 _window: &mut Window,
123 _cx: &mut Context<Picker<Self>>,
124 ) -> bool {
125 match self.filtered_entries.get(ix) {
126 Some(AcpModelPickerEntry::Model(_)) => true,
127 Some(AcpModelPickerEntry::Separator(_)) | None => false,
128 }
129 }
130
131 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
132 "Select a model…".into()
133 }
134
135 fn update_matches(
136 &mut self,
137 query: String,
138 window: &mut Window,
139 cx: &mut Context<Picker<Self>>,
140 ) -> Task<()> {
141 cx.spawn_in(window, async move |this, cx| {
142 let filtered_models = match this
143 .read_with(cx, |this, cx| {
144 this.delegate.models.clone().map(move |models| {
145 fuzzy_search(models, query, cx.background_executor().clone())
146 })
147 })
148 .ok()
149 .flatten()
150 {
151 Some(task) => task.await,
152 None => AgentModelList::Flat(vec![]),
153 };
154
155 this.update_in(cx, |this, window, cx| {
156 this.delegate.filtered_entries =
157 info_list_to_picker_entries(filtered_models).collect();
158 // Finds the currently selected model in the list
159 let new_index = this
160 .delegate
161 .selected_model
162 .as_ref()
163 .and_then(|selected| {
164 this.delegate.filtered_entries.iter().position(|entry| {
165 if let AcpModelPickerEntry::Model(model_info) = entry {
166 model_info.id == selected.id
167 } else {
168 false
169 }
170 })
171 })
172 .unwrap_or(0);
173 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
174 cx.notify();
175 })
176 .ok();
177 })
178 }
179
180 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
181 if let Some(AcpModelPickerEntry::Model(model_info)) =
182 self.filtered_entries.get(self.selected_index)
183 {
184 self.selector
185 .select_model(model_info.id.clone(), cx)
186 .detach_and_log_err(cx);
187 self.selected_model = Some(model_info.clone());
188 let current_index = self.selected_index;
189 self.set_selected_index(current_index, window, cx);
190
191 cx.emit(DismissEvent);
192 }
193 }
194
195 fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
196 cx.defer_in(window, |picker, window, cx| {
197 picker.set_query("", window, cx);
198 });
199 }
200
201 fn render_match(
202 &self,
203 ix: usize,
204 selected: bool,
205 _: &mut Window,
206 cx: &mut Context<Picker<Self>>,
207 ) -> Option<Self::ListItem> {
208 match self.filtered_entries.get(ix)? {
209 AcpModelPickerEntry::Separator(title) => Some(
210 div()
211 .px_2()
212 .pb_1()
213 .when(ix > 1, |this| {
214 this.mt_1()
215 .pt_2()
216 .border_t_1()
217 .border_color(cx.theme().colors().border_variant)
218 })
219 .child(
220 Label::new(title)
221 .size(LabelSize::XSmall)
222 .color(Color::Muted),
223 )
224 .into_any_element(),
225 ),
226 AcpModelPickerEntry::Model(model_info) => {
227 let is_selected = Some(model_info) == self.selected_model.as_ref();
228
229 let model_icon_color = if is_selected {
230 Color::Accent
231 } else {
232 Color::Muted
233 };
234
235 Some(
236 div()
237 .id(("model-picker-menu-child", ix))
238 .when_some(model_info.description.clone(), |this, description| {
239 this
240 .on_hover(cx.listener(move |menu, hovered, _, cx| {
241 if *hovered {
242 menu.delegate.selected_description = Some((ix, description.clone()));
243 } else if matches!(menu.delegate.selected_description, Some((id, _)) if id == ix) {
244 menu.delegate.selected_description = None;
245 }
246 cx.notify();
247 }))
248 })
249 .child(
250 ListItem::new(ix)
251 .inset(true)
252 .spacing(ListItemSpacing::Sparse)
253 .toggle_state(selected)
254 .start_slot::<Icon>(model_info.icon.map(|icon| {
255 Icon::new(icon)
256 .color(model_icon_color)
257 .size(IconSize::Small)
258 }))
259 .child(
260 h_flex()
261 .w_full()
262 .pl_0p5()
263 .gap_1p5()
264 .w(px(240.))
265 .child(Label::new(model_info.name.clone()).truncate()),
266 )
267 .end_slot(div().pr_3().when(is_selected, |this| {
268 this.child(
269 Icon::new(IconName::Check)
270 .color(Color::Accent)
271 .size(IconSize::Small),
272 )
273 })),
274 )
275 .into_any_element()
276 )
277 }
278 }
279 }
280
281 fn render_footer(
282 &self,
283 _: &mut Window,
284 cx: &mut Context<Picker<Self>>,
285 ) -> Option<gpui::AnyElement> {
286 Some(
287 h_flex()
288 .w_full()
289 .border_t_1()
290 .border_color(cx.theme().colors().border_variant)
291 .p_1()
292 .gap_4()
293 .justify_between()
294 .child(
295 Button::new("configure", "Configure")
296 .icon(IconName::Settings)
297 .icon_size(IconSize::Small)
298 .icon_color(Color::Muted)
299 .icon_position(IconPosition::Start)
300 .on_click(|_, window, cx| {
301 window.dispatch_action(
302 zed_actions::agent::OpenSettings.boxed_clone(),
303 cx,
304 );
305 }),
306 )
307 .into_any(),
308 )
309 }
310
311 fn documentation_aside(
312 &self,
313 _window: &mut Window,
314 _cx: &mut Context<Picker<Self>>,
315 ) -> Option<ui::DocumentationAside> {
316 self.selected_description.as_ref().map(|(_, description)| {
317 let description = description.clone();
318 DocumentationAside::new(
319 DocumentationSide::Left,
320 DocumentationEdge::Bottom,
321 Rc::new(move |_| Label::new(description.clone()).into_any_element()),
322 )
323 })
324 }
325}
326
327fn info_list_to_picker_entries(
328 model_list: AgentModelList,
329) -> impl Iterator<Item = AcpModelPickerEntry> {
330 match model_list {
331 AgentModelList::Flat(list) => {
332 itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
333 }
334 AgentModelList::Grouped(index_map) => {
335 itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
336 std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
337 .chain(models.into_iter().map(AcpModelPickerEntry::Model))
338 }))
339 }
340 }
341}
342
343async fn fuzzy_search(
344 model_list: AgentModelList,
345 query: String,
346 executor: BackgroundExecutor,
347) -> AgentModelList {
348 async fn fuzzy_search_list(
349 model_list: Vec<AgentModelInfo>,
350 query: &str,
351 executor: BackgroundExecutor,
352 ) -> Vec<AgentModelInfo> {
353 let candidates = model_list
354 .iter()
355 .enumerate()
356 .map(|(ix, model)| {
357 StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
358 })
359 .collect::<Vec<_>>();
360 let mut matches = match_strings(
361 &candidates,
362 query,
363 false,
364 true,
365 100,
366 &Default::default(),
367 executor,
368 )
369 .await;
370
371 matches.sort_unstable_by_key(|mat| {
372 let candidate = &candidates[mat.candidate_id];
373 (Reverse(OrderedFloat(mat.score)), candidate.id)
374 });
375
376 matches
377 .into_iter()
378 .map(|mat| model_list[mat.candidate_id].clone())
379 .collect()
380 }
381
382 match model_list {
383 AgentModelList::Flat(model_list) => {
384 AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
385 }
386 AgentModelList::Grouped(index_map) => {
387 let groups =
388 futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
389 fuzzy_search_list(models, &query, executor.clone())
390 .map(|results| (group_name, results))
391 }))
392 .await;
393 AgentModelList::Grouped(IndexMap::from_iter(
394 groups
395 .into_iter()
396 .filter(|(_, results)| !results.is_empty()),
397 ))
398 }
399 }
400}
401
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list; every model's id and name are the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId(model.to_string().into()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is grouped and holds exactly the `expected`
    /// groups and model names, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve the models' order whenever possible.
        // Groups are searched independently and keep their original order, so
        // the `zed` matches come before the `openai` matches even though both
        // groups contain the same model names. Groups without matches
        // (`ollama`) are dropped.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }
}