1use std::{cmp::Reverse, rc::Rc, sync::Arc};
2
3use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
4use agent_client_protocol as acp;
5use anyhow::Result;
6use collections::IndexMap;
7use futures::FutureExt;
8use fuzzy::{StringMatchCandidate, match_strings};
9use gpui::{Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity};
10use ordered_float::OrderedFloat;
11use picker::{Picker, PickerDelegate};
12use ui::{
13 AnyElement, App, Context, IntoElement, ListItem, ListItemSpacing, SharedString, Window,
14 prelude::*, rems,
15};
16use util::ResultExt;
17
/// The model-selection picker: a [`Picker`] driven by an [`AcpModelPickerDelegate`].
pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
19
20pub fn acp_model_selector(
21 session_id: acp::SessionId,
22 selector: Rc<dyn AgentModelSelector>,
23 window: &mut Window,
24 cx: &mut Context<AcpModelSelector>,
25) -> AcpModelSelector {
26 let delegate = AcpModelPickerDelegate::new(session_id, selector, window, cx);
27 Picker::list(delegate, window, cx)
28 .show_scrollbar(true)
29 .width(rems(20.))
30 .max_height(Some(rems(20.).into()))
31}
32
/// A single row in the model picker.
enum AcpModelPickerEntry {
    /// Header naming a model group. It cannot be selected and is only emitted
    /// for grouped model lists.
    Separator(SharedString),
    /// A selectable model row.
    Model(AgentModelInfo),
}
37
/// [`PickerDelegate`] that lists the models offered by an [`AgentModelSelector`]
/// and changes the model used by a single ACP session.
pub struct AcpModelPickerDelegate {
    /// Session whose model is shown and changed by this picker.
    session_id: acp::SessionId,
    /// Source of the model list and the target of model selection.
    selector: Rc<dyn AgentModelSelector>,
    /// Rows currently displayed, built from the most recent fuzzy-search result.
    filtered_entries: Vec<AcpModelPickerEntry>,
    /// Full, unfiltered model list. `None` until loaded, or when loading failed.
    models: Option<AgentModelList>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Model currently used by the session, if known.
    selected_model: Option<AgentModelInfo>,
    /// Task that reloads `models` and `selected_model`. It is stored here so it
    /// lives as long as the delegate (presumably dropping the `Task` cancels it).
    _refresh_models_task: Task<()>,
}
47
impl AcpModelPickerDelegate {
    /// Creates the delegate and spawns a task that keeps its model data in sync
    /// with `selector`.
    ///
    /// `models` and `selected_model` start out as `None`. The spawned task fills
    /// them in immediately, then again each time the receiver returned by
    /// `selector.watch(cx)` yields `Ok(())`.
    fn new(
        session_id: acp::SessionId,
        selector: Rc<dyn AgentModelSelector>,
        window: &mut Window,
        cx: &mut Context<AcpModelSelector>,
    ) -> Self {
        // Change notifications from the selector; each one triggers a reload below.
        let mut rx = selector.watch(cx);
        let refresh_models_task = cx.spawn_in(window, {
            let session_id = session_id.clone();
            async move |this, cx| {
                /// Fetches the model list and this session's selected model,
                /// stores both on the delegate, then calls `Picker::refresh` so
                /// the visible rows are rebuilt.
                ///
                /// Fails if either update through the weak picker handle fails
                /// (presumably because the picker has been dropped).
                async fn refresh(
                    this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
                    session_id: &acp::SessionId,
                    cx: &mut AsyncWindowContext,
                ) -> Result<()> {
                    let (models_task, selected_model_task) = this.update(cx, |this, cx| {
                        (
                            this.delegate.selector.list_models(cx),
                            this.delegate.selector.selected_model(session_id, cx),
                        )
                    })?;

                    // Await both requests concurrently rather than one after the other.
                    let (models, selected_model) = futures::join!(models_task, selected_model_task);

                    this.update_in(cx, |this, window, cx| {
                        // A failed listing is logged and leaves `models` as `None`.
                        this.delegate.models = models.log_err();
                        // NOTE(review): unlike `models`, an error here is discarded
                        // without logging — presumably "no selected model" is an
                        // expected outcome. Confirm with the selector implementations.
                        this.delegate.selected_model = selected_model.ok();
                        this.refresh(window, cx)
                    })
                }

                // Initial load, then one reload per change notification until
                // `recv` returns an error (presumably once the sender is dropped).
                // Individual refresh failures are logged and do not end the loop.
                refresh(&this, &session_id, cx).await.log_err();
                while let Ok(()) = rx.recv().await {
                    refresh(&this, &session_id, cx).await.log_err();
                }
            }
        });

        Self {
            session_id,
            selector,
            filtered_entries: Vec::new(),
            models: None,
            selected_model: None,
            selected_index: 0,
            _refresh_models_task: refresh_models_task,
        }
    }

