1use crate::schema::json_schema_for;
2use anyhow::{Context as _, Result, anyhow};
3use assistant_tool::{ActionLog, Tool};
4use futures::io::BufReader;
5use futures::{AsyncBufReadExt, AsyncReadExt};
6use gpui::{App, AppContext, Entity, Task};
7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
8use project::Project;
9use schemars::JsonSchema;
10use serde::{Deserialize, Serialize};
11use std::path::Path;
12use std::sync::Arc;
13use ui::IconName;
14use util::command::new_smol_command;
15use util::markdown::MarkdownString;
16
// Arguments the language model supplies when it calls the `bash` tool.
//
// NOTE: the `///` comments on the fields below are not only documentation:
// schemars' `JsonSchema` derive turns them into the `description`s of the
// generated input schema (see `input_schema`), so editing them changes what
// the model is told. Keep reviewer notes in `//` comments like this one.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct BashToolInput {
    /// The bash one-liner command to execute.
    command: String,
    // NOTE(review): `BashTool::run` also accepts "." (when the project has
    // exactly one worktree) and absolute paths inside a worktree, which the
    // description below does not mention.
    /// Working directory for the command. This must be one of the root directories of the project.
    cd: String,
}
24
/// The `bash` tool: after user confirmation, runs a model-provided bash
/// command in one of the project's worktree roots and returns its combined,
/// size-limited stdout/stderr.
pub struct BashTool;
26
27impl Tool for BashTool {
28 fn name(&self) -> String {
29 "bash".to_string()
30 }
31
32 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
33 true
34 }
35
36 fn description(&self) -> String {
37 include_str!("./bash_tool/description.md").to_string()
38 }
39
40 fn icon(&self) -> IconName {
41 IconName::Terminal
42 }
43
44 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
45 json_schema_for::<BashToolInput>(format)
46 }
47
48 fn ui_text(&self, input: &serde_json::Value) -> String {
49 match serde_json::from_value::<BashToolInput>(input.clone()) {
50 Ok(input) => {
51 let mut lines = input.command.lines();
52 let first_line = lines.next().unwrap_or_default();
53 let remaining_line_count = lines.count();
54 match remaining_line_count {
55 0 => MarkdownString::inline_code(&first_line).0,
56 1 => {
57 MarkdownString::inline_code(&format!(
58 "{} - {} more line",
59 first_line, remaining_line_count
60 ))
61 .0
62 }
63 n => {
64 MarkdownString::inline_code(&format!("{} - {} more lines", first_line, n)).0
65 }
66 }
67 }
68 Err(_) => "Run bash command".to_string(),
69 }
70 }
71
72 fn run(
73 self: Arc<Self>,
74 input: serde_json::Value,
75 _messages: &[LanguageModelRequestMessage],
76 project: Entity<Project>,
77 _action_log: Entity<ActionLog>,
78 cx: &mut App,
79 ) -> Task<Result<String>> {
80 let input: BashToolInput = match serde_json::from_value(input) {
81 Ok(input) => input,
82 Err(err) => return Task::ready(Err(anyhow!(err))),
83 };
84
85 let project = project.read(cx);
86 let input_path = Path::new(&input.cd);
87 let working_dir = if input.cd == "." {
88 // Accept "." as meaning "the one worktree" if we only have one worktree.
89 let mut worktrees = project.worktrees(cx);
90
91 let only_worktree = match worktrees.next() {
92 Some(worktree) => worktree,
93 None => return Task::ready(Err(anyhow!("No worktrees found in the project"))),
94 };
95
96 if worktrees.next().is_some() {
97 return Task::ready(Err(anyhow!(
98 "'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly."
99 )));
100 }
101
102 only_worktree.read(cx).abs_path()
103 } else if input_path.is_absolute() {
104 // Absolute paths are allowed, but only if they're in one of the project's worktrees.
105 if !project
106 .worktrees(cx)
107 .any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
108 {
109 return Task::ready(Err(anyhow!(
110 "The absolute path must be within one of the project's worktrees"
111 )));
112 }
113
114 input_path.into()
115 } else {
116 let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
117 return Task::ready(Err(anyhow!(
118 "`cd` directory {} not found in the project",
119 &input.cd
120 )));
121 };
122
123 worktree.read(cx).abs_path()
124 };
125
126 cx.background_spawn(run_command_limited(working_dir, input.command))
127 }
128}
129
/// Maximum number of bytes of command output returned to the model (16 KiB).
/// `run_command_limited` reads and discards anything beyond this.
const LIMIT: usize = 16_384;
131
132async fn run_command_limited(working_dir: Arc<Path>, command: String) -> Result<String> {
133 // Add 2>&1 to merge stderr into stdout for proper interleaving.
134 let command = format!("({}) 2>&1", command);
135
136 let mut cmd = new_smol_command("bash")
137 .arg("-c")
138 .arg(&command)
139 .current_dir(working_dir)
140 .stdout(std::process::Stdio::piped())
141 .spawn()
142 .context("Failed to execute bash command")?;
143
144 // Capture stdout with a limit
145 let stdout = cmd.stdout.take().unwrap();
146 let mut reader = BufReader::new(stdout);
147
148 // Read one more byte to determine whether the output was truncated
149 let mut buffer = vec![0; LIMIT + 1];
150 let mut bytes_read = 0;
151
152 // Read until we reach the limit
153 loop {
154 let read = reader.read(&mut buffer[bytes_read..]).await?;
155 if read == 0 {
156 break;
157 }
158
159 bytes_read += read;
160 if bytes_read > LIMIT {
161 bytes_read = LIMIT + 1;
162 break;
163 }
164 }
165
166 // Repeatedly fill the output reader's buffer without copying it.
167 loop {
168 let skipped_bytes = reader.fill_buf().await?;
169 if skipped_bytes.is_empty() {
170 break;
171 }
172 let skipped_bytes_len = skipped_bytes.len();
173 reader.consume_unpin(skipped_bytes_len);
174 }
175
176 let output_bytes = &buffer[..bytes_read.min(LIMIT)];
177
178 let status = cmd.status().await.context("Failed to get command status")?;
179
180 let output_string = if bytes_read > LIMIT {
181 // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
182 // multi-byte characters.
183 let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n');
184 let until_last_line = &output_bytes[..last_line_ix.unwrap_or(output_bytes.len())];
185 let output_string = String::from_utf8_lossy(until_last_line);
186
187 format!(
188 "Command output too long. The first {} bytes:\n\n{}",
189 output_string.len(),
190 output_block(&output_string),
191 )
192 } else {
193 output_block(&String::from_utf8_lossy(&output_bytes))
194 };
195
196 let output_with_status = if status.success() {
197 if output_string.is_empty() {
198 "Command executed successfully.".to_string()
199 } else {
200 output_string.to_string()
201 }
202 } else {
203 format!(
204 "Command failed with exit code {}\n\n{}",
205 status.code().unwrap_or(-1),
206 output_string,
207 )
208 };
209
210 Ok(output_with_status)
211}
212
/// Wraps `output` in a fenced Markdown code block, making sure the body ends
/// with a newline before the closing fence.
///
/// The fence is one backtick longer than the longest run of backticks inside
/// `output` (and at least three), so output that itself contains "```"
/// cannot terminate the block early.
fn output_block(output: &str) -> String {
    let longest_backtick_run = output
        .split(|c: char| c != '`')
        .map(str::len)
        .max()
        .unwrap_or(0);
    let fence = "`".repeat(longest_backtick_run.max(2) + 1);
    let trailing_newline = if output.ends_with('\n') { "" } else { "\n" };
    format!("{fence}\n{output}{trailing_newline}{fence}")
}
220
#[cfg(test)]
#[cfg(not(windows))]
mod tests {
    use gpui::TestAppContext;

