read_file_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::outline;
  4use assistant_tool::{ActionLog, Tool, ToolResult};
  5use gpui::{AnyWindowHandle, App, Entity, Task};
  6
  7use indoc::formatdoc;
  8use itertools::Itertools;
  9use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 10use project::Project;
 11use schemars::JsonSchema;
 12use serde::{Deserialize, Serialize};
 13use std::sync::Arc;
 14use ui::IconName;
 15use util::markdown::MarkdownInlineCode;
 16
 17/// If the model requests to read a file whose size exceeds this, then
 18#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 19pub struct ReadFileToolInput {
 20    /// The relative path of the file to read.
 21    ///
 22    /// This path should never be absolute, and the first component
 23    /// of the path should always be a root directory in a project.
 24    ///
 25    /// <example>
 26    /// If the project has the following root directories:
 27    ///
 28    /// - directory1
 29    /// - directory2
 30    ///
 31    /// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`.
 32    /// If you want to access `file.txt` in `directory2`, you should use the path `directory2/file.txt`.
 33    /// </example>
 34    pub path: String,
 35
 36    /// Optional line number to start reading on (1-based index)
 37    #[serde(default)]
 38    pub start_line: Option<usize>,
 39
 40    /// Optional line number to end reading on (1-based index, inclusive)
 41    #[serde(default)]
 42    pub end_line: Option<usize>,
 43}
 44
