find_path_tool.rs

  1use crate::{schema::json_schema_for, ui::ToolCallCardHeader};
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolResultOutput, ToolUseStatus};
  4use editor::Editor;
  5use futures::channel::oneshot::{self, Receiver};
  6use gpui::{
  7    AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
  8};
  9use language;
 10use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 11use project::Project;
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use std::fmt::Write;
 15use std::{cmp, path::PathBuf, sync::Arc};
 16use ui::{Disclosure, Tooltip, prelude::*};
 17use util::{ResultExt, paths::PathMatcher};
 18use workspace::Workspace;
 19
// Arguments accepted by the `find_path` tool, deserialized from the model's JSON input.
// NOTE(review): this struct derives `JsonSchema` and is passed to `json_schema_for`, so the
// `///` comments on its fields presumably end up as parameter descriptions shown to the model —
// edit them with that audience in mind.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FindPathToolInput {
    /// The glob to match against every path in the project.
    ///
    /// <example>
    /// If the project has the following root directories:
    ///
    /// - directory1/a/something.txt
    /// - directory2/a/things.txt
    /// - directory3/a/other.txt
    ///
    /// You can get back the first two paths by providing a glob of "*thing*.txt"
    /// </example>
    pub glob: String,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    // `#[serde(default)]` makes an omitted `offset` deserialize as `usize::default()` (0).
    #[serde(default)]
    pub offset: usize,
}
 40
/// Serialized result of a `find_path` run, stored as the tool's output and read back by
/// `deserialize_card` to rebuild the result card.
#[derive(Debug, Serialize, Deserialize)]
struct FindPathToolOutput {
    /// Glob the search was run with.
    glob: String,
    /// Every matching path, prefixed with its worktree root name (not just the page that was
    /// shown to the model).
    paths: Vec<PathBuf>,
}
 46
/// Maximum number of matching paths listed in one tool response; the model passes `offset`
/// to page through the remainder.
const RESULTS_PER_PAGE: usize = 50;
 48
