1use std::{cmp::Reverse, rc::Rc, sync::Arc};
2
3use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
4use agent_client_protocol as acp;
5use anyhow::Result;
6use collections::IndexMap;
7use futures::FutureExt;
8use fuzzy::{StringMatchCandidate, match_strings};
9use gpui::{Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity};
10use ordered_float::OrderedFloat;
11use picker::{Picker, PickerDelegate};
12use ui::{
13 AnyElement, App, Context, IntoElement, ListItem, ListItemSpacing, SharedString, Window,
14 prelude::*, rems,
15};
16use util::ResultExt;
17
/// Model picker for an ACP session: a [`Picker`] driven by [`AcpModelPickerDelegate`].
pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
19
20pub fn acp_model_selector(
21 session_id: acp::SessionId,
22 selector: Rc<dyn AgentModelSelector>,
23 window: &mut Window,
24 cx: &mut Context<AcpModelSelector>,
25) -> AcpModelSelector {
26 let delegate = AcpModelPickerDelegate::new(session_id, selector, window, cx);
27 Picker::list(delegate, window, cx)
28 .show_scrollbar(true)
29 .width(rems(20.))
30 .max_height(Some(rems(20.).into()))
31}
32
/// A single row in the model picker list.
enum AcpModelPickerEntry {
    /// Group heading inserted before the models of each group of an
    /// `AgentModelList::Grouped` list; never selectable.
    Separator(SharedString),
    /// A selectable model row.
    Model(AgentModelInfo),
}
37
/// [`PickerDelegate`] that lists the models an ACP agent exposes and lets the
/// user choose the model used by one session.
pub struct AcpModelPickerDelegate {
    /// Session whose model is displayed and changed by this picker.
    session_id: acp::SessionId,
    /// Source of the model list and of the session's currently selected model.
    selector: Rc<dyn AgentModelSelector>,
    /// Rows currently displayed, i.e. `models` after fuzzy filtering by the query.
    filtered_entries: Vec<AcpModelPickerEntry>,
    /// Full, unfiltered model list; `None` before the first load or when loading failed.
    models: Option<AgentModelList>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Model currently selected for the session, if known.
    selected_model: Option<AgentModelInfo>,
    /// Task that reloads `models` / `selected_model` whenever the selector's
    /// watch channel fires; stored here so it lives as long as the delegate.
    _refresh_models_task: Task<()>,
}
47
48impl AcpModelPickerDelegate {
49 fn new(
50 session_id: acp::SessionId,
51 selector: Rc<dyn AgentModelSelector>,
52 window: &mut Window,
53 cx: &mut Context<AcpModelSelector>,
54 ) -> Self {
55 let mut rx = selector.watch(cx);
56 let refresh_models_task = cx.spawn_in(window, {
57 let session_id = session_id.clone();
58 async move |this, cx| {
59 async fn refresh(
60 this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
61 session_id: &acp::SessionId,
62 cx: &mut AsyncWindowContext,
63 ) -> Result<()> {
64 let (models_task, selected_model_task) = this.update(cx, |this, cx| {
65 (
66 this.delegate.selector.list_models(cx),
67 this.delegate.selector.selected_model(session_id, cx),
68 )
69 })?;
70
71 let (models, selected_model) = futures::join!(models_task, selected_model_task);
72
73 this.update_in(cx, |this, window, cx| {
74 this.delegate.models = models.log_err();
75 this.delegate.selected_model = selected_model.ok();
76 this.delegate.update_matches(this.query(cx), window, cx)
77 })?
78 .await;
79
80 Ok(())
81 }
82
83 refresh(&this, &session_id, cx).await.log_err();
84 while let Ok(()) = rx.recv().await {
85 refresh(&this, &session_id, cx).await.log_err();
86 }
87 }
88 });
89
90 Self {
91 session_id,
92 selector,
93 filtered_entries: Vec::new(),
94 models: None,
95 selected_model: None,
96 selected_index: 0,
97 _refresh_models_task: refresh_models_task,
98 }
99 }
100
101 pub fn active_model(&self) -> Option<&AgentModelInfo> {
102 self.selected_model.as_ref()
103 }
104}
105
106impl PickerDelegate for AcpModelPickerDelegate {
107 type ListItem = AnyElement;
108
109 fn match_count(&self) -> usize {
110 self.filtered_entries.len()
111 }
112
113 fn selected_index(&self) -> usize {
114 self.selected_index
115 }
116
117 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
118 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
119 cx.notify();
120 }
121
122 fn can_select(
123 &mut self,
124 ix: usize,
125 _window: &mut Window,
126 _cx: &mut Context<Picker<Self>>,
127 ) -> bool {
128 match self.filtered_entries.get(ix) {
129 Some(AcpModelPickerEntry::Model(_)) => true,
130 Some(AcpModelPickerEntry::Separator(_)) | None => false,
131 }
132 }
133
134 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
135 "Select a model…".into()
136 }
137
138 fn update_matches(
139 &mut self,
140 query: String,
141 window: &mut Window,
142 cx: &mut Context<Picker<Self>>,
143 ) -> Task<()> {
144 cx.spawn_in(window, async move |this, cx| {
145 let filtered_models = match this
146 .read_with(cx, |this, cx| {
147 if let Some(models) = this.delegate.models.as_ref() {
148 log::debug!("Filtering {} models.", models.len());
149 } else {
150 log::debug!("No models available.");
151 }
152 this.delegate.models.clone().map(move |models| {
153 fuzzy_search(models, query, cx.background_executor().clone())
154 })
155 })
156 .ok()
157 .flatten()
158 {
159 Some(task) => task.await,
160 None => AgentModelList::Flat(vec![]),
161 };
162
163 log::debug!("Filtered models. {} available.", filtered_models.len());
164
165 this.update_in(cx, |this, window, cx| {
166 this.delegate.filtered_entries =
167 info_list_to_picker_entries(filtered_models).collect();
168 // Finds the currently selected model in the list
169 let new_index = this
170 .delegate
171 .selected_model
172 .as_ref()
173 .and_then(|selected| {
174 this.delegate.filtered_entries.iter().position(|entry| {
175 if let AcpModelPickerEntry::Model(model_info) = entry {
176 model_info.id == selected.id
177 } else {
178 false
179 }
180 })
181 })
182 .unwrap_or(0);
183 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
184 cx.notify();
185 })
186 .ok();
187 })
188 }
189
190 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
191 if let Some(AcpModelPickerEntry::Model(model_info)) =
192 self.filtered_entries.get(self.selected_index)
193 {
194 self.selector
195 .select_model(self.session_id.clone(), model_info.id.clone(), cx)
196 .detach_and_log_err(cx);
197 self.selected_model = Some(model_info.clone());
198 let current_index = self.selected_index;
199 self.set_selected_index(current_index, window, cx);
200
201 cx.emit(DismissEvent);
202 }
203 }
204
205 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
206 cx.emit(DismissEvent);
207 }
208
209 fn render_match(
210 &self,
211 ix: usize,
212 selected: bool,
213 _: &mut Window,
214 cx: &mut Context<Picker<Self>>,
215 ) -> Option<Self::ListItem> {
216 match self.filtered_entries.get(ix)? {
217 AcpModelPickerEntry::Separator(title) => Some(
218 div()
219 .px_2()
220 .pb_1()
221 .when(ix > 1, |this| {
222 this.mt_1()
223 .pt_2()
224 .border_t_1()
225 .border_color(cx.theme().colors().border_variant)
226 })
227 .child(
228 Label::new(title)
229 .size(LabelSize::XSmall)
230 .color(Color::Muted),
231 )
232 .into_any_element(),
233 ),
234 AcpModelPickerEntry::Model(model_info) => {
235 let is_selected = Some(model_info) == self.selected_model.as_ref();
236
237 let model_icon_color = if is_selected {
238 Color::Accent
239 } else {
240 Color::Muted
241 };
242
243 Some(
244 ListItem::new(ix)
245 .inset(true)
246 .spacing(ListItemSpacing::Sparse)
247 .toggle_state(selected)
248 .start_slot::<Icon>(model_info.icon.map(|icon| {
249 Icon::new(icon)
250 .color(model_icon_color)
251 .size(IconSize::Small)
252 }))
253 .child(
254 h_flex()
255 .w_full()
256 .pl_0p5()
257 .gap_1p5()
258 .w(px(240.))
259 .child(Label::new(model_info.name.clone()).truncate()),
260 )
261 .end_slot(div().pr_3().when(is_selected, |this| {
262 this.child(
263 Icon::new(IconName::Check)
264 .color(Color::Accent)
265 .size(IconSize::Small),
266 )
267 }))
268 .into_any_element(),
269 )
270 }
271 }
272 }
273
274 fn render_footer(
275 &self,
276 _: &mut Window,
277 cx: &mut Context<Picker<Self>>,
278 ) -> Option<gpui::AnyElement> {
279 Some(
280 h_flex()
281 .w_full()
282 .border_t_1()
283 .border_color(cx.theme().colors().border_variant)
284 .p_1()
285 .gap_4()
286 .justify_between()
287 .child(
288 Button::new("configure", "Configure")
289 .icon(IconName::Settings)
290 .icon_size(IconSize::Small)
291 .icon_color(Color::Muted)
292 .icon_position(IconPosition::Start)
293 .on_click(|_, window, cx| {
294 window.dispatch_action(
295 zed_actions::agent::OpenSettings.boxed_clone(),
296 cx,
297 );
298 }),
299 )
300 .into_any(),
301 )
302 }
303}
304
305fn info_list_to_picker_entries(
306 model_list: AgentModelList,
307) -> impl Iterator<Item = AcpModelPickerEntry> {
308 match model_list {
309 AgentModelList::Flat(list) => {
310 itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
311 }
312 AgentModelList::Grouped(index_map) => {
313 itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
314 std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
315 .chain(models.into_iter().map(AcpModelPickerEntry::Model))
316 }))
317 }
318 }
319}
320
321async fn fuzzy_search(
322 model_list: AgentModelList,
323 query: String,
324 executor: BackgroundExecutor,
325) -> AgentModelList {
326 async fn fuzzy_search_list(
327 model_list: Vec<AgentModelInfo>,
328 query: &str,
329 executor: BackgroundExecutor,
330 ) -> Vec<AgentModelInfo> {
331 let candidates = model_list
332 .iter()
333 .enumerate()
334 .map(|(ix, model)| {
335 StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
336 })
337 .collect::<Vec<_>>();
338 let mut matches = match_strings(
339 &candidates,
340 query,
341 false,
342 true,
343 100,
344 &Default::default(),
345 executor,
346 )
347 .await;
348
349 matches.sort_unstable_by_key(|mat| {
350 let candidate = &candidates[mat.candidate_id];
351 (Reverse(OrderedFloat(mat.score)), candidate.id)
352 });
353
354 matches
355 .into_iter()
356 .map(|mat| model_list[mat.candidate_id].clone())
357 .collect()
358 }
359
360 match model_list {
361 AgentModelList::Flat(model_list) => {
362 AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
363 }
364 AgentModelList::Grouped(index_map) => {
365 let groups =
366 futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
367 fuzzy_search_list(models, &query, executor.clone())
368 .map(|results| (group_name, results))
369 }))
370 .await;
371 AgentModelList::Grouped(IndexMap::from_iter(
372 groups
373 .into_iter()
374 .filter(|(_, results)| !results.is_empty()),
375 ))
376 }
377 }
378}
379
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;

    use super::*;

    /// Builds an `AgentModelList::Grouped` from `(group, model names)` pairs.
    /// Each model uses its name as its id and has no icon.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp_thread::AgentModelId(model.to_string().into()),
                            name: model.to_string().into(),
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose group names and model
    /// names equal `expected`, in the same order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );

        // An empty query keeps every group and every model, in the original order.
        let results = fuzzy_search(models.clone(), "".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                (
                    "zed",
                    vec![
                        "Claude 3.7 Sonnet",
                        "Claude 3.7 Sonnet Thinking",
                        "gpt-4.1",
                        "gpt-4.1-nano",
                    ],
                ),
                ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
                ("ollama", vec!["mistral", "deepseek"]),
            ],
        );

        // Groups with no matching model are dropped from the result entirely.
        let results = fuzzy_search(models, "xyz".into(), cx.executor()).await;
        assert_models_eq(results, vec![]);
    }
}