1use std::{collections::BTreeMap, sync::Arc};
2
3use agent_settings::{
4 AgentProfile, AgentProfileContent, AgentProfileId, AgentSettings, AgentSettingsContent,
5 ContextServerPresetContent,
6};
7use assistant_tool::{ToolSource, ToolWorkingSet};
8use fs::Fs;
9use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
10use picker::{Picker, PickerDelegate};
11use settings::{Settings as _, update_settings_file};
12use ui::{ListItem, ListItemSpacing, prelude::*};
13use util::ResultExt as _;
14
15use crate::ThreadStore;
16
/// View hosting the picker used to toggle which tools an agent profile may use.
///
/// It only wraps a [`Picker`] driven by [`ToolPickerDelegate`]; both focus and
/// rendering are forwarded to that picker.
pub struct ToolPicker {
    picker: Entity<Picker<ToolPickerDelegate>>,
}
20
/// Which family of tools a [`ToolPicker`] lists.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ToolPickerMode {
    /// Tools whose source is `ToolSource::Native`.
    BuiltinTools,
    /// Tools provided by context (MCP) servers, grouped under a header per server.
    McpTools,
}
26
27impl ToolPicker {
28 pub fn builtin_tools(
29 delegate: ToolPickerDelegate,
30 window: &mut Window,
31 cx: &mut Context<Self>,
32 ) -> Self {
33 let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false));
34 Self { picker }
35 }
36
37 pub fn mcp_tools(
38 delegate: ToolPickerDelegate,
39 window: &mut Window,
40 cx: &mut Context<Self>,
41 ) -> Self {
42 let picker = cx.new(|cx| Picker::list(delegate, window, cx).modal(false));
43 Self { picker }
44 }
45}
46
47impl EventEmitter<DismissEvent> for ToolPicker {}
48
49impl Focusable for ToolPicker {
50 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
51 self.picker.focus_handle(cx)
52 }
53}
54
55impl Render for ToolPicker {
56 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
57 v_flex().w(rems(34.)).child(self.picker.clone())
58 }
59}
60
/// One row of the tool picker.
#[derive(Clone, Debug)]
pub enum PickerItem {
    /// A selectable tool.
    ///
    /// `server_id` is `None` for built-in tools and the owning context server's id
    /// for MCP tools.
    Tool {
        server_id: Option<Arc<str>>,
        name: Arc<str>,
    },
    /// A non-selectable header row labelling the tools of one context server.
    ContextServer {
        server_id: Arc<str>,
    },
}
71
/// [`PickerDelegate`] backing [`ToolPicker`]: lists tools and toggles them in an
/// agent profile, persisting each toggle to the settings file.
pub struct ToolPickerDelegate {
    /// The owning picker view; receives `DismissEvent` when the picker is dismissed.
    tool_picker: WeakEntity<ToolPicker>,
    /// Reloaded with the edited profile when that profile is the default profile.
    thread_store: WeakEntity<ThreadStore>,
    /// Filesystem handle passed to `update_settings_file` when persisting a toggle.
    fs: Arc<dyn Fs>,
    /// All rows for `mode`, resolved from the tool working set in `new`.
    items: Arc<Vec<PickerItem>>,
    /// Id of the profile being edited; the key under which changes are saved.
    profile_id: AgentProfileId,
    /// In-memory copy of the profile; toggled on confirm and read to draw checkmarks.
    profile: AgentProfile,
    /// Rows matching the current query (tool rows plus server header rows).
    filtered_items: Vec<PickerItem>,
    /// Index into `filtered_items`.
    selected_index: usize,
    /// Whether built-in or MCP tools are listed.
    mode: ToolPickerMode,
}
83
84impl ToolPickerDelegate {
85 pub fn new(
86 mode: ToolPickerMode,
87 fs: Arc<dyn Fs>,
88 tool_set: Entity<ToolWorkingSet>,
89 thread_store: WeakEntity<ThreadStore>,
90 profile_id: AgentProfileId,
91 profile: AgentProfile,
92 cx: &mut Context<ToolPicker>,
93 ) -> Self {
94 let items = Arc::new(Self::resolve_items(mode, &tool_set, cx));
95
96 Self {
97 tool_picker: cx.entity().downgrade(),
98 thread_store,
99 fs,
100 items,
101 profile_id,
102 profile,
103 filtered_items: Vec::new(),
104 selected_index: 0,
105 mode,
106 }
107 }
108
109 fn resolve_items(
110 mode: ToolPickerMode,
111 tool_set: &Entity<ToolWorkingSet>,
112 cx: &mut App,
113 ) -> Vec<PickerItem> {
114 let mut items = Vec::new();
115 for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
116 match source {
117 ToolSource::Native => {
118 if mode == ToolPickerMode::BuiltinTools {
119 items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
120 name: tool.name().into(),
121 server_id: None,
122 }));
123 }
124 }
125 ToolSource::ContextServer { id } => {
126 if mode == ToolPickerMode::McpTools && !tools.is_empty() {
127 let server_id: Arc<str> = id.clone().into();
128 items.push(PickerItem::ContextServer {
129 server_id: server_id.clone(),
130 });
131 items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
132 name: tool.name().into(),
133 server_id: Some(server_id.clone()),
134 }));
135 }
136 }
137 }
138 }
139 items
140 }
141}
142
143impl PickerDelegate for ToolPickerDelegate {
144 type ListItem = AnyElement;
145
146 fn match_count(&self) -> usize {
147 self.filtered_items.len()
148 }
149
150 fn selected_index(&self) -> usize {
151 self.selected_index
152 }
153
154 fn set_selected_index(
155 &mut self,
156 ix: usize,
157 _window: &mut Window,
158 _cx: &mut Context<Picker<Self>>,
159 ) {
160 self.selected_index = ix;
161 }
162
163 fn can_select(
164 &mut self,
165 ix: usize,
166 _window: &mut Window,
167 _cx: &mut Context<Picker<Self>>,
168 ) -> bool {
169 let item = &self.filtered_items[ix];
170 match item {
171 PickerItem::Tool { .. } => true,
172 PickerItem::ContextServer { .. } => false,
173 }
174 }
175
176 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
177 match self.mode {
178 ToolPickerMode::BuiltinTools => "Search built-in tools…",
179 ToolPickerMode::McpTools => "Search MCP tools…",
180 }
181 .into()
182 }
183
184 fn update_matches(
185 &mut self,
186 query: String,
187 window: &mut Window,
188 cx: &mut Context<Picker<Self>>,
189 ) -> Task<()> {
190 let all_items = self.items.clone();
191
192 cx.spawn_in(window, async move |this, cx| {
193 let filtered_items = cx
194 .background_spawn(async move {
195 let mut tools_by_provider: BTreeMap<Option<Arc<str>>, Vec<Arc<str>>> =
196 BTreeMap::default();
197
198 for item in all_items.iter() {
199 if let PickerItem::Tool { server_id, name } = item.clone() {
200 if name.contains(&query) {
201 tools_by_provider.entry(server_id).or_default().push(name);
202 }
203 }
204 }
205
206 let mut items = Vec::new();
207
208 for (server_id, names) in tools_by_provider {
209 if let Some(server_id) = server_id.clone() {
210 items.push(PickerItem::ContextServer { server_id });
211 }
212 for name in names {
213 items.push(PickerItem::Tool {
214 server_id: server_id.clone(),
215 name,
216 });
217 }
218 }
219
220 items
221 })
222 .await;
223
224 this.update(cx, |this, _cx| {
225 this.delegate.filtered_items = filtered_items;
226 this.delegate.selected_index = this
227 .delegate
228 .selected_index
229 .min(this.delegate.filtered_items.len().saturating_sub(1));
230 })
231 .log_err();
232 })
233 }
234
235 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
236 if self.filtered_items.is_empty() {
237 self.dismissed(window, cx);
238 return;
239 }
240
241 let item = &self.filtered_items[self.selected_index];
242
243 let PickerItem::Tool {
244 name: tool_name,
245 server_id,
246 } = item
247 else {
248 return;
249 };
250
251 let is_currently_enabled = if let Some(server_id) = server_id.clone() {
252 let preset = self.profile.context_servers.entry(server_id).or_default();
253 let is_enabled = *preset.tools.entry(tool_name.clone()).or_default();
254 *preset.tools.entry(tool_name.clone()).or_default() = !is_enabled;
255 is_enabled
256 } else {
257 let is_enabled = *self.profile.tools.entry(tool_name.clone()).or_default();
258 *self.profile.tools.entry(tool_name.clone()).or_default() = !is_enabled;
259 is_enabled
260 };
261
262 let active_profile_id = &AgentSettings::get_global(cx).default_profile;
263 if active_profile_id == &self.profile_id {
264 self.thread_store
265 .update(cx, |this, cx| {
266 this.load_profile(self.profile.clone(), cx);
267 })
268 .log_err();
269 }
270
271 update_settings_file::<AgentSettings>(self.fs.clone(), cx, {
272 let profile_id = self.profile_id.clone();
273 let default_profile = self.profile.clone();
274 let server_id = server_id.clone();
275 let tool_name = tool_name.clone();
276 move |settings: &mut AgentSettingsContent, _cx| {
277 settings
278 .v2_setting(|v2_settings| {
279 let profiles = v2_settings.profiles.get_or_insert_default();
280 let profile =
281 profiles
282 .entry(profile_id)
283 .or_insert_with(|| AgentProfileContent {
284 name: default_profile.name.into(),
285 tools: default_profile.tools,
286 enable_all_context_servers: Some(
287 default_profile.enable_all_context_servers,
288 ),
289 context_servers: default_profile
290 .context_servers
291 .into_iter()
292 .map(|(server_id, preset)| {
293 (
294 server_id,
295 ContextServerPresetContent {
296 tools: preset.tools,
297 },
298 )
299 })
300 .collect(),
301 });
302
303 if let Some(server_id) = server_id {
304 let preset = profile.context_servers.entry(server_id).or_default();
305 *preset.tools.entry(tool_name).or_default() = !is_currently_enabled;
306 } else {
307 *profile.tools.entry(tool_name).or_default() = !is_currently_enabled;
308 }
309
310 Ok(())
311 })
312 .ok();
313 }
314 });
315 }
316
317 fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
318 self.tool_picker
319 .update(cx, |_this, cx| cx.emit(DismissEvent))
320 .log_err();
321 }
322
323 fn render_match(
324 &self,
325 ix: usize,
326 selected: bool,
327 _window: &mut Window,
328 cx: &mut Context<Picker<Self>>,
329 ) -> Option<Self::ListItem> {
330 let item = &self.filtered_items[ix];
331 match item {
332 PickerItem::ContextServer { server_id, .. } => Some(
333 div()
334 .px_2()
335 .pb_1()
336 .when(ix > 1, |this| {
337 this.mt_1()
338 .pt_2()
339 .border_t_1()
340 .border_color(cx.theme().colors().border_variant)
341 })
342 .child(
343 Label::new(server_id)
344 .size(LabelSize::XSmall)
345 .color(Color::Muted),
346 )
347 .into_any_element(),
348 ),
349 PickerItem::Tool { name, server_id } => {
350 let is_enabled = if let Some(server_id) = server_id {
351 self.profile
352 .context_servers
353 .get(server_id.as_ref())
354 .and_then(|preset| preset.tools.get(name))
355 .copied()
356 .unwrap_or(self.profile.enable_all_context_servers)
357 } else {
358 self.profile.tools.get(name).copied().unwrap_or(false)
359 };
360
361 Some(
362 ListItem::new(ix)
363 .inset(true)
364 .spacing(ListItemSpacing::Sparse)
365 .toggle_state(selected)
366 .child(Label::new(name.clone()))
367 .end_slot::<Icon>(is_enabled.then(|| {
368 Icon::new(IconName::Check)
369 .size(IconSize::Small)
370 .color(Color::Success)
371 }))
372 .into_any_element(),
373 )
374 }
375 }
376 }
377}