1use std::{cmp::Reverse, rc::Rc, sync::Arc};
2
3use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
4use agent_servers::AgentServer;
5use anyhow::Result;
6use collections::IndexMap;
7use fs::Fs;
8use futures::FutureExt;
9use fuzzy::{StringMatchCandidate, match_strings};
10use gpui::{
11 Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Task, WeakEntity,
12};
13use ordered_float::OrderedFloat;
14use picker::{Picker, PickerDelegate};
15use ui::{DocumentationAside, DocumentationEdge, DocumentationSide, prelude::*};
16use util::ResultExt;
17use zed_actions::agent::OpenSettings;
18
19use crate::ui::{HoldForDefault, ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
20
/// The ACP model picker: a [`Picker`] whose rows come from an [`AcpModelPickerDelegate`].
pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
22
23pub fn acp_model_selector(
24 selector: Rc<dyn AgentModelSelector>,
25 agent_server: Rc<dyn AgentServer>,
26 fs: Arc<dyn Fs>,
27 focus_handle: FocusHandle,
28 window: &mut Window,
29 cx: &mut Context<AcpModelSelector>,
30) -> AcpModelSelector {
31 let delegate =
32 AcpModelPickerDelegate::new(selector, agent_server, fs, focus_handle, window, cx);
33 Picker::list(delegate, window, cx)
34 .show_scrollbar(true)
35 .width(rems(20.))
36 .max_height(Some(rems(20.).into()))
37}
38
/// A single row in the model picker list.
enum AcpModelPickerEntry {
    /// Heading for a group of models, carrying the group's name.
    /// `can_select` rejects these rows.
    Separator(SharedString),
    /// A selectable model row.
    Model(AgentModelInfo),
}
43
/// [`PickerDelegate`] that lists the models offered by an ACP agent and lets
/// the user switch between them (or, with the secondary modifier held, change
/// the agent server's default model).
pub struct AcpModelPickerDelegate {
    // Source of the model list and current selection; also receives the
    // user's choice via `select_model`.
    selector: Rc<dyn AgentModelSelector>,
    // Used to read and update the persisted default model.
    agent_server: Rc<dyn AgentServer>,
    // Handed to `AgentServer::set_default_model` when the default changes.
    fs: Arc<dyn Fs>,
    // Rows currently displayed, rebuilt by `update_matches` from `models`.
    filtered_entries: Vec<AcpModelPickerEntry>,
    // Full, unfiltered model list from the last refresh; `None` before the
    // first refresh completes or if listing failed.
    models: Option<AgentModelList>,
    // Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    // Hovered row's (index, description, is_default), rendered by
    // `documentation_aside`.
    selected_description: Option<(usize, SharedString, bool)>,
    // Model currently reported as (or just chosen as) the active one.
    selected_model: Option<AgentModelInfo>,
    // Background refresh loop; stored so it lives as long as the delegate.
    _refresh_models_task: Task<()>,
    // Passed to the footer so its action can be dispatched from the picker.
    focus_handle: FocusHandle,
}
56
57impl AcpModelPickerDelegate {
58 fn new(
59 selector: Rc<dyn AgentModelSelector>,
60 agent_server: Rc<dyn AgentServer>,
61 fs: Arc<dyn Fs>,
62 focus_handle: FocusHandle,
63 window: &mut Window,
64 cx: &mut Context<AcpModelSelector>,
65 ) -> Self {
66 let rx = selector.watch(cx);
67 let refresh_models_task = {
68 cx.spawn_in(window, {
69 async move |this, cx| {
70 async fn refresh(
71 this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
72 cx: &mut AsyncWindowContext,
73 ) -> Result<()> {
74 let (models_task, selected_model_task) = this.update(cx, |this, cx| {
75 (
76 this.delegate.selector.list_models(cx),
77 this.delegate.selector.selected_model(cx),
78 )
79 })?;
80
81 let (models, selected_model) =
82 futures::join!(models_task, selected_model_task);
83
84 this.update_in(cx, |this, window, cx| {
85 this.delegate.models = models.ok();
86 this.delegate.selected_model = selected_model.ok();
87 this.refresh(window, cx)
88 })
89 }
90
91 refresh(&this, cx).await.log_err();
92 if let Some(mut rx) = rx {
93 while let Ok(()) = rx.recv().await {
94 refresh(&this, cx).await.log_err();
95 }
96 }
97 }
98 })
99 };
100
101 Self {
102 selector,
103 agent_server,
104 fs,
105 filtered_entries: Vec::new(),
106 models: None,
107 selected_model: None,
108 selected_index: 0,
109 selected_description: None,
110 _refresh_models_task: refresh_models_task,
111 focus_handle,
112 }
113 }
114
115 pub fn active_model(&self) -> Option<&AgentModelInfo> {
116 self.selected_model.as_ref()
117 }
118}
119
120impl PickerDelegate for AcpModelPickerDelegate {
121 type ListItem = AnyElement;
122
123 fn match_count(&self) -> usize {
124 self.filtered_entries.len()
125 }
126
127 fn selected_index(&self) -> usize {
128 self.selected_index
129 }
130
131 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
132 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
133 cx.notify();
134 }
135
136 fn can_select(
137 &mut self,
138 ix: usize,
139 _window: &mut Window,
140 _cx: &mut Context<Picker<Self>>,
141 ) -> bool {
142 match self.filtered_entries.get(ix) {
143 Some(AcpModelPickerEntry::Model(_)) => true,
144 Some(AcpModelPickerEntry::Separator(_)) | None => false,
145 }
146 }
147
148 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
149 "Select a model…".into()
150 }
151
152 fn update_matches(
153 &mut self,
154 query: String,
155 window: &mut Window,
156 cx: &mut Context<Picker<Self>>,
157 ) -> Task<()> {
158 cx.spawn_in(window, async move |this, cx| {
159 let filtered_models = match this
160 .read_with(cx, |this, cx| {
161 this.delegate.models.clone().map(move |models| {
162 fuzzy_search(models, query, cx.background_executor().clone())
163 })
164 })
165 .ok()
166 .flatten()
167 {
168 Some(task) => task.await,
169 None => AgentModelList::Flat(vec![]),
170 };
171
172 this.update_in(cx, |this, window, cx| {
173 this.delegate.filtered_entries =
174 info_list_to_picker_entries(filtered_models).collect();
175 // Finds the currently selected model in the list
176 let new_index = this
177 .delegate
178 .selected_model
179 .as_ref()
180 .and_then(|selected| {
181 this.delegate.filtered_entries.iter().position(|entry| {
182 if let AcpModelPickerEntry::Model(model_info) = entry {
183 model_info.id == selected.id
184 } else {
185 false
186 }
187 })
188 })
189 .unwrap_or(0);
190 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
191 cx.notify();
192 })
193 .ok();
194 })
195 }
196
197 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
198 if let Some(AcpModelPickerEntry::Model(model_info)) =
199 self.filtered_entries.get(self.selected_index)
200 {
201 if window.modifiers().secondary() {
202 let default_model = self.agent_server.default_model(cx);
203 let is_default = default_model.as_ref() == Some(&model_info.id);
204
205 self.agent_server.set_default_model(
206 if is_default {
207 None
208 } else {
209 Some(model_info.id.clone())
210 },
211 self.fs.clone(),
212 cx,
213 );
214 }
215
216 self.selector
217 .select_model(model_info.id.clone(), cx)
218 .detach_and_log_err(cx);
219 self.selected_model = Some(model_info.clone());
220 let current_index = self.selected_index;
221 self.set_selected_index(current_index, window, cx);
222
223 cx.emit(DismissEvent);
224 }
225 }
226
227 fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
228 cx.defer_in(window, |picker, window, cx| {
229 picker.set_query("", window, cx);
230 });
231 }
232
233 fn render_match(
234 &self,
235 ix: usize,
236 is_focused: bool,
237 _: &mut Window,
238 cx: &mut Context<Picker<Self>>,
239 ) -> Option<Self::ListItem> {
240 match self.filtered_entries.get(ix)? {
241 AcpModelPickerEntry::Separator(title) => {
242 Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
243 }
244 AcpModelPickerEntry::Model(model_info) => {
245 let is_selected = Some(model_info) == self.selected_model.as_ref();
246 let default_model = self.agent_server.default_model(cx);
247 let is_default = default_model.as_ref() == Some(&model_info.id);
248
249 Some(
250 div()
251 .id(("model-picker-menu-child", ix))
252 .when_some(model_info.description.clone(), |this, description| {
253 this
254 .on_hover(cx.listener(move |menu, hovered, _, cx| {
255 if *hovered {
256 menu.delegate.selected_description = Some((ix, description.clone(), is_default));
257 } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) {
258 menu.delegate.selected_description = None;
259 }
260 cx.notify();
261 }))
262 })
263 .child(
264 ModelSelectorListItem::new(ix, model_info.name.clone())
265 .is_focused(is_focused)
266 .is_selected(is_selected)
267 .when_some(model_info.icon, |this, icon| this.icon(icon)),
268 )
269 .into_any_element()
270 )
271 }
272 }
273 }
274
275 fn documentation_aside(
276 &self,
277 _window: &mut Window,
278 _cx: &mut Context<Picker<Self>>,
279 ) -> Option<ui::DocumentationAside> {
280 self.selected_description
281 .as_ref()
282 .map(|(_, description, is_default)| {
283 let description = description.clone();
284 let is_default = *is_default;
285
286 DocumentationAside::new(
287 DocumentationSide::Left,
288 DocumentationEdge::Top,
289 Rc::new(move |_| {
290 v_flex()
291 .gap_1()
292 .child(Label::new(description.clone()))
293 .child(HoldForDefault::new(is_default))
294 .into_any_element()
295 }),
296 )
297 })
298 }
299
300 fn render_footer(
301 &self,
302 _window: &mut Window,
303 _cx: &mut Context<Picker<Self>>,
304 ) -> Option<AnyElement> {
305 let focus_handle = self.focus_handle.clone();
306
307 if !self.selector.should_render_footer() {
308 return None;
309 }
310
311 Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
312 }
313}
314
315fn info_list_to_picker_entries(
316 model_list: AgentModelList,
317) -> impl Iterator<Item = AcpModelPickerEntry> {
318 match model_list {
319 AgentModelList::Flat(list) => {
320 itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
321 }
322 AgentModelList::Grouped(index_map) => {
323 itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
324 std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
325 .chain(models.into_iter().map(AcpModelPickerEntry::Model))
326 }))
327 }
328 }
329}
330
331async fn fuzzy_search(
332 model_list: AgentModelList,
333 query: String,
334 executor: BackgroundExecutor,
335) -> AgentModelList {
336 async fn fuzzy_search_list(
337 model_list: Vec<AgentModelInfo>,
338 query: &str,
339 executor: BackgroundExecutor,
340 ) -> Vec<AgentModelInfo> {
341 let candidates = model_list
342 .iter()
343 .enumerate()
344 .map(|(ix, model)| StringMatchCandidate::new(ix, model.name.as_ref()))
345 .collect::<Vec<_>>();
346 let mut matches = match_strings(
347 &candidates,
348 query,
349 false,
350 true,
351 100,
352 &Default::default(),
353 executor,
354 )
355 .await;
356
357 matches.sort_unstable_by_key(|mat| {
358 let candidate = &candidates[mat.candidate_id];
359 (Reverse(OrderedFloat(mat.score)), candidate.id)
360 });
361
362 matches
363 .into_iter()
364 .map(|mat| model_list[mat.candidate_id].clone())
365 .collect()
366 }
367
368 match model_list {
369 AgentModelList::Flat(model_list) => {
370 AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
371 }
372 AgentModelList::Grouped(index_map) => {
373 let groups =
374 futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
375 fuzzy_search_list(models, &query, executor.clone())
376 .map(|results| (group_name, results))
377 }))
378 .await;
379 AgentModelList::Grouped(IndexMap::from_iter(
380 groups
381 .into_iter()
382 .filter(|(_, results)| !results.is_empty()),
383 ))
384 }
385 }
386}
387
#[cfg(test)]
mod tests {
    use agent_client_protocol as acp;
    use gpui::TestAppContext;

