read_file_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::outline;
  4use assistant_tool::{ActionLog, Tool, ToolResult};
  5use gpui::{AnyWindowHandle, App, Entity, Task};
  6
  7use indoc::formatdoc;
  8use itertools::Itertools;
  9use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 10use project::Project;
 11use schemars::JsonSchema;
 12use serde::{Deserialize, Serialize};
 13use std::sync::Arc;
 14use ui::IconName;
 15use util::markdown::MarkdownInlineCode;
 16
// NOTE(review): `JsonSchema` is derived on this struct, so the doc comments below are
// presumably surfaced to the model as schema descriptions — confirm before rewording them.
/// Input for the `read_file` tool: a project-relative path plus an optional
/// line range (1-based, end inclusive).
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ReadFileToolInput {
    /// The relative path of the file to read.
    ///
    /// This path should never be absolute, and the first component
    /// of the path should always be a root directory in a project.
    ///
    /// <example>
    /// If the project has the following root directories:
    ///
    /// - directory1
    /// - directory2
    ///
    /// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`.
    /// If you want to access `file.txt` in `directory2`, you should use the path `directory2/file.txt`.
    /// </example>
    pub path: String,

    /// Optional line number to start reading on (1-based index)
    #[serde(default)]
    pub start_line: Option<usize>,

    /// Optional line number to end reading on (1-based index, inclusive)
    #[serde(default)]
    pub end_line: Option<usize>,
}
 44
