read_file_tool.rs

  1use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use gpui::{AnyWindowHandle, App, Entity, Task};
  5
  6use indoc::formatdoc;
  7use itertools::Itertools;
  8use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  9use project::Project;
 10use schemars::JsonSchema;
 11use serde::{Deserialize, Serialize};
 12use std::sync::Arc;
 13use ui::IconName;
 14use util::markdown::MarkdownString;
 15
/// Size threshold for returning a whole file, compared against `buffer.text().len()`.
///
/// If the model requests to read a file whose size exceeds this without giving
/// a line range, the tool does not return the file's contents. Instead it
/// returns the file's symbol outline (as a successful result, not an `Err`)
/// and suggests trying again using line ranges from the outline.
const MAX_FILE_SIZE_TO_READ: usize = 16384;
 20
// Arguments accepted by the `read_file` tool; `ReadFileTool::run` and
// `ReadFileTool::ui_text` deserialize their JSON input into this type.
//
// NOTE(review): this type derives `JsonSchema`, and schemars normally copies
// `///` doc comments into the generated schema as descriptions, so the field
// docs below are presumably text the model sees. Confirm before rewording them;
// reviewer notes here deliberately use `//` so they stay out of the schema.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ReadFileToolInput {
    /// The relative path of the file to read.
    ///
    /// This path should never be absolute, and the first component
    /// of the path should always be a root directory in a project.
    ///
    /// <example>
    /// If the project has the following root directories:
    ///
    /// - directory1
    /// - directory2
    ///
    /// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`.
    /// If you want to access `file.txt` in `directory2`, you should use the path `directory2/file.txt`.
    /// </example>
    pub path: String,

    // When omitted, `run` starts from line 1.
    /// Optional line number to start reading on (1-based index)
    #[serde(default)]
    pub start_line: Option<usize>,

    // NOTE: `run` treats this bound as exclusive: it takes `end_line - start_line`
    // lines (at least one), so the line numbered `end_line` itself is not returned
    // unless that minimum applies.
    /// Optional line number to end reading on (1-based index)
    #[serde(default)]
    pub end_line: Option<usize>,
}
 47
