tool_picker.rs

  1use std::sync::Arc;
  2
  3use assistant_settings::{
  4    AgentProfile, AgentProfileContent, AssistantSettings, AssistantSettingsContent,
  5    ContextServerPresetContent, VersionedAssistantSettingsContent,
  6};
  7use assistant_tool::{ToolSource, ToolWorkingSet};
  8use fs::Fs;
  9use fuzzy::{match_strings, StringMatch, StringMatchCandidate};
 10use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
 11use picker::{Picker, PickerDelegate};
 12use settings::update_settings_file;
 13use ui::{prelude::*, HighlightedLabel, ListItem, ListItemSpacing};
 14use util::ResultExt as _;
 15
 16pub struct ToolPicker {
 17    picker: Entity<Picker<ToolPickerDelegate>>,
 18}
 19
 20impl ToolPicker {
 21    pub fn new(delegate: ToolPickerDelegate, window: &mut Window, cx: &mut Context<Self>) -> Self {
 22        let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false));
 23        Self { picker }
 24    }
 25}
 26
 27impl EventEmitter<DismissEvent> for ToolPicker {}
 28
 29impl Focusable for ToolPicker {
 30    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
 31        self.picker.focus_handle(cx)
 32    }
 33}
 34
 35impl Render for ToolPicker {
 36    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
 37        v_flex().w(rems(34.)).child(self.picker.clone())
 38    }
 39}
 40
 41#[derive(Debug, Clone)]
 42pub struct ToolEntry {
 43    pub name: Arc<str>,
 44    pub source: ToolSource,
 45}
 46
 47pub struct ToolPickerDelegate {
 48    tool_picker: WeakEntity<ToolPicker>,
 49    fs: Arc<dyn Fs>,
 50    tools: Vec<ToolEntry>,
 51    profile_id: Arc<str>,
 52    profile: AgentProfile,
 53    matches: Vec<StringMatch>,
 54    selected_index: usize,
 55}
 56
 57impl ToolPickerDelegate {
 58    pub fn new(
 59        fs: Arc<dyn Fs>,
 60        tool_set: Arc<ToolWorkingSet>,
 61        profile_id: Arc<str>,
 62        profile: AgentProfile,
 63        cx: &mut Context<ToolPicker>,
 64    ) -> Self {
 65        let mut tool_entries = Vec::new();
 66
 67        for (source, tools) in tool_set.tools_by_source(cx) {
 68            tool_entries.extend(tools.into_iter().map(|tool| ToolEntry {
 69                name: tool.name().into(),
 70                source: source.clone(),
 71            }));
 72        }
 73
 74        Self {
 75            tool_picker: cx.entity().downgrade(),
 76            fs,
 77            tools: tool_entries,
 78            profile_id,
 79            profile,
 80            matches: Vec::new(),
 81            selected_index: 0,
 82        }
 83    }
 84}
 85
 86impl PickerDelegate for ToolPickerDelegate {
 87    type ListItem = ListItem;
 88
 89    fn match_count(&self) -> usize {
 90        self.matches.len()
 91    }
 92
 93    fn selected_index(&self) -> usize {
 94        self.selected_index
 95    }
 96
 97    fn set_selected_index(
 98        &mut self,
 99        ix: usize,
100        _window: &mut Window,
101        _cx: &mut Context<Picker<Self>>,
102    ) {
103        self.selected_index = ix;
104    }
105
106    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
107        "Search tools…".into()
108    }
109
110    fn update_matches(
111        &mut self,
112        query: String,
113        window: &mut Window,
114        cx: &mut Context<Picker<Self>>,
115    ) -> Task<()> {
116        let background = cx.background_executor().clone();
117        let candidates = self
118            .tools
119            .iter()
120            .enumerate()
121            .map(|(id, profile)| StringMatchCandidate::new(id, profile.name.as_ref()))
122            .collect::<Vec<_>>();
123
124        cx.spawn_in(window, async move |this, cx| {
125            let matches = if query.is_empty() {
126                candidates
127                    .into_iter()
128                    .enumerate()
129                    .map(|(index, candidate)| StringMatch {
130                        candidate_id: index,
131                        string: candidate.string,
132                        positions: Vec::new(),
133                        score: 0.,
134                    })
135                    .collect()
136            } else {
137                match_strings(
138                    &candidates,
139                    &query,
140                    false,
141                    100,
142                    &Default::default(),
143                    background,
144                )
145                .await
146            };
147
148            this.update(cx, |this, _cx| {
149                this.delegate.matches = matches;
150                this.delegate.selected_index = this
151                    .delegate
152                    .selected_index
153                    .min(this.delegate.matches.len().saturating_sub(1));
154            })
155            .log_err();
156        })
157    }
158
159    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
160        if self.matches.is_empty() {
161            self.dismissed(window, cx);
162            return;
163        }
164
165        let candidate_id = self.matches[self.selected_index].candidate_id;
166        let tool = &self.tools[candidate_id];
167
168        let is_enabled = match &tool.source {
169            ToolSource::Native => {
170                let is_enabled = self.profile.tools.entry(tool.name.clone()).or_default();
171                *is_enabled = !*is_enabled;
172                *is_enabled
173            }
174            ToolSource::ContextServer { id } => {
175                let preset = self
176                    .profile
177                    .context_servers
178                    .entry(id.clone().into())
179                    .or_default();
180                let is_enabled = preset.tools.entry(tool.name.clone()).or_default();
181                *is_enabled = !*is_enabled;
182                *is_enabled
183            }
184        };
185
186        update_settings_file::<AssistantSettings>(self.fs.clone(), cx, {
187            let profile_id = self.profile_id.clone();
188            let default_profile = self.profile.clone();
189            let tool = tool.clone();
190            move |settings, _cx| match settings {
191                AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(
192                    settings,
193                )) => {
194                    let profiles = settings.profiles.get_or_insert_default();
195                    let profile =
196                        profiles
197                            .entry(profile_id)
198                            .or_insert_with(|| AgentProfileContent {
199                                name: default_profile.name.into(),
200                                tools: default_profile.tools,
201                                context_servers: default_profile
202                                    .context_servers
203                                    .into_iter()
204                                    .map(|(server_id, preset)| {
205                                        (
206                                            server_id,
207                                            ContextServerPresetContent {
208                                                tools: preset.tools,
209                                            },
210                                        )
211                                    })
212                                    .collect(),
213                            });
214
215                    match tool.source {
216                        ToolSource::Native => {
217                            *profile.tools.entry(tool.name).or_default() = is_enabled;
218                        }
219                        ToolSource::ContextServer { id } => {
220                            let preset = profile
221                                .context_servers
222                                .entry(id.clone().into())
223                                .or_default();
224                            *preset.tools.entry(tool.name.clone()).or_default() = is_enabled;
225                        }
226                    }
227                }
228                _ => {}
229            }
230        });
231    }
232
233    fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
234        self.tool_picker
235            .update(cx, |_this, cx| cx.emit(DismissEvent))
236            .log_err();
237    }
238
239    fn render_match(
240        &self,
241        ix: usize,
242        selected: bool,
243        _window: &mut Window,
244        _cx: &mut Context<Picker<Self>>,
245    ) -> Option<Self::ListItem> {
246        let tool_match = &self.matches[ix];
247        let tool = &self.tools[tool_match.candidate_id];
248
249        let is_enabled = match &tool.source {
250            ToolSource::Native => self.profile.tools.get(&tool.name).copied().unwrap_or(false),
251            ToolSource::ContextServer { id } => self
252                .profile
253                .context_servers
254                .get(id.as_ref())
255                .and_then(|preset| preset.tools.get(&tool.name))
256                .copied()
257                .unwrap_or(false),
258        };
259
260        Some(
261            ListItem::new(ix)
262                .inset(true)
263                .spacing(ListItemSpacing::Sparse)
264                .toggle_state(selected)
265                .child(
266                    h_flex()
267                        .gap_2()
268                        .child(HighlightedLabel::new(
269                            tool_match.string.clone(),
270                            tool_match.positions.clone(),
271                        ))
272                        .map(|parent| match &tool.source {
273                            ToolSource::Native => parent,
274                            ToolSource::ContextServer { id } => parent
275                                .child(Label::new(id).size(LabelSize::XSmall).color(Color::Muted)),
276                        }),
277                )
278                .end_slot::<Icon>(is_enabled.then(|| {
279                    Icon::new(IconName::Check)
280                        .size(IconSize::Small)
281                        .color(Color::Success)
282                })),
283        )
284    }
285}