/// Tool named `read_file` that returns a project file's text: the whole file,
/// a requested line range, or — when no range is given and the file is larger
/// than `outline::AUTO_OUTLINE_SIZE` — an outline of its symbols.
pub struct ReadFileTool;
 46
 47impl Tool for ReadFileTool {
 48    fn name(&self) -> String {
 49        "read_file".into()
 50    }
 51
    /// Always returns `false`; neither the tool input nor the app context is consulted.
    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
        false
    }
 55
 56    fn description(&self) -> String {
 57        include_str!("./read_file_tool/description.md").into()
 58    }
 59
    /// Returns the icon associated with this tool (`IconName::FileSearch`).
    fn icon(&self) -> IconName {
        IconName::FileSearch
    }
 63
    /// Produces the JSON schema for [`ReadFileToolInput`] in the requested
    /// `format` by delegating to `json_schema_for`; its result (including any
    /// error) is returned unchanged.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
        json_schema_for::<ReadFileToolInput>(format)
    }
 67
 68    fn ui_text(&self, input: &serde_json::Value) -> String {
 69        match serde_json::from_value::<ReadFileToolInput>(input.clone()) {
 70            Ok(input) => {
 71                let path = MarkdownInlineCode(&input.path);
 72                match (input.start_line, input.end_line) {
 73                    (Some(start), None) => format!("Read file {path} (from line {start})"),
 74                    (Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
 75                    _ => format!("Read file {path}"),
 76                }
 77            }
 78            Err(_) => "Read file".to_string(),
 79        }
 80    }
 81
 82    fn run(
 83        self: Arc<Self>,
 84        input: serde_json::Value,
 85        _messages: &[LanguageModelRequestMessage],
 86        project: Entity<Project>,
 87        action_log: Entity<ActionLog>,
 88        _window: Option<AnyWindowHandle>,
 89        cx: &mut App,
 90    ) -> ToolResult {
 91        let input = match serde_json::from_value::<ReadFileToolInput>(input) {
 92            Ok(input) => input,
 93            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 94        };
 95
 96        let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
 97            return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
 98        };
 99        let Some(worktree) = project
100            .read(cx)
101            .worktree_for_id(project_path.worktree_id, cx)
102        else {
103            return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
104        };
105        let exists = worktree.update(cx, |worktree, cx| {
106            worktree.file_exists(&project_path.path, cx)
107        });
108
109        let file_path = input.path.clone();
110        cx.spawn(async move |cx| {
111            if !exists.await? {
112                return Err(anyhow!("{} not found", file_path))
113            }
114
115            let buffer = cx
116                .update(|cx| {
117                    project.update(cx, |project, cx| project.open_buffer(project_path, cx))
118                })?
119                .await?;
120
121            // Check if specific line ranges are provided
122            if input.start_line.is_some() || input.end_line.is_some() {
123                let result = buffer.read_with(cx, |buffer, _cx| {
124                    let text = buffer.text();
125                    // .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
126                    let start = input.start_line.unwrap_or(1).max(1);
127                    let lines = text.split('\n').skip(start - 1);
128                    if let Some(end) = input.end_line {
129                        let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
130                        Itertools::intersperse(lines.take(count), "\n").collect()
131                    } else {
132                        Itertools::intersperse(lines, "\n").collect()
133                    }
134                })?;
135
136                action_log.update(cx, |log, cx| {
137                    log.track_buffer(buffer, cx);
138                })?;
139
140                Ok(result)
141            } else {
142                // No line ranges specified, so check file size to see if it's too big.
143                let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
144
145                if file_size <= outline::AUTO_OUTLINE_SIZE {
146                    // File is small enough, so return its contents.
147                    let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
148
149                    action_log.update(cx, |log, cx| {
150                        log.track_buffer(buffer, cx);
151                    })?;
152
153                    Ok(result)
154                } else {
155                    // File is too big, so return the outline
156                    // and a suggestion to read again with line numbers.
157                    let outline = outline::file_outline(project, file_path, action_log, None, cx).await?;
158                    Ok(formatdoc! {"
159                        This file was too big to read all at once. Here is an outline of its symbols:
160
161                        {outline}
162
163                        Using the line numbers in this outline, you can call this tool again while specifying
164                        the start_line and end_line fields to see the implementations of symbols in the outline."
165                    })
166                }
167            }
168        }).into()
169    }
170}
171
172#[cfg(test)]
173mod test {
174    use super::*;
175    use gpui::{AppContext, TestAppContext};
176    use language::{Language, LanguageConfig, LanguageMatcher};
177    use project::{FakeFs, Project};
178    use serde_json::json;
179    use settings::SettingsStore;
180    use util::path;
181
182    #[gpui::test]
183    async fn test_read_nonexistent_file(cx: &mut TestAppContext) {
184        init_test(cx);
185
186        let fs = FakeFs::new(cx.executor());
187        fs.insert_tree("/root", json!({})).await;
188        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
189        let action_log = cx.new(|_| ActionLog::new(project.clone()));
190        let result = cx
191            .update(|cx| {
192                let input = json!({
193                    "path": "root/nonexistent_file.txt"
194                });
195                Arc::new(ReadFileTool)
196                    .run(input, &[], project.clone(), action_log, None, cx)
197                    .output
198            })
199            .await;
200        assert_eq!(
201            result.unwrap_err().to_string(),
202            "root/nonexistent_file.txt not found"
203        );
204    }
205
206    #[gpui::test]
207    async fn test_read_small_file(cx: &mut TestAppContext) {
208        init_test(cx);
209
210        let fs = FakeFs::new(cx.executor());
211        fs.insert_tree(
212            "/root",
213            json!({
214                "small_file.txt": "This is a small file content"
215            }),
216        )
217        .await;
218        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
219        let action_log = cx.new(|_| ActionLog::new(project.clone()));
220        let result = cx
221            .update(|cx| {
222                let input = json!({
223                    "path": "root/small_file.txt"
224                });
225                Arc::new(ReadFileTool)
226                    .run(input, &[], project.clone(), action_log, None, cx)
227                    .output
228            })
229            .await;
230        assert_eq!(result.unwrap(), "This is a small file content");
231    }
232
    /// Reading a large Rust file without a line range is expected to return a
    /// symbol outline instead of the file's text.
    #[gpui::test]
    async fn test_read_large_file(cx: &mut TestAppContext) {
        init_test(cx);

        // 1000 four-line structs; presumably large enough to exceed
        // `outline::AUTO_OUTLINE_SIZE` — the assertions below rely on that.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "large_file.rs": (0..1000).map(|i| format!("struct Test{} {{\n    a: u32,\n    b: usize,\n}}", i)).collect::<Vec<_>>().join("\n")
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        // Register the test Rust language so `.rs` files get the outline query from `rust_lang`.
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        language_registry.add(Arc::new(rust_lang()));
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/large_file.rs"
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log.clone(), None, cx)
                    .output
            })
            .await;
        let content = result.unwrap();
        // `skip(2)` drops the "too big" sentence and the blank line that `run`
        // emits before the outline itself.
        assert_eq!(
            content.lines().skip(2).take(6).collect::<Vec<_>>(),
            vec![
                "struct Test0 [L1-4]",
                " a [L2]",
                " b [L3]",
                "struct Test1 [L5-8]",
                " a [L6]",
                " b [L7]",
            ]
        );

        // NOTE(review): `offset` is not a field of `ReadFileToolInput`, and the struct shows
        // no `deny_unknown_fields`, so serde presumably ignores it and this call behaves
        // exactly like the first one — confirm whether the key is a stale leftover.
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/large_file.rs",
                    "offset": 1
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, None, cx)
                    .output
            })
            .await;
        let content = result.unwrap();
        // Expected outline for every struct: struct `i` spans lines 4i+1..=4i+4,
        // with its two fields on the second and third of those lines.
        let expected_content = (0..1000)
            .flat_map(|i| {
                vec![
                    format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4),
                    format!(" a [L{}]", i * 4 + 2),
                    format!(" b [L{}]", i * 4 + 3),
                ]
            })
            .collect::<Vec<_>>();
        pretty_assertions::assert_eq!(
            content
                .lines()
                .skip(2)
                .take(expected_content.len())
                .collect::<Vec<_>>(),
            expected_content
        );
    }
