read_file_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::outline;
  4use assistant_tool::{ActionLog, Tool, ToolResult};
  5use gpui::{AnyWindowHandle, App, Entity, Task};
  6
  7use indoc::formatdoc;
  8use itertools::Itertools;
  9use language::{Anchor, Point};
 10use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 11use project::{AgentLocation, Project};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use std::sync::Arc;
 15use ui::IconName;
 16use util::markdown::MarkdownInlineCode;
 17
 18/// If the model requests to read a file whose size exceeds this, then
 19#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 20pub struct ReadFileToolInput {
 21    /// The relative path of the file to read.
 22    ///
 23    /// This path should never be absolute, and the first component
 24    /// of the path should always be a root directory in a project.
 25    ///
 26    /// <example>
 27    /// If the project has the following root directories:
 28    ///
 29    /// - directory1
 30    /// - directory2
 31    ///
 32    /// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`.
 33    /// If you want to access `file.txt` in `directory2`, you should use the path `directory2/file.txt`.
 34    /// </example>
 35    pub path: String,
 36
 37    /// Optional line number to start reading on (1-based index)
 38    #[serde(default)]
 39    pub start_line: Option<u32>,
 40
 41    /// Optional line number to end reading on (1-based index, inclusive)
 42    #[serde(default)]
 43    pub end_line: Option<u32>,
 44}
 45
