path_search_tool.rs

  1use anyhow::{anyhow, Result};
  2use assistant_tool::{ActionLog, Tool};
  3use gpui::{App, AppContext, Entity, Task};
  4use language_model::LanguageModelRequestMessage;
  5use project::Project;
  6use schemars::JsonSchema;
  7use serde::{Deserialize, Serialize};
  8use std::{path::PathBuf, sync::Arc};
  9use util::paths::PathMatcher;
 10use worktree::Snapshot;
 11
 12#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 13pub struct PathSearchToolInput {
 14    /// The glob to search all project paths for.
 15    ///
 16    /// <example>
 17    /// If the project has the following root directories:
 18    ///
 19    /// - directory1/a/something.txt
 20    /// - directory2/a/things.txt
 21    /// - directory3/a/other.txt
 22    ///
 23    /// You can get back the first two paths by providing a glob of "*thing*.txt"
 24    /// </example>
 25    pub glob: String,
 26
 27    /// Optional starting position for paginated results (0-based).
 28    /// When not provided, starts from the beginning.
 29    #[serde(default)]
 30    pub offset: Option<usize>,
 31}
 32
 /// The `path-search` agent tool: lists project paths that match a glob.
pub struct PathSearchTool;

/// Page size used when `path-search` paginates its list of matching paths.
const RESULTS_PER_PAGE: usize = 50;
 36
 37impl Tool for PathSearchTool {
 38    fn name(&self) -> String {
 39        "path-search".into()
 40    }
 41
 42    fn needs_confirmation(&self) -> bool {
 43        false
 44    }
 45
 46    fn description(&self) -> String {
 47        include_str!("./path_search_tool/description.md").into()
 48    }
 49
 50    fn input_schema(&self) -> serde_json::Value {
 51        let schema = schemars::schema_for!(PathSearchToolInput);
 52        serde_json::to_value(&schema).unwrap()
 53    }
 54
 55    fn ui_text(&self, input: &serde_json::Value) -> String {
 56        match serde_json::from_value::<PathSearchToolInput>(input.clone()) {
 57            Ok(input) => format!("Find paths matching “`{}`”", input.glob),
 58            Err(_) => "Search paths".to_string(),
 59        }
 60    }
 61
 62    fn run(
 63        self: Arc<Self>,
 64        input: serde_json::Value,
 65        _messages: &[LanguageModelRequestMessage],
 66        project: Entity<Project>,
 67        _action_log: Entity<ActionLog>,
 68        cx: &mut App,
 69    ) -> Task<Result<String>> {
 70        let (offset, glob) = match serde_json::from_value::<PathSearchToolInput>(input) {
 71            Ok(input) => (input.offset.unwrap_or(0), input.glob),
 72            Err(err) => return Task::ready(Err(anyhow!(err))),
 73        };
 74
 75        let path_matcher = match PathMatcher::new([
 76            // Sometimes models try to search for "". In this case, return all paths in the project.
 77            if glob.is_empty() { "*" } else { &glob },
 78        ]) {
 79            Ok(matcher) => matcher,
 80            Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
 81        };
 82        let snapshots: Vec<Snapshot> = project
 83            .read(cx)
 84            .worktrees(cx)
 85            .map(|worktree| worktree.read(cx).snapshot())
 86            .collect();
 87
 88        cx.background_spawn(async move {
 89            let mut matches = Vec::new();
 90
 91            for worktree in snapshots {
 92                let root_name = worktree.root_name();
 93
 94                // Don't consider ignored entries.
 95                for entry in worktree.entries(false, 0) {
 96                    if path_matcher.is_match(&entry.path) {
 97                        matches.push(
 98                            PathBuf::from(root_name)
 99                                .join(&entry.path)
100                                .to_string_lossy()
101                                .to_string(),
102                        );
103                    }
104                }
105            }
106
107            if matches.is_empty() {
108                Ok(format!("No paths in the project matched the glob {glob:?}"))
109            } else {
110                // Sort to group entries in the same directory together.
111                matches.sort();
112
113                let total_matches = matches.len();
114                let response = if total_matches > offset + RESULTS_PER_PAGE {
115                  let paginated_matches: Vec<_> = matches
116                      .into_iter()
117                      .skip(offset)
118                      .take(RESULTS_PER_PAGE)
119                      .collect();
120
121                    format!(
122                        "Found {} total matches. Showing results {}-{} (provide 'offset' parameter for more results):\n\n{}",
123                        total_matches,
124                        offset + 1,
125                        offset + paginated_matches.len(),
126                        paginated_matches.join("\n")
127                    )
128                } else {
129                    matches.join("\n")
130                };
131
132                Ok(response)
133            }
134        })
135    }
136}