find_path_tool.rs

  1use crate::{schema::json_schema_for, ui::ToolCallCardHeader};
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{
  4    ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus,
  5};
  6use editor::Editor;
  7use futures::channel::oneshot::{self, Receiver};
  8use gpui::{
  9    AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
 10};
 11use language;
 12use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 13use project::Project;
 14use schemars::JsonSchema;
 15use serde::{Deserialize, Serialize};
 16use std::fmt::Write;
 17use std::{cmp, path::PathBuf, sync::Arc};
 18use ui::{Disclosure, Tooltip, prelude::*};
 19use util::{ResultExt, paths::PathMatcher};
 20use workspace::Workspace;
 21
// Input accepted by the `find_path` tool.
//
// NOTE: the `///` comments on the fields below are more than documentation. The
// `JsonSchema` derive turns them into schema `description`s, and `input_schema`
// builds the tool schema from this type, so they are presumably the text the model
// reads. Treat edits to them as prompt changes. Plain `//` comments, like this one,
// are not picked up by the derive.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FindPathToolInput {
    /// The glob to match against every path in the project.
    ///
    /// <example>
    /// If the project has the following root directories:
    ///
    /// - directory1/a/something.txt
    /// - directory2/a/things.txt
    /// - directory3/a/other.txt
    ///
    /// You can get back the first two paths by providing a glob of "*thing*.txt"
    /// </example>
    pub glob: String,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    // `serde(default)` makes a missing `offset` deserialize as 0.
    #[serde(default)]
    pub offset: usize,
}
 42
/// Serialized result of a `find_path` run.
///
/// Stored as the tool's structured `output` and read back by `deserialize_card`
/// to rebuild the card when a thread is reloaded.
#[derive(Debug, Serialize, Deserialize)]
struct FindPathToolOutput {
    /// The glob the search was run with (shown in the card header).
    glob: String,
    /// Matched paths, each prefixed with its worktree's root name.
    paths: Vec<PathBuf>,
}
 48
/// Maximum number of matches reported by a single `find_path` call; further
/// matches are reached by passing `FindPathToolInput::offset`.
const RESULTS_PER_PAGE: usize = 50;
 50
