read_file_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Context as _, Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use assistant_tool::{ToolResultContent, outline};
  5use gpui::{AnyWindowHandle, App, Entity, Task};
  6use project::{ImageItem, image_store};
  7
  8use assistant_tool::ToolResultOutput;
  9use indoc::formatdoc;
 10use itertools::Itertools;
 11use language::{Anchor, Point};
 12use language_model::{
 13    LanguageModel, LanguageModelImage, LanguageModelRequest, LanguageModelToolSchemaFormat,
 14};
 15use project::{AgentLocation, Project};
 16use schemars::JsonSchema;
 17use serde::{Deserialize, Serialize};
 18use std::sync::Arc;
 19use ui::IconName;
 20use util::markdown::MarkdownInlineCode;
 21
 22/// If the model requests to read a file whose size exceeds this, then
 23#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 24pub struct ReadFileToolInput {
 25    /// The relative path of the file to read.
 26    ///
 27    /// This path should never be absolute, and the first component
 28    /// of the path should always be a root directory in a project.
 29    ///
 30    /// <example>
 31    /// If the project has the following root directories:
 32    ///
 33    /// - directory1
 34    /// - directory2
 35    ///
 36    /// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`.
 37    /// If you want to access `file.txt` in `directory2`, you should use the path `directory2/file.txt`.
 38    /// </example>
 39    pub path: String,
 40
 41    /// Optional line number to start reading on (1-based index)
 42    #[serde(default)]
 43    pub start_line: Option<u32>,
 44
 45    /// Optional line number to end reading on (1-based index, inclusive)
 46    #[serde(default)]
 47    pub end_line: Option<u32>,
 48}
 49
 50pub struct ReadFileTool;
 51
 52impl Tool for ReadFileTool {
 53    fn name(&self) -> String {
 54        "read_file".into()
 55    }
 56
 57    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 58        false
 59    }
 60
 61    fn description(&self) -> String {
 62        include_str!("./read_file_tool/description.md").into()
 63    }
 64
 65    fn icon(&self) -> IconName {
 66        IconName::FileSearch
 67    }
 68
 69    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 70        json_schema_for::<ReadFileToolInput>(format)
 71    }
 72
 73    fn ui_text(&self, input: &serde_json::Value) -> String {
 74        match serde_json::from_value::<ReadFileToolInput>(input.clone()) {
 75            Ok(input) => {
 76                let path = MarkdownInlineCode(&input.path);
 77                match (input.start_line, input.end_line) {
 78                    (Some(start), None) => format!("Read file {path} (from line {start})"),
 79                    (Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
 80                    _ => format!("Read file {path}"),
 81                }
 82            }
 83            Err(_) => "Read file".to_string(),
 84        }
 85    }
 86
 87    fn run(
 88        self: Arc<Self>,
 89        input: serde_json::Value,
 90        _request: Arc<LanguageModelRequest>,
 91        project: Entity<Project>,
 92        action_log: Entity<ActionLog>,
 93        model: Arc<dyn LanguageModel>,
 94        _window: Option<AnyWindowHandle>,
 95        cx: &mut App,
 96    ) -> ToolResult {
 97        let input = match serde_json::from_value::<ReadFileToolInput>(input) {
 98            Ok(input) => input,
 99            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
100        };
101
102        let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
103            return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
104        };
105
106        let file_path = input.path.clone();
107
108        if image_store::is_image_file(&project, &project_path, cx) {
109            if !model.supports_images() {
110                return Task::ready(Err(anyhow!(
111                    "Attempted to read an image, but Zed doesn't currently sending images to {}.",
112                    model.name().0
113                )))
114                .into();
115            }
116
117            let task = cx.spawn(async move |cx| -> Result<ToolResultOutput> {
118                let image_entity: Entity<ImageItem> = cx
119                    .update(|cx| {
120                        project.update(cx, |project, cx| {
121                            project.open_image(project_path.clone(), cx)
122                        })
123                    })?
124                    .await?;
125
126                let image =
127                    image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
128
129                let language_model_image = cx
130                    .update(|cx| LanguageModelImage::from_image(image, cx))?
131                    .await
132                    .context("processing image")?;
133
134                Ok(ToolResultOutput {
135                    content: ToolResultContent::Image(language_model_image),
136                    output: None,
137                })
138            });
139
140            return task.into();
141        }
142
143        cx.spawn(async move |cx| {
144            let buffer = cx
145                .update(|cx| {
146                    project.update(cx, |project, cx| project.open_buffer(project_path, cx))
147                })?
148                .await?;
149            if buffer.read_with(cx, |buffer, _| {
150                buffer
151                    .file()
152                    .as_ref()
153                    .map_or(true, |file| !file.disk_state().exists())
154            })? {
155                anyhow::bail!("{file_path} not found");
156            }
157
158            project.update(cx, |project, cx| {
159                project.set_agent_location(
160                    Some(AgentLocation {
161                        buffer: buffer.downgrade(),
162                        position: Anchor::MIN,
163                    }),
164                    cx,
165                );
166            })?;
167
168            // Check if specific line ranges are provided
169            if input.start_line.is_some() || input.end_line.is_some() {
170                let mut anchor = None;
171                let result = buffer.read_with(cx, |buffer, _cx| {
172                    let text = buffer.text();
173                    // .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
174                    let start = input.start_line.unwrap_or(1).max(1);
175                    let start_row = start - 1;
176                    if start_row <= buffer.max_point().row {
177                        let column = buffer.line_indent_for_row(start_row).raw_len();
178                        anchor = Some(buffer.anchor_before(Point::new(start_row, column)));
179                    }
180
181                    let lines = text.split('\n').skip(start_row as usize);
182                    if let Some(end) = input.end_line {
183                        let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
184                        Itertools::intersperse(lines.take(count as usize), "\n")
185                            .collect::<String>()
186                            .into()
187                    } else {
188                        Itertools::intersperse(lines, "\n")
189                            .collect::<String>()
190                            .into()
191                    }
192                })?;
193
194                action_log.update(cx, |log, cx| {
195                    log.buffer_read(buffer.clone(), cx);
196                })?;
197
198                if let Some(anchor) = anchor {
199                    project.update(cx, |project, cx| {
200                        project.set_agent_location(
201                            Some(AgentLocation {
202                                buffer: buffer.downgrade(),
203                                position: anchor,
204                            }),
205                            cx,
206                        );
207                    })?;
208                }
209
210                Ok(result)
211            } else {
212                // No line ranges specified, so check file size to see if it's too big.
213                let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
214
215                if file_size <= outline::AUTO_OUTLINE_SIZE {
216                    // File is small enough, so return its contents.
217                    let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
218
219                    action_log.update(cx, |log, cx| {
220                        log.buffer_read(buffer, cx);
221                    })?;
222
223                    Ok(result.into())
224                } else {
225                    // File is too big, so return the outline
226                    // and a suggestion to read again with line numbers.
227                    let outline =
228                        outline::file_outline(project, file_path, action_log, None, cx).await?;
229                    Ok(formatdoc! {"
230                        This file was too big to read all at once.
231
232                        Here is an outline of its symbols:
233
234                        {outline}
235
236                        Using the line numbers in this outline, you can call this tool again
237                        while specifying the start_line and end_line fields to see the
238                        implementations of symbols in the outline."
239                    }
240                    .into())
241                }
242            }
243        })
244        .into()
245    }
246}
247
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{AppContext, TestAppContext};
    use language::{Language, LanguageConfig, LanguageMatcher};
    use language_model::fake_provider::FakeLanguageModel;
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Initializes globals and builds a test project rooted at `/root` whose
    /// file tree is `tree`, along with a fresh action log and fake model.
    async fn setup(
        tree: serde_json::Value,
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<ActionLog>, Arc<FakeLanguageModel>) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/root", tree).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));
        let model = Arc::new(FakeLanguageModel::default());
        (project, action_log, model)
    }

    /// Runs `ReadFileTool` with `input` and awaits its output.
    async fn read_file(
        input: serde_json::Value,
        project: &Entity<Project>,
        action_log: &Entity<ActionLog>,
        model: &Arc<FakeLanguageModel>,
        cx: &mut TestAppContext,
    ) -> Result<ToolResultOutput> {
        cx.update(|cx| {
            Arc::new(ReadFileTool)
                .run(
                    input,
                    Arc::default(),
                    project.clone(),
                    action_log.clone(),
                    model.clone(),
                    None,
                    cx,
                )
                .output
        })
        .await
    }

    #[gpui::test]
    async fn test_read_nonexistent_file(cx: &mut TestAppContext) {
        let (project, action_log, model) = setup(json!({}), cx).await;
        let result = read_file(
            json!({ "path": "root/nonexistent_file.txt" }),
            &project,
            &action_log,
            &model,
            cx,
        )
        .await;
        assert_eq!(
            result.unwrap_err().to_string(),
            "root/nonexistent_file.txt not found"
        );
    }

    #[gpui::test]
    async fn test_read_small_file(cx: &mut TestAppContext) {
        let (project, action_log, model) = setup(
            json!({
                "small_file.txt": "This is a small file content"
            }),
            cx,
        )
        .await;
        let result = read_file(
            json!({ "path": "root/small_file.txt" }),
            &project,
            &action_log,
            &model,
            cx,
        )
        .await;
        assert_eq!(
            result.unwrap().content.as_str(),
            Some("This is a small file content")
        );
    }

    #[gpui::test]
    async fn test_read_large_file(cx: &mut TestAppContext) {
        let file_content = (0..1000)
            .map(|i| format!("struct Test{} {{\n    a: u32,\n    b: usize,\n}}", i))
            .collect::<Vec<_>>()
            .join("\n");
        let (project, action_log, model) =
            setup(json!({ "large_file.rs": file_content }), cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        language_registry.add(Arc::new(rust_lang()));

        // Without a line range, a file this large is returned as an outline.
        let result = read_file(
            json!({ "path": "root/large_file.rs" }),
            &project,
            &action_log,
            &model,
            cx,
        )
        .await;
        let output = result.unwrap();
        let content = output.content.as_str().unwrap();

        // The first 4 lines are the explanatory header preceding the outline.
        assert_eq!(
            content.lines().skip(4).take(6).collect::<Vec<_>>(),
            vec![
                "struct Test0 [L1-4]",
                " a [L2]",
                " b [L3]",
                "struct Test1 [L5-8]",
                " a [L6]",
                " b [L7]",
            ]
        );

        // The outline covers every struct in the file.
        let expected_content = (0..1000)
            .flat_map(|i| {
                vec![
                    format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4),
                    format!(" a [L{}]", i * 4 + 2),
                    format!(" b [L{}]", i * 4 + 3),
                ]
            })
            .collect::<Vec<_>>();
        pretty_assertions::assert_eq!(
            content
                .lines()
                .skip(4)
                .take(expected_content.len())
                .collect::<Vec<_>>(),
            expected_content
        );
    }

    #[gpui::test]
    async fn test_read_file_with_line_range(cx: &mut TestAppContext) {
        let (project, action_log, model) = setup(
            json!({
                "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
            }),
            cx,
        )
        .await;
        let result = read_file(
            json!({
                "path": "root/multiline.txt",
                "start_line": 2,
                "end_line": 4
            }),
            &project,
            &action_log,
            &model,
            cx,
        )
        .await;
        assert_eq!(
            result.unwrap().content.as_str(),
            Some("Line 2\nLine 3\nLine 4")
        );
    }

    #[gpui::test]
    async fn test_read_file_open_ended_line_range(cx: &mut TestAppContext) {
        let (project, action_log, model) = setup(
            json!({
                "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
            }),
            cx,
        )
        .await;

        // Only start_line: read through the end of the file.
        let result = read_file(
            json!({
                "path": "root/multiline.txt",
                "start_line": 4
            }),
            &project,
            &action_log,
            &model,
            cx,
        )
        .await;
        assert_eq!(result.unwrap().content.as_str(), Some("Line 4\nLine 5"));

        // Only end_line: read from the start of the file.
        let result = read_file(
            json!({
                "path": "root/multiline.txt",
                "end_line": 2
            }),
            &project,
            &action_log,
            &model,
            cx,
        )
        .await;
        assert_eq!(result.unwrap().content.as_str(), Some("Line 1\nLine 2"));
    }

    #[gpui::test]
    async fn test_read_file_line_range_edge_cases(cx: &mut TestAppContext) {
        let (project, action_log, model) = setup(
            json!({
                "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
            }),
            cx,
        )
        .await;

        // start_line of 0 should be treated as 1
        let result = read_file(
            json!({
                "path": "root/multiline.txt",
                "start_line": 0,
                "end_line": 2
            }),
            &project,
            &action_log,
            &model,
            cx,
        )
        .await;
        assert_eq!(result.unwrap().content.as_str(), Some("Line 1\nLine 2"));

        // end_line of 0 should result in at least 1 line
        let result = read_file(
            json!({
                "path": "root/multiline.txt",
                "start_line": 1,
                "end_line": 0
            }),
            &project,
            &action_log,
            &model,
            cx,
        )
        .await;
        assert_eq!(result.unwrap().content.as_str(), Some("Line 1"));

        // when start_line > end_line, should still return at least 1 line
        let result = read_file(
            json!({
                "path": "root/multiline.txt",
                "start_line": 3,
                "end_line": 2
            }),
            &project,
            &action_log,
            &model,
            cx,
        )
        .await;
        assert_eq!(result.unwrap().content.as_str(), Some("Line 3"));
    }

    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }

    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(
            r#"
            (line_comment) @annotation

            (struct_item
                "struct" @context
                name: (_) @name) @item
            (enum_item
                "enum" @context
                name: (_) @name) @item
            (enum_variant
                name: (_) @name) @item
            (field_declaration
                name: (_) @name) @item
            (impl_item
                "impl" @context
                trait: (_)? @name
                "for"? @context
                type: (_) @name
                body: (_ "{" (_)* "}")) @item
            (function_item
                "fn" @context
                name: (_) @name) @item
            (mod_item
                "mod" @context
                name: (_) @name) @item
            "#,
        )
        .unwrap()
    }
}