find_path_tool.rs

  1use agent_client_protocol as acp;
  2use anyhow::{anyhow, Result};
  3use gpui::{App, AppContext, Entity, SharedString, Task};
  4use project::Project;
  5use schemars::JsonSchema;
  6use serde::{Deserialize, Serialize};
  7use std::fmt::Write;
  8use std::{cmp, path::PathBuf, sync::Arc};
  9use util::paths::PathMatcher;
 10
 11use crate::{AgentTool, ToolCallEventStream};
 12
// NOTE(review): the `///` comments on this struct and its fields are model-facing
// text, not only rustdoc: `#[derive(JsonSchema)]` (schemars) turns doc comments
// into schema descriptions. Edit them with the same care as a runtime string.
// NOTE(review): "sorted alphabetically" is not backed by an explicit sort in
// `search_paths`; results come back in worktree / snapshot iteration order —
// confirm that order is alphabetical before relying on this claim.
// NOTE(review): "50 matches per page" must be kept in sync with `RESULTS_PER_PAGE`.
/// Fast file path pattern matching tool that works with any codebase size
///
/// - Supports glob patterns like "**/*.js" or "src/**/*.ts"
/// - Returns matching file paths sorted alphabetically
/// - Prefer the `grep` tool to this tool when searching for symbols unless you have specific information about paths.
/// - Use this tool when you need to find files by name patterns
/// - Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FindPathToolInput {
    /// The glob to match against every path in the project.
    ///
    /// <example>
    /// If the project has the following root directories:
    ///
    /// - directory1/a/something.txt
    /// - directory2/a/things.txt
    /// - directory3/a/other.txt
    ///
    /// You can get back the first two paths by providing a glob of "*thing*.txt"
    /// </example>
    // Matched against `<worktree root name>/<path within worktree>`; an empty
    // string is replaced with "*" by `search_paths`.
    pub glob: String,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    // `#[serde(default)]` makes a missing `offset` deserialize as 0. The value is
    // supplied by the model and may be larger than the number of matches.
    #[serde(default)]
    pub offset: usize,
}
 40
/// Serde representation of a `{"paths": [...]}` result: the list of matched paths.
// NOTE(review): nothing visible in this file constructs this type; `run` builds
// the same `{"paths": [...]}` shape for `raw_output` via `serde_json::json!`.
// Confirm whether it is used elsewhere or can be dropped.
#[derive(Debug, Serialize, Deserialize)]
struct FindPathToolOutput {
    /// Matched paths, each prefixed with its worktree root name.
    paths: Vec<PathBuf>,
}
 45
/// Maximum number of matches included in a single page of `find_path` results.
const RESULTS_PER_PAGE: usize = 50;
 47
/// Agent tool that lists project paths matching a glob pattern.
pub struct FindPathTool {
    /// Project whose worktrees are searched.
    project: Entity<Project>,
}
 51
 52impl FindPathTool {
 53    pub fn new(project: Entity<Project>) -> Self {
 54        Self { project }
 55    }
 56}
 57
 58impl AgentTool for FindPathTool {
 59    type Input = FindPathToolInput;
 60
 61    fn name(&self) -> SharedString {
 62        "find_path".into()
 63    }
 64
 65    fn kind(&self) -> acp::ToolKind {
 66        acp::ToolKind::Search
 67    }
 68
 69    fn initial_title(&self, input: Self::Input) -> SharedString {
 70        format!("Find paths matching “`{}`”", input.glob).into()
 71    }
 72
 73    fn run(
 74        self: Arc<Self>,
 75        input: Self::Input,
 76        event_stream: ToolCallEventStream,
 77        cx: &mut App,
 78    ) -> Task<Result<String>> {
 79        let search_paths_task = search_paths(&input.glob, self.project.clone(), cx);
 80
 81        cx.background_spawn(async move {
 82            let matches = search_paths_task.await?;
 83            let paginated_matches: &[PathBuf] = &matches[cmp::min(input.offset, matches.len())
 84                ..cmp::min(input.offset + RESULTS_PER_PAGE, matches.len())];
 85
 86            event_stream.send_update(acp::ToolCallUpdateFields {
 87                title: Some(if paginated_matches.len() == 0 {
 88                    "No matches".into()
 89                } else if paginated_matches.len() == 1 {
 90                    "1 match".into()
 91                } else {
 92                    format!("{} matches", paginated_matches.len())
 93                }),
 94                content: Some(
 95                    paginated_matches
 96                        .iter()
 97                        .map(|path| acp::ToolCallContent::Content {
 98                            content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
 99                                uri: format!("file://{}", path.display()),
100                                name: path.to_string_lossy().into(),
101                                annotations: None,
102                                description: None,
103                                mime_type: None,
104                                size: None,
105                                title: None,
106                            }),
107                        })
108                        .collect(),
109                ),
110                raw_output: Some(serde_json::json!({
111                    "paths": &matches,
112                })),
113                ..Default::default()
114            });
115
116            if matches.is_empty() {
117                Ok("No matches found".into())
118            } else {
119                let mut message = format!("Found {} total matches.", matches.len());
120                if matches.len() > RESULTS_PER_PAGE {
121                    write!(
122                        &mut message,
123                        "\nShowing results {}-{} (provide 'offset' parameter for more results):",
124                        input.offset + 1,
125                        input.offset + paginated_matches.len()
126                    )
127                    .unwrap();
128                }
129
130                for mat in matches.iter().skip(input.offset).take(RESULTS_PER_PAGE) {
131                    write!(&mut message, "\n{}", mat.display()).unwrap();
132                }
133
134                Ok(message)
135            }
136        })
137    }
138}
139
140fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Result<Vec<PathBuf>>> {
141    let path_matcher = match PathMatcher::new([
142        // Sometimes models try to search for "". In this case, return all paths in the project.
143        if glob.is_empty() { "*" } else { glob },
144    ]) {
145        Ok(matcher) => matcher,
146        Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
147    };
148    let snapshots: Vec<_> = project
149        .read(cx)
150        .worktrees(cx)
151        .map(|worktree| worktree.read(cx).snapshot())
152        .collect();
153
154    cx.background_spawn(async move {
155        Ok(snapshots
156            .iter()
157            .flat_map(|snapshot| {
158                let root_name = PathBuf::from(snapshot.root_name());
159                snapshot
160                    .entries(false, 0)
161                    .map(move |entry| root_name.join(&entry.path))
162                    .filter(|path| path_matcher.is_match(&path))
163            })
164            .collect())
165    })
166}
167
#[cfg(test)]
mod test {
    use super::*;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use settings::SettingsStore;
    use util::path;

    /// Globs are matched against `<root name>/<path in worktree>`, both when the
    /// pattern names the worktree root explicitly and when it starts with `**`.
    #[gpui::test]
    async fn test_find_path_tool(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // Build the tree at the same `path!`-normalized location the project is
        // opened at below, so the two agree on platforms where `path!` rewrites
        // the literal (e.g. adding a drive prefix).
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "apple": {
                    "banana": {
                        "carrot": "1",
                    },
                    "bandana": {
                        "carbonara": "2",
                    },
                    "endive": "3"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        let expected = [
            PathBuf::from("root/apple/banana/carrot"),
            PathBuf::from("root/apple/bandana/carbonara"),
        ];

        for glob in ["root/**/car*", "**/car*"] {
            let matches = cx
                .update(|cx| search_paths(glob, project.clone(), cx))
                .await
                .unwrap();
            assert_eq!(matches, expected, "unexpected matches for glob {glob:?}");
        }
    }

    /// Installs the global settings store and language/project settings the
    /// project test harness needs.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }
}