1use std::{cmp::Reverse, rc::Rc, sync::Arc};
2
3use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
4use agent_client_protocol as acp;
5use anyhow::Result;
6use collections::IndexMap;
7use futures::FutureExt;
8use fuzzy::{StringMatchCandidate, match_strings};
9use gpui::{Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity};
10use ordered_float::OrderedFloat;
11use picker::{Picker, PickerDelegate};
12use ui::{
13 AnyElement, App, Context, IntoElement, ListItem, ListItemSpacing, SharedString, Window,
14 prelude::*, rems,
15};
16use util::ResultExt;
17
18pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
19
20pub fn acp_model_selector(
21 session_id: acp::SessionId,
22 selector: Rc<dyn AgentModelSelector>,
23 window: &mut Window,
24 cx: &mut Context<AcpModelSelector>,
25) -> AcpModelSelector {
26 let delegate = AcpModelPickerDelegate::new(session_id, selector, window, cx);
27 Picker::list(delegate, window, cx)
28 .show_scrollbar(true)
29 .width(rems(20.))
30 .max_height(Some(rems(20.).into()))
31}
32
33enum AcpModelPickerEntry {
34 Separator(SharedString),
35 Model(AgentModelInfo),
36}
37
/// [`PickerDelegate`] that lists the models offered by an [`AgentModelSelector`]
/// and switches the model used by a single ACP session.
///
/// NOTE(review): field declaration order is also drop order; `_refresh_models_task`
/// is intentionally kept last.
pub struct AcpModelPickerDelegate {
    /// Session whose model is shown as selected and changed on confirm.
    session_id: acp::SessionId,
    /// Source of the model list and of the session's current model.
    selector: Rc<dyn AgentModelSelector>,
    /// Rows currently displayed: the models matching the query, plus group
    /// header rows when the list is grouped.
    filtered_entries: Vec<AcpModelPickerEntry>,
    /// Full, unfiltered model list; `None` until first loaded or if loading failed.
    models: Option<AgentModelList>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// The session's current model, as last reported by `selector` or as chosen
    /// via `confirm`.
    selected_model: Option<AgentModelInfo>,
    /// Background loop that reloads `models` / `selected_model`; owned here so
    /// it lives as long as the delegate (presumably dropping a gpui `Task`
    /// cancels it — confirm).
    _refresh_models_task: Task<()>,
}
47
48impl AcpModelPickerDelegate {
49 fn new(
50 session_id: acp::SessionId,
51 selector: Rc<dyn AgentModelSelector>,
52 window: &mut Window,
53 cx: &mut Context<AcpModelSelector>,
54 ) -> Self {
55 let mut rx = selector.watch(cx);
56 let refresh_models_task = cx.spawn_in(window, {
57 let session_id = session_id.clone();
58 async move |this, cx| {
59 async fn refresh(
60 this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
61 session_id: &acp::SessionId,
62 cx: &mut AsyncWindowContext,
63 ) -> Result<()> {
64 let (models_task, selected_model_task) = this.update(cx, |this, cx| {
65 (
66 this.delegate.selector.list_models(cx),
67 this.delegate.selector.selected_model(session_id, cx),
68 )
69 })?;
70
71 let (models, selected_model) = futures::join!(models_task, selected_model_task);
72
73 this.update_in(cx, |this, window, cx| {
74 this.delegate.models = models.ok();
75 this.delegate.selected_model = selected_model.ok();
76 this.delegate.update_matches(this.query(cx), window, cx)
77 })?
78 .await;
79
80 Ok(())
81 }
82
83 refresh(&this, &session_id, cx).await.log_err();
84 while let Ok(()) = rx.recv().await {
85 refresh(&this, &session_id, cx).await.log_err();
86 }
87 }
88 });
89
90 Self {
91 session_id,
92 selector,
93 filtered_entries: Vec::new(),
94 models: None,
95 selected_model: None,
96 selected_index: 0,
97 _refresh_models_task: refresh_models_task,
98 }
99 }
100
101 pub fn active_model(&self) -> Option<&AgentModelInfo> {
102 self.selected_model.as_ref()
103 }
104}
105
106impl PickerDelegate for AcpModelPickerDelegate {
107 type ListItem = AnyElement;
108
109 fn match_count(&self) -> usize {
110 self.filtered_entries.len()
111 }
112
113 fn selected_index(&self) -> usize {
114 self.selected_index
115 }
116
117 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
118 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
119 cx.notify();
120 }
121
122 fn can_select(
123 &mut self,
124 ix: usize,
125 _window: &mut Window,
126 _cx: &mut Context<Picker<Self>>,
127 ) -> bool {
128 match self.filtered_entries.get(ix) {
129 Some(AcpModelPickerEntry::Model(_)) => true,
130 Some(AcpModelPickerEntry::Separator(_)) | None => false,
131 }
132 }
133
134 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
135 "Select a model…".into()
136 }
137
138 fn update_matches(
139 &mut self,
140 query: String,
141 window: &mut Window,
142 cx: &mut Context<Picker<Self>>,
143 ) -> Task<()> {
144 cx.spawn_in(window, async move |this, cx| {
145 let filtered_models = match this
146 .read_with(cx, |this, cx| {
147 this.delegate.models.clone().map(move |models| {
148 fuzzy_search(models, query, cx.background_executor().clone())
149 })
150 })
151 .ok()
152 .flatten()
153 {
154 Some(task) => task.await,
155 None => AgentModelList::Flat(vec![]),
156 };
157
158 this.update_in(cx, |this, window, cx| {
159 this.delegate.filtered_entries =
160 info_list_to_picker_entries(filtered_models).collect();
161 // Finds the currently selected model in the list
162 let new_index = this
163 .delegate
164 .selected_model
165 .as_ref()
166 .and_then(|selected| {
167 this.delegate.filtered_entries.iter().position(|entry| {
168 if let AcpModelPickerEntry::Model(model_info) = entry {
169 model_info.id == selected.id
170 } else {
171 false
172 }
173 })
174 })
175 .unwrap_or(0);
176 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
177 cx.notify();
178 })
179 .ok();
180 })
181 }
182
183 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
184 if let Some(AcpModelPickerEntry::Model(model_info)) =
185 self.filtered_entries.get(self.selected_index)
186 {
187 self.selector
188 .select_model(self.session_id.clone(), model_info.id.clone(), cx)
189 .detach_and_log_err(cx);
190 self.selected_model = Some(model_info.clone());
191 let current_index = self.selected_index;
192 self.set_selected_index(current_index, window, cx);
193
194 cx.emit(DismissEvent);
195 }
196 }
197
198 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
199 cx.emit(DismissEvent);
200 }
201
202 fn render_match(
203 &self,
204 ix: usize,
205 selected: bool,
206 _: &mut Window,
207 cx: &mut Context<Picker<Self>>,
208 ) -> Option<Self::ListItem> {
209 match self.filtered_entries.get(ix)? {
210 AcpModelPickerEntry::Separator(title) => Some(
211 div()
212 .px_2()
213 .pb_1()
214 .when(ix > 1, |this| {
215 this.mt_1()
216 .pt_2()
217 .border_t_1()
218 .border_color(cx.theme().colors().border_variant)
219 })
220 .child(
221 Label::new(title)
222 .size(LabelSize::XSmall)
223 .color(Color::Muted),
224 )
225 .into_any_element(),
226 ),
227 AcpModelPickerEntry::Model(model_info) => {
228 let is_selected = Some(model_info) == self.selected_model.as_ref();
229
230 let model_icon_color = if is_selected {
231 Color::Accent
232 } else {
233 Color::Muted
234 };
235
236 Some(
237 ListItem::new(ix)
238 .inset(true)
239 .spacing(ListItemSpacing::Sparse)
240 .toggle_state(selected)
241 .start_slot::<Icon>(model_info.icon.map(|icon| {
242 Icon::new(icon)
243 .color(model_icon_color)
244 .size(IconSize::Small)
245 }))
246 .child(
247 h_flex()
248 .w_full()
249 .pl_0p5()
250 .gap_1p5()
251 .w(px(240.))
252 .child(Label::new(model_info.name.clone()).truncate()),
253 )
254 .end_slot(div().pr_3().when(is_selected, |this| {
255 this.child(
256 Icon::new(IconName::Check)
257 .color(Color::Accent)
258 .size(IconSize::Small),
259 )
260 }))
261 .into_any_element(),
262 )
263 }
264 }
265 }
266
267 fn render_footer(
268 &self,
269 _: &mut Window,
270 cx: &mut Context<Picker<Self>>,
271 ) -> Option<gpui::AnyElement> {
272 Some(
273 h_flex()
274 .w_full()
275 .border_t_1()
276 .border_color(cx.theme().colors().border_variant)
277 .p_1()
278 .gap_4()
279 .justify_between()
280 .child(
281 Button::new("configure", "Configure")
282 .icon(IconName::Settings)
283 .icon_size(IconSize::Small)
284 .icon_color(Color::Muted)
285 .icon_position(IconPosition::Start)
286 .on_click(|_, window, cx| {
287 window.dispatch_action(
288 zed_actions::agent::OpenSettings.boxed_clone(),
289 cx,
290 );
291 }),
292 )
293 .into_any(),
294 )
295 }
296}
297
298fn info_list_to_picker_entries(
299 model_list: AgentModelList,
300) -> impl Iterator<Item = AcpModelPickerEntry> {
301 match model_list {
302 AgentModelList::Flat(list) => {
303 itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
304 }
305 AgentModelList::Grouped(index_map) => {
306 itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
307 std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
308 .chain(models.into_iter().map(AcpModelPickerEntry::Model))
309 }))
310 }
311 }
312}
313
314async fn fuzzy_search(
315 model_list: AgentModelList,
316 query: String,
317 executor: BackgroundExecutor,
318) -> AgentModelList {
319 async fn fuzzy_search_list(
320 model_list: Vec<AgentModelInfo>,
321 query: &str,
322 executor: BackgroundExecutor,
323 ) -> Vec<AgentModelInfo> {
324 let candidates = model_list
325 .iter()
326 .enumerate()
327 .map(|(ix, model)| {
328 StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
329 })
330 .collect::<Vec<_>>();
331 let mut matches = match_strings(
332 &candidates,
333 query,
334 false,
335 true,
336 100,
337 &Default::default(),
338 executor,
339 )
340 .await;
341
342 matches.sort_unstable_by_key(|mat| {
343 let candidate = &candidates[mat.candidate_id];
344 (Reverse(OrderedFloat(mat.score)), candidate.id)
345 });
346
347 matches
348 .into_iter()
349 .map(|mat| model_list[mat.candidate_id].clone())
350 .collect()
351 }
352
353 match model_list {
354 AgentModelList::Flat(model_list) => {
355 AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
356 }
357 AgentModelList::Grouped(index_map) => {
358 let groups =
359 futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
360 fuzzy_search_list(models, &query, executor.clone())
361 .map(|results| (group_name, results))
362 }))
363 .await;
364 AgentModelList::Grouped(IndexMap::from_iter(
365 groups
366 .into_iter()
367 .filter(|(_, results)| !results.is_empty()),
368 ))
369 }
370 }
371}
372
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;

    use super::*;

    /// Builds a model whose id and name are both `name`, with no icon.
    fn create_model_info(name: &str) -> AgentModelInfo {
        acp_thread::AgentModelInfo {
            id: acp_thread::AgentModelId(name.to_string().into()),
            name: name.to_string().into(),
            icon: None,
        }
    }

    /// Builds a grouped model list from `(group name, model names)` pairs,
    /// preserving the given order.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(create_model_info)
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names
    /// match `expected` exactly, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }

    #[gpui::test]
    async fn test_fuzzy_match_empty_query_keeps_all_models(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["Claude 3.7 Sonnet", "gpt-4.1"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // The picker opens with an empty query, so every model must be listed,
        // in its original order.
        let results = fuzzy_search(models, String::new(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["Claude 3.7 Sonnet", "gpt-4.1"]),
                ("ollama", vec!["mistral", "deepseek"]),
            ],
        );
    }

    #[gpui::test]
    async fn test_fuzzy_match_flat_list(cx: &mut TestAppContext) {
        let models = AgentModelList::Flat(vec![
            create_model_info("gpt-4.1"),
            create_model_info("gpt-4.1-nano"),
            create_model_info("mistral"),
        ]);

        let results = fuzzy_search(models, "4n".into(), cx.executor()).await;
        let AgentModelList::Flat(results) = results else {
            panic!("Expected AgentModelList::Flat, got {:?}", results);
        };
        assert_eq!(results.len(), 1, "Only gpt-4.1-nano should match");
        assert_eq!(results[0].name, "gpt-4.1-nano");
    }

    #[test]
    fn test_info_list_to_picker_entries() {
        let models = create_model_list(vec![("zed", vec!["a", "b"]), ("ollama", vec!["c"])]);

        let labels = info_list_to_picker_entries(models)
            .map(|entry| match entry {
                AcpModelPickerEntry::Separator(title) => format!("# {title}"),
                AcpModelPickerEntry::Model(model) => model.name.to_string(),
            })
            .collect::<Vec<String>>();

        assert_eq!(labels, vec!["# zed", "a", "b", "# ollama", "c"]);
    }
}