head_tool.rs

  1use crate::{AgentTool, ToolCallEventStream};
  2use agent_client_protocol as acp;
  3use anyhow::{Result, anyhow};
  4use gpui::{App, Entity, SharedString, Task};
  5use project::{Project, WorktreeSettings};
  6use schemars::JsonSchema;
  7use serde::{Deserialize, Serialize};
  8use settings::Settings;
  9use std::sync::Arc;
 10use util::markdown::MarkdownCodeBlock;
 11
// NOTE(review): this type derives `JsonSchema`; schemars normally exports `///` doc
// comments as schema descriptions, so the wording below is presumably surfaced to
// whoever consumes the tool schema. Treat rewording as a behavior change — confirm.
/// Reads the first N bytes of a file in the project
///
/// - Useful for quickly previewing the beginning of files
/// - More efficient than reading the entire file when only the start is needed
/// - By default reads the first 1024 bytes
/// - Can be used to check file headers, magic numbers, or initial content
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct HeadToolInput {
    /// The relative path of the file to read.
    ///
    /// This path should never be absolute, and the first component of the path should always be a root directory in a project.
    pub path: String,
    /// Number of bytes to read from the beginning of the file. Defaults to 1024.
    // When `bytes` is absent from the deserialized input, serde fills it in by
    // calling `default_byte_count`.
    #[serde(default = "default_byte_count")]
    pub bytes: u32,
}
 28
 29fn default_byte_count() -> u32 {
 30    1024
 31}
 32