    /// The model the session is using, as last loaded or confirmed in this picker.
    pub fn active_model(&self) -> Option<&AgentModelInfo> {
        self.selected_model.as_ref()
    }
}
102
103impl PickerDelegate for AcpModelPickerDelegate {
104 type ListItem = AnyElement;
105
106 fn match_count(&self) -> usize {
107 self.filtered_entries.len()
108 }
109
110 fn selected_index(&self) -> usize {
111 self.selected_index
112 }
113
114 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
115 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
116 cx.notify();
117 }
118
119 fn can_select(
120 &mut self,
121 ix: usize,
122 _window: &mut Window,
123 _cx: &mut Context<Picker<Self>>,
124 ) -> bool {
125 match self.filtered_entries.get(ix) {
126 Some(AcpModelPickerEntry::Model(_)) => true,
127 Some(AcpModelPickerEntry::Separator(_)) | None => false,
128 }
129 }
130
131 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
132 "Select a model…".into()
133 }
134
135 fn update_matches(
136 &mut self,
137 query: String,
138 window: &mut Window,
139 cx: &mut Context<Picker<Self>>,
140 ) -> Task<()> {
141 cx.spawn_in(window, async move |this, cx| {
142 let filtered_models = match this
143 .read_with(cx, |this, cx| {
144 if let Some(models) = this.delegate.models.as_ref() {
145 log::debug!("Filtering {} models.", models.len());
146 } else {
147 log::debug!("No models available.");
148 }
149 this.delegate.models.clone().map(move |models| {
150 fuzzy_search(models, query, cx.background_executor().clone())
151 })
152 })
153 .ok()
154 .flatten()
155 {
156 Some(task) => task.await,
157 None => AgentModelList::Flat(vec![]),
158 };
159
160 log::debug!("Filtered models. {} available.", filtered_models.len());
161
162 this.update_in(cx, |this, window, cx| {
163 this.delegate.filtered_entries =
164 info_list_to_picker_entries(filtered_models).collect();
165 // Finds the currently selected model in the list
166 let new_index = this
167 .delegate
168 .selected_model
169 .as_ref()
170 .and_then(|selected| {
171 this.delegate.filtered_entries.iter().position(|entry| {
172 if let AcpModelPickerEntry::Model(model_info) = entry {
173 model_info.id == selected.id
174 } else {
175 false
176 }
177 })
178 })
179 .unwrap_or(0);
180 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
181 cx.notify();
182 })
183 .ok();
184 })
185 }
186
187 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
188 if let Some(AcpModelPickerEntry::Model(model_info)) =
189 self.filtered_entries.get(self.selected_index)
190 {
191 self.selector
192 .select_model(self.session_id.clone(), model_info.id.clone(), cx)
193 .detach_and_log_err(cx);
194 self.selected_model = Some(model_info.clone());
195 let current_index = self.selected_index;
196 self.set_selected_index(current_index, window, cx);
197
198 cx.emit(DismissEvent);
199 }
200 }
201
202 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
203 cx.emit(DismissEvent);
204 }
205
206 fn render_match(
207 &self,
208 ix: usize,
209 selected: bool,
210 _: &mut Window,
211 cx: &mut Context<Picker<Self>>,
212 ) -> Option<Self::ListItem> {
213 match self.filtered_entries.get(ix)? {
214 AcpModelPickerEntry::Separator(title) => Some(
215 div()
216 .px_2()
217 .pb_1()
218 .when(ix > 1, |this| {
219 this.mt_1()
220 .pt_2()
221 .border_t_1()
222 .border_color(cx.theme().colors().border_variant)
223 })
224 .child(
225 Label::new(title)
226 .size(LabelSize::XSmall)
227 .color(Color::Muted),
228 )
229 .into_any_element(),
230 ),
231 AcpModelPickerEntry::Model(model_info) => {
232 let is_selected = Some(model_info) == self.selected_model.as_ref();
233
234 let model_icon_color = if is_selected {
235 Color::Accent
236 } else {
237 Color::Muted
238 };
239
240 Some(
241 ListItem::new(ix)
242 .inset(true)
243 .spacing(ListItemSpacing::Sparse)
244 .toggle_state(selected)
245 .start_slot::<Icon>(model_info.icon.map(|icon| {
246 Icon::new(icon)
247 .color(model_icon_color)
248 .size(IconSize::Small)
249 }))
250 .child(
251 h_flex()
252 .w_full()
253 .pl_0p5()
254 .gap_1p5()
255 .w(px(240.))
256 .child(Label::new(model_info.name.clone()).truncate()),
257 )
258 .end_slot(div().pr_3().when(is_selected, |this| {
259 this.child(
260 Icon::new(IconName::Check)
261 .color(Color::Accent)
262 .size(IconSize::Small),
263 )
264 }))
265 .into_any_element(),
266 )
267 }
268 }
269 }
270
271 fn render_footer(
272 &self,
273 _: &mut Window,
274 cx: &mut Context<Picker<Self>>,
275 ) -> Option<gpui::AnyElement> {
276 Some(
277 h_flex()
278 .w_full()
279 .border_t_1()
280 .border_color(cx.theme().colors().border_variant)
281 .p_1()
282 .gap_4()
283 .justify_between()
284 .child(
285 Button::new("configure", "Configure")
286 .icon(IconName::Settings)
287 .icon_size(IconSize::Small)
288 .icon_color(Color::Muted)
289 .icon_position(IconPosition::Start)
290 .on_click(|_, window, cx| {
291 window.dispatch_action(
292 zed_actions::agent::OpenSettings.boxed_clone(),
293 cx,
294 );
295 }),
296 )
297 .into_any(),
298 )
299 }
300}
301
302fn info_list_to_picker_entries(
303 model_list: AgentModelList,
304) -> impl Iterator<Item = AcpModelPickerEntry> {
305 match model_list {
306 AgentModelList::Flat(list) => {
307 itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
308 }
309 AgentModelList::Grouped(index_map) => {
310 itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
311 std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
312 .chain(models.into_iter().map(AcpModelPickerEntry::Model))
313 }))
314 }
315 }
316}
317
318async fn fuzzy_search(
319 model_list: AgentModelList,
320 query: String,
321 executor: BackgroundExecutor,
322) -> AgentModelList {
323 async fn fuzzy_search_list(
324 model_list: Vec<AgentModelInfo>,
325 query: &str,
326 executor: BackgroundExecutor,
327 ) -> Vec<AgentModelInfo> {
328 let candidates = model_list
329 .iter()
330 .enumerate()
331 .map(|(ix, model)| {
332 StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
333 })
334 .collect::<Vec<_>>();
335 let mut matches = match_strings(
336 &candidates,
337 query,
338 false,
339 true,
340 100,
341 &Default::default(),
342 executor,
343 )
344 .await;
345
346 matches.sort_unstable_by_key(|mat| {
347 let candidate = &candidates[mat.candidate_id];
348 (Reverse(OrderedFloat(mat.score)), candidate.id)
349 });
350
351 matches
352 .into_iter()
353 .map(|mat| model_list[mat.candidate_id].clone())
354 .collect()
355 }
356
357 match model_list {
358 AgentModelList::Flat(model_list) => {
359 AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
360 }
361 AgentModelList::Grouped(index_map) => {
362 let groups =
363 futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
364 fuzzy_search_list(models, &query, executor.clone())
365 .map(|results| (group_name, results))
366 }))
367 .await;
368 AgentModelList::Grouped(IndexMap::from_iter(
369 groups
370 .into_iter()
371 .filter(|(_, results)| !results.is_empty()),
372 ))
373 }
374 }
375}
376
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;

    use super::*;

    /// Builds an `AgentModelList::Grouped` from `(group, [model names])` pairs.
    /// Each model's id and name are both set to the given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp_thread::AgentModelId(model.to_string().into()),
                            name: model.to_string().into(),
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names
    /// match `expected` exactly, including order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }

    #[gpui::test]
    async fn test_fuzzy_match_without_results_drops_all_groups(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["Claude 3.7 Sonnet", "gpt-4.1"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // No candidate string contains an `x`, so every group ends up empty
        // and must be removed rather than rendered as a bare header.
        let results = fuzzy_search(models, "xyz".into(), cx.executor()).await;
        assert_models_eq(results, vec![]);
    }
}