path_search_tool.rs

  1use anyhow::{anyhow, Result};
  2use assistant_tool::{ActionLog, Tool};
  3use gpui::{App, AppContext, Entity, Task};
  4use language_model::LanguageModelRequestMessage;
  5use project::Project;
  6use schemars::JsonSchema;
  7use serde::{Deserialize, Serialize};
  8use std::{path::PathBuf, sync::Arc};
  9use ui::IconName;
 10use util::paths::PathMatcher;
 11use worktree::Snapshot;
 12
 13#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 14pub struct PathSearchToolInput {
 15    /// The glob to search all project paths for.
 16    ///
 17    /// <example>
 18    /// If the project has the following root directories:
 19    ///
 20    /// - directory1/a/something.txt
 21    /// - directory2/a/things.txt
 22    /// - directory3/a/other.txt
 23    ///
 24    /// You can get back the first two paths by providing a glob of "*thing*.txt"
 25    /// </example>
 26    pub glob: String,
 27
 28    /// Optional starting position for paginated results (0-based).
 29    /// When not provided, starts from the beginning.
 30    #[serde(default)]
 31    pub offset: u32,
 32}
 33
 34const RESULTS_PER_PAGE: usize = 50;
 35
 36pub struct PathSearchTool;
 37
 38impl Tool for PathSearchTool {
 39    fn name(&self) -> String {
 40        "path-search".into()
 41    }
 42
 43    fn needs_confirmation(&self) -> bool {
 44        false
 45    }
 46
 47    fn description(&self) -> String {
 48        include_str!("./path_search_tool/description.md").into()
 49    }
 50
 51    fn icon(&self) -> IconName {
 52        IconName::SearchCode
 53    }
 54
 55    fn input_schema(&self) -> serde_json::Value {
 56        let schema = schemars::schema_for!(PathSearchToolInput);
 57        serde_json::to_value(&schema).unwrap()
 58    }
 59
 60    fn ui_text(&self, input: &serde_json::Value) -> String {
 61        match serde_json::from_value::<PathSearchToolInput>(input.clone()) {
 62            Ok(input) => format!("Find paths matching “`{}`”", input.glob),
 63            Err(_) => "Search paths".to_string(),
 64        }
 65    }
 66
 67    fn run(
 68        self: Arc<Self>,
 69        input: serde_json::Value,
 70        _messages: &[LanguageModelRequestMessage],
 71        project: Entity<Project>,
 72        _action_log: Entity<ActionLog>,
 73        cx: &mut App,
 74    ) -> Task<Result<String>> {
 75        let (offset, glob) = match serde_json::from_value::<PathSearchToolInput>(input) {
 76            Ok(input) => (input.offset, input.glob),
 77            Err(err) => return Task::ready(Err(anyhow!(err))),
 78        };
 79
 80        let path_matcher = match PathMatcher::new([
 81            // Sometimes models try to search for "". In this case, return all paths in the project.
 82            if glob.is_empty() { "*" } else { &glob },
 83        ]) {
 84            Ok(matcher) => matcher,
 85            Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
 86        };
 87        let snapshots: Vec<Snapshot> = project
 88            .read(cx)
 89            .worktrees(cx)
 90            .map(|worktree| worktree.read(cx).snapshot())
 91            .collect();
 92
 93        cx.background_spawn(async move {
 94            let mut matches = Vec::new();
 95
 96            for worktree in snapshots {
 97                let root_name = worktree.root_name();
 98
 99                // Don't consider ignored entries.
100                for entry in worktree.entries(false, 0) {
101                    if path_matcher.is_match(&entry.path) {
102                        matches.push(
103                            PathBuf::from(root_name)
104                                .join(&entry.path)
105                                .to_string_lossy()
106                                .to_string(),
107                        );
108                    }
109                }
110            }
111
112            if matches.is_empty() {
113                Ok(format!("No paths in the project matched the glob {glob:?}"))
114            } else {
115                // Sort to group entries in the same directory together.
116                matches.sort();
117
118                let total_matches = matches.len();
119                let response = if total_matches > RESULTS_PER_PAGE + offset as usize {
120                let paginated_matches: Vec<_> = matches
121                      .into_iter()
122                      .skip(offset as usize)
123                      .take(RESULTS_PER_PAGE)
124                      .collect();
125
126                    format!(
127                        "Found {} total matches. Showing results {}-{} (provide 'offset' parameter for more results):\n\n{}",
128                        total_matches,
129                        offset + 1,
130                        offset as usize + paginated_matches.len(),
131                        paginated_matches.join("\n")
132                    )
133                } else {
134                    matches.join("\n")
135                };
136
137                Ok(response)
138            }
139        })
140    }
141}