find_path_tool.rs

  1use crate::{schema::json_schema_for, ui::ToolCallCardHeader};
  2use action_log::ActionLog;
  3use anyhow::{Result, anyhow};
  4use assistant_tool::{
  5    Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus,
  6};
  7use editor::Editor;
  8use futures::channel::oneshot::{self, Receiver};
  9use gpui::{
 10    AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
 11};
 12use language;
 13use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 14use project::Project;
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17use std::fmt::Write;
 18use std::{cmp, path::PathBuf, sync::Arc};
 19use ui::{Disclosure, Tooltip, prelude::*};
 20use util::{ResultExt, paths::PathMatcher};
 21use workspace::Workspace;
 22
// Arguments accepted by the `find_path` tool; `run` and `ui_text` deserialize them from the
// tool call's JSON input.
//
// NOTE(review): this type derives `JsonSchema`, and schemars copies `///` doc comments into
// the generated schema as descriptions. The doc comments below are therefore presumably part
// of what `input_schema` hands to the model — treat them as prompt text, not as notes for
// maintainers.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FindPathToolInput {
    /// The glob to match against every path in the project.
    ///
    /// <example>
    /// If the project has the following root directories:
    ///
    /// - directory1/a/something.txt
    /// - directory2/a/things.txt
    /// - directory3/a/other.txt
    ///
    /// You can get back the first two paths by providing a glob of "*thing*.txt"
    /// </example>
    pub glob: String,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    #[serde(default)]
    pub offset: usize,
}
 43
/// Serialized result of a search; `deserialize_card` turns it back into a card.
#[derive(Debug, Serialize, Deserialize)]
struct FindPathToolOutput {
    /// The glob the search was run with.
    glob: String,
    /// The matched paths. `run` stores the complete match list here, not only the page that
    /// was reported in the tool's text output.
    paths: Vec<PathBuf>,
}
 49
/// Maximum number of matches reported per page of tool output.
const RESULTS_PER_PAGE: usize = 50;
 51
