find_path_tool.rs

  1use crate::{schema::json_schema_for, ui::ToolCallCardHeader};
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{
  4    ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus,
  5};
  6use editor::Editor;
  7use futures::channel::oneshot::{self, Receiver};
  8use gpui::{
  9    AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
 10};
 11use language;
 12use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 13use project::Project;
 14use schemars::JsonSchema;
 15use serde::{Deserialize, Serialize};
 16use std::fmt::Write;
 17use std::{cmp, path::PathBuf, sync::Arc};
 18use ui::{Disclosure, Tooltip, prelude::*};
 19use util::{ResultExt, paths::PathMatcher};
 20use workspace::Workspace;
 21
  // Arguments accepted by the `find_path` tool; `run` and `ui_text` deserialize this from the
  // JSON value they are given, and `input_schema` derives the tool schema from it.
  //
  // NOTE(review): because this type derives `JsonSchema`, the `///` field docs below are
  // presumably copied into the schema descriptions the model sees — confirm before rewording
  // them, since that would change runtime output rather than just documentation.
  22#[derive(Debug, Serialize, Deserialize, JsonSchema)]
  23pub struct FindPathToolInput {
  24    /// The glob to match against every path in the project.
  25    ///
  26    /// <example>
  27    /// If the project has the following root directories:
  28    ///
  29    /// - directory1/a/something.txt
  30    /// - directory2/a/things.txt
  31    /// - directory3/a/other.txt
  32    ///
  33    /// You can get back the first two paths by providing a glob of "*thing*.txt"
  34    /// </example>
  35    pub glob: String,
  36
  37    /// Optional starting position for paginated results (0-based).
  38    /// When not provided, starts from the beginning.
        // `#[serde(default)]` lets the field be omitted; a missing `offset` deserializes as 0.
  39    #[serde(default)]
  40    pub offset: usize,
  41}
 42
  /// Structured result of a tool run: `run` serializes it into `ToolResultOutput::output`, and
  /// `deserialize_card` reads it back to rebuild a `FindPathToolCard`.
  43#[derive(Debug, Serialize, Deserialize)]
  44struct FindPathToolOutput {
        /// The glob taken from the tool input.
  45    glob: String,
        /// Every matching path found by the search, not only the page listed in the text result.
  46    paths: Vec<PathBuf>,
  47}
 48
  /// Maximum number of paths listed in a single tool result; `run` uses the input's `offset`
  /// to select which window of this size is returned.
  49const RESULTS_PER_PAGE: usize = 50;
 50
  /// Stateless unit type that implements `Tool` under the name `find_path` (glob search over
  /// the paths of a project) and `Component` for its preview.
  51pub struct FindPathTool;
 52
 53impl Tool for FindPathTool {
 54    fn name(&self) -> String {
 55        "find_path".into()
 56    }
 57
 58    fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
 59        false
 60    }
 61
 62    fn may_perform_edits(&self) -> bool {
 63        false
 64    }
 65
 66    fn description(&self) -> String {
 67        include_str!("./find_path_tool/description.md").into()
 68    }
 69
 70    fn icon(&self) -> IconName {
 71        IconName::ToolSearch
 72    }
 73
 74    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 75        json_schema_for::<FindPathToolInput>(format)
 76    }
 77
 78    fn ui_text(&self, input: &serde_json::Value) -> String {
 79        match serde_json::from_value::<FindPathToolInput>(input.clone()) {
 80            Ok(input) => format!("Find paths matching “`{}`”", input.glob),
 81            Err(_) => "Search paths".to_string(),
 82        }
 83    }
 84
 85    fn run(
 86        self: Arc<Self>,
 87        input: serde_json::Value,
 88        _request: Arc<LanguageModelRequest>,
 89        project: Entity<Project>,
 90        _action_log: Entity<ActionLog>,
 91        _model: Arc<dyn LanguageModel>,
 92        _window: Option<AnyWindowHandle>,
 93        cx: &mut App,
 94    ) -> ToolResult {
 95        let (offset, glob) = match serde_json::from_value::<FindPathToolInput>(input) {
 96            Ok(input) => (input.offset, input.glob),
 97            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 98        };
 99
100        let (sender, receiver) = oneshot::channel();
101
102        let card = cx.new(|cx| FindPathToolCard::new(glob.clone(), receiver, cx));
103
104        let search_paths_task = search_paths(&glob, project, cx);
105
106        let task = cx.background_spawn(async move {
107            let matches = search_paths_task.await?;
108            let paginated_matches: &[PathBuf] = &matches[cmp::min(offset, matches.len())
109                ..cmp::min(offset + RESULTS_PER_PAGE, matches.len())];
110
111            sender.send(paginated_matches.to_vec()).log_err();
112
113            if matches.is_empty() {
114                Ok("No matches found".to_string().into())
115            } else {
116                let mut message = format!("Found {} total matches.", matches.len());
117                if matches.len() > RESULTS_PER_PAGE {
118                    write!(
119                        &mut message,
120                        "\nShowing results {}-{} (provide 'offset' parameter for more results):",
121                        offset + 1,
122                        offset + paginated_matches.len()
123                    )
124                    .unwrap();
125                }
126
127                for mat in matches.iter().skip(offset).take(RESULTS_PER_PAGE) {
128                    write!(&mut message, "\n{}", mat.display()).unwrap();
129                }
130
131                let output = FindPathToolOutput {
132                    glob,
133                    paths: matches,
134                };
135
136                Ok(ToolResultOutput {
137                    content: ToolResultContent::Text(message),
138                    output: Some(serde_json::to_value(output)?),
139                })
140            }
141        });
142
143        ToolResult {
144            output: task,
145            card: Some(card.into()),
146        }
147    }
148
149    fn deserialize_card(
150        self: Arc<Self>,
151        output: serde_json::Value,
152        _project: Entity<Project>,
153        _window: &mut Window,
154        cx: &mut App,
155    ) -> Option<assistant_tool::AnyToolCard> {
156        let output = serde_json::from_value::<FindPathToolOutput>(output).ok()?;
157        let card = cx.new(|_| FindPathToolCard::from_output(output));
158        Some(card.into())
159    }
160}
161
162fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Result<Vec<PathBuf>>> {
163    let path_matcher = match PathMatcher::new([
164        // Sometimes models try to search for "". In this case, return all paths in the project.
165        if glob.is_empty() { "*" } else { glob },
166    ]) {
167        Ok(matcher) => matcher,
168        Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
169    };
170    let snapshots: Vec<_> = project
171        .read(cx)
172        .worktrees(cx)
173        .map(|worktree| worktree.read(cx).snapshot())
174        .collect();
175
176    cx.background_spawn(async move {
177        Ok(snapshots
178            .iter()
179            .flat_map(|snapshot| {
180                let root_name = PathBuf::from(snapshot.root_name());
181                snapshot
182                    .entries(false, 0)
183                    .map(move |entry| root_name.join(&entry.path))
184                    .filter(|path| path_matcher.is_match(&path))
185            })
186            .collect())
187    })
188}
189
190struct FindPathToolCard {
191    paths: Vec<PathBuf>,
192    expanded: bool,
193    glob: String,
194    _receiver_task: Option<Task<Result<()>>>,
195}
196
197impl FindPathToolCard {
198    fn new(glob: String, receiver: Receiver<Vec<PathBuf>>, cx: &mut Context<Self>) -> Self {
199        let _receiver_task = cx.spawn(async move |this, cx| {
200            let paths = receiver.await?;
201
202            this.update(cx, |this, _cx| {
203                this.paths = paths;
204            })
205            .log_err();
206
207            Ok(())
208        });
209
210        Self {
211            paths: Vec::new(),
212            expanded: false,
213            glob,
214            _receiver_task: Some(_receiver_task),
215        }
216    }
217
218    fn from_output(output: FindPathToolOutput) -> Self {
219        Self {
220            glob: output.glob,
221            paths: output.paths,
222            expanded: false,
223            _receiver_task: None,
224        }
225    }
226}
227
228impl ToolCard for FindPathToolCard {
229    fn render(
230        &mut self,
231        _status: &ToolUseStatus,
232        _window: &mut Window,
233        workspace: WeakEntity<Workspace>,
234        cx: &mut Context<Self>,
235    ) -> impl IntoElement {
236        let matches_label: SharedString = if self.paths.len() == 0 {
237            "No matches".into()
238        } else if self.paths.len() == 1 {
239            "1 match".into()
240        } else {
241            format!("{} matches", self.paths.len()).into()
242        };
243
244        let content = if !self.paths.is_empty() && self.expanded {
245            Some(
246                v_flex()
247                    .relative()
248                    .ml_1p5()
249                    .px_1p5()
250                    .gap_0p5()
251                    .border_l_1()
252                    .border_color(cx.theme().colors().border_variant)
253                    .children(self.paths.iter().enumerate().map(|(index, path)| {
254                        let path_clone = path.clone();
255                        let workspace_clone = workspace.clone();
256                        let button_label = path.to_string_lossy().to_string();
257
258                        Button::new(("path", index), button_label)
259                            .icon(IconName::ArrowUpRight)
260                            .icon_size(IconSize::Small)
261                            .icon_position(IconPosition::End)
262                            .label_size(LabelSize::Small)
263                            .color(Color::Muted)
264                            .tooltip(Tooltip::text("Jump to File"))
265                            .on_click(move |_, window, cx| {
266                                workspace_clone
267                                    .update(cx, |workspace, cx| {
268                                        let path = PathBuf::from(&path_clone);
269                                        let Some(project_path) = workspace
270                                            .project()
271                                            .read(cx)
272                                            .find_project_path(&path, cx)
273                                        else {
274                                            return;
275                                        };
276                                        let open_task = workspace.open_path(
277                                            project_path,
278                                            None,
279                                            true,
280                                            window,
281                                            cx,
282                                        );
283                                        window
284                                            .spawn(cx, async move |cx| {
285                                                let item = open_task.await?;
286                                                if let Some(active_editor) =
287                                                    item.downcast::<Editor>()
288                                                {
289                                                    active_editor
290                                                        .update_in(cx, |editor, window, cx| {
291                                                            editor.go_to_singleton_buffer_point(
292                                                                language::Point::new(0, 0),
293                                                                window,
294                                                                cx,
295                                                            );
296                                                        })
297                                                        .log_err();
298                                                }
299                                                anyhow::Ok(())
300                                            })
301                                            .detach_and_log_err(cx);
302                                    })
303                                    .ok();
304                            })
305                    }))
306                    .into_any(),
307            )
308        } else {
309            None
310        };
311
312        v_flex()
313            .mb_2()
314            .gap_1()
315            .child(
316                ToolCallCardHeader::new(IconName::ToolSearch, matches_label)
317                    .with_code_path(&self.glob)
318                    .disclosure_slot(
319                        Disclosure::new("path-search-disclosure", self.expanded)
320                            .opened_icon(IconName::ChevronUp)
321                            .closed_icon(IconName::ChevronDown)
322                            .disabled(self.paths.is_empty())
323                            .on_click(cx.listener(move |this, _, _, _cx| {
324                                this.expanded = !this.expanded;
325                            })),
326                    ),
327            )
328            .children(content)
329    }
330}
331
332impl Component for FindPathTool {
333    fn scope() -> ComponentScope {
334        ComponentScope::Agent
335    }
336
337    fn sort_name() -> &'static str {
338        "FindPathTool"
339    }
340
341    fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
342        let successful_card = cx.new(|_| FindPathToolCard {
343            paths: vec![
344                PathBuf::from("src/main.rs"),
345                PathBuf::from("src/lib.rs"),
346                PathBuf::from("tests/test.rs"),
347            ],
348            expanded: true,
349            glob: "*.rs".to_string(),
350            _receiver_task: None,
351        });
352
353        let empty_card = cx.new(|_| FindPathToolCard {
354            paths: Vec::new(),
355            expanded: false,
356            glob: "*.nonexistent".to_string(),
357            _receiver_task: None,
358        });
359
360        Some(
361            v_flex()
362                .gap_6()
363                .children(vec![example_group(vec![
364                    single_example(
365                        "With Paths",
366                        div()
367                            .size_full()
368                            .child(successful_card.update(cx, |tool, cx| {
369                                tool.render(
370                                    &ToolUseStatus::Finished("".into()),
371                                    window,
372                                    WeakEntity::new_invalid(),
373                                    cx,
374                                )
375                                .into_any_element()
376                            }))
377                            .into_any_element(),
378                    ),
379                    single_example(
380                        "No Paths",
381                        div()
382                            .size_full()
383                            .child(empty_card.update(cx, |tool, cx| {
384                                tool.render(
385                                    &ToolUseStatus::Finished("".into()),
386                                    window,
387                                    WeakEntity::new_invalid(),
388                                    cx,
389                                )
390                                .into_any_element()
391                            }))
392                            .into_any_element(),
393                    ),
394                ])])
395                .into_any_element(),
396        )
397    }
398}
399
400#[cfg(test)]
401mod test {
402    use super::*;
403    use gpui::TestAppContext;
404    use project::{FakeFs, Project};
405    use settings::SettingsStore;
406    use util::path;
407
408    #[gpui::test]
409    async fn test_find_path_tool(cx: &mut TestAppContext) {
410        init_test(cx);
411
412        let fs = FakeFs::new(cx.executor());
413        fs.insert_tree(
414            "/root",
415            serde_json::json!({
416                "apple": {
417                    "banana": {
418                        "carrot": "1",
419                    },
420                    "bandana": {
421                        "carbonara": "2",
422                    },
423                    "endive": "3"
424                }
425            }),
426        )
427        .await;
428        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
429
430        let matches = cx
431            .update(|cx| search_paths("root/**/car*", project.clone(), cx))
432            .await
433            .unwrap();
434        assert_eq!(
435            matches,
436            &[
437                PathBuf::from("root/apple/banana/carrot"),
438                PathBuf::from("root/apple/bandana/carbonara")
439            ]
440        );
441
442        let matches = cx
443            .update(|cx| search_paths("**/car*", project.clone(), cx))
444            .await
445            .unwrap();
446        assert_eq!(
447            matches,
448            &[
449                PathBuf::from("root/apple/banana/carrot"),
450                PathBuf::from("root/apple/bandana/carbonara")
451            ]
452        );
453    }
454
455    fn init_test(cx: &mut TestAppContext) {
456        cx.update(|cx| {
457            let settings_store = SettingsStore::test(cx);
458            cx.set_global(settings_store);
459            language::init(cx);
460            Project::init_settings(cx);
461        });
462    }
463}