303
304    #[gpui::test]
305    async fn test_read_file_with_line_range(cx: &mut TestAppContext) {
306        init_test(cx);
307
308        let fs = FakeFs::new(cx.executor());
309        fs.insert_tree(
310            "/root",
311            json!({
312                "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
313            }),
314        )
315        .await;
316        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
317        let action_log = cx.new(|_| ActionLog::new(project.clone()));
318        let result = cx
319            .update(|cx| {
320                let input = json!({
321                    "path": "root/multiline.txt",
322                    "start_line": 2,
323                    "end_line": 4
324                });
325                Arc::new(ReadFileTool)
326                    .run(input, &[], project.clone(), action_log, None, cx)
327                    .output
328            })
329            .await;
330        assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
331    }
332
333    #[gpui::test]
334    async fn test_read_file_line_range_edge_cases(cx: &mut TestAppContext) {
335        init_test(cx);
336
337        let fs = FakeFs::new(cx.executor());
338        fs.insert_tree(
339            "/root",
340            json!({
341                "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
342            }),
343        )
344        .await;
345        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
346        let action_log = cx.new(|_| ActionLog::new(project.clone()));
347
348        // start_line of 0 should be treated as 1
349        let result = cx
350            .update(|cx| {
351                let input = json!({
352                    "path": "root/multiline.txt",
353                    "start_line": 0,
354                    "end_line": 2
355                });
356                Arc::new(ReadFileTool)
357                    .run(input, &[], project.clone(), action_log.clone(), None, cx)
358                    .output
359            })
360            .await;
361        assert_eq!(result.unwrap(), "Line 1\nLine 2");
362
363        // end_line of 0 should result in at least 1 line
364        let result = cx
365            .update(|cx| {
366                let input = json!({
367                    "path": "root/multiline.txt",
368                    "start_line": 1,
369                    "end_line": 0
370                });
371                Arc::new(ReadFileTool)
372                    .run(input, &[], project.clone(), action_log.clone(), None, cx)
373                    .output
374            })
375            .await;
376        assert_eq!(result.unwrap(), "Line 1");
377
378        // when start_line > end_line, should still return at least 1 line
379        let result = cx
380            .update(|cx| {
381                let input = json!({
382                    "path": "root/multiline.txt",
383                    "start_line": 3,
384                    "end_line": 2
385                });
386                Arc::new(ReadFileTool)
387                    .run(input, &[], project.clone(), action_log, None, cx)
388                    .output
389            })
390            .await;
391        assert_eq!(result.unwrap(), "Line 3");
392    }
393
    /// Shared test setup: installs a test `SettingsStore` as a global, then
    /// runs `language::init` and `Project::init_settings`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store is set as a global before the two init calls below;
            // presumably they read settings from it — keep this order.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }
402
    /// Builds a minimal Rust `Language` for the outline test: it is named
    /// "Rust", matches the `rs` path suffix, uses the tree-sitter Rust grammar,
    /// and carries an outline query covering structs, enums, enum variants,
    /// fields, impls, functions and modules.
    ///
    /// Panics if `with_outline_query` fails for the query below.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        // `@item` marks an outline entry, `@name` its label, and `@context` the
        // keyword shown with it, as reflected in the expected test output
        // (e.g. "struct Test0", " a").
        .with_outline_query(
            r#"
            (line_comment) @annotation

            (struct_item
                "struct" @context
                name: (_) @name) @item
            (enum_item
                "enum" @context
                name: (_) @name) @item
            (enum_variant
                name: (_) @name) @item
            (field_declaration
                name: (_) @name) @item
            (impl_item
                "impl" @context
                trait: (_)? @name
                "for"? @context
                type: (_) @name
                body: (_ "{" (_)* "}")) @item
            (function_item
                "fn" @context
                name: (_) @name) @item
            (mod_item
                "mod" @context
                name: (_) @name) @item
            "#,
        )
        .unwrap()
    }
445}