tool_picker.rs

  1use std::sync::Arc;
  2
  3use assistant_settings::{
  4    AgentProfile, AgentProfileContent, AgentProfileId, AssistantSettings, AssistantSettingsContent,
  5    ContextServerPresetContent,
  6};
  7use assistant_tool::{ToolSource, ToolWorkingSet};
  8use fs::Fs;
  9use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
 10use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
 11use picker::{Picker, PickerDelegate};
 12use settings::{Settings as _, update_settings_file};
 13use ui::{HighlightedLabel, ListItem, ListItemSpacing, prelude::*};
 14use util::ResultExt as _;
 15
 16use crate::ThreadStore;
 17
 18pub struct ToolPicker {
 19    picker: Entity<Picker<ToolPickerDelegate>>,
 20}
 21
 22impl ToolPicker {
 23    pub fn new(delegate: ToolPickerDelegate, window: &mut Window, cx: &mut Context<Self>) -> Self {
 24        let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false));
 25        Self { picker }
 26    }
 27}
 28
 29impl EventEmitter<DismissEvent> for ToolPicker {}
 30
 31impl Focusable for ToolPicker {
 32    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
 33        self.picker.focus_handle(cx)
 34    }
 35}
 36
 37impl Render for ToolPicker {
 38    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
 39        v_flex().w(rems(34.)).child(self.picker.clone())
 40    }
 41}
 42
 43#[derive(Debug, Clone)]
 44pub struct ToolEntry {
 45    pub name: Arc<str>,
 46    pub source: ToolSource,
 47}
 48
 49pub struct ToolPickerDelegate {
 50    tool_picker: WeakEntity<ToolPicker>,
 51    thread_store: WeakEntity<ThreadStore>,
 52    fs: Arc<dyn Fs>,
 53    tools: Vec<ToolEntry>,
 54    profile_id: AgentProfileId,
 55    profile: AgentProfile,
 56    matches: Vec<StringMatch>,
 57    selected_index: usize,
 58}
 59
 60impl ToolPickerDelegate {
 61    pub fn new(
 62        fs: Arc<dyn Fs>,
 63        tool_set: Entity<ToolWorkingSet>,
 64        thread_store: WeakEntity<ThreadStore>,
 65        profile_id: AgentProfileId,
 66        profile: AgentProfile,
 67        cx: &mut Context<ToolPicker>,
 68    ) -> Self {
 69        let mut tool_entries = Vec::new();
 70
 71        for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
 72            tool_entries.extend(tools.into_iter().map(|tool| ToolEntry {
 73                name: tool.name().into(),
 74                source: source.clone(),
 75            }));
 76        }
 77
 78        Self {
 79            tool_picker: cx.entity().downgrade(),
 80            thread_store,
 81            fs,
 82            tools: tool_entries,
 83            profile_id,
 84            profile,
 85            matches: Vec::new(),
 86            selected_index: 0,
 87        }
 88    }
 89}
 90
 91impl PickerDelegate for ToolPickerDelegate {
 92    type ListItem = ListItem;
 93
 94    fn match_count(&self) -> usize {
 95        self.matches.len()
 96    }
 97
 98    fn selected_index(&self) -> usize {
 99        self.selected_index
100    }
101
102    fn set_selected_index(
103        &mut self,
104        ix: usize,
105        _window: &mut Window,
106        _cx: &mut Context<Picker<Self>>,
107    ) {
108        self.selected_index = ix;
109    }
110
111    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
112        "Search tools…".into()
113    }
114
115    fn update_matches(
116        &mut self,
117        query: String,
118        window: &mut Window,
119        cx: &mut Context<Picker<Self>>,
120    ) -> Task<()> {
121        let background = cx.background_executor().clone();
122        let candidates = self
123            .tools
124            .iter()
125            .enumerate()
126            .map(|(id, profile)| StringMatchCandidate::new(id, profile.name.as_ref()))
127            .collect::<Vec<_>>();
128
129        cx.spawn_in(window, async move |this, cx| {
130            let matches = if query.is_empty() {
131                candidates
132                    .into_iter()
133                    .enumerate()
134                    .map(|(index, candidate)| StringMatch {
135                        candidate_id: index,
136                        string: candidate.string,
137                        positions: Vec::new(),
138                        score: 0.,
139                    })
140                    .collect()
141            } else {
142                match_strings(
143                    &candidates,
144                    &query,
145                    false,
146                    100,
147                    &Default::default(),
148                    background,
149                )
150                .await
151            };
152
153            this.update(cx, |this, _cx| {
154                this.delegate.matches = matches;
155                this.delegate.selected_index = this
156                    .delegate
157                    .selected_index
158                    .min(this.delegate.matches.len().saturating_sub(1));
159            })
160            .log_err();
161        })
162    }
163
164    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
165        if self.matches.is_empty() {
166            self.dismissed(window, cx);
167            return;
168        }
169
170        let candidate_id = self.matches[self.selected_index].candidate_id;
171        let tool = &self.tools[candidate_id];
172
173        let is_enabled = match &tool.source {
174            ToolSource::Native => {
175                let is_enabled = self.profile.tools.entry(tool.name.clone()).or_default();
176                *is_enabled = !*is_enabled;
177                *is_enabled
178            }
179            ToolSource::ContextServer { id } => {
180                let preset = self
181                    .profile
182                    .context_servers
183                    .entry(id.clone().into())
184                    .or_default();
185                let is_enabled = preset.tools.entry(tool.name.clone()).or_default();
186                *is_enabled = !*is_enabled;
187                *is_enabled
188            }
189        };
190
191        let active_profile_id = &AssistantSettings::get_global(cx).default_profile;
192        if active_profile_id == &self.profile_id {
193            self.thread_store
194                .update(cx, |this, cx| {
195                    this.load_profile(self.profile.clone(), cx);
196                })
197                .log_err();
198        }
199
200        update_settings_file::<AssistantSettings>(self.fs.clone(), cx, {
201            let profile_id = self.profile_id.clone();
202            let default_profile = self.profile.clone();
203            let tool = tool.clone();
204            move |settings: &mut AssistantSettingsContent, _cx| {
205                settings
206                    .v2_setting(|v2_settings| {
207                        let profiles = v2_settings.profiles.get_or_insert_default();
208                        let profile =
209                            profiles
210                                .entry(profile_id)
211                                .or_insert_with(|| AgentProfileContent {
212                                    name: default_profile.name.into(),
213                                    tools: default_profile.tools,
214                                    enable_all_context_servers: Some(
215                                        default_profile.enable_all_context_servers,
216                                    ),
217                                    context_servers: default_profile
218                                        .context_servers
219                                        .into_iter()
220                                        .map(|(server_id, preset)| {
221                                            (
222                                                server_id,
223                                                ContextServerPresetContent {
224                                                    tools: preset.tools,
225                                                },
226                                            )
227                                        })
228                                        .collect(),
229                                });
230
231                        match tool.source {
232                            ToolSource::Native => {
233                                *profile.tools.entry(tool.name).or_default() = is_enabled;
234                            }
235                            ToolSource::ContextServer { id } => {
236                                let preset = profile
237                                    .context_servers
238                                    .entry(id.clone().into())
239                                    .or_default();
240                                *preset.tools.entry(tool.name.clone()).or_default() = is_enabled;
241                            }
242                        }
243
244                        Ok(())
245                    })
246                    .ok();
247            }
248        });
249    }
250
251    fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
252        self.tool_picker
253            .update(cx, |_this, cx| cx.emit(DismissEvent))
254            .log_err();
255    }
256
257    fn render_match(
258        &self,
259        ix: usize,
260        selected: bool,
261        _window: &mut Window,
262        _cx: &mut Context<Picker<Self>>,
263    ) -> Option<Self::ListItem> {
264        let tool_match = &self.matches[ix];
265        let tool = &self.tools[tool_match.candidate_id];
266
267        let is_enabled = match &tool.source {
268            ToolSource::Native => self.profile.tools.get(&tool.name).copied().unwrap_or(false),
269            ToolSource::ContextServer { id } => self
270                .profile
271                .context_servers
272                .get(id.as_ref())
273                .and_then(|preset| preset.tools.get(&tool.name))
274                .copied()
275                .unwrap_or(self.profile.enable_all_context_servers),
276        };
277
278        Some(
279            ListItem::new(ix)
280                .inset(true)
281                .spacing(ListItemSpacing::Sparse)
282                .toggle_state(selected)
283                .child(
284                    h_flex()
285                        .gap_2()
286                        .child(HighlightedLabel::new(
287                            tool_match.string.clone(),
288                            tool_match.positions.clone(),
289                        ))
290                        .map(|parent| match &tool.source {
291                            ToolSource::Native => parent,
292                            ToolSource::ContextServer { id } => parent
293                                .child(Label::new(id).size(LabelSize::XSmall).color(Color::Muted)),
294                        }),
295                )
296                .end_slot::<Icon>(is_enabled.then(|| {
297                    Icon::new(IconName::Check)
298                        .size(IconSize::Small)
299                        .color(Color::Success)
300                })),
301        )
302    }
303}