path_search_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
  5use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  6use project::Project;
  7use schemars::JsonSchema;
  8use serde::{Deserialize, Serialize};
  9use std::{cmp, fmt::Write as _, path::PathBuf, sync::Arc};
 10use ui::IconName;
 11use util::paths::PathMatcher;
 12use worktree::Snapshot;
 13
 14#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 15pub struct PathSearchToolInput {
 16    /// The glob to match against every path in the project.
 17    ///
 18    /// <example>
 19    /// If the project has the following root directories:
 20    ///
 21    /// - directory1/a/something.txt
 22    /// - directory2/a/things.txt
 23    /// - directory3/a/other.txt
 24    ///
 25    /// You can get back the first two paths by providing a glob of "*thing*.txt"
 26    /// </example>
 27    pub glob: String,
 28
 29    /// Optional starting position for paginated results (0-based).
 30    /// When not provided, starts from the beginning.
 31    #[serde(default)]
 32    pub offset: u32,
 33}
 34
 35const RESULTS_PER_PAGE: usize = 50;
 36
 37pub struct PathSearchTool;
 38
 39impl Tool for PathSearchTool {
 40    fn name(&self) -> String {
 41        "path_search".into()
 42    }
 43
 44    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 45        false
 46    }
 47
 48    fn description(&self) -> String {
 49        include_str!("./path_search_tool/description.md").into()
 50    }
 51
 52    fn icon(&self) -> IconName {
 53        IconName::SearchCode
 54    }
 55
 56    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 57        json_schema_for::<PathSearchToolInput>(format)
 58    }
 59
 60    fn ui_text(&self, input: &serde_json::Value) -> String {
 61        match serde_json::from_value::<PathSearchToolInput>(input.clone()) {
 62            Ok(input) => format!("Find paths matching “`{}`”", input.glob),
 63            Err(_) => "Search paths".to_string(),
 64        }
 65    }
 66
 67    fn run(
 68        self: Arc<Self>,
 69        input: serde_json::Value,
 70        _messages: &[LanguageModelRequestMessage],
 71        project: Entity<Project>,
 72        _action_log: Entity<ActionLog>,
 73        _window: Option<AnyWindowHandle>,
 74        cx: &mut App,
 75    ) -> ToolResult {
 76        let (offset, glob) = match serde_json::from_value::<PathSearchToolInput>(input) {
 77            Ok(input) => (input.offset, input.glob),
 78            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 79        };
 80        let offset = offset as usize;
 81        let task = search_paths(&glob, project, cx);
 82        cx.background_spawn(async move {
 83            let matches = task.await?;
 84            let paginated_matches = &matches[cmp::min(offset, matches.len())
 85                ..cmp::min(offset + RESULTS_PER_PAGE, matches.len())];
 86
 87            if matches.is_empty() {
 88                Ok("No matches found".to_string())
 89            } else {
 90                let mut message = format!("Found {} total matches.", matches.len());
 91                if matches.len() > RESULTS_PER_PAGE {
 92                    write!(
 93                        &mut message,
 94                        "\nShowing results {}-{} (provide 'offset' parameter for more results):",
 95                        offset + 1,
 96                        offset + paginated_matches.len()
 97                    )
 98                    .unwrap();
 99                }
100                for mat in matches.into_iter().skip(offset).take(RESULTS_PER_PAGE) {
101                    write!(&mut message, "\n{}", mat.display()).unwrap();
102                }
103                Ok(message)
104            }
105        })
106        .into()
107    }
108}
109
110fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Result<Vec<PathBuf>>> {
111    let path_matcher = match PathMatcher::new([
112        // Sometimes models try to search for "". In this case, return all paths in the project.
113        if glob.is_empty() { "*" } else { glob },
114    ]) {
115        Ok(matcher) => matcher,
116        Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
117    };
118    let snapshots: Vec<Snapshot> = project
119        .read(cx)
120        .worktrees(cx)
121        .map(|worktree| worktree.read(cx).snapshot())
122        .collect();
123
124    cx.background_spawn(async move {
125        Ok(snapshots
126            .iter()
127            .flat_map(|snapshot| {
128                let root_name = PathBuf::from(snapshot.root_name());
129                snapshot
130                    .entries(false, 0)
131                    .map(move |entry| root_name.join(&entry.path))
132                    .filter(|path| path_matcher.is_match(&path))
133            })
134            .collect())
135    })
136}
137
#[cfg(test)]
mod test {
    use super::*;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use settings::SettingsStore;
    use util::path;

    #[gpui::test]
    async fn test_path_search_tool(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // Build the fake tree with `path!` as well, so it is created at the same
        // location as the project root below on every platform.
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "apple": {
                    "banana": {
                        "carrot": "1",
                    },
                    "bandana": {
                        "carbonara": "2",
                    },
                    "endive": "3"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        let matches = cx
            .update(|cx| search_paths("root/**/car*", project.clone(), cx))
            .await
            .unwrap();
        assert_eq!(
            matches,
            &[
                PathBuf::from("root/apple/banana/carrot"),
                PathBuf::from("root/apple/bandana/carbonara")
            ]
        );

        let matches = cx
            .update(|cx| search_paths("**/car*", project.clone(), cx))
            .await
            .unwrap();
        assert_eq!(
            matches,
            &[
                PathBuf::from("root/apple/banana/carrot"),
                PathBuf::from("root/apple/bandana/carbonara")
            ]
        );

        // An empty glob falls back to "*", so every file in the tree is returned.
        let matches = cx
            .update(|cx| search_paths("", project.clone(), cx))
            .await
            .unwrap();
        for expected in [
            "root/apple/banana/carrot",
            "root/apple/bandana/carbonara",
            "root/apple/endive",
        ] {
            assert!(
                matches.contains(&PathBuf::from(expected)),
                "expected {expected} in {matches:?}"
            );
        }

        // A malformed glob (unclosed character class) is reported as an error.
        let result = cx
            .update(|cx| search_paths("[", project.clone(), cx))
            .await;
        assert!(result.is_err());
    }

    /// Installs the global settings and language/project state the tests rely on.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }
}