read_file_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Context as _, Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use assistant_tool::{ToolResultContent, outline};
  5use gpui::{AnyWindowHandle, App, Entity, Task};
  6use project::{ImageItem, image_store};
  7
  8use assistant_tool::ToolResultOutput;
  9use indoc::formatdoc;
 10use itertools::Itertools;
 11use language::{Anchor, Point};
 12use language_model::{
 13    LanguageModel, LanguageModelImage, LanguageModelRequest, LanguageModelToolSchemaFormat,
 14};
 15use project::{AgentLocation, Project};
 16use schemars::JsonSchema;
 17use serde::{Deserialize, Serialize};
 18use std::sync::Arc;
 19use ui::IconName;
 20use util::markdown::MarkdownInlineCode;
 21
 22/// If the model requests to read a file whose size exceeds this, then
 23#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 24pub struct ReadFileToolInput {
 25    /// The relative path of the file to read.
 26    ///
 27    /// This path should never be absolute, and the first component
 28    /// of the path should always be a root directory in a project.
 29    ///
 30    /// <example>
 31    /// If the project has the following root directories:
 32    ///
 33    /// - directory1
 34    /// - directory2
 35    ///
 36    /// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`.
 37    /// If you want to access `file.txt` in `directory2`, you should use the path `directory2/file.txt`.
 38    /// </example>
 39    pub path: String,
 40
 41    /// Optional line number to start reading on (1-based index)
 42    #[serde(default)]
 43    pub start_line: Option<u32>,
 44
 45    /// Optional line number to end reading on (1-based index, inclusive)
 46    #[serde(default)]
 47    pub end_line: Option<u32>,
 48}
 49
 50pub struct ReadFileTool;
 51
 52impl Tool for ReadFileTool {
 53    fn name(&self) -> String {
 54        "read_file".into()
 55    }
 56
 57    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 58        false
 59    }
 60
 61    fn may_perform_edits(&self) -> bool {
 62        false
 63    }
 64
 65    fn description(&self) -> String {
 66        include_str!("./read_file_tool/description.md").into()
 67    }
 68
 69    fn icon(&self) -> IconName {
 70        IconName::FileSearch
 71    }
 72
 73    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 74        json_schema_for::<ReadFileToolInput>(format)
 75    }
 76
 77    fn ui_text(&self, input: &serde_json::Value) -> String {
 78        match serde_json::from_value::<ReadFileToolInput>(input.clone()) {
 79            Ok(input) => {
 80                let path = MarkdownInlineCode(&input.path);
 81                match (input.start_line, input.end_line) {
 82                    (Some(start), None) => format!("Read file {path} (from line {start})"),
 83                    (Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
 84                    _ => format!("Read file {path}"),
 85                }
 86            }
 87            Err(_) => "Read file".to_string(),
 88        }
 89    }
 90
 91    fn run(
 92        self: Arc<Self>,
 93        input: serde_json::Value,
 94        _request: Arc<LanguageModelRequest>,
 95        project: Entity<Project>,
 96        action_log: Entity<ActionLog>,
 97        model: Arc<dyn LanguageModel>,
 98        _window: Option<AnyWindowHandle>,
 99        cx: &mut App,
100    ) -> ToolResult {
101        let input = match serde_json::from_value::<ReadFileToolInput>(input) {
102            Ok(input) => input,
103            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
104        };
105
106        let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
107            return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into();
108        };
109
110        let file_path = input.path.clone();
111
112        if image_store::is_image_file(&project, &project_path, cx) {
113            if !model.supports_images() {
114                return Task::ready(Err(anyhow!(
115                    "Attempted to read an image, but Zed doesn't currently sending images to {}.",
116                    model.name().0
117                )))
118                .into();
119            }
120
121            let task = cx.spawn(async move |cx| -> Result<ToolResultOutput> {
122                let image_entity: Entity<ImageItem> = cx
123                    .update(|cx| {
124                        project.update(cx, |project, cx| {
125                            project.open_image(project_path.clone(), cx)
126                        })
127                    })?
128                    .await?;
129
130                let image =
131                    image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
132
133                let language_model_image = cx
134                    .update(|cx| LanguageModelImage::from_image(image, cx))?
135                    .await
136                    .context("processing image")?;
137
138                Ok(ToolResultOutput {
139                    content: ToolResultContent::Image(language_model_image),
140                    output: None,
141                })
142            });
143
144            return task.into();
145        }
146
147        cx.spawn(async move |cx| {
148            let buffer = cx
149                .update(|cx| {
150                    project.update(cx, |project, cx| project.open_buffer(project_path, cx))
151                })?
152                .await?;
153            if buffer.read_with(cx, |buffer, _| {
154                buffer
155                    .file()
156                    .as_ref()
157                    .map_or(true, |file| !file.disk_state().exists())
158            })? {
159                anyhow::bail!("{file_path} not found");
160            }
161
162            project.update(cx, |project, cx| {
163                project.set_agent_location(
164                    Some(AgentLocation {
165                        buffer: buffer.downgrade(),
166                        position: Anchor::MIN,
167                    }),
168                    cx,
169                );
170            })?;
171
172            // Check if specific line ranges are provided
173            if input.start_line.is_some() || input.end_line.is_some() {
174                let mut anchor = None;
175                let result = buffer.read_with(cx, |buffer, _cx| {
176                    let text = buffer.text();
177                    // .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
178                    let start = input.start_line.unwrap_or(1).max(1);
179                    let start_row = start - 1;
180                    if start_row <= buffer.max_point().row {
181                        let column = buffer.line_indent_for_row(start_row).raw_len();
182                        anchor = Some(buffer.anchor_before(Point::new(start_row, column)));
183                    }
184
185                    let lines = text.split('\n').skip(start_row as usize);
186                    if let Some(end) = input.end_line {
187                        let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
188                        Itertools::intersperse(lines.take(count as usize), "\n")
189                            .collect::<String>()
190                            .into()
191                    } else {
192                        Itertools::intersperse(lines, "\n")
193                            .collect::<String>()
194                            .into()
195                    }
196                })?;
197
198                action_log.update(cx, |log, cx| {
199                    log.buffer_read(buffer.clone(), cx);
200                })?;
201
202                if let Some(anchor) = anchor {
203                    project.update(cx, |project, cx| {
204                        project.set_agent_location(
205                            Some(AgentLocation {
206                                buffer: buffer.downgrade(),
207                                position: anchor,
208                            }),
209                            cx,
210                        );
211                    })?;
212                }
213
214                Ok(result)
215            } else {
216                // No line ranges specified, so check file size to see if it's too big.
217                let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
218
219                if file_size <= outline::AUTO_OUTLINE_SIZE {
220                    // File is small enough, so return its contents.
221                    let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
222
223                    action_log.update(cx, |log, cx| {
224                        log.buffer_read(buffer, cx);
225                    })?;
226
227                    Ok(result.into())
228                } else {
229                    // File is too big, so return the outline
230                    // and a suggestion to read again with line numbers.
231                    let outline =
232                        outline::file_outline(project, file_path, action_log, None, cx).await?;
233                    Ok(formatdoc! {"
234                        This file was too big to read all at once.
235
236                        Here is an outline of its symbols:
237
238                        {outline}
239
240                        Using the line numbers in this outline, you can call this tool again
241                        while specifying the start_line and end_line fields to see the
242                        implementations of symbols in the outline."
243                    }
244                    .into())
245                }
246            }
247        })
248        .into()
249    }
250}
251
#[cfg(test)]
mod test {
    use super::*;
    use gpui::{AppContext, TestAppContext};
    use language::{Language, LanguageConfig, LanguageMatcher};
    use language_model::fake_provider::FakeLanguageModel;
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// A path inside the project that has no file on disk yields a "not found" error.
    #[gpui::test]
    async fn test_read_nonexistent_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({})).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        let result = run_read_file(
            json!({
                "path": "root/nonexistent_file.txt"
            }),
            &project,
            &action_log,
            cx,
        )
        .await;
        assert_eq!(
            result.unwrap_err().to_string(),
            "root/nonexistent_file.txt not found"
        );
    }

    /// A small file read without a line range is returned verbatim.
    #[gpui::test]
    async fn test_read_small_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "small_file.txt": "This is a small file content"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        let result = run_read_file(
            json!({
                "path": "root/small_file.txt"
            }),
            &project,
            &action_log,
            cx,
        )
        .await;
        assert_eq!(
            result.unwrap().content.as_str(),
            Some("This is a small file content")
        );
    }

    /// A large file read without a line range yields a symbol outline
    /// instead of the file's contents.
    #[gpui::test]
    async fn test_read_large_file(cx: &mut TestAppContext) {
        init_test(cx);

        let large_file = (0..1000)
            .map(|i| format!("struct Test{} {{\n    a: u32,\n    b: usize,\n}}", i))
            .collect::<Vec<_>>()
            .join("\n");
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "large_file.rs": large_file
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        language_registry.add(Arc::new(rust_lang()));
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        // Spot-check the beginning of the outline.
        let result = run_read_file(
            json!({
                "path": "root/large_file.rs"
            }),
            &project,
            &action_log,
            cx,
        )
        .await;
        let content = result.unwrap();
        let content = content.as_str().unwrap();
        assert_eq!(
            content.lines().skip(4).take(6).collect::<Vec<_>>(),
            vec![
                "struct Test0 [L1-4]",
                " a [L2]",
                " b [L3]",
                "struct Test1 [L5-8]",
                " a [L6]",
                " b [L7]",
            ]
        );

        // Read again and verify that the outline covers every symbol in the file.
        // This request previously carried an "offset" key, which
        // `ReadFileToolInput` does not declare and serde therefore ignored.
        let result = run_read_file(
            json!({
                "path": "root/large_file.rs"
            }),
            &project,
            &action_log,
            cx,
        )
        .await;
        let content = result.unwrap();
        let expected_content = (0..1000)
            .flat_map(|i| {
                vec![
                    format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4),
                    format!(" a [L{}]", i * 4 + 2),
                    format!(" b [L{}]", i * 4 + 3),
                ]
            })
            .collect::<Vec<_>>();
        pretty_assertions::assert_eq!(
            content
                .as_str()
                .unwrap()
                .lines()
                .skip(4)
                .take(expected_content.len())
                .collect::<Vec<_>>(),
            expected_content
        );
    }

    /// An explicit 1-based inclusive range returns exactly those lines.
    #[gpui::test]
    async fn test_read_file_with_line_range(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        let result = run_read_file(
            json!({
                "path": "root/multiline.txt",
                "start_line": 2,
                "end_line": 4
            }),
            &project,
            &action_log,
            cx,
        )
        .await;
        assert_eq!(
            result.unwrap().content.as_str(),
            Some("Line 2\nLine 3\nLine 4")
        );
    }

    /// Out-of-contract ranges (zero or inverted bounds) still return at least one line.
    #[gpui::test]
    async fn test_read_file_line_range_edge_cases(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let action_log = cx.new(|_| ActionLog::new(project.clone()));

        // start_line of 0 should be treated as 1
        let result = run_read_file(
            json!({
                "path": "root/multiline.txt",
                "start_line": 0,
                "end_line": 2
            }),
            &project,
            &action_log,
            cx,
        )
        .await;
        assert_eq!(result.unwrap().content.as_str(), Some("Line 1\nLine 2"));

        // end_line of 0 should result in at least 1 line
        let result = run_read_file(
            json!({
                "path": "root/multiline.txt",
                "start_line": 1,
                "end_line": 0
            }),
            &project,
            &action_log,
            cx,
        )
        .await;
        assert_eq!(result.unwrap().content.as_str(), Some("Line 1"));

        // when start_line > end_line, should still return at least 1 line
        let result = run_read_file(
            json!({
                "path": "root/multiline.txt",
                "start_line": 3,
                "end_line": 2
            }),
            &project,
            &action_log,
            cx,
        )
        .await;
        assert_eq!(result.unwrap().content.as_str(), Some("Line 3"));
    }

    /// Invokes `ReadFileTool::run` with `input` using a fresh fake model and
    /// awaits the tool's output.
    ///
    /// NOTE(review): the return type assumes `ToolResult::output` resolves to
    /// `Result<ToolResultOutput>`, matching the tasks `run` converts into a
    /// `ToolResult` — confirm against the `ToolResult` definition.
    async fn run_read_file(
        input: serde_json::Value,
        project: &Entity<Project>,
        action_log: &Entity<ActionLog>,
        cx: &mut TestAppContext,
    ) -> Result<ToolResultOutput> {
        let model = Arc::new(FakeLanguageModel::default());
        cx.update(|cx| {
            Arc::new(ReadFileTool)
                .run(
                    input,
                    Arc::default(),
                    project.clone(),
                    action_log.clone(),
                    model,
                    None,
                    cx,
                )
                .output
        })
        .await
    }

    /// Installs the globals the tests rely on: a test settings store plus the
    /// language and project settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }

    /// Builds a minimal Rust language definition (matched by the `rs` suffix)
    /// with an outline query, so outlines can be computed for `.rs` files.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(
            r#"
            (line_comment) @annotation

            (struct_item
                "struct" @context
                name: (_) @name) @item
            (enum_item
                "enum" @context
                name: (_) @name) @item
            (enum_variant
                name: (_) @name) @item
            (field_declaration
                name: (_) @name) @item
            (impl_item
                "impl" @context
                trait: (_)? @name
                "for"? @context
                type: (_) @name
                body: (_ "{" (_)* "}")) @item
            (function_item
                "fn" @context
                name: (_) @name) @item
            (mod_item
                "mod" @context
                name: (_) @name) @item
            "#,
        )
        .unwrap()
    }
}