/// Agent tool that lists project paths matching a glob, with simple offset-based pagination.
pub struct FindPathTool;
 50
 51impl Tool for FindPathTool {
 52    fn name(&self) -> String {
 53        "find_path".into()
 54    }
 55
 56    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 57        false
 58    }
 59
 60    fn description(&self) -> String {
 61        include_str!("./find_path_tool/description.md").into()
 62    }
 63
 64    fn icon(&self) -> IconName {
 65        IconName::SearchCode
 66    }
 67
 68    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 69        json_schema_for::<FindPathToolInput>(format)
 70    }
 71
 72    fn ui_text(&self, input: &serde_json::Value) -> String {
 73        match serde_json::from_value::<FindPathToolInput>(input.clone()) {
 74            Ok(input) => format!("Find paths matching “`{}`”", input.glob),
 75            Err(_) => "Search paths".to_string(),
 76        }
 77    }
 78
 79    fn run(
 80        self: Arc<Self>,
 81        input: serde_json::Value,
 82        _request: Arc<LanguageModelRequest>,
 83        project: Entity<Project>,
 84        _action_log: Entity<ActionLog>,
 85        _model: Arc<dyn LanguageModel>,
 86        _window: Option<AnyWindowHandle>,
 87        cx: &mut App,
 88    ) -> ToolResult {
 89        let (offset, glob) = match serde_json::from_value::<FindPathToolInput>(input) {
 90            Ok(input) => (input.offset, input.glob),
 91            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 92        };
 93
 94        let (sender, receiver) = oneshot::channel();
 95
 96        let card = cx.new(|cx| FindPathToolCard::new(glob.clone(), receiver, cx));
 97
 98        let search_paths_task = search_paths(&glob, project, cx);
 99
100        let task = cx.background_spawn(async move {
101            let matches = search_paths_task.await?;
102            let paginated_matches: &[PathBuf] = &matches[cmp::min(offset, matches.len())
103                ..cmp::min(offset + RESULTS_PER_PAGE, matches.len())];
104
105            sender.send(paginated_matches.to_vec()).log_err();
106
107            if matches.is_empty() {
108                Ok("No matches found".to_string().into())
109            } else {
110                let mut message = format!("Found {} total matches.", matches.len());
111                if matches.len() > RESULTS_PER_PAGE {
112                    write!(
113                        &mut message,
114                        "\nShowing results {}-{} (provide 'offset' parameter for more results):",
115                        offset + 1,
116                        offset + paginated_matches.len()
117                    )
118                    .unwrap();
119                }
120                let output = FindPathToolOutput {
121                    glob,
122                    paths: matches.clone(),
123                };
124
125                for mat in matches.into_iter().skip(offset).take(RESULTS_PER_PAGE) {
126                    write!(&mut message, "\n{}", mat.display()).unwrap();
127                }
128                Ok(ToolResultOutput {
129                    content: message,
130                    output: Some(serde_json::to_value(output)?),
131                })
132            }
133        });
134
135        ToolResult {
136            output: task,
137            card: Some(card.into()),
138        }
139    }
140
141    fn deserialize_card(
142        self: Arc<Self>,
143        output: serde_json::Value,
144        _project: Entity<Project>,
145        _window: &mut Window,
146        cx: &mut App,
147    ) -> Option<assistant_tool::AnyToolCard> {
148        let output = serde_json::from_value::<FindPathToolOutput>(output).ok()?;
149        let card = cx.new(|_| FindPathToolCard::from_output(output));
150        Some(card.into())
151    }
152}
153
154fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Result<Vec<PathBuf>>> {
155    let path_matcher = match PathMatcher::new([
156        // Sometimes models try to search for "". In this case, return all paths in the project.
157        if glob.is_empty() { "*" } else { glob },
158    ]) {
159        Ok(matcher) => matcher,
160        Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
161    };
162    let snapshots: Vec<_> = project
163        .read(cx)
164        .worktrees(cx)
165        .map(|worktree| worktree.read(cx).snapshot())
166        .collect();
167
168    cx.background_spawn(async move {
169        Ok(snapshots
170            .iter()
171            .flat_map(|snapshot| {
172                let root_name = PathBuf::from(snapshot.root_name());
173                snapshot
174                    .entries(false, 0)
175                    .map(move |entry| root_name.join(&entry.path))
176                    .filter(|path| path_matcher.is_match(&path))
177            })
178            .collect())
179    })
180}
181
/// UI card showing the paths matched by a `find_path` call.
struct FindPathToolCard {
    /// Paths listed in the card; starts empty for a live search and is filled when results
    /// arrive.
    paths: Vec<PathBuf>,
    /// Whether the path list below the header is shown.
    expanded: bool,
    /// Glob the search was run with, shown in the card header.
    glob: String,
    /// Task that fills `paths` from the search result; `None` for cards built from stored
    /// output or previews. Held so the task is not dropped (presumably dropping a `Task`
    /// cancels it — confirm).
    _receiver_task: Option<Task<Result<()>>>,
}
188
189impl FindPathToolCard {
190    fn new(glob: String, receiver: Receiver<Vec<PathBuf>>, cx: &mut Context<Self>) -> Self {
191        let _receiver_task = cx.spawn(async move |this, cx| {
192            let paths = receiver.await?;
193
194            this.update(cx, |this, _cx| {
195                this.paths = paths;
196            })
197            .log_err();
198
199            Ok(())
200        });
201
202        Self {
203            paths: Vec::new(),
204            expanded: false,
205            glob,
206            _receiver_task: Some(_receiver_task),
207        }
208    }
209
210    fn from_output(output: FindPathToolOutput) -> Self {
211        Self {
212            glob: output.glob,
213            paths: output.paths,
214            expanded: false,
215            _receiver_task: None,
216        }
217    }
218}
219
220impl ToolCard for FindPathToolCard {
221    fn render(
222        &mut self,
223        _status: &ToolUseStatus,
224        _window: &mut Window,
225        workspace: WeakEntity<Workspace>,
226        cx: &mut Context<Self>,
227    ) -> impl IntoElement {
228        let matches_label: SharedString = if self.paths.len() == 0 {
229            "No matches".into()
230        } else if self.paths.len() == 1 {
231            "1 match".into()
232        } else {
233            format!("{} matches", self.paths.len()).into()
234        };
235
236        let glob_label = self.glob.to_string();
237
238        let content = if !self.paths.is_empty() && self.expanded {
239            Some(
240                v_flex()
241                    .relative()
242                    .ml_1p5()
243                    .px_1p5()
244                    .gap_0p5()
245                    .border_l_1()
246                    .border_color(cx.theme().colors().border_variant)
247                    .children(self.paths.iter().enumerate().map(|(index, path)| {
248                        let path_clone = path.clone();
249                        let workspace_clone = workspace.clone();
250                        let button_label = path.to_string_lossy().to_string();
251
252                        Button::new(("path", index), button_label)
253                            .icon(IconName::ArrowUpRight)
254                            .icon_size(IconSize::XSmall)
255                            .icon_position(IconPosition::End)
256                            .label_size(LabelSize::Small)
257                            .color(Color::Muted)
258                            .tooltip(Tooltip::text("Jump to File"))
259                            .on_click(move |_, window, cx| {
260                                workspace_clone
261                                    .update(cx, |workspace, cx| {
262                                        let path = PathBuf::from(&path_clone);
263                                        let Some(project_path) = workspace
264                                            .project()
265                                            .read(cx)
266                                            .find_project_path(&path, cx)
267                                        else {
268                                            return;
269                                        };
270                                        let open_task = workspace.open_path(
271                                            project_path,
272                                            None,
273                                            true,
274                                            window,
275                                            cx,
276                                        );
277                                        window
278                                            .spawn(cx, async move |cx| {
279                                                let item = open_task.await?;
280                                                if let Some(active_editor) =
281                                                    item.downcast::<Editor>()
282                                                {
283                                                    active_editor
284                                                        .update_in(cx, |editor, window, cx| {
285                                                            editor.go_to_singleton_buffer_point(
286                                                                language::Point::new(0, 0),
287                                                                window,
288                                                                cx,
289                                                            );
290                                                        })
291                                                        .log_err();
292                                                }
293                                                anyhow::Ok(())
294                                            })
295                                            .detach_and_log_err(cx);
296                                    })
297                                    .ok();
298                            })
299                    }))
300                    .into_any(),
301            )
302        } else {
303            None
304        };
305
306        v_flex()
307            .mb_2()
308            .gap_1()
309            .child(
310                ToolCallCardHeader::new(IconName::SearchCode, matches_label)
311                    .with_code_path(glob_label)
312                    .disclosure_slot(
313                        Disclosure::new("path-search-disclosure", self.expanded)
314                            .opened_icon(IconName::ChevronUp)
315                            .closed_icon(IconName::ChevronDown)
316                            .disabled(self.paths.is_empty())
317                            .on_click(cx.listener(move |this, _, _, _cx| {
318                                this.expanded = !this.expanded;
319                            })),
320                    ),
321            )
322            .children(content)
323    }
324}
325
326impl Component for FindPathTool {
327    fn scope() -> ComponentScope {
328        ComponentScope::Agent
329    }
330
331    fn sort_name() -> &'static str {
332        "FindPathTool"
333    }
334
335    fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
336        let successful_card = cx.new(|_| FindPathToolCard {
337            paths: vec![
338                PathBuf::from("src/main.rs"),
339                PathBuf::from("src/lib.rs"),
340                PathBuf::from("tests/test.rs"),
341            ],
342            expanded: true,
343            glob: "*.rs".to_string(),
344            _receiver_task: None,
345        });
346
347        let empty_card = cx.new(|_| FindPathToolCard {
348            paths: Vec::new(),
349            expanded: false,
350            glob: "*.nonexistent".to_string(),
351            _receiver_task: None,
352        });
353
354        Some(
355            v_flex()
356                .gap_6()
357                .children(vec![example_group(vec![
358                    single_example(
359                        "With Paths",
360                        div()
361                            .size_full()
362                            .child(successful_card.update(cx, |tool, cx| {
363                                tool.render(
364                                    &ToolUseStatus::Finished("".into()),
365                                    window,
366                                    WeakEntity::new_invalid(),
367                                    cx,
368                                )
369                                .into_any_element()
370                            }))
371                            .into_any_element(),
372                    ),
373                    single_example(
374                        "No Paths",
375                        div()
376                            .size_full()
377                            .child(empty_card.update(cx, |tool, cx| {
378                                tool.render(
379                                    &ToolUseStatus::Finished("".into()),
380                                    window,
381                                    WeakEntity::new_invalid(),
382                                    cx,
383                                )
384                                .into_any_element()
385                            }))
386                            .into_any_element(),
387                    ),
388                ])])
389                .into_any_element(),
390        )
391    }
392}
393
#[cfg(test)]
mod test {
    use super::*;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use settings::SettingsStore;
    use util::path;

    /// Globs are matched against root-name-prefixed paths, both with an explicit root and
    /// with a leading `**`.
    #[gpui::test]
    async fn test_find_path_tool(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // Build the tree with `path!` too, matching the `Project::test` call below, so both
        // refer to the same (platform-adjusted) root directory.
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "apple": {
                    "banana": {
                        "carrot": "1",
                    },
                    "bandana": {
                        "carbonara": "2",
                    },
                    "endive": "3"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        let matches = cx
            .update(|cx| search_paths("root/**/car*", project.clone(), cx))
            .await
            .unwrap();
        assert_eq!(
            matches,
            &[
                PathBuf::from("root/apple/banana/carrot"),
                PathBuf::from("root/apple/bandana/carbonara")
            ]
        );

        let matches = cx
            .update(|cx| search_paths("**/car*", project.clone(), cx))
            .await
            .unwrap();
        assert_eq!(
            matches,
            &[
                PathBuf::from("root/apple/banana/carrot"),
                PathBuf::from("root/apple/bandana/carbonara")
            ]
        );
    }

    /// Installs the test settings store and the language/project settings the tool relies on.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }
}