read_file_tool.rs

  1use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use gpui::{App, Entity, Task};
  5use indoc::formatdoc;
  6use itertools::Itertools;
  7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  8use project::Project;
  9use schemars::JsonSchema;
 10use serde::{Deserialize, Serialize};
 11use std::sync::Arc;
 12use ui::IconName;
 13use util::markdown::MarkdownString;
 14
 15/// If the model requests to read a file whose size exceeds this, then
 16/// the tool will return an error along with the model's symbol outline,
 17/// and suggest trying again using line ranges from the outline.
 18const MAX_FILE_SIZE_TO_READ: usize = 16384;
 19
 20#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 21pub struct ReadFileToolInput {
 22    /// The relative path of the file to read.
 23    ///
 24    /// This path should never be absolute, and the first component
 25    /// of the path should always be a root directory in a project.
 26    ///
 27    /// <example>
 28    /// If the project has the following root directories:
 29    ///
 30    /// - directory1
 31    /// - directory2
 32    ///
 33    /// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`.
 34    /// If you want to access `file.txt` in `directory2`, you should use the path `directory2/file.txt`.
 35    /// </example>
 36    pub path: String,
 37
 38    /// Optional line number to start reading on (1-based index)
 39    #[serde(default)]
 40    pub start_line: Option<usize>,
 41
 42    /// Optional line number to end reading on (1-based index)
 43    #[serde(default)]
 44    pub end_line: Option<usize>,
 45}
 46
 47pub struct ReadFileTool;
 48
 49impl Tool for ReadFileTool {
 50    fn name(&self) -> String {
 51        "read_file".into()
 52    }
 53
 54    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 55        false
 56    }
 57
 58    fn description(&self) -> String {
 59        include_str!("./read_file_tool/description.md").into()
 60    }
 61
 62    fn icon(&self) -> IconName {
 63        IconName::FileSearch
 64    }
 65
 66    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 67        json_schema_for::<ReadFileToolInput>(format)
 68    }
 69
 70    fn ui_text(&self, input: &serde_json::Value) -> String {
 71        match serde_json::from_value::<ReadFileToolInput>(input.clone()) {
 72            Ok(input) => {
 73                let path = MarkdownString::inline_code(&input.path);
 74                match (input.start_line, input.end_line) {
 75                    (Some(start), None) => format!("Read file {path} (from line {start})"),
 76                    (Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
 77                    _ => format!("Read file {path}"),
 78                }
 79            }
 80            Err(_) => "Read file".to_string(),
 81        }
 82    }
 83
 84    fn run(
 85        self: Arc<Self>,
 86        input: serde_json::Value,
 87        _messages: &[LanguageModelRequestMessage],
 88        project: Entity<Project>,
 89        action_log: Entity<ActionLog>,
 90        cx: &mut App,
 91    ) -> ToolResult {
 92        let input = match serde_json::from_value::<ReadFileToolInput>(input) {
 93            Ok(input) => input,
 94            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 95        };
 96
 97        let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
 98            return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
 99        };
100        let Some(worktree) = project
101            .read(cx)
102            .worktree_for_id(project_path.worktree_id, cx)
103        else {
104            return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
105        };
106        let exists = worktree.update(cx, |worktree, cx| {
107            worktree.file_exists(&project_path.path, cx)
108        });
109
110        let file_path = input.path.clone();
111        cx.spawn(async move |cx| {
112            if !exists.await? {
113                return Err(anyhow!("{} not found", file_path))
114            }
115
116            let buffer = cx
117                .update(|cx| {
118                    project.update(cx, |project, cx| project.open_buffer(project_path, cx))
119                })?
120                .await?;
121
122            // Check if specific line ranges are provided
123            if input.start_line.is_some() || input.end_line.is_some() {
124                let result = buffer.read_with(cx, |buffer, _cx| {
125                    let text = buffer.text();
126                    let start = input.start_line.unwrap_or(1);
127                    let lines = text.split('\n').skip(start - 1);
128                    if let Some(end) = input.end_line {
129                        let count = end.saturating_sub(start).max(1); // Ensure at least 1 line
130                        Itertools::intersperse(lines.take(count), "\n").collect()
131                    } else {
132                        Itertools::intersperse(lines, "\n").collect()
133                    }
134                })?;
135
136                action_log.update(cx, |log, cx| {
137                    log.buffer_read(buffer, cx);
138                })?;
139
140                Ok(result)
141            } else {
142                // No line ranges specified, so check file size to see if it's too big.
143                let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
144
145                if file_size <= MAX_FILE_SIZE_TO_READ {
146                    // File is small enough, so return its contents.
147                    let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
148
149                    action_log.update(cx, |log, cx| {
150                        log.buffer_read(buffer, cx);
151                    })?;
152
153                    Ok(result)
154                } else {
155                    // File is too big, so return an error with the outline
156                    // and a suggestion to read again with line numbers.
157                    let outline = file_outline(project, file_path, action_log, None, cx).await?;
158                    Ok(formatdoc! {"
159                        This file was too big to read all at once. Here is an outline of its symbols:
160
161                        {outline}
162
163                        Using the line numbers in this outline, you can call this tool again while specifying
164                        the start_line and end_line fields to see the implementations of symbols in the outline."
165                    })
166                }
167            }
168        }).into()
169    }
170}
171
172#[cfg(test)]
173mod test {
174    use super::*;
175    use gpui::{AppContext, TestAppContext};
176    use language::{Language, LanguageConfig, LanguageMatcher};
177    use project::{FakeFs, Project};
178    use serde_json::json;
179    use settings::SettingsStore;
180    use util::path;
181
182    #[gpui::test]
183    async fn test_read_nonexistent_file(cx: &mut TestAppContext) {
184        init_test(cx);
185
186        let fs = FakeFs::new(cx.executor());
187        fs.insert_tree("/root", json!({})).await;
188        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
189        let action_log = cx.new(|_| ActionLog::new(project.clone()));
190        let result = cx
191            .update(|cx| {
192                let input = json!({
193                    "path": "root/nonexistent_file.txt"
194                });
195                Arc::new(ReadFileTool)
196                    .run(input, &[], project.clone(), action_log, cx)
197                    .output
198            })
199            .await;
200        assert_eq!(
201            result.unwrap_err().to_string(),
202            "root/nonexistent_file.txt not found"
203        );
204    }
205
206    #[gpui::test]
207    async fn test_read_small_file(cx: &mut TestAppContext) {
208        init_test(cx);
209
210        let fs = FakeFs::new(cx.executor());
211        fs.insert_tree(
212            "/root",
213            json!({
214                "small_file.txt": "This is a small file content"
215            }),
216        )
217        .await;
218        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
219        let action_log = cx.new(|_| ActionLog::new(project.clone()));
220        let result = cx
221            .update(|cx| {
222                let input = json!({
223                    "path": "root/small_file.txt"
224                });
225                Arc::new(ReadFileTool)
226                    .run(input, &[], project.clone(), action_log, cx)
227                    .output
228            })
229            .await;
230        assert_eq!(result.unwrap(), "This is a small file content");
231    }
232
233    #[gpui::test]
234    async fn test_read_large_file(cx: &mut TestAppContext) {
235        init_test(cx);
236
237        let fs = FakeFs::new(cx.executor());
238        fs.insert_tree(
239            "/root",
240            json!({
241                "large_file.rs": (0..1000).map(|i| format!("struct Test{} {{\n    a: u32,\n    b: usize,\n}}", i)).collect::<Vec<_>>().join("\n")
242            }),
243        )
244        .await;
245        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
246        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
247        language_registry.add(Arc::new(rust_lang()));
248        let action_log = cx.new(|_| ActionLog::new(project.clone()));
249
250        let result = cx
251            .update(|cx| {
252                let input = json!({
253                    "path": "root/large_file.rs"
254                });
255                Arc::new(ReadFileTool)
256                    .run(input, &[], project.clone(), action_log.clone(), cx)
257                    .output
258            })
259            .await;
260        let content = result.unwrap();
261        assert_eq!(
262            content.lines().skip(2).take(6).collect::<Vec<_>>(),
263            vec![
264                "struct Test0 [L1-4]",
265                " a [L2]",
266                " b [L3]",
267                "struct Test1 [L5-8]",
268                " a [L6]",
269                " b [L7]",
270            ]
271        );
272
273        let result = cx
274            .update(|cx| {
275                let input = json!({
276                    "path": "root/large_file.rs",
277                    "offset": 1
278                });
279                Arc::new(ReadFileTool)
280                    .run(input, &[], project.clone(), action_log, cx)
281                    .output
282            })
283            .await;
284        let content = result.unwrap();
285        let expected_content = (0..1000)
286            .flat_map(|i| {
287                vec![
288                    format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4),
289                    format!(" a [L{}]", i * 4 + 2),
290                    format!(" b [L{}]", i * 4 + 3),
291                ]
292            })
293            .collect::<Vec<_>>();
294        pretty_assertions::assert_eq!(
295            content
296                .lines()
297                .skip(2)
298                .take(expected_content.len())
299                .collect::<Vec<_>>(),
300            expected_content
301        );
302    }
303
304    #[gpui::test]
305    async fn test_read_file_with_line_range(cx: &mut TestAppContext) {
306        init_test(cx);
307
308        let fs = FakeFs::new(cx.executor());
309        fs.insert_tree(
310            "/root",
311            json!({
312                "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
313            }),
314        )
315        .await;
316        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
317        let action_log = cx.new(|_| ActionLog::new(project.clone()));
318        let result = cx
319            .update(|cx| {
320                let input = json!({
321                    "path": "root/multiline.txt",
322                    "start_line": 2,
323                    "end_line": 4
324                });
325                Arc::new(ReadFileTool)
326                    .run(input, &[], project.clone(), action_log, cx)
327                    .output
328            })
329            .await;
330        assert_eq!(result.unwrap(), "Line 2\nLine 3");
331    }
332
333    fn init_test(cx: &mut TestAppContext) {
334        cx.update(|cx| {
335            let settings_store = SettingsStore::test(cx);
336            cx.set_global(settings_store);
337            language::init(cx);
338            Project::init_settings(cx);
339        });
340    }
341
342    fn rust_lang() -> Language {
343        Language::new(
344            LanguageConfig {
345                name: "Rust".into(),
346                matcher: LanguageMatcher {
347                    path_suffixes: vec!["rs".to_string()],
348                    ..Default::default()
349                },
350                ..Default::default()
351            },
352            Some(tree_sitter_rust::LANGUAGE.into()),
353        )
354        .with_outline_query(
355            r#"
356            (line_comment) @annotation
357
358            (struct_item
359                "struct" @context
360                name: (_) @name) @item
361            (enum_item
362                "enum" @context
363                name: (_) @name) @item
364            (enum_variant
365                name: (_) @name) @item
366            (field_declaration
367                name: (_) @name) @item
368            (impl_item
369                "impl" @context
370                trait: (_)? @name
371                "for"? @context
372                type: (_) @name
373                body: (_ "{" (_)* "}")) @item
374            (function_item
375                "fn" @context
376                name: (_) @name) @item
377            (mod_item
378                "mod" @context
379                name: (_) @name) @item
380            "#,
381        )
382        .unwrap()
383    }
384}