    use super::*;

    /// Runs `command` in the current directory and returns the tool output,
    /// failing the test with the underlying error if the command could not run.
    async fn run_in_cwd(command: impl Into<String>) -> String {
        run_command_limited(Path::new(".").into(), command.into())
            .await
            .expect("run_command_limited failed")
    }

    /// Length in bytes of the text between the first opening "```\n" fence
    /// and the last closing "\n```" fence in `output`.
    fn code_block_content_len(output: &str) -> usize {
        let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
        let content_end = output.rfind("\n```").unwrap_or(output.len());
        content_end - content_start
    }

    #[gpui::test]
    async fn test_run_command_simple(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let output = run_in_cwd("echo 'Hello, World!'").await;

        assert_eq!(output, "```\nHello, World!\n```");
    }

    #[gpui::test]
    async fn test_interleaved_stdout_stderr(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let output = run_in_cwd(
            "echo 'stdout 1' && echo 'stderr 1' >&2 && echo 'stdout 2' && echo 'stderr 2' >&2",
        )
        .await;

        assert_eq!(output, "```\nstdout 1\nstderr 1\nstdout 2\nstderr 2\n```");
    }

    #[gpui::test]
    async fn test_multiple_output_reads(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        // Command with multiple outputs that might require multiple reads
        let output = run_in_cwd("echo '1'; sleep 0.01; echo '2'; sleep 0.01; echo '3'").await;

        assert_eq!(output, "```\n1\n2\n3\n```");
    }

    #[gpui::test]
    async fn test_output_truncation_single_line(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let output = run_in_cwd(format!("echo '{}';", "X".repeat(LIMIT * 2))).await;

        // With no newline to cut at, the output should be exactly the limit.
        assert_eq!(code_block_content_len(&output), LIMIT);
    }

    #[gpui::test]
    async fn test_output_truncation_multiline(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let output = run_in_cwd(format!("echo '{}'; ", "X".repeat(120)).repeat(160)).await;

        // Each line is 121 bytes; 135 whole lines fit in LIMIT, and the cut
        // drops the final newline: 135 * 121 - 1 = 16334.
        assert!(output.starts_with("Command output too long. The first 16334 bytes:\n\n"));
        assert!(code_block_content_len(&output) <= LIMIT);
    }
}