read_file_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::outline;
  4use assistant_tool::{ActionLog, Tool, ToolResult};
  5use gpui::{AnyWindowHandle, App, Entity, Task};
  6
  7use indoc::formatdoc;
  8use itertools::Itertools;
  9use language::{Anchor, Point};
 10use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 11use project::{AgentLocation, Project};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use std::sync::Arc;
 15use ui::IconName;
 16use util::markdown::MarkdownInlineCode;
 17
 18/// If the model requests to read a file whose size exceeds this, then
 19#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 20pub struct ReadFileToolInput {
 21    /// The relative path of the file to read.
 22    ///
 23    /// This path should never be absolute, and the first component
 24    /// of the path should always be a root directory in a project.
 25    ///
 26    /// <example>
 27    /// If the project has the following root directories:
 28    ///
 29    /// - directory1
 30    /// - directory2
 31    ///
 32    /// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`.
 33    /// If you want to access `file.txt` in `directory2`, you should use the path `directory2/file.txt`.
 34    /// </example>
 35    pub path: String,
 36
 37    /// Optional line number to start reading on (1-based index)
 38    #[serde(default)]
 39    pub start_line: Option<u32>,
 40
 41    /// Optional line number to end reading on (1-based index, inclusive)
 42    #[serde(default)]
 43    pub end_line: Option<u32>,
 44}
 45
 46pub struct ReadFileTool;
 47
 48impl Tool for ReadFileTool {
 49    fn name(&self) -> String {
 50        "read_file".into()
 51    }
 52
 53    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 54        false
 55    }
 56
 57    fn description(&self) -> String {
 58        include_str!("./read_file_tool/description.md").into()
 59    }
 60
 61    fn icon(&self) -> IconName {
 62        IconName::FileSearch
 63    }
 64
 65    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 66        json_schema_for::<ReadFileToolInput>(format)
 67    }
 68
 69    fn ui_text(&self, input: &serde_json::Value) -> String {
 70        match serde_json::from_value::<ReadFileToolInput>(input.clone()) {
 71            Ok(input) => {
 72                let path = MarkdownInlineCode(&input.path);
 73                match (input.start_line, input.end_line) {
 74                    (Some(start), None) => format!("Read file {path} (from line {start})"),
 75                    (Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
 76                    _ => format!("Read file {path}"),
 77                }
 78            }
 79            Err(_) => "Read file".to_string(),
 80        }
 81    }
 82
 83    fn run(
 84        self: Arc<Self>,
 85        input: serde_json::Value,
 86        _messages: &[LanguageModelRequestMessage],
 87        project: Entity<Project>,
 88        action_log: Entity<ActionLog>,
 89        _window: Option<AnyWindowHandle>,
 90        cx: &mut App,
 91    ) -> ToolResult {
 92        let input = match serde_json::from_value::<ReadFileToolInput>(input) {
 93            Ok(input) => input,
 94            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 95        };
 96
 97        let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
 98            return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
 99        };
100        let Some(worktree) = project
101            .read(cx)
102            .worktree_for_id(project_path.worktree_id, cx)
103        else {
104            return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
105        };
106        let exists = worktree.update(cx, |worktree, cx| {
107            worktree.file_exists(&project_path.path, cx)
108        });
109
110        let file_path = input.path.clone();
111        cx.spawn(async move |cx| {
112            if !exists.await? {
113                return Err(anyhow!("{} not found", file_path));
114            }
115
116            let buffer = cx
117                .update(|cx| {
118                    project.update(cx, |project, cx| project.open_buffer(project_path, cx))
119                })?
120                .await?;
121
122            project.update(cx, |project, cx| {
123                project.set_agent_location(
124                    Some(AgentLocation {
125                        buffer: buffer.downgrade(),
126                        position: Anchor::MIN,
127                    }),
128                    cx,
129                );
130            })?;
131
132            // Check if specific line ranges are provided
133            if input.start_line.is_some() || input.end_line.is_some() {
134                let mut anchor = None;
135                let result = buffer.read_with(cx, |buffer, _cx| {
136                    let text = buffer.text();
137                    // .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
138                    let start = input.start_line.unwrap_or(1).max(1);
139                    let start_row = start - 1;
140                    if start_row <= buffer.max_point().row {
141                        let column = buffer.line_indent_for_row(start_row).raw_len();
142                        anchor = Some(buffer.anchor_before(Point::new(start_row, column)));
143                    }
144
145                    let lines = text.split('\n').skip(start_row as usize);
146                    if let Some(end) = input.end_line {
147                        let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
148                        Itertools::intersperse(lines.take(count as usize), "\n").collect()
149                    } else {
150                        Itertools::intersperse(lines, "\n").collect()
151                    }
152                })?;
153
154                action_log.update(cx, |log, cx| {
155                    log.buffer_read(buffer.clone(), cx);
156                })?;
157
158                if let Some(anchor) = anchor {
159                    project.update(cx, |project, cx| {
160                        project.set_agent_location(
161                            Some(AgentLocation {
162                                buffer: buffer.downgrade(),
163                                position: anchor,
164                            }),
165                            cx,
166                        );
167                    })?;
168                }
169
170                Ok(result)
171            } else {
172                // No line ranges specified, so check file size to see if it's too big.
173                let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
174
175                if file_size <= outline::AUTO_OUTLINE_SIZE {
176                    // File is small enough, so return its contents.
177                    let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
178
179                    action_log.update(cx, |log, cx| {
180                        log.buffer_read(buffer, cx);
181                    })?;
182
183                    Ok(result)
184                } else {
185                    // File is too big, so return the outline
186                    // and a suggestion to read again with line numbers.
187                    let outline = outline::file_outline(project, file_path, action_log, None, cx).await?;
188                    Ok(formatdoc! {"
189                        This file was too big to read all at once. Here is an outline of its symbols:
190
191                        {outline}
192
193                        Using the line numbers in this outline, you can call this tool again while specifying
194                        the start_line and end_line fields to see the implementations of symbols in the outline."
195                    })
196                }
197            }
198        })
199        .into()
200    }
201}
202
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{AppContext, TestAppContext};
    use language::{Language, LanguageConfig, LanguageMatcher};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    #[gpui::test]
    async fn test_read_nonexistent_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({})).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/nonexistent_file.txt"
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, None, cx)
                    .output
            })
            .await;
        assert_eq!(
            result.unwrap_err().to_string(),
            "root/nonexistent_file.txt not found"
        );
    }

    #[gpui::test]
    async fn test_read_small_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "small_file.txt": "This is a small file content"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/small_file.txt"
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, None, cx)
                    .output
            })
            .await;
        assert_eq!(result.unwrap(), "This is a small file content");
    }

    #[gpui::test]
    async fn test_read_large_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "large_file.rs": (0..1000).map(|i| format!("struct Test{} {{\n    a: u32,\n    b: usize,\n}}", i)).collect::<Vec<_>>().join("\n")
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        language_registry.add(Arc::new(rust_lang()));
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/large_file.rs"
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log.clone(), None, cx)
                    .output
            })
            .await;
        let content = result.unwrap();
        // The first two lines are the "too big" header and a blank line.
        assert_eq!(
            content.lines().skip(2).take(6).collect::<Vec<_>>(),
            vec![
                "struct Test0 [L1-4]",
                " a [L2]",
                " b [L3]",
                "struct Test1 [L5-8]",
                " a [L6]",
                " b [L7]",
            ]
        );

        // Reading the same file again must return the complete outline: every struct
        // and both of its fields, with nothing truncated. (`ReadFileToolInput` has no
        // paging field, so the request is identical to the first one.)
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/large_file.rs"
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, None, cx)
                    .output
            })
            .await;
        let content = result.unwrap();
        let expected_content = (0..1000)
            .flat_map(|i| {
                vec![
                    format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4),
                    format!(" a [L{}]", i * 4 + 2),
                    format!(" b [L{}]", i * 4 + 3),
                ]
            })
            .collect::<Vec<_>>();
        pretty_assertions::assert_eq!(
            content
                .lines()
                .skip(2)
                .take(expected_content.len())
                .collect::<Vec<_>>(),
            expected_content
        );
    }

    #[gpui::test]
    async fn test_read_file_with_line_range(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/multiline.txt",
                    "start_line": 2,
                    "end_line": 4
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, None, cx)
                    .output
            })
            .await;
        assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
    }

    #[gpui::test]
    async fn test_read_file_line_range_edge_cases(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        // start_line of 0 should be treated as 1
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/multiline.txt",
                    "start_line": 0,
                    "end_line": 2
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log.clone(), None, cx)
                    .output
            })
            .await;
        assert_eq!(result.unwrap(), "Line 1\nLine 2");

        // end_line of 0 should result in at least 1 line
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/multiline.txt",
                    "start_line": 1,
                    "end_line": 0
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log.clone(), None, cx)
                    .output
            })
            .await;
        assert_eq!(result.unwrap(), "Line 1");

        // when start_line > end_line, should still return at least 1 line
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/multiline.txt",
                    "start_line": 3,
                    "end_line": 2
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, None, cx)
                    .output
            })
            .await;
        assert_eq!(result.unwrap(), "Line 3");
    }

    #[gpui::test]
    async fn test_read_file_with_partial_line_range(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        // A missing start_line defaults to line 1.
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/multiline.txt",
                    "end_line": 2
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log.clone(), None, cx)
                    .output
            })
            .await;
        assert_eq!(result.unwrap(), "Line 1\nLine 2");

        // A missing end_line reads through the end of the file.
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/multiline.txt",
                    "start_line": 4
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log.clone(), None, cx)
                    .output
            })
            .await;
        assert_eq!(result.unwrap(), "Line 4\nLine 5");

        // A start_line past the end of the file yields empty text, not an error.
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/multiline.txt",
                    "start_line": 10
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, None, cx)
                    .output
            })
            .await;
        assert_eq!(result.unwrap(), "");
    }

    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }

    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(
            r#"
            (line_comment) @annotation

            (struct_item
                "struct" @context
                name: (_) @name) @item
            (enum_item
                "enum" @context
                name: (_) @name) @item
            (enum_variant
                name: (_) @name) @item
            (field_declaration
                name: (_) @name) @item
            (impl_item
                "impl" @context
                trait: (_)? @name
                "for"? @context
                type: (_) @name
                body: (_ "{" (_)* "}")) @item
            (function_item
                "fn" @context
                name: (_) @name) @item
            (mod_item
                "mod" @context
                name: (_) @name) @item
            "#,
        )
        .unwrap()
    }
}