// path_search_tool.rs

  1use anyhow::{anyhow, Result};
  2use assistant_tool::{ActionLog, Tool};
  3use gpui::{App, AppContext, Entity, Task};
  4use language_model::LanguageModelRequestMessage;
  5use project::Project;
  6use schemars::JsonSchema;
  7use serde::{Deserialize, Serialize};
  8use std::{path::PathBuf, sync::Arc};
  9use util::paths::PathMatcher;
 10use worktree::Snapshot;
 11
 12#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 13pub struct PathSearchToolInput {
 14    /// The glob to search all project paths for.
 15    ///
 16    /// <example>
 17    /// If the project has the following root directories:
 18    ///
 19    /// - directory1/a/something.txt
 20    /// - directory2/a/things.txt
 21    /// - directory3/a/other.txt
 22    ///
 23    /// You can get back the first two paths by providing a glob of "*thing*.txt"
 24    /// </example>
 25    pub glob: String,
 26
 27    /// Optional starting position for paginated results (0-based).
 28    /// When not provided, starts from the beginning.
 29    #[serde(default)]
 30    pub offset: Option<usize>,
 31}
 32
 33const RESULTS_PER_PAGE: usize = 50;
 34
 35pub struct PathSearchTool;
 36
 37impl Tool for PathSearchTool {
 38    fn name(&self) -> String {
 39        "path-search".into()
 40    }
 41
 42    fn description(&self) -> String {
 43        include_str!("./path_search_tool/description.md").into()
 44    }
 45
 46    fn input_schema(&self) -> serde_json::Value {
 47        let schema = schemars::schema_for!(PathSearchToolInput);
 48        serde_json::to_value(&schema).unwrap()
 49    }
 50
 51    fn run(
 52        self: Arc<Self>,
 53        input: serde_json::Value,
 54        _messages: &[LanguageModelRequestMessage],
 55        project: Entity<Project>,
 56        _action_log: Entity<ActionLog>,
 57        cx: &mut App,
 58    ) -> Task<Result<String>> {
 59        let (offset, glob) = match serde_json::from_value::<PathSearchToolInput>(input) {
 60            Ok(input) => (input.offset.unwrap_or(0), input.glob),
 61            Err(err) => return Task::ready(Err(anyhow!(err))),
 62        };
 63        let path_matcher = match PathMatcher::new(&[glob.clone()]) {
 64            Ok(matcher) => matcher,
 65            Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {}", err))),
 66        };
 67        let snapshots: Vec<Snapshot> = project
 68            .read(cx)
 69            .worktrees(cx)
 70            .map(|worktree| worktree.read(cx).snapshot())
 71            .collect();
 72
 73        cx.background_spawn(async move {
 74            let mut matches = Vec::new();
 75
 76            for worktree in snapshots {
 77                let root_name = worktree.root_name();
 78
 79                // Don't consider ignored entries.
 80                for entry in worktree.entries(false, 0) {
 81                    if path_matcher.is_match(&entry.path) {
 82                        matches.push(
 83                            PathBuf::from(root_name)
 84                                .join(&entry.path)
 85                                .to_string_lossy()
 86                                .to_string(),
 87                        );
 88                    }
 89                }
 90            }
 91
 92            if matches.is_empty() {
 93                Ok(format!("No paths in the project matched the glob {glob:?}"))
 94            } else {
 95                // Sort to group entries in the same directory together.
 96                matches.sort();
 97
 98                let total_matches = matches.len();
 99                let response = if total_matches > offset + RESULTS_PER_PAGE {
100                  let paginated_matches: Vec<_> = matches
101                      .into_iter()
102                      .skip(offset)
103                      .take(RESULTS_PER_PAGE)
104                      .collect();
105
106                    format!(
107                        "Found {} total matches. Showing results {}-{} (provide 'offset' parameter for more results):\n\n{}",
108                        total_matches,
109                        offset + 1,
110                        offset + paginated_matches.len(),
111                        paginated_matches.join("\n")
112                    )
113                } else {
114                    matches.join("\n")
115                };
116
117                Ok(response)
118            }
119        })
120    }
121}