/// Tool that lets the model find project paths matching a glob.
pub struct FindPathTool;
 52
 53impl Tool for FindPathTool {
 54    fn name(&self) -> String {
 55        "find_path".into()
 56    }
 57
 58    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 59        false
 60    }
 61
 62    fn description(&self) -> String {
 63        include_str!("./find_path_tool/description.md").into()
 64    }
 65
 66    fn icon(&self) -> IconName {
 67        IconName::SearchCode
 68    }
 69
 70    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 71        json_schema_for::<FindPathToolInput>(format)
 72    }
 73
 74    fn ui_text(&self, input: &serde_json::Value) -> String {
 75        match serde_json::from_value::<FindPathToolInput>(input.clone()) {
 76            Ok(input) => format!("Find paths matching “`{}`”", input.glob),
 77            Err(_) => "Search paths".to_string(),
 78        }
 79    }
 80
 81    fn run(
 82        self: Arc<Self>,
 83        input: serde_json::Value,
 84        _request: Arc<LanguageModelRequest>,
 85        project: Entity<Project>,
 86        _action_log: Entity<ActionLog>,
 87        _model: Arc<dyn LanguageModel>,
 88        _window: Option<AnyWindowHandle>,
 89        cx: &mut App,
 90    ) -> ToolResult {
 91        let (offset, glob) = match serde_json::from_value::<FindPathToolInput>(input) {
 92            Ok(input) => (input.offset, input.glob),
 93            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 94        };
 95
 96        let (sender, receiver) = oneshot::channel();
 97
 98        let card = cx.new(|cx| FindPathToolCard::new(glob.clone(), receiver, cx));
 99
100        let search_paths_task = search_paths(&glob, project, cx);
101
102        let task = cx.background_spawn(async move {
103            let matches = search_paths_task.await?;
104            let paginated_matches: &[PathBuf] = &matches[cmp::min(offset, matches.len())
105                ..cmp::min(offset + RESULTS_PER_PAGE, matches.len())];
106
107            sender.send(paginated_matches.to_vec()).log_err();
108
109            if matches.is_empty() {
110                Ok("No matches found".to_string().into())
111            } else {
112                let mut message = format!("Found {} total matches.", matches.len());
113                if matches.len() > RESULTS_PER_PAGE {
114                    write!(
115                        &mut message,
116                        "\nShowing results {}-{} (provide 'offset' parameter for more results):",
117                        offset + 1,
118                        offset + paginated_matches.len()
119                    )
120                    .unwrap();
121                }
122
123                for mat in matches.iter().skip(offset).take(RESULTS_PER_PAGE) {
124                    write!(&mut message, "\n{}", mat.display()).unwrap();
125                }
126
127                let output = FindPathToolOutput {
128                    glob,
129                    paths: matches,
130                };
131
132                Ok(ToolResultOutput {
133                    content: ToolResultContent::Text(message),
134                    output: Some(serde_json::to_value(output)?),
135                })
136            }
137        });
138
139        ToolResult {
140            output: task,
141            card: Some(card.into()),
142        }
143    }
144
145    fn deserialize_card(
146        self: Arc<Self>,
147        output: serde_json::Value,
148        _project: Entity<Project>,
149        _window: &mut Window,
150        cx: &mut App,
151    ) -> Option<assistant_tool::AnyToolCard> {
152        let output = serde_json::from_value::<FindPathToolOutput>(output).ok()?;
153        let card = cx.new(|_| FindPathToolCard::from_output(output));
154        Some(card.into())
155    }
156}
157
158fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Result<Vec<PathBuf>>> {
159    let path_matcher = match PathMatcher::new([
160        // Sometimes models try to search for "". In this case, return all paths in the project.
161        if glob.is_empty() { "*" } else { glob },
162    ]) {
163        Ok(matcher) => matcher,
164        Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
165    };
166    let snapshots: Vec<_> = project
167        .read(cx)
168        .worktrees(cx)
169        .map(|worktree| worktree.read(cx).snapshot())
170        .collect();
171
172    cx.background_spawn(async move {
173        Ok(snapshots
174            .iter()
175            .flat_map(|snapshot| {
176                let root_name = PathBuf::from(snapshot.root_name());
177                snapshot
178                    .entries(false, 0)
179                    .map(move |entry| root_name.join(&entry.path))
180                    .filter(|path| path_matcher.is_match(&path))
181            })
182            .collect())
183    })
184}
185
/// UI card showing the paths matched by a `find_path` call.
struct FindPathToolCard {
    /// Paths listed in the card; empty until results arrive for a live search.
    paths: Vec<PathBuf>,
    /// Whether the list of paths is currently expanded.
    expanded: bool,
    /// The glob shown in the card header.
    glob: String,
    /// Task that fills `paths` from the running search; `None` for cards built
    /// from stored output or for previews.
    // Presumably held only so the task is not cancelled by being dropped.
    _receiver_task: Option<Task<Result<()>>>,
}
192
193impl FindPathToolCard {
194    fn new(glob: String, receiver: Receiver<Vec<PathBuf>>, cx: &mut Context<Self>) -> Self {
195        let _receiver_task = cx.spawn(async move |this, cx| {
196            let paths = receiver.await?;
197
198            this.update(cx, |this, _cx| {
199                this.paths = paths;
200            })
201            .log_err();
202
203            Ok(())
204        });
205
206        Self {
207            paths: Vec::new(),
208            expanded: false,
209            glob,
210            _receiver_task: Some(_receiver_task),
211        }
212    }
213
214    fn from_output(output: FindPathToolOutput) -> Self {
215        Self {
216            glob: output.glob,
217            paths: output.paths,
218            expanded: false,
219            _receiver_task: None,
220        }
221    }
222}
223
224impl ToolCard for FindPathToolCard {
225    fn render(
226        &mut self,
227        _status: &ToolUseStatus,
228        _window: &mut Window,
229        workspace: WeakEntity<Workspace>,
230        cx: &mut Context<Self>,
231    ) -> impl IntoElement {
232        let matches_label: SharedString = if self.paths.len() == 0 {
233            "No matches".into()
234        } else if self.paths.len() == 1 {
235            "1 match".into()
236        } else {
237            format!("{} matches", self.paths.len()).into()
238        };
239
240        let content = if !self.paths.is_empty() && self.expanded {
241            Some(
242                v_flex()
243                    .relative()
244                    .ml_1p5()
245                    .px_1p5()
246                    .gap_0p5()
247                    .border_l_1()
248                    .border_color(cx.theme().colors().border_variant)
249                    .children(self.paths.iter().enumerate().map(|(index, path)| {
250                        let path_clone = path.clone();
251                        let workspace_clone = workspace.clone();
252                        let button_label = path.to_string_lossy().to_string();
253
254                        Button::new(("path", index), button_label)
255                            .icon(IconName::ArrowUpRight)
256                            .icon_size(IconSize::XSmall)
257                            .icon_position(IconPosition::End)
258                            .label_size(LabelSize::Small)
259                            .color(Color::Muted)
260                            .tooltip(Tooltip::text("Jump to File"))
261                            .on_click(move |_, window, cx| {
262                                workspace_clone
263                                    .update(cx, |workspace, cx| {
264                                        let path = PathBuf::from(&path_clone);
265                                        let Some(project_path) = workspace
266                                            .project()
267                                            .read(cx)
268                                            .find_project_path(&path, cx)
269                                        else {
270                                            return;
271                                        };
272                                        let open_task = workspace.open_path(
273                                            project_path,
274                                            None,
275                                            true,
276                                            window,
277                                            cx,
278                                        );
279                                        window
280                                            .spawn(cx, async move |cx| {
281                                                let item = open_task.await?;
282                                                if let Some(active_editor) =
283                                                    item.downcast::<Editor>()
284                                                {
285                                                    active_editor
286                                                        .update_in(cx, |editor, window, cx| {
287                                                            editor.go_to_singleton_buffer_point(
288                                                                language::Point::new(0, 0),
289                                                                window,
290                                                                cx,
291                                                            );
292                                                        })
293                                                        .log_err();
294                                                }
295                                                anyhow::Ok(())
296                                            })
297                                            .detach_and_log_err(cx);
298                                    })
299                                    .ok();
300                            })
301                    }))
302                    .into_any(),
303            )
304        } else {
305            None
306        };
307
308        v_flex()
309            .mb_2()
310            .gap_1()
311            .child(
312                ToolCallCardHeader::new(IconName::SearchCode, matches_label)
313                    .with_code_path(&self.glob)
314                    .disclosure_slot(
315                        Disclosure::new("path-search-disclosure", self.expanded)
316                            .opened_icon(IconName::ChevronUp)
317                            .closed_icon(IconName::ChevronDown)
318                            .disabled(self.paths.is_empty())
319                            .on_click(cx.listener(move |this, _, _, _cx| {
320                                this.expanded = !this.expanded;
321                            })),
322                    ),
323            )
324            .children(content)
325    }
326}
327
328impl Component for FindPathTool {
329    fn scope() -> ComponentScope {
330        ComponentScope::Agent
331    }
332
333    fn sort_name() -> &'static str {
334        "FindPathTool"
335    }
336
337    fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
338        let successful_card = cx.new(|_| FindPathToolCard {
339            paths: vec![
340                PathBuf::from("src/main.rs"),
341                PathBuf::from("src/lib.rs"),
342                PathBuf::from("tests/test.rs"),
343            ],
344            expanded: true,
345            glob: "*.rs".to_string(),
346            _receiver_task: None,
347        });
348
349        let empty_card = cx.new(|_| FindPathToolCard {
350            paths: Vec::new(),
351            expanded: false,
352            glob: "*.nonexistent".to_string(),
353            _receiver_task: None,
354        });
355
356        Some(
357            v_flex()
358                .gap_6()
359                .children(vec![example_group(vec![
360                    single_example(
361                        "With Paths",
362                        div()
363                            .size_full()
364                            .child(successful_card.update(cx, |tool, cx| {
365                                tool.render(
366                                    &ToolUseStatus::Finished("".into()),
367                                    window,
368                                    WeakEntity::new_invalid(),
369                                    cx,
370                                )
371                                .into_any_element()
372                            }))
373                            .into_any_element(),
374                    ),
375                    single_example(
376                        "No Paths",
377                        div()
378                            .size_full()
379                            .child(empty_card.update(cx, |tool, cx| {
380                                tool.render(
381                                    &ToolUseStatus::Finished("".into()),
382                                    window,
383                                    WeakEntity::new_invalid(),
384                                    cx,
385                                )
386                                .into_any_element()
387                            }))
388                            .into_any_element(),
389                    ),
390                ])])
391                .into_any_element(),
392        )
393    }
394}
395
#[cfg(test)]
mod test {
    use super::*;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use settings::SettingsStore;
    use util::path;

    /// A glob that includes the worktree root name and one that does not should
    /// both find the same files.
    #[gpui::test]
    async fn test_find_path_tool(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let tree = serde_json::json!({
            "apple": {
                "banana": {
                    "carrot": "1",
                },
                "bandana": {
                    "carbonara": "2",
                },
                "endive": "3"
            }
        });
        fs.insert_tree("/root", tree).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        let expected = [
            PathBuf::from("root/apple/banana/carrot"),
            PathBuf::from("root/apple/bandana/carbonara"),
        ];
        for glob in ["root/**/car*", "**/car*"] {
            let matches = cx
                .update(|cx| search_paths(glob, project.clone(), cx))
                .await
                .unwrap();
            assert_eq!(matches, expected, "unexpected matches for glob {glob:?}");
        }
    }

    /// Installs the globals the test project relies on: a test settings store,
    /// language support and project settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }
}