// list_directory_tool.rs

use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{fmt::Write, path::Path, sync::Arc};
use ui::IconName;
use util::markdown::MarkdownInlineCode;

 13#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 14pub struct ListDirectoryToolInput {
 15    /// The fully-qualified path of the directory to list in the project.
 16    ///
 17    /// This path should never be absolute, and the first component
 18    /// of the path should always be a root directory in a project.
 19    ///
 20    /// <example>
 21    /// If the project has the following root directories:
 22    ///
 23    /// - directory1
 24    /// - directory2
 25    ///
 26    /// You can list the contents of `directory1` by using the path `directory1`.
 27    /// </example>
 28    ///
 29    /// <example>
 30    /// If the project has the following root directories:
 31    ///
 32    /// - foo
 33    /// - bar
 34    ///
 35    /// If you wanna list contents in the directory `foo/baz`, you should use the path `foo/baz`.
 36    /// </example>
 37    pub path: String,
 38}
 39
 40pub struct ListDirectoryTool;
 41
 42impl Tool for ListDirectoryTool {
 43    fn name(&self) -> String {
 44        "list_directory".into()
 45    }
 46
 47    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 48        false
 49    }
 50
 51    fn may_perform_edits(&self) -> bool {
 52        false
 53    }
 54
 55    fn description(&self) -> String {
 56        include_str!("./list_directory_tool/description.md").into()
 57    }
 58
 59    fn icon(&self) -> IconName {
 60        IconName::Folder
 61    }
 62
 63    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 64        json_schema_for::<ListDirectoryToolInput>(format)
 65    }
 66
 67    fn ui_text(&self, input: &serde_json::Value) -> String {
 68        match serde_json::from_value::<ListDirectoryToolInput>(input.clone()) {
 69            Ok(input) => {
 70                let path = MarkdownInlineCode(&input.path);
 71                format!("List the {path} directory's contents")
 72            }
 73            Err(_) => "List directory".to_string(),
 74        }
 75    }
 76
 77    fn run(
 78        self: Arc<Self>,
 79        input: serde_json::Value,
 80        _request: Arc<LanguageModelRequest>,
 81        project: Entity<Project>,
 82        _action_log: Entity<ActionLog>,
 83        _model: Arc<dyn LanguageModel>,
 84        _window: Option<AnyWindowHandle>,
 85        cx: &mut App,
 86    ) -> ToolResult {
 87        let input = match serde_json::from_value::<ListDirectoryToolInput>(input) {
 88            Ok(input) => input,
 89            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 90        };
 91
 92        // Sometimes models will return these even though we tell it to give a path and not a glob.
 93        // When this happens, just list the root worktree directories.
 94        if matches!(input.path.as_str(), "." | "" | "./" | "*") {
 95            let output = project
 96                .read(cx)
 97                .worktrees(cx)
 98                .filter_map(|worktree| {
 99                    worktree.read(cx).root_entry().and_then(|entry| {
100                        if entry.is_dir() {
101                            entry.path.to_str()
102                        } else {
103                            None
104                        }
105                    })
106                })
107                .collect::<Vec<_>>()
108                .join("\n");
109
110            return Task::ready(Ok(output.into())).into();
111        }
112
113        let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
114            return Task::ready(Err(anyhow!("Path {} not found in project", input.path))).into();
115        };
116        let Some(worktree) = project
117            .read(cx)
118            .worktree_for_id(project_path.worktree_id, cx)
119        else {
120            return Task::ready(Err(anyhow!("Worktree not found"))).into();
121        };
122        let worktree = worktree.read(cx);
123
124        let Some(entry) = worktree.entry_for_path(&project_path.path) else {
125            return Task::ready(Err(anyhow!("Path not found: {}", input.path))).into();
126        };
127
128        if !entry.is_dir() {
129            return Task::ready(Err(anyhow!("{} is not a directory.", input.path))).into();
130        }
131
132        let mut folders = Vec::new();
133        let mut files = Vec::new();
134
135        for entry in worktree.child_entries(&project_path.path) {
136            let full_path = Path::new(worktree.root_name())
137                .join(&entry.path)
138                .display()
139                .to_string();
140            if entry.is_dir() {
141                folders.push(full_path);
142            } else {
143                files.push(full_path);
144            }
145        }
146
147        let mut output = String::new();
148
149        if !folders.is_empty() {
150            writeln!(output, "# Folders:\n{}", folders.join("\n")).unwrap();
151        }
152
153        if !files.is_empty() {
154            writeln!(output, "\n# Files:\n{}", files.join("\n")).unwrap();
155        }
156
157        if output.is_empty() {
158            writeln!(output, "{} is empty.", input.path).unwrap();
159        }
160
161        Task::ready(Ok(output.into())).into()
162    }
163}
164
165#[cfg(test)]
166mod tests {
167    use super::*;
168    use assistant_tool::Tool;
169    use gpui::{AppContext, TestAppContext};
170    use indoc::indoc;
171    use language_model::fake_provider::FakeLanguageModel;
172    use project::{FakeFs, Project};
173    use serde_json::json;
174    use settings::SettingsStore;
175    use util::path;
176
177    fn platform_paths(path_str: &str) -> String {
178        if cfg!(target_os = "windows") {
179            path_str.replace("/", "\\")
180        } else {
181            path_str.to_string()
182        }
183    }
184
185    fn init_test(cx: &mut TestAppContext) {
186        cx.update(|cx| {
187            let settings_store = SettingsStore::test(cx);
188            cx.set_global(settings_store);
189            language::init(cx);
190            Project::init_settings(cx);
191        });
192    }
193
194    #[gpui::test]
195    async fn test_list_directory_separates_files_and_dirs(cx: &mut TestAppContext) {
196        init_test(cx);
197
198        let fs = FakeFs::new(cx.executor());
199        fs.insert_tree(
200            "/project",
201            json!({
202                "src": {
203                    "main.rs": "fn main() {}",
204                    "lib.rs": "pub fn hello() {}",
205                    "models": {
206                        "user.rs": "struct User {}",
207                        "post.rs": "struct Post {}"
208                    },
209                    "utils": {
210                        "helper.rs": "pub fn help() {}"
211                    }
212                },
213                "tests": {
214                    "integration_test.rs": "#[test] fn test() {}"
215                },
216                "README.md": "# Project",
217                "Cargo.toml": "[package]"
218            }),
219        )
220        .await;
221
222        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
223        let action_log = cx.new(|_| ActionLog::new(project.clone()));
224        let model = Arc::new(FakeLanguageModel::default());
225        let tool = Arc::new(ListDirectoryTool);
226
227        // Test listing root directory
228        let input = json!({
229            "path": "project"
230        });
231
232        let result = cx
233            .update(|cx| {
234                tool.clone().run(
235                    input,
236                    Arc::default(),
237                    project.clone(),
238                    action_log.clone(),
239                    model.clone(),
240                    None,
241                    cx,
242                )
243            })
244            .output
245            .await
246            .unwrap();
247
248        let content = result.content.as_str().unwrap();
249        assert_eq!(
250            content,
251            platform_paths(indoc! {"
252                # Folders:
253                project/src
254                project/tests
255
256                # Files:
257                project/Cargo.toml
258                project/README.md
259            "})
260        );
261
262        // Test listing src directory
263        let input = json!({
264            "path": "project/src"
265        });
266
267        let result = cx
268            .update(|cx| {
269                tool.clone().run(
270                    input,
271                    Arc::default(),
272                    project.clone(),
273                    action_log.clone(),
274                    model.clone(),
275                    None,
276                    cx,
277                )
278            })
279            .output
280            .await
281            .unwrap();
282
283        let content = result.content.as_str().unwrap();
284        assert_eq!(
285            content,
286            platform_paths(indoc! {"
287                # Folders:
288                project/src/models
289                project/src/utils
290
291                # Files:
292                project/src/lib.rs
293                project/src/main.rs
294            "})
295        );
296
297        // Test listing directory with only files
298        let input = json!({
299            "path": "project/tests"
300        });
301
302        let result = cx
303            .update(|cx| {
304                tool.clone().run(
305                    input,
306                    Arc::default(),
307                    project.clone(),
308                    action_log.clone(),
309                    model.clone(),
310                    None,
311                    cx,
312                )
313            })
314            .output
315            .await
316            .unwrap();
317
318        let content = result.content.as_str().unwrap();
319        assert!(!content.contains("# Folders:"));
320        assert!(content.contains("# Files:"));
321        assert!(content.contains(&platform_paths("project/tests/integration_test.rs")));
322    }
323
324    #[gpui::test]
325    async fn test_list_directory_empty_directory(cx: &mut TestAppContext) {
326        init_test(cx);
327
328        let fs = FakeFs::new(cx.executor());
329        fs.insert_tree(
330            "/project",
331            json!({
332                "empty_dir": {}
333            }),
334        )
335        .await;
336
337        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
338        let action_log = cx.new(|_| ActionLog::new(project.clone()));
339        let model = Arc::new(FakeLanguageModel::default());
340        let tool = Arc::new(ListDirectoryTool);
341
342        let input = json!({
343            "path": "project/empty_dir"
344        });
345
346        let result = cx
347            .update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx))
348            .output
349            .await
350            .unwrap();
351
352        let content = result.content.as_str().unwrap();
353        assert_eq!(content, "project/empty_dir is empty.\n");
354    }
355
356    #[gpui::test]
357    async fn test_list_directory_error_cases(cx: &mut TestAppContext) {
358        init_test(cx);
359
360        let fs = FakeFs::new(cx.executor());
361        fs.insert_tree(
362            "/project",
363            json!({
364                "file.txt": "content"
365            }),
366        )
367        .await;
368
369        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
370        let action_log = cx.new(|_| ActionLog::new(project.clone()));
371        let model = Arc::new(FakeLanguageModel::default());
372        let tool = Arc::new(ListDirectoryTool);
373
374        // Test non-existent path
375        let input = json!({
376            "path": "project/nonexistent"
377        });
378
379        let result = cx
380            .update(|cx| {
381                tool.clone().run(
382                    input,
383                    Arc::default(),
384                    project.clone(),
385                    action_log.clone(),
386                    model.clone(),
387                    None,
388                    cx,
389                )
390            })
391            .output
392            .await;
393
394        assert!(result.is_err());
395        assert!(result.unwrap_err().to_string().contains("Path not found"));
396
397        // Test trying to list a file instead of directory
398        let input = json!({
399            "path": "project/file.txt"
400        });
401
402        let result = cx
403            .update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx))
404            .output
405            .await;
406
407        assert!(result.is_err());
408        assert!(
409            result
410                .unwrap_err()
411                .to_string()
412                .contains("is not a directory")
413        );
414    }
415}