/// Agent tool that returns the beginning of a file belonging to a project.
pub struct HeadTool {
    /// Project handle used to resolve paths and open buffers for the files being read.
    project: Entity<Project>,
}
 36
 37impl HeadTool {
 38    pub fn new(project: Entity<Project>) -> Self {
 39        Self { project }
 40    }
 41}
 42
 43impl AgentTool for HeadTool {
 44    type Input = HeadToolInput;
 45    type Output = String;
 46
 47    fn name() -> &'static str {
 48        "head"
 49    }
 50
 51    fn kind() -> acp::ToolKind {
 52        acp::ToolKind::Read
 53    }
 54
 55    fn initial_title(
 56        &self,
 57        input: Result<Self::Input, serde_json::Value>,
 58        cx: &mut App,
 59    ) -> SharedString {
 60        match input {
 61            Ok(input) => {
 62                if let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx)
 63                    && let Some(path) = self
 64                        .project
 65                        .read(cx)
 66                        .short_full_path_for_project_path(&project_path, cx)
 67                {
 68                    format!("Read first {} bytes of `{}`", input.bytes, path)
 69                } else {
 70                    format!("Read first {} bytes of file", input.bytes)
 71                }
 72            }
 73            Err(_) => "Read beginning of file".into(),
 74        }
 75        .into()
 76    }
 77
 78    fn run(
 79        self: Arc<Self>,
 80        input: Self::Input,
 81        event_stream: ToolCallEventStream,
 82        cx: &mut App,
 83    ) -> Task<Result<Self::Output>> {
 84        let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else {
 85            return Task::ready(Err(anyhow!("Path {} not found in project", &input.path)));
 86        };
 87
 88        let Some(abs_path) = self.project.read(cx).absolute_path(&project_path, cx) else {
 89            return Task::ready(Err(anyhow!(
 90                "Failed to convert {} to absolute path",
 91                &input.path
 92            )));
 93        };
 94
 95        // Error out if this path is either excluded or private in global settings
 96        let global_settings = WorktreeSettings::get_global(cx);
 97        if global_settings.is_path_excluded(&project_path.path) {
 98            return Task::ready(Err(anyhow!(
 99                "Cannot read file because its path matches the global `file_scan_exclusions` setting: {}",
100                &input.path
101            )));
102        }
103
104        if global_settings.is_path_private(&project_path.path) {
105            return Task::ready(Err(anyhow!(
106                "Cannot read file because its path matches the global `private_files` setting: {}",
107                &input.path
108            )));
109        }
110
111        // Error out if this path is either excluded or private in worktree settings
112        let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx);
113        if worktree_settings.is_path_excluded(&project_path.path) {
114            return Task::ready(Err(anyhow!(
115                "Cannot read file because its path matches the worktree `file_scan_exclusions` setting: {}",
116                &input.path
117            )));
118        }
119
120        if worktree_settings.is_path_private(&project_path.path) {
121            return Task::ready(Err(anyhow!(
122                "Cannot read file because its path matches the worktree `private_files` setting: {}",
123                &input.path
124            )));
125        }
126
127        let file_path = input.path.clone();
128        let bytes_to_read = input.bytes.max(1) as usize; // Ensure at least 1 byte is read
129
130        event_stream.update_fields(acp::ToolCallUpdateFields {
131            locations: Some(vec![acp::ToolCallLocation {
132                path: abs_path.clone(),
133                line: Some(0),
134                meta: None,
135            }]),
136            ..Default::default()
137        });
138
139        let project = self.project.clone();
140
141        cx.spawn(async move |cx| {
142            let buffer = cx
143                .update(|cx| {
144                    project.update(cx, |project, cx| {
145                        project.open_buffer(project_path.clone(), cx)
146                    })
147                })?
148                .await?;
149
150            if buffer.read_with(cx, |buffer, _| {
151                buffer
152                    .file()
153                    .as_ref()
154                    .is_none_or(|file| !file.disk_state().exists())
155            })? {
156                anyhow::bail!("{file_path} not found");
157            }
158
159            let result = buffer.read_with(cx, |buffer, _cx| {
160                let full_text = buffer.text();
161                let total_bytes = full_text.len();
162                let bytes_read = bytes_to_read.min(total_bytes);
163
164                let text = if bytes_read < total_bytes {
165                    &full_text[..bytes_read]
166                } else {
167                    &full_text
168                };
169
170                if bytes_read < total_bytes {
171                    format!("{}\n\n(showing first {} of {} bytes)", text, bytes_read, total_bytes)
172                } else {
173                    format!("{}\n\n(file has only {} bytes total)", text, total_bytes)
174                }
175            })?;
176
177            // Update the event stream with formatted content
178            let markdown = MarkdownCodeBlock {
179                tag: &file_path,
180                text: &result,
181            }
182            .to_string();
183
184            event_stream.update_fields(acp::ToolCallUpdateFields {
185                content: Some(vec![acp::ToolCallContent::Content {
186                    content: markdown.into(),
187                }]),
188                ..Default::default()
189            });
190
191            Ok(result)
192        })
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolCallEventStream;
    use gpui::{TestAppContext, UpdateGlobal};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test `SettingsStore` as a global so settings lookups work in tests.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }

    /// Runs the tool against `rel_path` asking for `bytes` bytes, using a throwaway
    /// test event stream, and returns whatever the tool produced.
    async fn run_head(
        tool: &Arc<HeadTool>,
        rel_path: &str,
        bytes: u32,
        cx: &mut TestAppContext,
    ) -> Result<String> {
        let input = HeadToolInput {
            path: rel_path.to_string(),
            bytes,
        };
        cx.update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx))
            .await
    }

    /// Truncated reads return the requested prefix plus a "showing first N of" note.
    #[gpui::test]
    async fn test_head_tool_basic(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "test.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5\nLine 6\nLine 7\nLine 8\nLine 9\nLine 10\nLine 11\nLine 12"
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(HeadTool::new(project.clone()));

        // First 20 bytes
        let output = run_head(&tool, "root/test.txt", 20, cx).await.unwrap();
        assert!(output.starts_with("Line 1\nLine 2\nLine 3"));
        assert!(output.contains("showing first 20 of"));

        // First 50 bytes
        let output = run_head(&tool, "root/test.txt", 50, cx).await.unwrap();
        assert!(output.starts_with("Line 1\nLine 2"));
        assert!(output.contains("showing first 50 of"));
    }

    /// Asking for more bytes than the file holds returns the whole file.
    #[gpui::test]
    async fn test_head_tool_small_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "small.txt": "Line 1\nLine 2\nLine 3"
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(HeadTool::new(project));

        // Request more bytes than exist
        let output = run_head(&tool, "root/small.txt", 1000, cx).await.unwrap();

        for expected in ["Line 1", "Line 2", "Line 3", "file has only"] {
            assert!(output.contains(expected));
        }
    }

    /// A path that does not exist on disk yields a "not found" error.
    #[gpui::test]
    async fn test_head_tool_nonexistent_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/root"), json!({})).await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(HeadTool::new(project));

        let outcome = run_head(&tool, "root/nonexistent.txt", 1024, cx).await;

        assert!(outcome.is_err());
        assert_eq!(
            outcome.unwrap_err().to_string(),
            "root/nonexistent.txt not found"
        );
    }

    /// Files matching `file_scan_exclusions` or `private_files` must be refused.
    #[gpui::test]
    async fn test_head_tool_security(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/"),
            json!({
                "project_root": {
                    "allowed.txt": "This is allowed",
                    ".secret": "SECRET_KEY=abc123",
                    "private.key": "private key content"
                },
                "outside": {
                    "sensitive.txt": "Outside project"
                }
            }),
        )
        .await;

        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings.project.worktree.file_scan_exclusions =
                        Some(vec!["**/.secret".to_string()]);
                    settings.project.worktree.private_files =
                        Some(vec!["**/*.key".to_string()].into());
                });
            });
        });

        let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
        let tool = Arc::new(HeadTool::new(project));

        // Reading allowed file should succeed
        let allowed = run_head(&tool, "project_root/allowed.txt", 1024, cx).await;
        assert!(allowed.is_ok());

        // Reading excluded file should fail
        let excluded = run_head(&tool, "project_root/.secret", 1024, cx).await;
        assert!(excluded.is_err());

        // Reading private file should fail
        let private = run_head(&tool, "project_root/private.key", 1024, cx).await;
        assert!(private.is_err());
    }
}