/// Tool that finds project paths matching a glob and reports them one page at a time.
pub struct FindPathTool;
 53
 54impl Tool for FindPathTool {
 55    fn name(&self) -> String {
 56        "find_path".into()
 57    }
 58
 59    fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
 60        false
 61    }
 62
 63    fn may_perform_edits(&self) -> bool {
 64        false
 65    }
 66
 67    fn description(&self) -> String {
 68        include_str!("./find_path_tool/description.md").into()
 69    }
 70
 71    fn icon(&self) -> IconName {
 72        IconName::ToolSearch
 73    }
 74
 75    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 76        json_schema_for::<FindPathToolInput>(format)
 77    }
 78
 79    fn ui_text(&self, input: &serde_json::Value) -> String {
 80        match serde_json::from_value::<FindPathToolInput>(input.clone()) {
 81            Ok(input) => format!("Find paths matching “`{}`”", input.glob),
 82            Err(_) => "Search paths".to_string(),
 83        }
 84    }
 85
 86    fn run(
 87        self: Arc<Self>,
 88        input: serde_json::Value,
 89        _request: Arc<LanguageModelRequest>,
 90        project: Entity<Project>,
 91        _action_log: Entity<ActionLog>,
 92        _model: Arc<dyn LanguageModel>,
 93        _window: Option<AnyWindowHandle>,
 94        cx: &mut App,
 95    ) -> ToolResult {
 96        let (offset, glob) = match serde_json::from_value::<FindPathToolInput>(input) {
 97            Ok(input) => (input.offset, input.glob),
 98            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 99        };
100
101        let (sender, receiver) = oneshot::channel();
102
103        let card = cx.new(|cx| FindPathToolCard::new(glob.clone(), receiver, cx));
104
105        let search_paths_task = search_paths(&glob, project, cx);
106
107        let task = cx.background_spawn(async move {
108            let matches = search_paths_task.await?;
109            let paginated_matches: &[PathBuf] = &matches[cmp::min(offset, matches.len())
110                ..cmp::min(offset + RESULTS_PER_PAGE, matches.len())];
111
112            sender.send(paginated_matches.to_vec()).log_err();
113
114            if matches.is_empty() {
115                Ok("No matches found".to_string().into())
116            } else {
117                let mut message = format!("Found {} total matches.", matches.len());
118                if matches.len() > RESULTS_PER_PAGE {
119                    write!(
120                        &mut message,
121                        "\nShowing results {}-{} (provide 'offset' parameter for more results):",
122                        offset + 1,
123                        offset + paginated_matches.len()
124                    )
125                    .unwrap();
126                }
127
128                for mat in matches.iter().skip(offset).take(RESULTS_PER_PAGE) {
129                    write!(&mut message, "\n{}", mat.display()).unwrap();
130                }
131
132                let output = FindPathToolOutput {
133                    glob,
134                    paths: matches,
135                };
136
137                Ok(ToolResultOutput {
138                    content: ToolResultContent::Text(message),
139                    output: Some(serde_json::to_value(output)?),
140                })
141            }
142        });
143
144        ToolResult {
145            output: task,
146            card: Some(card.into()),
147        }
148    }
149
150    fn deserialize_card(
151        self: Arc<Self>,
152        output: serde_json::Value,
153        _project: Entity<Project>,
154        _window: &mut Window,
155        cx: &mut App,
156    ) -> Option<assistant_tool::AnyToolCard> {
157        let output = serde_json::from_value::<FindPathToolOutput>(output).ok()?;
158        let card = cx.new(|_| FindPathToolCard::from_output(output));
159        Some(card.into())
160    }
161}
162
163fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Result<Vec<PathBuf>>> {
164    let path_matcher = match PathMatcher::new(
165        [
166            // Sometimes models try to search for "". In this case, return all paths in the project.
167            if glob.is_empty() { "*" } else { glob },
168        ],
169        project.read(cx).path_style(cx),
170    ) {
171        Ok(matcher) => matcher,
172        Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
173    };
174    let snapshots: Vec<_> = project
175        .read(cx)
176        .worktrees(cx)
177        .map(|worktree| worktree.read(cx).snapshot())
178        .collect();
179
180    cx.background_spawn(async move {
181        Ok(snapshots
182            .iter()
183            .flat_map(|snapshot| {
184                snapshot
185                    .entries(false, 0)
186                    .map(move |entry| {
187                        snapshot
188                            .root_name()
189                            .join(&entry.path)
190                            .as_std_path()
191                            .to_path_buf()
192                    })
193                    .filter(|path| path_matcher.is_match(&path))
194            })
195            .collect())
196    })
197}
198
/// Card that presents the outcome of a path search.
struct FindPathToolCard {
    /// Paths listed by the card; starts empty when the card is created with a receiver.
    paths: Vec<PathBuf>,
    /// Whether the list of paths is currently shown.
    expanded: bool,
    /// The glob that was searched for, shown in the card header.
    glob: String,
    /// Task that waits for the search results; `None` for cards built from stored output.
    /// NOTE(review): presumably held so the task is not dropped while the card is alive —
    /// confirm against gpui's `Task` drop semantics.
    _receiver_task: Option<Task<Result<()>>>,
}
205
206impl FindPathToolCard {
207    fn new(glob: String, receiver: Receiver<Vec<PathBuf>>, cx: &mut Context<Self>) -> Self {
208        let _receiver_task = cx.spawn(async move |this, cx| {
209            let paths = receiver.await?;
210
211            this.update(cx, |this, _cx| {
212                this.paths = paths;
213            })
214            .log_err();
215
216            Ok(())
217        });
218
219        Self {
220            paths: Vec::new(),
221            expanded: false,
222            glob,
223            _receiver_task: Some(_receiver_task),
224        }
225    }
226
227    fn from_output(output: FindPathToolOutput) -> Self {
228        Self {
229            glob: output.glob,
230            paths: output.paths,
231            expanded: false,
232            _receiver_task: None,
233        }
234    }
235}
236
237impl ToolCard for FindPathToolCard {
238    fn render(
239        &mut self,
240        _status: &ToolUseStatus,
241        _window: &mut Window,
242        workspace: WeakEntity<Workspace>,
243        cx: &mut Context<Self>,
244    ) -> impl IntoElement {
245        let matches_label: SharedString = if self.paths.is_empty() {
246            "No matches".into()
247        } else if self.paths.len() == 1 {
248            "1 match".into()
249        } else {
250            format!("{} matches", self.paths.len()).into()
251        };
252
253        let content = if !self.paths.is_empty() && self.expanded {
254            Some(
255                v_flex()
256                    .relative()
257                    .ml_1p5()
258                    .px_1p5()
259                    .gap_0p5()
260                    .border_l_1()
261                    .border_color(cx.theme().colors().border_variant)
262                    .children(self.paths.iter().enumerate().map(|(index, path)| {
263                        let path_clone = path.clone();
264                        let workspace_clone = workspace.clone();
265                        let button_label = path.to_string_lossy().into_owned();
266
267                        Button::new(("path", index), button_label)
268                            .icon(IconName::ArrowUpRight)
269                            .icon_size(IconSize::Small)
270                            .icon_position(IconPosition::End)
271                            .label_size(LabelSize::Small)
272                            .color(Color::Muted)
273                            .tooltip(Tooltip::text("Jump to File"))
274                            .on_click(move |_, window, cx| {
275                                workspace_clone
276                                    .update(cx, |workspace, cx| {
277                                        let path = PathBuf::from(&path_clone);
278                                        let Some(project_path) = workspace
279                                            .project()
280                                            .read(cx)
281                                            .find_project_path(&path, cx)
282                                        else {
283                                            return;
284                                        };
285                                        let open_task = workspace.open_path(
286                                            project_path,
287                                            None,
288                                            true,
289                                            window,
290                                            cx,
291                                        );
292                                        window
293                                            .spawn(cx, async move |cx| {
294                                                let item = open_task.await?;
295                                                if let Some(active_editor) =
296                                                    item.downcast::<Editor>()
297                                                {
298                                                    active_editor
299                                                        .update_in(cx, |editor, window, cx| {
300                                                            editor.go_to_singleton_buffer_point(
301                                                                language::Point::new(0, 0),
302                                                                window,
303                                                                cx,
304                                                            );
305                                                        })
306                                                        .log_err();
307                                                }
308                                                anyhow::Ok(())
309                                            })
310                                            .detach_and_log_err(cx);
311                                    })
312                                    .ok();
313                            })
314                    }))
315                    .into_any(),
316            )
317        } else {
318            None
319        };
320
321        v_flex()
322            .mb_2()
323            .gap_1()
324            .child(
325                ToolCallCardHeader::new(IconName::ToolSearch, matches_label)
326                    .with_code_path(&self.glob)
327                    .disclosure_slot(
328                        Disclosure::new("path-search-disclosure", self.expanded)
329                            .opened_icon(IconName::ChevronUp)
330                            .closed_icon(IconName::ChevronDown)
331                            .disabled(self.paths.is_empty())
332                            .on_click(cx.listener(move |this, _, _, _cx| {
333                                this.expanded = !this.expanded;
334                            })),
335                    ),
336            )
337            .children(content)
338    }
339}
340
341impl Component for FindPathTool {
342    fn scope() -> ComponentScope {
343        ComponentScope::Agent
344    }
345
346    fn sort_name() -> &'static str {
347        "FindPathTool"
348    }
349
350    fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
351        let successful_card = cx.new(|_| FindPathToolCard {
352            paths: vec![
353                PathBuf::from("src/main.rs"),
354                PathBuf::from("src/lib.rs"),
355                PathBuf::from("tests/test.rs"),
356            ],
357            expanded: true,
358            glob: "*.rs".to_string(),
359            _receiver_task: None,
360        });
361
362        let empty_card = cx.new(|_| FindPathToolCard {
363            paths: Vec::new(),
364            expanded: false,
365            glob: "*.nonexistent".to_string(),
366            _receiver_task: None,
367        });
368
369        Some(
370            v_flex()
371                .gap_6()
372                .children(vec![example_group(vec![
373                    single_example(
374                        "With Paths",
375                        div()
376                            .size_full()
377                            .child(successful_card.update(cx, |tool, cx| {
378                                tool.render(
379                                    &ToolUseStatus::Finished("".into()),
380                                    window,
381                                    WeakEntity::new_invalid(),
382                                    cx,
383                                )
384                                .into_any_element()
385                            }))
386                            .into_any_element(),
387                    ),
388                    single_example(
389                        "No Paths",
390                        div()
391                            .size_full()
392                            .child(empty_card.update(cx, |tool, cx| {
393                                tool.render(
394                                    &ToolUseStatus::Finished("".into()),
395                                    window,
396                                    WeakEntity::new_invalid(),
397                                    cx,
398                                )
399                                .into_any_element()
400                            }))
401                            .into_any_element(),
402                    ),
403                ])])
404                .into_any_element(),
405        )
406    }
407}
408
#[cfg(test)]
mod test {
    use super::*;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use settings::SettingsStore;
    use util::path;

    /// Both a root-prefixed glob and a root-agnostic glob find the two `car*` files.
    #[gpui::test]
    async fn test_find_path_tool(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let tree = serde_json::json!({
            "apple": {
                "banana": {
                    "carrot": "1",
                },
                "bandana": {
                    "carbonara": "2",
                },
                "endive": "3"
            }
        });
        fs.insert_tree("/root", tree).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        let expected = [
            PathBuf::from(path!("root/apple/banana/carrot")),
            PathBuf::from(path!("root/apple/bandana/carbonara")),
        ];

        for glob in ["root/**/car*", "**/car*"] {
            let matches = cx
                .update(|cx| search_paths(glob, project.clone(), cx))
                .await
                .unwrap();
            assert_eq!(matches, expected);
        }
    }

    /// Installs the globals the project needs in tests.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }
}
472}