1use crate::{AgentTool, ToolCallEventStream};
2use agent_client_protocol as acp;
3use anyhow::{Result, anyhow};
4use gpui::{App, Entity, SharedString, Task};
5use project::{Project, WorktreeSettings};
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use settings::Settings;
9use std::sync::Arc;
10use util::markdown::MarkdownCodeBlock;
11
12/// Reads the first N bytes of a file in the project
13///
14/// - Useful for quickly previewing the beginning of files
15/// - More efficient than reading the entire file when only the start is needed
16/// - By default reads the first 1024 bytes
17/// - Can be used to check file headers, magic numbers, or initial content
18#[derive(Debug, Serialize, Deserialize, JsonSchema)]
19pub struct HeadToolInput {
20 /// The relative path of the file to read.
21 ///
22 /// This path should never be absolute, and the first component of the path should always be a root directory in a project.
23 pub path: String,
24 /// Number of bytes to read from the beginning of the file. Defaults to 1024.
25 #[serde(default = "default_byte_count")]
26 pub bytes: u32,
27}
28
/// Fallback for `HeadToolInput::bytes`, referenced by its `#[serde(default = ...)]` attribute.
///
/// Returns the number of bytes to read when the caller does not specify one.
fn default_byte_count() -> u32 {
    const DEFAULT_BYTES: u32 = 1024;
    DEFAULT_BYTES
}
32
33pub struct HeadTool {
34 project: Entity<Project>,
35}
36
37impl HeadTool {
38 pub fn new(project: Entity<Project>) -> Self {
39 Self { project }
40 }
41}
42
43impl AgentTool for HeadTool {
44 type Input = HeadToolInput;
45 type Output = String;
46
47 fn name() -> &'static str {
48 "head"
49 }
50
51 fn kind() -> acp::ToolKind {
52 acp::ToolKind::Read
53 }
54
55 fn initial_title(
56 &self,
57 input: Result<Self::Input, serde_json::Value>,
58 cx: &mut App,
59 ) -> SharedString {
60 match input {
61 Ok(input) => {
62 if let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx)
63 && let Some(path) = self
64 .project
65 .read(cx)
66 .short_full_path_for_project_path(&project_path, cx)
67 {
68 format!("Read first {} bytes of `{}`", input.bytes, path)
69 } else {
70 format!("Read first {} bytes of file", input.bytes)
71 }
72 }
73 Err(_) => "Read beginning of file".into(),
74 }
75 .into()
76 }
77
78 fn run(
79 self: Arc<Self>,
80 input: Self::Input,
81 event_stream: ToolCallEventStream,
82 cx: &mut App,
83 ) -> Task<Result<Self::Output>> {
84 let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else {
85 return Task::ready(Err(anyhow!("Path {} not found in project", &input.path)));
86 };
87
88 let Some(abs_path) = self.project.read(cx).absolute_path(&project_path, cx) else {
89 return Task::ready(Err(anyhow!(
90 "Failed to convert {} to absolute path",
91 &input.path
92 )));
93 };
94
95 // Error out if this path is either excluded or private in global settings
96 let global_settings = WorktreeSettings::get_global(cx);
97 if global_settings.is_path_excluded(&project_path.path) {
98 return Task::ready(Err(anyhow!(
99 "Cannot read file because its path matches the global `file_scan_exclusions` setting: {}",
100 &input.path
101 )));
102 }
103
104 if global_settings.is_path_private(&project_path.path) {
105 return Task::ready(Err(anyhow!(
106 "Cannot read file because its path matches the global `private_files` setting: {}",
107 &input.path
108 )));
109 }
110
111 // Error out if this path is either excluded or private in worktree settings
112 let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx);
113 if worktree_settings.is_path_excluded(&project_path.path) {
114 return Task::ready(Err(anyhow!(
115 "Cannot read file because its path matches the worktree `file_scan_exclusions` setting: {}",
116 &input.path
117 )));
118 }
119
120 if worktree_settings.is_path_private(&project_path.path) {
121 return Task::ready(Err(anyhow!(
122 "Cannot read file because its path matches the worktree `private_files` setting: {}",
123 &input.path
124 )));
125 }
126
127 let file_path = input.path.clone();
128 let bytes_to_read = input.bytes.max(1) as usize; // Ensure at least 1 byte is read
129
130 event_stream.update_fields(acp::ToolCallUpdateFields {
131 locations: Some(vec![acp::ToolCallLocation {
132 path: abs_path.clone(),
133 line: Some(0),
134 meta: None,
135 }]),
136 ..Default::default()
137 });
138
139 let project = self.project.clone();
140
141 cx.spawn(async move |cx| {
142 let buffer = cx
143 .update(|cx| {
144 project.update(cx, |project, cx| {
145 project.open_buffer(project_path.clone(), cx)
146 })
147 })?
148 .await?;
149
150 if buffer.read_with(cx, |buffer, _| {
151 buffer
152 .file()
153 .as_ref()
154 .is_none_or(|file| !file.disk_state().exists())
155 })? {
156 anyhow::bail!("{file_path} not found");
157 }
158
159 let result = buffer.read_with(cx, |buffer, _cx| {
160 let full_text = buffer.text();
161 let total_bytes = full_text.len();
162 let bytes_read = bytes_to_read.min(total_bytes);
163
164 let text = if bytes_read < total_bytes {
165 &full_text[..bytes_read]
166 } else {
167 &full_text
168 };
169
170 if bytes_read < total_bytes {
171 format!("{}\n\n(showing first {} of {} bytes)", text, bytes_read, total_bytes)
172 } else {
173 format!("{}\n\n(file has only {} bytes total)", text, total_bytes)
174 }
175 })?;
176
177 // Update the event stream with formatted content
178 let markdown = MarkdownCodeBlock {
179 tag: &file_path,
180 text: &result,
181 }
182 .to_string();
183
184 event_stream.update_fields(acp::ToolCallUpdateFields {
185 content: Some(vec![acp::ToolCallContent::Content {
186 content: markdown.into(),
187 }]),
188 ..Default::default()
189 });
190
191 Ok(result)
192 })
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolCallEventStream;
    use gpui::{TestAppContext, UpdateGlobal};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test `SettingsStore` as a gpui global; called at the start of every test.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }

    /// Runs `tool` for `path`, requesting `bytes` bytes, with a throwaway event stream.
    async fn head(
        tool: &Arc<HeadTool>,
        path: &str,
        bytes: u32,
        cx: &mut TestAppContext,
    ) -> anyhow::Result<String> {
        let input = HeadToolInput {
            path: path.to_string(),
            bytes,
        };
        let task = cx.update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx));
        task.await
    }

    #[gpui::test]
    async fn test_head_tool_basic(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "test.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5\nLine 6\nLine 7\nLine 8\nLine 9\nLine 10\nLine 11\nLine 12"
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(HeadTool::new(project.clone()));

        // Test reading first 20 bytes
        let first_20 = head(&tool, "root/test.txt", 20, cx).await.unwrap();
        assert!(first_20.starts_with("Line 1\nLine 2\nLine 3"));
        assert!(first_20.contains("showing first 20 of"));

        // Test reading first 50 bytes
        let first_50 = head(&tool, "root/test.txt", 50, cx).await.unwrap();
        assert!(first_50.starts_with("Line 1\nLine 2"));
        assert!(first_50.contains("showing first 50 of"));
    }

    #[gpui::test]
    async fn test_head_tool_small_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "small.txt": "Line 1\nLine 2\nLine 3"
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(HeadTool::new(project));

        // Request more bytes than exist
        let output = head(&tool, "root/small.txt", 1000, cx).await.unwrap();

        for expected in ["Line 1", "Line 2", "Line 3", "file has only"] {
            assert!(output.contains(expected));
        }
    }

    #[gpui::test]
    async fn test_head_tool_nonexistent_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/root"), json!({})).await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(HeadTool::new(project));

        let outcome = head(&tool, "root/nonexistent.txt", 1024, cx).await;

        assert!(outcome.is_err());
        assert_eq!(
            outcome.unwrap_err().to_string(),
            "root/nonexistent.txt not found"
        );
    }

    #[gpui::test]
    async fn test_head_tool_security(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/"),
            json!({
                "project_root": {
                    "allowed.txt": "This is allowed",
                    ".secret": "SECRET_KEY=abc123",
                    "private.key": "private key content"
                },
                "outside": {
                    "sensitive.txt": "Outside project"
                }
            }),
        )
        .await;

        // Exclude `.secret` files from scanning and mark `*.key` files as private.
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings.project.worktree.file_scan_exclusions =
                        Some(vec!["**/.secret".to_string()]);
                    settings.project.worktree.private_files =
                        Some(vec!["**/*.key".to_string()].into());
                });
            });
        });

        let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
        let tool = Arc::new(HeadTool::new(project));

        // Reading allowed file should succeed
        let allowed = head(&tool, "project_root/allowed.txt", 1024, cx).await;
        assert!(allowed.is_ok());

        // Reading excluded file should fail
        let excluded = head(&tool, "project_root/.secret", 1024, cx).await;
        assert!(excluded.is_err());

        // Reading private file should fail
        let private = head(&tool, "project_root/private.key", 1024, cx).await;
        assert!(private.is_err());
    }
}