    use super::*;

    /// Builds a grouped model list in which each model's id and name are the
    /// given string.
    fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
        AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
            |(group, models)| {
                (
                    acp_thread::AgentModelGroupName(group.to_string().into()),
                    models
                        .into_iter()
                        .map(|model| acp_thread::AgentModelInfo {
                            id: acp::ModelId::new(model.to_string()),
                            name: model.to_string().into(),
                            description: None,
                            icon: None,
                        })
                        .collect::<Vec<_>>(),
                )
            },
        )))
    }

    /// Asserts that `result` is a grouped list whose groups and model names
    /// match `expected` exactly, in order.
    fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
        let AgentModelList::Grouped(groups) = result else {
            panic!("Expected AgentModelList::Grouped, got {:?}", result);
        };

        assert_eq!(
            groups.len(),
            expected.len(),
            "Number of groups doesn't match"
        );

        for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
            let (actual_group, actual_models) = groups.get_index(i).unwrap();
            assert_eq!(
                actual_group.0.as_ref(),
                *expected_group,
                "Group at position {} doesn't match expected group",
                i
            );
            assert_eq!(
                actual_models.len(),
                expected_models.len(),
                "Number of models in group {} doesn't match",
                expected_group
            );

            for (j, expected_model_name) in expected_models.iter().enumerate() {
                assert_eq!(
                    actual_models[j].name, *expected_model_name,
                    "Model at position {} in group {} doesn't match expected model",
                    j, expected_group
                );
            }
        }
    }

    #[gpui::test]
    async fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            (
                "zed",
                vec![
                    "Claude 3.7 Sonnet",
                    "Claude 3.7 Sonnet Thinking",
                    "gpt-4.1",
                    "gpt-4.1-nano",
                ],
            ),
            ("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
            ],
        );

        // Fuzzy search
        let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
        assert_models_eq(
            results,
            vec![
                ("zed", vec!["gpt-4.1-nano"]),
                ("openai", vec!["gpt-4.1-nano"]),
            ],
        );
    }

    #[gpui::test]
    async fn test_fuzzy_match_drops_empty_groups(cx: &mut TestAppContext) {
        let models = create_model_list(vec![
            ("zed", vec!["Claude 3.7 Sonnet", "gpt-4.1"]),
            ("ollama", vec!["mistral", "deepseek"]),
        ]);

        // Only one group contains a match; the other group must be omitted.
        let results = fuzzy_search(models.clone(), "mis".into(), cx.executor()).await;
        assert_models_eq(results, vec![("ollama", vec!["mistral"])]);

        // No model matches, so no groups remain at all.
        let results = fuzzy_search(models, "xyz".into(), cx.executor()).await;
        assert_models_eq(results, vec![]);
    }
}