1use std::sync::Arc;
2
3use assistant_settings::{
4 AgentProfile, AgentProfileContent, AssistantSettings, AssistantSettingsContent,
5 ContextServerPresetContent, VersionedAssistantSettingsContent,
6};
7use assistant_tool::{ToolSource, ToolWorkingSet};
8use fs::Fs;
9use fuzzy::{match_strings, StringMatch, StringMatchCandidate};
10use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
11use picker::{Picker, PickerDelegate};
12use settings::update_settings_file;
13use ui::{prelude::*, HighlightedLabel, ListItem, ListItemSpacing};
14use util::ResultExt as _;
15
/// View that hosts a [`Picker`] driven by a [`ToolPickerDelegate`].
pub struct ToolPicker {
    /// The picker entity that this view renders and forwards focus to.
    picker: Entity<Picker<ToolPickerDelegate>>,
}
19
20impl ToolPicker {
21 pub fn new(delegate: ToolPickerDelegate, window: &mut Window, cx: &mut Context<Self>) -> Self {
22 let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false));
23 Self { picker }
24 }
25}
26
// `ToolPicker` emits `DismissEvent`; the delegate's `dismissed` callback is what fires it.
impl EventEmitter<DismissEvent> for ToolPicker {}
28
29impl Focusable for ToolPicker {
30 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
31 self.picker.focus_handle(cx)
32 }
33}
34
35impl Render for ToolPicker {
36 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
37 v_flex().w(rems(34.)).child(self.picker.clone())
38 }
39}
40
/// A single tool listed in the picker.
#[derive(Debug, Clone)]
pub struct ToolEntry {
    /// Name of the tool; used as the fuzzy-match candidate string and as the
    /// key under which its enabled flag is stored in the profile.
    pub name: Arc<str>,
    /// Where the tool comes from (native, or a specific context server).
    pub source: ToolSource,
}
46
/// Delegate backing [`ToolPicker`]: lists tools, filters them by the query, and
/// toggles the selected tool in an agent profile on confirm.
pub struct ToolPickerDelegate {
    /// Weak handle to the owning view, used to emit `DismissEvent` on it.
    tool_picker: WeakEntity<ToolPicker>,
    /// Filesystem handle passed to `update_settings_file` when persisting a toggle.
    fs: Arc<dyn Fs>,
    /// Every tool offered by the picker; match `candidate_id`s index into this list.
    tools: Vec<ToolEntry>,
    /// Key of the profile being edited within the settings' profile map.
    profile_id: Arc<str>,
    /// In-memory copy of the profile; toggled immediately on confirm and read
    /// when rendering the enabled check mark.
    profile: AgentProfile,
    /// Results of the most recent query.
    matches: Vec<StringMatch>,
    /// Index into `matches` of the currently selected row.
    selected_index: usize,
}
56
57impl ToolPickerDelegate {
58 pub fn new(
59 fs: Arc<dyn Fs>,
60 tool_set: Arc<ToolWorkingSet>,
61 profile_id: Arc<str>,
62 profile: AgentProfile,
63 cx: &mut Context<ToolPicker>,
64 ) -> Self {
65 let mut tool_entries = Vec::new();
66
67 for (source, tools) in tool_set.tools_by_source(cx) {
68 tool_entries.extend(tools.into_iter().map(|tool| ToolEntry {
69 name: tool.name().into(),
70 source: source.clone(),
71 }));
72 }
73
74 Self {
75 tool_picker: cx.entity().downgrade(),
76 fs,
77 tools: tool_entries,
78 profile_id,
79 profile,
80 matches: Vec::new(),
81 selected_index: 0,
82 }
83 }
84}
85
86impl PickerDelegate for ToolPickerDelegate {
87 type ListItem = ListItem;
88
89 fn match_count(&self) -> usize {
90 self.matches.len()
91 }
92
93 fn selected_index(&self) -> usize {
94 self.selected_index
95 }
96
97 fn set_selected_index(
98 &mut self,
99 ix: usize,
100 _window: &mut Window,
101 _cx: &mut Context<Picker<Self>>,
102 ) {
103 self.selected_index = ix;
104 }
105
106 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
107 "Search tools…".into()
108 }
109
110 fn update_matches(
111 &mut self,
112 query: String,
113 window: &mut Window,
114 cx: &mut Context<Picker<Self>>,
115 ) -> Task<()> {
116 let background = cx.background_executor().clone();
117 let candidates = self
118 .tools
119 .iter()
120 .enumerate()
121 .map(|(id, profile)| StringMatchCandidate::new(id, profile.name.as_ref()))
122 .collect::<Vec<_>>();
123
124 cx.spawn_in(window, async move |this, cx| {
125 let matches = if query.is_empty() {
126 candidates
127 .into_iter()
128 .enumerate()
129 .map(|(index, candidate)| StringMatch {
130 candidate_id: index,
131 string: candidate.string,
132 positions: Vec::new(),
133 score: 0.,
134 })
135 .collect()
136 } else {
137 match_strings(
138 &candidates,
139 &query,
140 false,
141 100,
142 &Default::default(),
143 background,
144 )
145 .await
146 };
147
148 this.update(cx, |this, _cx| {
149 this.delegate.matches = matches;
150 this.delegate.selected_index = this
151 .delegate
152 .selected_index
153 .min(this.delegate.matches.len().saturating_sub(1));
154 })
155 .log_err();
156 })
157 }
158
159 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
160 if self.matches.is_empty() {
161 self.dismissed(window, cx);
162 return;
163 }
164
165 let candidate_id = self.matches[self.selected_index].candidate_id;
166 let tool = &self.tools[candidate_id];
167
168 let is_enabled = match &tool.source {
169 ToolSource::Native => {
170 let is_enabled = self.profile.tools.entry(tool.name.clone()).or_default();
171 *is_enabled = !*is_enabled;
172 *is_enabled
173 }
174 ToolSource::ContextServer { id } => {
175 let preset = self
176 .profile
177 .context_servers
178 .entry(id.clone().into())
179 .or_default();
180 let is_enabled = preset.tools.entry(tool.name.clone()).or_default();
181 *is_enabled = !*is_enabled;
182 *is_enabled
183 }
184 };
185
186 update_settings_file::<AssistantSettings>(self.fs.clone(), cx, {
187 let profile_id = self.profile_id.clone();
188 let default_profile = self.profile.clone();
189 let tool = tool.clone();
190 move |settings, _cx| match settings {
191 AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(
192 settings,
193 )) => {
194 let profiles = settings.profiles.get_or_insert_default();
195 let profile =
196 profiles
197 .entry(profile_id)
198 .or_insert_with(|| AgentProfileContent {
199 name: default_profile.name.into(),
200 tools: default_profile.tools,
201 context_servers: default_profile
202 .context_servers
203 .into_iter()
204 .map(|(server_id, preset)| {
205 (
206 server_id,
207 ContextServerPresetContent {
208 tools: preset.tools,
209 },
210 )
211 })
212 .collect(),
213 });
214
215 match tool.source {
216 ToolSource::Native => {
217 *profile.tools.entry(tool.name).or_default() = is_enabled;
218 }
219 ToolSource::ContextServer { id } => {
220 let preset = profile
221 .context_servers
222 .entry(id.clone().into())
223 .or_default();
224 *preset.tools.entry(tool.name.clone()).or_default() = is_enabled;
225 }
226 }
227 }
228 _ => {}
229 }
230 });
231 }
232
233 fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
234 self.tool_picker
235 .update(cx, |_this, cx| cx.emit(DismissEvent))
236 .log_err();
237 }
238
239 fn render_match(
240 &self,
241 ix: usize,
242 selected: bool,
243 _window: &mut Window,
244 _cx: &mut Context<Picker<Self>>,
245 ) -> Option<Self::ListItem> {
246 let tool_match = &self.matches[ix];
247 let tool = &self.tools[tool_match.candidate_id];
248
249 let is_enabled = match &tool.source {
250 ToolSource::Native => self.profile.tools.get(&tool.name).copied().unwrap_or(false),
251 ToolSource::ContextServer { id } => self
252 .profile
253 .context_servers
254 .get(id.as_ref())
255 .and_then(|preset| preset.tools.get(&tool.name))
256 .copied()
257 .unwrap_or(false),
258 };
259
260 Some(
261 ListItem::new(ix)
262 .inset(true)
263 .spacing(ListItemSpacing::Sparse)
264 .toggle_state(selected)
265 .child(
266 h_flex()
267 .gap_2()
268 .child(HighlightedLabel::new(
269 tool_match.string.clone(),
270 tool_match.positions.clone(),
271 ))
272 .map(|parent| match &tool.source {
273 ToolSource::Native => parent,
274 ToolSource::ContextServer { id } => parent
275 .child(Label::new(id).size(LabelSize::XSmall).color(Color::Muted)),
276 }),
277 )
278 .end_slot::<Icon>(is_enabled.then(|| {
279 Icon::new(IconName::Check)
280 .size(IconSize::Small)
281 .color(Color::Success)
282 })),
283 )
284 }
285}