/// Agent tool that returns the text of a file in the project, optionally
/// restricted to a 1-based, inclusive line range.
///
/// When no line range is requested and the file's text is longer than
/// `outline::AUTO_OUTLINE_SIZE`, the tool answers with a symbol outline of the
/// file instead of its contents.
pub struct ReadFileTool;
 47
 48impl Tool for ReadFileTool {
 49    fn name(&self) -> String {
 50        "read_file".into()
 51    }
 52
 53    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 54        false
 55    }
 56
 57    fn description(&self) -> String {
 58        include_str!("./read_file_tool/description.md").into()
 59    }
 60
 61    fn icon(&self) -> IconName {
 62        IconName::FileSearch
 63    }
 64
 65    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 66        json_schema_for::<ReadFileToolInput>(format)
 67    }
 68
 69    fn ui_text(&self, input: &serde_json::Value) -> String {
 70        match serde_json::from_value::<ReadFileToolInput>(input.clone()) {
 71            Ok(input) => {
 72                let path = MarkdownInlineCode(&input.path);
 73                match (input.start_line, input.end_line) {
 74                    (Some(start), None) => format!("Read file {path} (from line {start})"),
 75                    (Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
 76                    _ => format!("Read file {path}"),
 77                }
 78            }
 79            Err(_) => "Read file".to_string(),
 80        }
 81    }
 82
 83    fn run(
 84        self: Arc<Self>,
 85        input: serde_json::Value,
 86        _messages: &[LanguageModelRequestMessage],
 87        project: Entity<Project>,
 88        action_log: Entity<ActionLog>,
 89        _model: Arc<dyn LanguageModel>,
 90        _window: Option<AnyWindowHandle>,
 91        cx: &mut App,
 92    ) -> ToolResult {
 93        let input = match serde_json::from_value::<ReadFileToolInput>(input) {
 94            Ok(input) => input,
 95            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 96        };
 97
 98        let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
 99            return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
100        };
101
102        let file_path = input.path.clone();
103        cx.spawn(async move |cx| {
104            let buffer = cx
105                .update(|cx| {
106                    project.update(cx, |project, cx| project.open_buffer(project_path, cx))
107                })?
108                .await?;
109            if buffer.read_with(cx, |buffer, _| {
110                buffer
111                    .file()
112                    .as_ref()
113                    .map_or(true, |file| !file.disk_state().exists())
114            })? {
115                return Err(anyhow!("{} not found", file_path));
116            }
117
118            project.update(cx, |project, cx| {
119                project.set_agent_location(
120                    Some(AgentLocation {
121                        buffer: buffer.downgrade(),
122                        position: Anchor::MIN,
123                    }),
124                    cx,
125                );
126            })?;
127
128            // Check if specific line ranges are provided
129            if input.start_line.is_some() || input.end_line.is_some() {
130                let mut anchor = None;
131                let result = buffer.read_with(cx, |buffer, _cx| {
132                    let text = buffer.text();
133                    // .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
134                    let start = input.start_line.unwrap_or(1).max(1);
135                    let start_row = start - 1;
136                    if start_row <= buffer.max_point().row {
137                        let column = buffer.line_indent_for_row(start_row).raw_len();
138                        anchor = Some(buffer.anchor_before(Point::new(start_row, column)));
139                    }
140
141                    let lines = text.split('\n').skip(start_row as usize);
142                    if let Some(end) = input.end_line {
143                        let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
144                        Itertools::intersperse(lines.take(count as usize), "\n")
145                            .collect::<String>()
146                            .into()
147                    } else {
148                        Itertools::intersperse(lines, "\n")
149                            .collect::<String>()
150                            .into()
151                    }
152                })?;
153
154                action_log.update(cx, |log, cx| {
155                    log.buffer_read(buffer.clone(), cx);
156                })?;
157
158                if let Some(anchor) = anchor {
159                    project.update(cx, |project, cx| {
160                        project.set_agent_location(
161                            Some(AgentLocation {
162                                buffer: buffer.downgrade(),
163                                position: anchor,
164                            }),
165                            cx,
166                        );
167                    })?;
168                }
169
170                Ok(result)
171            } else {
172                // No line ranges specified, so check file size to see if it's too big.
173                let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
174
175                if file_size <= outline::AUTO_OUTLINE_SIZE {
176                    // File is small enough, so return its contents.
177                    let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
178
179                    action_log.update(cx, |log, cx| {
180                        log.buffer_read(buffer, cx);
181                    })?;
182
183                    Ok(result.into())
184                } else {
185                    // File is too big, so return the outline
186                    // and a suggestion to read again with line numbers.
187                    let outline =
188                        outline::file_outline(project, file_path, action_log, None, cx).await?;
189                    Ok(formatdoc! {"
190                        This file was too big to read all at once.
191
192                        Here is an outline of its symbols:
193
194                        {outline}
195
196                        Using the line numbers in this outline, you can call this tool again
197                        while specifying the start_line and end_line fields to see the
198                        implementations of symbols in the outline."
199                    }
200                    .into())
201                }
202            }
203        })
204        .into()
205    }
206}
207
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{AppContext, TestAppContext};
    use language::{Language, LanguageConfig, LanguageMatcher};
    use language_model::fake_provider::FakeLanguageModel;
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    // A path inside the project that has no file behind it must fail with
    // "<path> not found".
    #[gpui::test]
    async fn test_read_nonexistent_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({})).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let model = Arc::new(FakeLanguageModel::default());
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/nonexistent_file.txt"
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, model, None, cx)
                    .output
            })
            .await;
        assert_eq!(
            result.unwrap_err().to_string(),
            "root/nonexistent_file.txt not found"
        );
    }

    // A small file read without a line range is returned verbatim.
    #[gpui::test]
    async fn test_read_small_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "small_file.txt": "This is a small file content"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let model = Arc::new(FakeLanguageModel::default());
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/small_file.txt"
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, model, None, cx)
                    .output
            })
            .await;
        assert_eq!(result.unwrap().content, "This is a small file content");
    }

    // A large file read without a line range yields a symbol outline instead
    // of the file contents.
    #[gpui::test]
    async fn test_read_large_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // 1000 four-line structs: `Test{i}` spans lines 4i+1 ..= 4i+4.
        fs.insert_tree(
            "/root",
            json!({
                "large_file.rs": (0..1000).map(|i| format!("struct Test{} {{\n    a: u32,\n    b: usize,\n}}", i)).collect::<Vec<_>>().join("\n")
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        language_registry.add(Arc::new(rust_lang()));
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let model = Arc::new(FakeLanguageModel::default());

        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/large_file.rs"
                });
                Arc::new(ReadFileTool)
                    .run(
                        input,
                        &[],
                        project.clone(),
                        action_log.clone(),
                        model.clone(),
                        None,
                        cx,
                    )
                    .output
            })
            .await;
        let content = result.unwrap();
        // The first four lines of the reply are the fixed preamble that
        // precedes the outline, so skip them and spot-check the first entries.
        assert_eq!(
            content.lines().skip(4).take(6).collect::<Vec<_>>(),
            vec![
                "struct Test0 [L1-4]",
                " a [L2]",
                " b [L3]",
                "struct Test1 [L5-8]",
                " a [L6]",
                " b [L7]",
            ]
        );

        // Read the same file again and verify the outline covers every struct.
        // (`ReadFileToolInput` has no `offset` field, so none is passed; the
        // outline is returned in full.)
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/large_file.rs"
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, model, None, cx)
                    .output
            })
            .await;
        let content = result.unwrap();
        let expected_content = (0..1000)
            .flat_map(|i| {
                vec![
                    format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4),
                    format!(" a [L{}]", i * 4 + 2),
                    format!(" b [L{}]", i * 4 + 3),
                ]
            })
            .collect::<Vec<_>>();
        pretty_assertions::assert_eq!(
            content
                .lines()
                .skip(4)
                .take(expected_content.len())
                .collect::<Vec<_>>(),
            expected_content
        );
    }

    // `start_line`/`end_line` select a 1-based, inclusive range of lines.
    #[gpui::test]
    async fn test_read_file_with_line_range(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let model = Arc::new(FakeLanguageModel::default());
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/multiline.txt",
                    "start_line": 2,
                    "end_line": 4
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, model, None, cx)
                    .output
            })
            .await;
        assert_eq!(result.unwrap().content, "Line 2\nLine 3\nLine 4");
    }

    // Out-of-contract ranges are clamped rather than rejected.
    #[gpui::test]
    async fn test_read_file_line_range_edge_cases(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let model = Arc::new(FakeLanguageModel::default());

        // start_line of 0 should be treated as 1
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/multiline.txt",
                    "start_line": 0,
                    "end_line": 2
                });
                Arc::new(ReadFileTool)
                    .run(
                        input,
                        &[],
                        project.clone(),
                        action_log.clone(),
                        model.clone(),
                        None,
                        cx,
                    )
                    .output
            })
            .await;
        assert_eq!(result.unwrap().content, "Line 1\nLine 2");

        // end_line of 0 should result in at least 1 line
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/multiline.txt",
                    "start_line": 1,
                    "end_line": 0
                });
                Arc::new(ReadFileTool)
                    .run(
                        input,
                        &[],
                        project.clone(),
                        action_log.clone(),
                        model.clone(),
                        None,
                        cx,
                    )
                    .output
            })
            .await;
        assert_eq!(result.unwrap().content, "Line 1");

        // when start_line > end_line, should still return at least 1 line
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/multiline.txt",
                    "start_line": 3,
                    "end_line": 2
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, model, None, cx)
                    .output
            })
            .await;
        assert_eq!(result.unwrap().content, "Line 3");
    }

    // Installs the globals the tests rely on: a test settings store, plus
    // language and project settings initialization.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }

    // Rust language definition for `.rs` files with an outline query, so that
    // the large-file test gets a symbol outline.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(
            r#"
            (line_comment) @annotation

            (struct_item
                "struct" @context
                name: (_) @name) @item
            (enum_item
                "enum" @context
                name: (_) @name) @item
            (enum_variant
                name: (_) @name) @item
            (field_declaration
                name: (_) @name) @item
            (impl_item
                "impl" @context
                trait: (_)? @name
                "for"? @context
                type: (_) @name
                body: (_ "{" (_)* "}")) @item
            (function_item
                "fn" @context
                name: (_) @name) @item
            (mod_item
                "mod" @context
                name: (_) @name) @item
            "#,
        )
        .unwrap()
    }
}