path_search_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use gpui::{App, AppContext, Entity, Task};
  5use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  6use project::Project;
  7use schemars::JsonSchema;
  8use serde::{Deserialize, Serialize};
  9use std::{cmp, fmt::Write as _, path::PathBuf, sync::Arc};
 10use ui::IconName;
 11use util::paths::PathMatcher;
 12use worktree::Snapshot;
 13
// Arguments accepted by the `path_search` tool, deserialized from the model's JSON input
// in `run`/`ui_text`.
//
// NOTE: the `///` comments below are runtime text, not just documentation: the
// `JsonSchema` derive uses them as the schema descriptions returned by `input_schema`,
// i.e. what the model is shown. Edit them as carefully as any user-facing string, and
// keep the field order stable for the same reason.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct PathSearchToolInput {
    /// The glob to match against every path in the project.
    ///
    /// <example>
    /// If the project has the following root directories:
    ///
    /// - directory1/a/something.txt
    /// - directory2/a/things.txt
    /// - directory3/a/other.txt
    ///
    /// You can get back the first two paths by providing a glob of "*thing*.txt"
    /// </example>
    pub glob: String,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    // `#[serde(default)]` makes the field optional in the JSON input; a missing value
    // deserializes as `u32::default()`, i.e. 0.
    #[serde(default)]
    pub offset: u32,
}
 34
 35const RESULTS_PER_PAGE: usize = 50;
 36
 37pub struct PathSearchTool;
 38
 39impl Tool for PathSearchTool {
 40    fn name(&self) -> String {
 41        "path_search".into()
 42    }
 43
 44    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 45        false
 46    }
 47
 48    fn description(&self) -> String {
 49        include_str!("./path_search_tool/description.md").into()
 50    }
 51
 52    fn icon(&self) -> IconName {
 53        IconName::SearchCode
 54    }
 55
 56    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 57        json_schema_for::<PathSearchToolInput>(format)
 58    }
 59
 60    fn ui_text(&self, input: &serde_json::Value) -> String {
 61        match serde_json::from_value::<PathSearchToolInput>(input.clone()) {
 62            Ok(input) => format!("Find paths matching “`{}`”", input.glob),
 63            Err(_) => "Search paths".to_string(),
 64        }
 65    }
 66
 67    fn run(
 68        self: Arc<Self>,
 69        input: serde_json::Value,
 70        _messages: &[LanguageModelRequestMessage],
 71        project: Entity<Project>,
 72        _action_log: Entity<ActionLog>,
 73        cx: &mut App,
 74    ) -> ToolResult {
 75        let (offset, glob) = match serde_json::from_value::<PathSearchToolInput>(input) {
 76            Ok(input) => (input.offset, input.glob),
 77            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 78        };
 79        let offset = offset as usize;
 80        let task = search_paths(&glob, project, cx);
 81        cx.background_spawn(async move {
 82            let matches = task.await?;
 83            let paginated_matches = &matches[cmp::min(offset, matches.len())
 84                ..cmp::min(offset + RESULTS_PER_PAGE, matches.len())];
 85
 86            if matches.is_empty() {
 87                Ok("No matches found".to_string())
 88            } else {
 89                let mut message = format!("Found {} total matches.", matches.len());
 90                if matches.len() > RESULTS_PER_PAGE {
 91                    write!(
 92                        &mut message,
 93                        "\nShowing results {}-{} (provide 'offset' parameter for more results):",
 94                        offset + 1,
 95                        offset + paginated_matches.len()
 96                    )
 97                    .unwrap();
 98                }
 99                for mat in matches.into_iter().skip(offset).take(RESULTS_PER_PAGE) {
100                    write!(&mut message, "\n{}", mat.display()).unwrap();
101                }
102                Ok(message)
103            }
104        })
105        .into()
106    }
107}
108
109fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Result<Vec<PathBuf>>> {
110    let path_matcher = match PathMatcher::new([
111        // Sometimes models try to search for "". In this case, return all paths in the project.
112        if glob.is_empty() { "*" } else { glob },
113    ]) {
114        Ok(matcher) => matcher,
115        Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
116    };
117    let snapshots: Vec<Snapshot> = project
118        .read(cx)
119        .worktrees(cx)
120        .map(|worktree| worktree.read(cx).snapshot())
121        .collect();
122
123    cx.background_spawn(async move {
124        Ok(snapshots
125            .iter()
126            .flat_map(|snapshot| {
127                let root_name = PathBuf::from(snapshot.root_name());
128                snapshot
129                    .entries(false, 0)
130                    .map(move |entry| root_name.join(&entry.path))
131                    .filter(|path| path_matcher.is_match(&path))
132            })
133            .collect())
134    })
135}
136
#[cfg(test)]
mod test {
    use super::*;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use settings::SettingsStore;
    use util::path;

    /// Globs with and without the worktree root-name prefix should find the same files.
    #[gpui::test]
    async fn test_path_search_tool(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // Insert the fake tree under the same `path!("/root")` the project opens below,
        // so both agree on the root even where `path!` rewrites it for the platform.
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "apple": {
                    "banana": {
                        "carrot": "1",
                    },
                    "bandana": {
                        "carbonara": "2",
                    },
                    "endive": "3"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Glob that includes the worktree root name explicitly.
        let matches = cx
            .update(|cx| search_paths("root/**/car*", project.clone(), cx))
            .await
            .unwrap();
        assert_eq!(
            matches,
            &[
                PathBuf::from("root/apple/banana/carrot"),
                PathBuf::from("root/apple/bandana/carbonara")
            ]
        );

        // Glob that matches at any depth, including the root-name component.
        let matches = cx
            .update(|cx| search_paths("**/car*", project.clone(), cx))
            .await
            .unwrap();
        assert_eq!(
            matches,
            &[
                PathBuf::from("root/apple/banana/carrot"),
                PathBuf::from("root/apple/bandana/carbonara")
            ]
        );
    }

    /// Installs the global settings store and the settings the project needs.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }
}