read_file_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::outline;
  4use assistant_tool::{ActionLog, Tool, ToolResult};
  5use gpui::{AnyWindowHandle, App, Entity, Task};
  6
  7use indoc::formatdoc;
  8use itertools::Itertools;
  9use language::{Anchor, Point};
 10use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 11use project::{AgentLocation, Project};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use std::sync::Arc;
 15use ui::IconName;
 16use util::markdown::MarkdownInlineCode;
 17
 18/// If the model requests to read a file whose size exceeds this, then
 19#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 20pub struct ReadFileToolInput {
 21    /// The relative path of the file to read.
 22    ///
 23    /// This path should never be absolute, and the first component
 24    /// of the path should always be a root directory in a project.
 25    ///
 26    /// <example>
 27    /// If the project has the following root directories:
 28    ///
 29    /// - directory1
 30    /// - directory2
 31    ///
 32    /// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`.
 33    /// If you want to access `file.txt` in `directory2`, you should use the path `directory2/file.txt`.
 34    /// </example>
 35    pub path: String,
 36
 37    /// Optional line number to start reading on (1-based index)
 38    #[serde(default)]
 39    pub start_line: Option<u32>,
 40
 41    /// Optional line number to end reading on (1-based index, inclusive)
 42    #[serde(default)]
 43    pub end_line: Option<u32>,
 44}
 45
 46pub struct ReadFileTool;
 47
 48impl Tool for ReadFileTool {
 49    fn name(&self) -> String {
 50        "read_file".into()
 51    }
 52
 53    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 54        false
 55    }
 56
 57    fn description(&self) -> String {
 58        include_str!("./read_file_tool/description.md").into()
 59    }
 60
 61    fn icon(&self) -> IconName {
 62        IconName::FileSearch
 63    }
 64
 65    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 66        json_schema_for::<ReadFileToolInput>(format)
 67    }
 68
 69    fn ui_text(&self, input: &serde_json::Value) -> String {
 70        match serde_json::from_value::<ReadFileToolInput>(input.clone()) {
 71            Ok(input) => {
 72                let path = MarkdownInlineCode(&input.path);
 73                match (input.start_line, input.end_line) {
 74                    (Some(start), None) => format!("Read file {path} (from line {start})"),
 75                    (Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
 76                    _ => format!("Read file {path}"),
 77                }
 78            }
 79            Err(_) => "Read file".to_string(),
 80        }
 81    }
 82
 83    fn run(
 84        self: Arc<Self>,
 85        input: serde_json::Value,
 86        _request: Arc<LanguageModelRequest>,
 87        project: Entity<Project>,
 88        action_log: Entity<ActionLog>,
 89        _model: Arc<dyn LanguageModel>,
 90        _window: Option<AnyWindowHandle>,
 91        cx: &mut App,
 92    ) -> ToolResult {
 93        let input = match serde_json::from_value::<ReadFileToolInput>(input) {
 94            Ok(input) => input,
 95            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 96        };
 97
 98        let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
 99            return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
100        };
101
102        let file_path = input.path.clone();
103        cx.spawn(async move |cx| {
104            let buffer = cx
105                .update(|cx| {
106                    project.update(cx, |project, cx| project.open_buffer(project_path, cx))
107                })?
108                .await?;
109            if buffer.read_with(cx, |buffer, _| {
110                buffer
111                    .file()
112                    .as_ref()
113                    .map_or(true, |file| !file.disk_state().exists())
114            })? {
115                return Err(anyhow!("{} not found", file_path));
116            }
117
118            project.update(cx, |project, cx| {
119                project.set_agent_location(
120                    Some(AgentLocation {
121                        buffer: buffer.downgrade(),
122                        position: Anchor::MIN,
123                    }),
124                    cx,
125                );
126            })?;
127
128            // Check if specific line ranges are provided
129            if input.start_line.is_some() || input.end_line.is_some() {
130                let mut anchor = None;
131                let result = buffer.read_with(cx, |buffer, _cx| {
132                    let text = buffer.text();
133                    // .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
134                    let start = input.start_line.unwrap_or(1).max(1);
135                    let start_row = start - 1;
136                    if start_row <= buffer.max_point().row {
137                        let column = buffer.line_indent_for_row(start_row).raw_len();
138                        anchor = Some(buffer.anchor_before(Point::new(start_row, column)));
139                    }
140
141                    let lines = text.split('\n').skip(start_row as usize);
142                    if let Some(end) = input.end_line {
143                        let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
144                        Itertools::intersperse(lines.take(count as usize), "\n")
145                            .collect::<String>()
146                            .into()
147                    } else {
148                        Itertools::intersperse(lines, "\n")
149                            .collect::<String>()
150                            .into()
151                    }
152                })?;
153
154                action_log.update(cx, |log, cx| {
155                    log.buffer_read(buffer.clone(), cx);
156                })?;
157
158                if let Some(anchor) = anchor {
159                    project.update(cx, |project, cx| {
160                        project.set_agent_location(
161                            Some(AgentLocation {
162                                buffer: buffer.downgrade(),
163                                position: anchor,
164                            }),
165                            cx,
166                        );
167                    })?;
168                }
169
170                Ok(result)
171            } else {
172                // No line ranges specified, so check file size to see if it's too big.
173                let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
174
175                if file_size <= outline::AUTO_OUTLINE_SIZE {
176                    // File is small enough, so return its contents.
177                    let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
178
179                    action_log.update(cx, |log, cx| {
180                        log.buffer_read(buffer, cx);
181                    })?;
182
183                    Ok(result.into())
184                } else {
185                    // File is too big, so return the outline
186                    // and a suggestion to read again with line numbers.
187                    let outline =
188                        outline::file_outline(project, file_path, action_log, None, cx).await?;
189                    Ok(formatdoc! {"
190                        This file was too big to read all at once.
191
192                        Here is an outline of its symbols:
193
194                        {outline}
195
196                        Using the line numbers in this outline, you can call this tool again
197                        while specifying the start_line and end_line fields to see the
198                        implementations of symbols in the outline."
199                    }
200                    .into())
201                }
202            }
203        })
204        .into()
205    }
206}
207
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{AppContext, TestAppContext};
    use language::{Language, LanguageConfig, LanguageMatcher};
    use language_model::fake_provider::FakeLanguageModel;
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Five-line fixture shared by the line-range tests.
    const FIVE_LINES: &str = "Line 1\nLine 2\nLine 3\nLine 4\nLine 5";

    /// Reading a path that resolves into the worktree but has no file fails
    /// with the "not found" error.
    #[gpui::test]
    async fn test_read_nonexistent_file(cx: &mut TestAppContext) {
        let (project, action_log) = fixture(json!({}), cx).await;

        let outcome = read_file(
            json!({
                "path": "root/nonexistent_file.txt"
            }),
            &project,
            &action_log,
            cx,
        )
        .output
        .await;

        assert_eq!(
            outcome.unwrap_err().to_string(),
            "root/nonexistent_file.txt not found"
        );
    }

    /// A small file is returned verbatim.
    #[gpui::test]
    async fn test_read_small_file(cx: &mut TestAppContext) {
        let (project, action_log) = fixture(
            json!({
                "small_file.txt": "This is a small file content"
            }),
            cx,
        )
        .await;

        let outcome = read_file(
            json!({
                "path": "root/small_file.txt"
            }),
            &project,
            &action_log,
            cx,
        )
        .output
        .await;

        assert_eq!(outcome.unwrap().content, "This is a small file content");
    }

    /// A large file is answered with a symbol outline instead of its text.
    #[gpui::test]
    async fn test_read_large_file(cx: &mut TestAppContext) {
        let source = (0..1000)
            .map(|i| format!("struct Test{} {{\n    a: u32,\n    b: usize,\n}}", i))
            .collect::<Vec<_>>()
            .join("\n");
        let (project, action_log) = fixture(
            json!({
                "large_file.rs": source
            }),
            cx,
        )
        .await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        language_registry.add(Arc::new(rust_lang()));

        let outcome = read_file(
            json!({
                "path": "root/large_file.rs"
            }),
            &project,
            &action_log,
            cx,
        )
        .output
        .await;
        let content = outcome.unwrap();
        assert_eq!(
            content.lines().skip(4).take(6).collect::<Vec<_>>(),
            vec![
                "struct Test0 [L1-4]",
                " a [L2]",
                " b [L3]",
                "struct Test1 [L5-8]",
                " a [L6]",
                " b [L7]",
            ]
        );

        // NOTE: `offset` is not a field of `ReadFileToolInput`, and that struct
        // does not opt into rejecting unknown keys, so this request carries no
        // line range and is expected to yield the outline again.
        let outcome = read_file(
            json!({
                "path": "root/large_file.rs",
                "offset": 1
            }),
            &project,
            &action_log,
            cx,
        )
        .output
        .await;
        let content = outcome.unwrap();
        let expected_content = (0..1000)
            .flat_map(|i| {
                [
                    format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4),
                    format!(" a [L{}]", i * 4 + 2),
                    format!(" b [L{}]", i * 4 + 3),
                ]
            })
            .collect::<Vec<_>>();
        pretty_assertions::assert_eq!(
            content
                .lines()
                .skip(4)
                .take(expected_content.len())
                .collect::<Vec<_>>(),
            expected_content
        );
    }

    /// An explicit inclusive range returns exactly those lines.
    #[gpui::test]
    async fn test_read_file_with_line_range(cx: &mut TestAppContext) {
        let (project, action_log) = fixture(
            json!({
                "multiline.txt": FIVE_LINES
            }),
            cx,
        )
        .await;

        let outcome = read_file(
            json!({
                "path": "root/multiline.txt",
                "start_line": 2,
                "end_line": 4
            }),
            &project,
            &action_log,
            cx,
        )
        .output
        .await;

        assert_eq!(outcome.unwrap().content, "Line 2\nLine 3\nLine 4");
    }

    /// Out-of-spec ranges are clamped rather than rejected.
    #[gpui::test]
    async fn test_read_file_line_range_edge_cases(cx: &mut TestAppContext) {
        let (project, action_log) = fixture(
            json!({
                "multiline.txt": FIVE_LINES
            }),
            cx,
        )
        .await;

        // start_line of 0 should be treated as 1
        let outcome = read_file(
            json!({
                "path": "root/multiline.txt",
                "start_line": 0,
                "end_line": 2
            }),
            &project,
            &action_log,
            cx,
        )
        .output
        .await;
        assert_eq!(outcome.unwrap().content, "Line 1\nLine 2");

        // end_line of 0 should result in at least 1 line
        let outcome = read_file(
            json!({
                "path": "root/multiline.txt",
                "start_line": 1,
                "end_line": 0
            }),
            &project,
            &action_log,
            cx,
        )
        .output
        .await;
        assert_eq!(outcome.unwrap().content, "Line 1");

        // when start_line > end_line, should still return at least 1 line
        let outcome = read_file(
            json!({
                "path": "root/multiline.txt",
                "start_line": 3,
                "end_line": 2
            }),
            &project,
            &action_log,
            cx,
        )
        .output
        .await;
        assert_eq!(outcome.unwrap().content, "Line 3");
    }

    /// Initializes globals, then builds a fake-filesystem project rooted at
    /// `/root` containing `tree`, together with an action log for it.
    async fn fixture(
        tree: serde_json::Value,
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<ActionLog>) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/root", tree).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        (project, action_log)
    }

    /// Invokes `ReadFileTool::run` with `input` against the given project,
    /// using a default request, a fake language model and no window.
    fn read_file(
        input: serde_json::Value,
        project: &Entity<Project>,
        action_log: &Entity<ActionLog>,
        cx: &mut TestAppContext,
    ) -> ToolResult {
        let model = Arc::new(FakeLanguageModel::default());
        cx.update(|cx| {
            Arc::new(ReadFileTool).run(
                input,
                Arc::default(),
                project.clone(),
                action_log.clone(),
                model,
                None,
                cx,
            )
        })
    }

    /// Installs the test settings store and initializes language and project settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }

    /// Minimal Rust language definition with an outline query, registered for
    /// `.rs` files so the large-file test can produce an outline.
    fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher: LanguageMatcher {
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            ..Default::default()
        };

        Language::new(config, Some(tree_sitter_rust::LANGUAGE.into()))
            .with_outline_query(
                r#"
            (line_comment) @annotation

            (struct_item
                "struct" @context
                name: (_) @name) @item
            (enum_item
                "enum" @context
                name: (_) @name) @item
            (enum_variant
                name: (_) @name) @item
            (field_declaration
                name: (_) @name) @item
            (impl_item
                "impl" @context
                trait: (_)? @name
                "for"? @context
                type: (_) @name
                body: (_ "{" (_)* "}")) @item
            (function_item
                "fn" @context
                name: (_) @name) @item
            (mod_item
                "mod" @context
                name: (_) @name) @item
            "#,
            )
            .unwrap()
    }
}