/// Stateless tool that returns the text of a file in the project, either whole
/// or restricted to a line range; its behavior lives in the `Tool` impl below.
pub struct ReadFileTool;
 49
 50impl Tool for ReadFileTool {
 51    fn name(&self) -> String {
 52        "read_file".into()
 53    }
 54
 55    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 56        false
 57    }
 58
 59    fn description(&self) -> String {
 60        include_str!("./read_file_tool/description.md").into()
 61    }
 62
 63    fn icon(&self) -> IconName {
 64        IconName::FileSearch
 65    }
 66
 67    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 68        json_schema_for::<ReadFileToolInput>(format)
 69    }
 70
 71    fn ui_text(&self, input: &serde_json::Value) -> String {
 72        match serde_json::from_value::<ReadFileToolInput>(input.clone()) {
 73            Ok(input) => {
 74                let path = MarkdownString::inline_code(&input.path);
 75                match (input.start_line, input.end_line) {
 76                    (Some(start), None) => format!("Read file {path} (from line {start})"),
 77                    (Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
 78                    _ => format!("Read file {path}"),
 79                }
 80            }
 81            Err(_) => "Read file".to_string(),
 82        }
 83    }
 84
 85    fn run(
 86        self: Arc<Self>,
 87        input: serde_json::Value,
 88        _messages: &[LanguageModelRequestMessage],
 89        project: Entity<Project>,
 90        action_log: Entity<ActionLog>,
 91        _window: Option<AnyWindowHandle>,
 92        cx: &mut App,
 93    ) -> ToolResult {
 94        let input = match serde_json::from_value::<ReadFileToolInput>(input) {
 95            Ok(input) => input,
 96            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 97        };
 98
 99        let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
100            return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
101        };
102        let Some(worktree) = project
103            .read(cx)
104            .worktree_for_id(project_path.worktree_id, cx)
105        else {
106            return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
107        };
108        let exists = worktree.update(cx, |worktree, cx| {
109            worktree.file_exists(&project_path.path, cx)
110        });
111
112        let file_path = input.path.clone();
113        cx.spawn(async move |cx| {
114            if !exists.await? {
115                return Err(anyhow!("{} not found", file_path))
116            }
117
118            let buffer = cx
119                .update(|cx| {
120                    project.update(cx, |project, cx| project.open_buffer(project_path, cx))
121                })?
122                .await?;
123
124            // Check if specific line ranges are provided
125            if input.start_line.is_some() || input.end_line.is_some() {
126                let result = buffer.read_with(cx, |buffer, _cx| {
127                    let text = buffer.text();
128                    let start = input.start_line.unwrap_or(1);
129                    let lines = text.split('\n').skip(start - 1);
130                    if let Some(end) = input.end_line {
131                        let count = end.saturating_sub(start).max(1); // Ensure at least 1 line
132                        Itertools::intersperse(lines.take(count), "\n").collect()
133                    } else {
134                        Itertools::intersperse(lines, "\n").collect()
135                    }
136                })?;
137
138                action_log.update(cx, |log, cx| {
139                    log.track_buffer(buffer, cx);
140                })?;
141
142                Ok(result)
143            } else {
144                // No line ranges specified, so check file size to see if it's too big.
145                let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
146
147                if file_size <= MAX_FILE_SIZE_TO_READ {
148                    // File is small enough, so return its contents.
149                    let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
150
151                    action_log.update(cx, |log, cx| {
152                        log.track_buffer(buffer, cx);
153                    })?;
154
155                    Ok(result)
156                } else {
157                    // File is too big, so return an error with the outline
158                    // and a suggestion to read again with line numbers.
159                    let outline = file_outline(project, file_path, action_log, None, cx).await?;
160                    Ok(formatdoc! {"
161                        This file was too big to read all at once. Here is an outline of its symbols:
162
163                        {outline}
164
165                        Using the line numbers in this outline, you can call this tool again while specifying
166                        the start_line and end_line fields to see the implementations of symbols in the outline."
167                    })
168                }
169            }
170        }).into()
171    }
172}
173
/// Tests for [`ReadFileTool`], run against an in-memory `FakeFs` project.
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{AppContext, TestAppContext};
    use language::{Language, LanguageConfig, LanguageMatcher};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// A path that resolves to the project but has no file behind it fails
    /// with the "<path> not found" error.
    #[gpui::test]
    async fn test_read_nonexistent_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({})).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/nonexistent_file.txt"
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, None, cx)
                    .output
            })
            .await;
        assert_eq!(
            result.unwrap_err().to_string(),
            "root/nonexistent_file.txt not found"
        );
    }

    /// A file under the size limit is returned verbatim.
    #[gpui::test]
    async fn test_read_small_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "small_file.txt": "This is a small file content"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/small_file.txt"
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, None, cx)
                    .output
            })
            .await;
        assert_eq!(result.unwrap(), "This is a small file content");
    }

    /// A file over the size limit yields the symbol outline instead of the
    /// contents. The first request spot-checks the start of the outline; the
    /// second compares the outline for all 1000 structs.
    #[gpui::test]
    async fn test_read_large_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "large_file.rs": (0..1000).map(|i| format!("struct Test{} {{\n    a: u32,\n    b: usize,\n}}", i)).collect::<Vec<_>>().join("\n")
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        language_registry.add(Arc::new(rust_lang()));
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/large_file.rs"
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log.clone(), None, cx)
                    .output
            })
            .await;
        let content = result.unwrap();
        // The first two lines are the "too big" notice and a blank line.
        assert_eq!(
            content.lines().skip(2).take(6).collect::<Vec<_>>(),
            vec![
                "struct Test0 [L1-4]",
                " a [L2]",
                " b [L3]",
                "struct Test1 [L5-8]",
                " a [L6]",
                " b [L7]",
            ]
        );

        // This request previously carried an `"offset": 1` key. That is not a
        // field of `ReadFileToolInput` and was silently ignored during
        // deserialization, so it has been dropped.
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/large_file.rs"
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, None, cx)
                    .output
            })
            .await;
        let content = result.unwrap();
        let expected_content = (0..1000)
            .flat_map(|i| {
                vec![
                    format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4),
                    format!(" a [L{}]", i * 4 + 2),
                    format!(" b [L{}]", i * 4 + 3),
                ]
            })
            .collect::<Vec<_>>();
        pretty_assertions::assert_eq!(
            content
                .lines()
                .skip(2)
                .take(expected_content.len())
                .collect::<Vec<_>>(),
            expected_content
        );
    }

    /// A bounded range returns `end_line - start_line` lines starting at
    /// `start_line`; the line numbered `end_line` is not included.
    #[gpui::test]
    async fn test_read_file_with_line_range(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/multiline.txt",
                    "start_line": 2,
                    "end_line": 4
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, None, cx)
                    .output
            })
            .await;
        assert_eq!(result.unwrap(), "Line 2\nLine 3");
    }

    /// With only `start_line` given, everything from that line to the end of
    /// the file is returned.
    #[gpui::test]
    async fn test_read_file_from_start_line(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let result = cx
            .update(|cx| {
                let input = json!({
                    "path": "root/multiline.txt",
                    "start_line": 4
                });
                Arc::new(ReadFileTool)
                    .run(input, &[], project.clone(), action_log, None, cx)
                    .output
            })
            .await;
        assert_eq!(result.unwrap(), "Line 4\nLine 5");
    }

    /// Installs the globals these tests rely on: a test settings store plus
    /// language and project settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }

    /// Builds a Rust language matched by the `.rs` suffix, with the outline
    /// query that `test_read_large_file` depends on.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(
            r#"
            (line_comment) @annotation

            (struct_item
                "struct" @context
                name: (_) @name) @item
            (enum_item
                "enum" @context
                name: (_) @name) @item
            (enum_variant
                name: (_) @name) @item
            (field_declaration
                name: (_) @name) @item
            (impl_item
                "impl" @context
                trait: (_)? @name
                "for"? @context
                type: (_) @name
                body: (_ "{" (_)* "}")) @item
            (function_item
                "fn" @context
                name: (_) @name) @item
            (mod_item
                "mod" @context
                name: (_) @name) @item
            "#,
        )
        .unwrap()
    }
}