tool_picker.rs

  1use std::sync::Arc;
  2
  3use assistant_settings::{
  4    AgentProfile, AssistantSettings, AssistantSettingsContent, VersionedAssistantSettingsContent,
  5};
  6use assistant_tool::{ToolSource, ToolWorkingSet};
  7use fs::Fs;
  8use fuzzy::{match_strings, StringMatch, StringMatchCandidate};
  9use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
 10use picker::{Picker, PickerDelegate};
 11use settings::update_settings_file;
 12use ui::{prelude::*, HighlightedLabel, ListItem, ListItemSpacing};
 13use util::ResultExt as _;
 14
 15pub struct ToolPicker {
 16    picker: Entity<Picker<ToolPickerDelegate>>,
 17}
 18
 19impl ToolPicker {
 20    pub fn new(delegate: ToolPickerDelegate, window: &mut Window, cx: &mut Context<Self>) -> Self {
 21        let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
 22        Self { picker }
 23    }
 24}
 25
 26impl EventEmitter<DismissEvent> for ToolPicker {}
 27
 28impl Focusable for ToolPicker {
 29    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
 30        self.picker.focus_handle(cx)
 31    }
 32}
 33
 34impl Render for ToolPicker {
 35    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
 36        v_flex().w(rems(34.)).child(self.picker.clone())
 37    }
 38}
 39
 40#[derive(Debug, Clone)]
 41pub struct ToolEntry {
 42    pub name: Arc<str>,
 43    pub source: ToolSource,
 44}
 45
 46pub struct ToolPickerDelegate {
 47    tool_picker: WeakEntity<ToolPicker>,
 48    fs: Arc<dyn Fs>,
 49    tools: Vec<ToolEntry>,
 50    profile_id: Arc<str>,
 51    profile: AgentProfile,
 52    matches: Vec<StringMatch>,
 53    selected_index: usize,
 54}
 55
 56impl ToolPickerDelegate {
 57    pub fn new(
 58        fs: Arc<dyn Fs>,
 59        tool_set: Arc<ToolWorkingSet>,
 60        profile_id: Arc<str>,
 61        profile: AgentProfile,
 62        cx: &mut Context<ToolPicker>,
 63    ) -> Self {
 64        let mut tool_entries = Vec::new();
 65
 66        for (source, tools) in tool_set.tools_by_source(cx) {
 67            tool_entries.extend(tools.into_iter().map(|tool| ToolEntry {
 68                name: tool.name().into(),
 69                source: source.clone(),
 70            }));
 71        }
 72
 73        Self {
 74            tool_picker: cx.entity().downgrade(),
 75            fs,
 76            tools: tool_entries,
 77            profile_id,
 78            profile,
 79            matches: Vec::new(),
 80            selected_index: 0,
 81        }
 82    }
 83}
 84
 85impl PickerDelegate for ToolPickerDelegate {
 86    type ListItem = ListItem;
 87
 88    fn match_count(&self) -> usize {
 89        self.matches.len()
 90    }
 91
 92    fn selected_index(&self) -> usize {
 93        self.selected_index
 94    }
 95
 96    fn set_selected_index(
 97        &mut self,
 98        ix: usize,
 99        _window: &mut Window,
100        _cx: &mut Context<Picker<Self>>,
101    ) {
102        self.selected_index = ix;
103    }
104
105    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
106        "Search tools…".into()
107    }
108
109    fn update_matches(
110        &mut self,
111        query: String,
112        window: &mut Window,
113        cx: &mut Context<Picker<Self>>,
114    ) -> Task<()> {
115        let background = cx.background_executor().clone();
116        let candidates = self
117            .tools
118            .iter()
119            .enumerate()
120            .map(|(id, profile)| StringMatchCandidate::new(id, profile.name.as_ref()))
121            .collect::<Vec<_>>();
122
123        cx.spawn_in(window, async move |this, cx| {
124            let matches = if query.is_empty() {
125                candidates
126                    .into_iter()
127                    .enumerate()
128                    .map(|(index, candidate)| StringMatch {
129                        candidate_id: index,
130                        string: candidate.string,
131                        positions: Vec::new(),
132                        score: 0.,
133                    })
134                    .collect()
135            } else {
136                match_strings(
137                    &candidates,
138                    &query,
139                    false,
140                    100,
141                    &Default::default(),
142                    background,
143                )
144                .await
145            };
146
147            this.update(cx, |this, _cx| {
148                this.delegate.matches = matches;
149                this.delegate.selected_index = this
150                    .delegate
151                    .selected_index
152                    .min(this.delegate.matches.len().saturating_sub(1));
153            })
154            .log_err();
155        })
156    }
157
158    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
159        if self.matches.is_empty() {
160            self.dismissed(window, cx);
161            return;
162        }
163
164        let candidate_id = self.matches[self.selected_index].candidate_id;
165        let tool = &self.tools[candidate_id];
166
167        let is_enabled = match &tool.source {
168            ToolSource::Native => {
169                let is_enabled = self.profile.tools.entry(tool.name.clone()).or_default();
170                *is_enabled = !*is_enabled;
171                *is_enabled
172            }
173            ToolSource::ContextServer { id } => {
174                let preset = self
175                    .profile
176                    .context_servers
177                    .entry(id.clone().into())
178                    .or_default();
179                let is_enabled = preset.tools.entry(tool.name.clone()).or_default();
180                *is_enabled = !*is_enabled;
181                *is_enabled
182            }
183        };
184
185        update_settings_file::<AssistantSettings>(self.fs.clone(), cx, {
186            let profile_id = self.profile_id.clone();
187            let tool = tool.clone();
188            move |settings, _cx| match settings {
189                AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(
190                    settings,
191                )) => {
192                    if let Some(profiles) = &mut settings.profiles {
193                        if let Some(profile) = profiles.get_mut(&profile_id) {
194                            match tool.source {
195                                ToolSource::Native => {
196                                    *profile.tools.entry(tool.name).or_default() = is_enabled;
197                                }
198                                ToolSource::ContextServer { id } => {
199                                    let preset = profile
200                                        .context_servers
201                                        .entry(id.clone().into())
202                                        .or_default();
203                                    *preset.tools.entry(tool.name.clone()).or_default() =
204                                        is_enabled;
205                                }
206                            }
207                        }
208                    }
209                }
210                _ => {}
211            }
212        });
213    }
214
215    fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
216        self.tool_picker
217            .update(cx, |_this, cx| cx.emit(DismissEvent))
218            .log_err();
219    }
220
221    fn render_match(
222        &self,
223        ix: usize,
224        selected: bool,
225        _window: &mut Window,
226        _cx: &mut Context<Picker<Self>>,
227    ) -> Option<Self::ListItem> {
228        let tool_match = &self.matches[ix];
229        let tool = &self.tools[tool_match.candidate_id];
230
231        let is_enabled = match &tool.source {
232            ToolSource::Native => self.profile.tools.get(&tool.name).copied().unwrap_or(false),
233            ToolSource::ContextServer { id } => self
234                .profile
235                .context_servers
236                .get(id.as_ref())
237                .and_then(|preset| preset.tools.get(&tool.name))
238                .copied()
239                .unwrap_or(false),
240        };
241
242        Some(
243            ListItem::new(ix)
244                .inset(true)
245                .spacing(ListItemSpacing::Sparse)
246                .toggle_state(selected)
247                .child(
248                    h_flex()
249                        .gap_2()
250                        .child(HighlightedLabel::new(
251                            tool_match.string.clone(),
252                            tool_match.positions.clone(),
253                        ))
254                        .map(|parent| match &tool.source {
255                            ToolSource::Native => parent,
256                            ToolSource::ContextServer { id } => parent
257                                .child(Label::new(id).size(LabelSize::XSmall).color(Color::Muted)),
258                        }),
259                )
260                .end_slot::<Icon>(is_enabled.then(|| {
261                    Icon::new(IconName::Check)
262                        .size(IconSize::Small)
263                        .color(Color::Success)
264                })),
265        )
266    }
267}