/// Tool that lets the model read a project file: either whole, as a 1-based line range, or
/// (for large files requested without a range) as an outline of the file's symbols.
pub struct ReadFileTool;
 46
 47impl Tool for ReadFileTool {
 48    fn name(&self) -> String {
 49        "read_file".into()
 50    }
 51
 52    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 53        false
 54    }
 55
 56    fn description(&self) -> String {
 57        include_str!("./read_file_tool/description.md").into()
 58    }
 59
 60    fn icon(&self) -> IconName {
 61        IconName::FileSearch
 62    }
 63
 64    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 65        json_schema_for::<ReadFileToolInput>(format)
 66    }
 67
 68    fn ui_text(&self, input: &serde_json::Value) -> String {
 69        match serde_json::from_value::<ReadFileToolInput>(input.clone()) {
 70            Ok(input) => {
 71                let path = MarkdownInlineCode(&input.path);
 72                match (input.start_line, input.end_line) {
 73                    (Some(start), None) => format!("Read file {path} (from line {start})"),
 74                    (Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
 75                    _ => format!("Read file {path}"),
 76                }
 77            }
 78            Err(_) => "Read file".to_string(),
 79        }
 80    }
 81
 82    fn run(
 83        self: Arc<Self>,
 84        input: serde_json::Value,
 85        _messages: &[LanguageModelRequestMessage],
 86        project: Entity<Project>,
 87        action_log: Entity<ActionLog>,
 88        _window: Option<AnyWindowHandle>,
 89        cx: &mut App,
 90    ) -> ToolResult {
 91        let input = match serde_json::from_value::<ReadFileToolInput>(input) {
 92            Ok(input) => input,
 93            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 94        };
 95
 96        let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
 97            return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
 98        };
 99        let Some(worktree) = project
100            .read(cx)
101            .worktree_for_id(project_path.worktree_id, cx)
102        else {
103            return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
104        };
105        let exists = worktree.update(cx, |worktree, cx| {
106            worktree.file_exists(&project_path.path, cx)
107        });
108
109        let file_path = input.path.clone();
110        cx.spawn(async move |cx| {
111            if !exists.await? {
112                return Err(anyhow!("{} not found", file_path))
113            }
114
115            let buffer = cx
116                .update(|cx| {
117                    project.update(cx, |project, cx| project.open_buffer(project_path, cx))
118                })?
119                .await?;
120
121            // Check if specific line ranges are provided
122            if input.start_line.is_some() || input.end_line.is_some() {
123                let result = buffer.read_with(cx, |buffer, _cx| {
124                    let text = buffer.text();
125                    let start = input.start_line.unwrap_or(1);
126                    let lines = text.split('\n').skip(start - 1);
127                    if let Some(end) = input.end_line {
128                        let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
129                        Itertools::intersperse(lines.take(count), "\n").collect()
130                    } else {
131                        Itertools::intersperse(lines, "\n").collect()
132                    }
133                })?;
134
135                action_log.update(cx, |log, cx| {
136                    log.track_buffer(buffer, cx);
137                })?;
138
139                Ok(result)
140            } else {
141                // No line ranges specified, so check file size to see if it's too big.
142                let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
143
144                if file_size <= outline::AUTO_OUTLINE_SIZE {
145                    // File is small enough, so return its contents.
146                    let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
147
148                    action_log.update(cx, |log, cx| {
149                        log.track_buffer(buffer, cx);
150                    })?;
151
152                    Ok(result)
153                } else {
154                    // File is too big, so return the outline
155                    // and a suggestion to read again with line numbers.
156                    let outline = outline::file_outline(project, file_path, action_log, None, cx).await?;
157                    Ok(formatdoc! {"
158                        This file was too big to read all at once. Here is an outline of its symbols:
159
160                        {outline}
161
162                        Using the line numbers in this outline, you can call this tool again while specifying
163                        the start_line and end_line fields to see the implementations of symbols in the outline."
164                    })
165                }
166            }
167        }).into()
168    }
169}
170
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{AppContext, TestAppContext};
    use language::{Language, LanguageConfig, LanguageMatcher};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// A path that does not exist in the worktree is reported as "<path> not found".
    #[gpui::test]
    async fn test_read_nonexistent_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/root"), json!({})).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/nonexistent_file.txt"
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, None, cx)
                    .output
            })
            .await;
        assert_eq!(
            result.unwrap_err().to_string(),
            "root/nonexistent_file.txt not found"
        );
    }

    /// A small file read without a line range is returned verbatim.
    #[gpui::test]
    async fn test_read_small_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "small_file.txt": "This is a small file content"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/small_file.txt"
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, None, cx)
                    .output
            })
            .await;
        assert_eq!(result.unwrap(), "This is a small file content");
    }

    /// A large file read without a line range comes back as an outline of its symbols. The
    /// same file read with a line range comes back as the requested lines verbatim.
    #[gpui::test]
    async fn test_read_large_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "large_file.rs": (0..1000).map(|i| format!("struct Test{} {{\n    a: u32,\n    b: usize,\n}}", i)).collect::<Vec<_>>().join("\n")
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        language_registry.add(Arc::new(rust_lang()));
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        // Without a line range the file is too big to return whole, so we get an outline.
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/large_file.rs"
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log.clone(), None, cx)
                    .output
            })
            .await;
        let content = result.unwrap();
        // The first two lines are the "too big" notice and a blank line.
        assert_eq!(
            content.lines().skip(2).take(6).collect::<Vec<_>>(),
            vec![
                "struct Test0 [L1-4]",
                " a [L2]",
                " b [L3]",
                "struct Test1 [L5-8]",
                " a [L6]",
                " b [L7]",
            ]
        );

        // Every struct and each of its fields appears in the outline, in file order.
        let expected_content = (0..1000)
            .flat_map(|i| {
                vec![
                    format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4),
                    format!(" a [L{}]", i * 4 + 2),
                    format!(" b [L{}]", i * 4 + 3),
                ]
            })
            .collect::<Vec<_>>();
        pretty_assertions::assert_eq!(
            content
                .lines()
                .skip(2)
                .take(expected_content.len())
                .collect::<Vec<_>>(),
            expected_content
        );

        // With a line range, the requested lines are returned verbatim even for a large file.
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/large_file.rs",
                    "start_line": 5,
                    "end_line": 8
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, None, cx)
                    .output
            })
            .await;
        assert_eq!(
            result.unwrap(),
            "struct Test1 {\n    a: u32,\n    b: usize,\n}"
        );
    }

    /// `start_line`/`end_line` select a 1-based, inclusive range of lines.
    #[gpui::test]
    async fn test_read_file_with_line_range(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/multiline.txt",
                    "start_line": 2,
                    "end_line": 4
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, None, cx)
                    .output
            })
            .await;
        assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
    }

    /// Installs the global settings store plus the language and project settings the tool needs.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }

    /// A minimal Rust language whose outline query captures structs, enums, fields, impls,
    /// functions, and modules, which is enough to outline `large_file.rs`.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(
            r#"
            (line_comment) @annotation

            (struct_item
                "struct" @context
                name: (_) @name) @item
            (enum_item
                "enum" @context
                name: (_) @name) @item
            (enum_variant
                name: (_) @name) @item
            (field_declaration
                name: (_) @name) @item
            (impl_item
                "impl" @context
                trait: (_)? @name
                "for"? @context
                type: (_) @name
                body: (_ "{" (_)* "}")) @item
            (function_item
                "fn" @context
                name: (_) @name) @item
            (mod_item
                "mod" @context
                name: (_) @name) @item
            "#,
        )
        .unwrap()
    }
}