1use std::{cmp::Reverse, rc::Rc, sync::Arc};
2
3use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
4use agent_servers::AgentServer;
5use anyhow::Result;
6use collections::IndexMap;
7use fs::Fs;
8use futures::FutureExt;
9use fuzzy::{StringMatchCandidate, match_strings};
10use gpui::{AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity};
11use ordered_float::OrderedFloat;
12use picker::{Picker, PickerDelegate};
13use ui::{
14 DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, ListItem,
15 ListItemSpacing, prelude::*,
16};
17use util::ResultExt;
18
19use crate::ui::HoldForDefault;
20
/// The ACP model selector view: a [`Picker`] driven by [`AcpModelPickerDelegate`].
pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
22
23pub fn acp_model_selector(
24 selector: Rc<dyn AgentModelSelector>,
25 agent_server: Rc<dyn AgentServer>,
26 fs: Arc<dyn Fs>,
27 window: &mut Window,
28 cx: &mut Context<AcpModelSelector>,
29) -> AcpModelSelector {
30 let delegate = AcpModelPickerDelegate::new(selector, agent_server, fs, window, cx);
31 Picker::list(delegate, window, cx)
32 .show_scrollbar(true)
33 .width(rems(20.))
34 .max_height(Some(rems(20.).into()))
35}
36
/// One row of the model picker list.
enum AcpModelPickerEntry {
    /// A non-selectable heading carrying a group's name; emitted before that
    /// group's models when the model list is grouped.
    Separator(SharedString),
    /// A selectable model row.
    Model(AgentModelInfo),
}
41
/// [`PickerDelegate`] that lists an ACP agent's models and lets the user pick one.
pub struct AcpModelPickerDelegate {
    /// Source of the model list and of the currently selected model; also used
    /// to apply a newly confirmed selection.
    selector: Rc<dyn AgentModelSelector>,
    /// Agent server whose default model is shown, and toggled when a model is
    /// confirmed with the secondary modifier held.
    agent_server: Rc<dyn AgentServer>,
    /// Filesystem handle passed through to `AgentServer::set_default_model`.
    fs: Arc<dyn Fs>,
    /// Rows currently displayed: the fuzzy-filtered models, preceded by group
    /// separators when the list is grouped.
    filtered_entries: Vec<AcpModelPickerEntry>,
    /// Full, unfiltered model list from the selector; `None` before the first
    /// load completes or when loading failed.
    models: Option<AgentModelList>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// What the documentation aside shows for the hovered row:
    /// `(row index, model description, whether that model is the agent's default)`.
    selected_description: Option<(usize, SharedString, bool)>,
    /// The active model: as reported by the selector, or as last confirmed here.
    selected_model: Option<AgentModelInfo>,
    /// Holds the model-refresh task so it lives as long as the delegate.
    _refresh_models_task: Task<()>,
}
53
54impl AcpModelPickerDelegate {
55 fn new(
56 selector: Rc<dyn AgentModelSelector>,
57 agent_server: Rc<dyn AgentServer>,
58 fs: Arc<dyn Fs>,
59 window: &mut Window,
60 cx: &mut Context<AcpModelSelector>,
61 ) -> Self {
62 let rx = selector.watch(cx);
63 let refresh_models_task = {
64 cx.spawn_in(window, {
65 async move |this, cx| {
66 async fn refresh(
67 this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
68 cx: &mut AsyncWindowContext,
69 ) -> Result<()> {
70 let (models_task, selected_model_task) = this.update(cx, |this, cx| {
71 (
72 this.delegate.selector.list_models(cx),
73 this.delegate.selector.selected_model(cx),
74 )
75 })?;
76
77 let (models, selected_model) =
78 futures::join!(models_task, selected_model_task);
79
80 this.update_in(cx, |this, window, cx| {
81 this.delegate.models = models.ok();
82 this.delegate.selected_model = selected_model.ok();
83 this.refresh(window, cx)
84 })
85 }
86
87 refresh(&this, cx).await.log_err();
88 if let Some(mut rx) = rx {
89 while let Ok(()) = rx.recv().await {
90 refresh(&this, cx).await.log_err();
91 }
92 }
93 }
94 })
95 };
96
97 Self {
98 selector,
99 agent_server,
100 fs,
101 filtered_entries: Vec::new(),
102 models: None,
103 selected_model: None,
104 selected_index: 0,
105 selected_description: None,
106 _refresh_models_task: refresh_models_task,
107 }
108 }
109
110 pub fn active_model(&self) -> Option<&AgentModelInfo> {
111 self.selected_model.as_ref()
112 }
113}
114
115impl PickerDelegate for AcpModelPickerDelegate {
116 type ListItem = AnyElement;
117
118 fn match_count(&self) -> usize {
119 self.filtered_entries.len()
120 }
121
122 fn selected_index(&self) -> usize {
123 self.selected_index
124 }
125
126 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
127 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
128 cx.notify();
129 }
130
131 fn can_select(
132 &mut self,
133 ix: usize,
134 _window: &mut Window,
135 _cx: &mut Context<Picker<Self>>,
136 ) -> bool {
137 match self.filtered_entries.get(ix) {
138 Some(AcpModelPickerEntry::Model(_)) => true,
139 Some(AcpModelPickerEntry::Separator(_)) | None => false,
140 }
141 }
142
143 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
144 "Select a model…".into()
145 }
146
147 fn update_matches(
148 &mut self,
149 query: String,
150 window: &mut Window,
151 cx: &mut Context<Picker<Self>>,
152 ) -> Task<()> {
153 cx.spawn_in(window, async move |this, cx| {
154 let filtered_models = match this
155 .read_with(cx, |this, cx| {
156 this.delegate.models.clone().map(move |models| {
157 fuzzy_search(models, query, cx.background_executor().clone())
158 })
159 })
160 .ok()
161 .flatten()
162 {
163 Some(task) => task.await,
164 None => AgentModelList::Flat(vec![]),
165 };
166
167 this.update_in(cx, |this, window, cx| {
168 this.delegate.filtered_entries =
169 info_list_to_picker_entries(filtered_models).collect();
170 // Finds the currently selected model in the list
171 let new_index = this
172 .delegate
173 .selected_model
174 .as_ref()
175 .and_then(|selected| {
176 this.delegate.filtered_entries.iter().position(|entry| {
177 if let AcpModelPickerEntry::Model(model_info) = entry {
178 model_info.id == selected.id
179 } else {
180 false
181 }
182 })
183 })
184 .unwrap_or(0);
185 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
186 cx.notify();
187 })
188 .ok();
189 })
190 }
191
192 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
193 if let Some(AcpModelPickerEntry::Model(model_info)) =
194 self.filtered_entries.get(self.selected_index)
195 {
196 if window.modifiers().secondary() {
197 let default_model = self.agent_server.default_model(cx);
198 let is_default = default_model.as_ref() == Some(&model_info.id);
199
200 self.agent_server.set_default_model(
201 if is_default {
202 None
203 } else {
204 Some(model_info.id.clone())
205 },
206 self.fs.clone(),
207 cx,
208 );
209 }
210
211 self.selector
212 .select_model(model_info.id.clone(), cx)
213 .detach_and_log_err(cx);
214 self.selected_model = Some(model_info.clone());
215 let current_index = self.selected_index;
216 self.set_selected_index(current_index, window, cx);
217
218 cx.emit(DismissEvent);
219 }
220 }
221
222 fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
223 cx.defer_in(window, |picker, window, cx| {
224 picker.set_query("", window, cx);
225 });
226 }
227
228 fn render_match(
229 &self,
230 ix: usize,
231 selected: bool,
232 _: &mut Window,
233 cx: &mut Context<Picker<Self>>,
234 ) -> Option<Self::ListItem> {
235 match self.filtered_entries.get(ix)? {
236 AcpModelPickerEntry::Separator(title) => Some(
237 div()
238 .px_2()
239 .pb_1()
240 .when(ix > 1, |this| {
241 this.mt_1()
242 .pt_2()
243 .border_t_1()
244 .border_color(cx.theme().colors().border_variant)
245 })
246 .child(
247 Label::new(title)
248 .size(LabelSize::XSmall)
249 .color(Color::Muted),
250 )
251 .into_any_element(),
252 ),
253 AcpModelPickerEntry::Model(model_info) => {
254 let is_selected = Some(model_info) == self.selected_model.as_ref();
255 let default_model = self.agent_server.default_model(cx);
256 let is_default = default_model.as_ref() == Some(&model_info.id);
257
258 let model_icon_color = if is_selected {
259 Color::Accent
260 } else {
261 Color::Muted
262 };
263
264 Some(
265 div()
266 .id(("model-picker-menu-child", ix))
267 .when_some(model_info.description.clone(), |this, description| {
268 this
269 .on_hover(cx.listener(move |menu, hovered, _, cx| {
270 if *hovered {
271 menu.delegate.selected_description = Some((ix, description.clone(), is_default));
272 } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
273 menu.delegate.selected_description = None;
274 }
275 cx.notify();
276 }))
277 })
278 .child(
279 ListItem::new(ix)
280 .inset(true)
281 .spacing(ListItemSpacing::Sparse)
282 .toggle_state(selected)
283 .child(
284 h_flex()
285 .w_full()
286 .gap_1p5()
287 .when_some(model_info.icon, |this, icon| {
288 this.child(
289 Icon::new(icon)
290 .color(model_icon_color)
291 .size(IconSize::Small)
292 )
293 })
294 .child(Label::new(model_info.name.clone()).truncate()),
295 )
296 .end_slot(div().pr_3().when(is_selected, |this| {
297 this.child(
298 Icon::new(IconName::Check)
299 .color(Color::Accent)
300 .size(IconSize::Small),
301 )
302 })),
303 )
304 .into_any_element()
305 )
306 }
307 }
308 }
309
310 fn documentation_aside(
311 &self,
312 _window: &mut Window,
313 _cx: &mut Context<Picker<Self>>,
314 ) -> Option<ui::DocumentationAside> {
315 self.selected_description
316 .as_ref()
317 .map(|(_, description, is_default)| {
318 let description = description.clone();
319 let is_default = *is_default;
320
321 DocumentationAside::new(
322 DocumentationSide::Left,
323 DocumentationEdge::Top,
324 Rc::new(move |_| {
325 v_flex()
326 .gap_1()
327 .child(Label::new(description.clone()))
328 .child(HoldForDefault::new(is_default))
329 .into_any_element()
330 }),
331 )
332 })
333 }
334}
335
336fn info_list_to_picker_entries(
337 model_list: AgentModelList,
338) -> impl Iterator<Item = AcpModelPickerEntry> {
339 match model_list {
340 AgentModelList::Flat(list) => {
341 itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
342 }
343 AgentModelList::Grouped(index_map) => {
344 itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
345 std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
346 .chain(models.into_iter().map(AcpModelPickerEntry::Model))
347 }))
348 }
349 }
350}
351
352async fn fuzzy_search(
353 model_list: AgentModelList,
354 query: String,
355 executor: BackgroundExecutor,
356) -> AgentModelList {
357 async fn fuzzy_search_list(
358 model_list: Vec<AgentModelInfo>,
359 query: &str,
360 executor: BackgroundExecutor,
361 ) -> Vec<AgentModelInfo> {
362 let candidates = model_list
363 .iter()
364 .enumerate()
365 .map(|(ix, model)| {
366 StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
367 })
368 .collect::<Vec<_>>();
369 let mut matches = match_strings(
370 &candidates,
371 query,
372 false,
373 true,
374 100,
375 &Default::default(),
376 executor,
377 )
378 .await;
379
380 matches.sort_unstable_by_key(|mat| {
381 let candidate = &candidates[mat.candidate_id];
382 (Reverse(OrderedFloat(mat.score)), candidate.id)
383 });
384
385 matches
386 .into_iter()
387 .map(|mat| model_list[mat.candidate_id].clone())
388 .collect()
389 }
390
391 match model_list {
392 AgentModelList::Flat(model_list) => {
393 AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
394 }
395 AgentModelList::Grouped(index_map) => {
396 let groups =
397 futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
398 fuzzy_search_list(models, &query, executor.clone())
399 .map(|results| (group_name, results))
400 }))
401 .await;
402 AgentModelList::Grouped(IndexMap::from_iter(
403 groups
404 .into_iter()
405 .filter(|(_, results)| !results.is_empty()),
406 ))
407 }
408 }
409}
410
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list where every model's id and name are the
    /// given string and description/icon are empty.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId(model.to_string().into()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names
    /// match `expected` exactly, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }
}