bash_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Context as _, Result, anyhow};
  3use assistant_tool::{ActionLog, Tool};
  4use futures::io::BufReader;
  5use futures::{AsyncBufReadExt, AsyncReadExt};
  6use gpui::{App, AppContext, Entity, Task};
  7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  8use project::Project;
  9use schemars::JsonSchema;
 10use serde::{Deserialize, Serialize};
 11use std::path::Path;
 12use std::sync::Arc;
 13use ui::IconName;
 14use util::command::new_smol_command;
 15use util::markdown::MarkdownString;
 16
// Arguments for a single bash tool call. `ui_text` and `run` deserialize the
// tool-call JSON into this struct, and `input_schema` derives the schema from it.
// NOTE(review): the `///` field comments below presumably become the field
// descriptions in the generated JSON schema (schemars derive) — treat them as
// model-facing text rather than ordinary documentation; confirm before rewording.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct BashToolInput {
    /// The bash one-liner command to execute.
    command: String,
    /// Working directory for the command. This must be one of the root directories of the project.
    cd: String,
}
 24
/// Assistant tool that runs a bash command in a working directory resolved
/// against the project's worktrees and returns the command's output as text.
pub struct BashTool;
 26
 27impl Tool for BashTool {
 28    fn name(&self) -> String {
 29        "bash".to_string()
 30    }
 31
 32    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 33        true
 34    }
 35
 36    fn description(&self) -> String {
 37        include_str!("./bash_tool/description.md").to_string()
 38    }
 39
 40    fn icon(&self) -> IconName {
 41        IconName::Terminal
 42    }
 43
 44    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
 45        json_schema_for::<BashToolInput>(format)
 46    }
 47
 48    fn ui_text(&self, input: &serde_json::Value) -> String {
 49        match serde_json::from_value::<BashToolInput>(input.clone()) {
 50            Ok(input) => {
 51                let mut lines = input.command.lines();
 52                let first_line = lines.next().unwrap_or_default();
 53                let remaining_line_count = lines.count();
 54                match remaining_line_count {
 55                    0 => MarkdownString::inline_code(&first_line).0,
 56                    1 => {
 57                        MarkdownString::inline_code(&format!(
 58                            "{} - {} more line",
 59                            first_line, remaining_line_count
 60                        ))
 61                        .0
 62                    }
 63                    n => {
 64                        MarkdownString::inline_code(&format!("{} - {} more lines", first_line, n)).0
 65                    }
 66                }
 67            }
 68            Err(_) => "Run bash command".to_string(),
 69        }
 70    }
 71
 72    fn run(
 73        self: Arc<Self>,
 74        input: serde_json::Value,
 75        _messages: &[LanguageModelRequestMessage],
 76        project: Entity<Project>,
 77        _action_log: Entity<ActionLog>,
 78        cx: &mut App,
 79    ) -> Task<Result<String>> {
 80        let input: BashToolInput = match serde_json::from_value(input) {
 81            Ok(input) => input,
 82            Err(err) => return Task::ready(Err(anyhow!(err))),
 83        };
 84
 85        let project = project.read(cx);
 86        let input_path = Path::new(&input.cd);
 87        let working_dir = if input.cd == "." {
 88            // Accept "." as meaning "the one worktree" if we only have one worktree.
 89            let mut worktrees = project.worktrees(cx);
 90
 91            let only_worktree = match worktrees.next() {
 92                Some(worktree) => worktree,
 93                None => return Task::ready(Err(anyhow!("No worktrees found in the project"))),
 94            };
 95
 96            if worktrees.next().is_some() {
 97                return Task::ready(Err(anyhow!(
 98                    "'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly."
 99                )));
100            }
101
102            only_worktree.read(cx).abs_path()
103        } else if input_path.is_absolute() {
104            // Absolute paths are allowed, but only if they're in one of the project's worktrees.
105            if !project
106                .worktrees(cx)
107                .any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
108            {
109                return Task::ready(Err(anyhow!(
110                    "The absolute path must be within one of the project's worktrees"
111                )));
112            }
113
114            input_path.into()
115        } else {
116            let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
117                return Task::ready(Err(anyhow!(
118                    "`cd` directory {} not found in the project",
119                    &input.cd
120                )));
121            };
122
123            worktree.read(cx).abs_path()
124        };
125
126        cx.background_spawn(run_command_limited(working_dir, input.command))
127    }
128}
129
/// Maximum number of bytes of command output that are kept (16 KiB).
const LIMIT: usize = 16_384;
131
132async fn run_command_limited(working_dir: Arc<Path>, command: String) -> Result<String> {
133    // Add 2>&1 to merge stderr into stdout for proper interleaving.
134    let command = format!("({}) 2>&1", command);
135
136    let mut cmd = new_smol_command("bash")
137        .arg("-c")
138        .arg(&command)
139        .current_dir(working_dir)
140        .stdout(std::process::Stdio::piped())
141        .spawn()
142        .context("Failed to execute bash command")?;
143
144    // Capture stdout with a limit
145    let stdout = cmd.stdout.take().unwrap();
146    let mut reader = BufReader::new(stdout);
147
148    // Read one more byte to determine whether the output was truncated
149    let mut buffer = vec![0; LIMIT + 1];
150    let mut bytes_read = 0;
151
152    // Read until we reach the limit
153    loop {
154        let read = reader.read(&mut buffer[bytes_read..]).await?;
155        if read == 0 {
156            break;
157        }
158
159        bytes_read += read;
160        if bytes_read > LIMIT {
161            bytes_read = LIMIT + 1;
162            break;
163        }
164    }
165
166    // Repeatedly fill the output reader's buffer without copying it.
167    loop {
168        let skipped_bytes = reader.fill_buf().await?;
169        if skipped_bytes.is_empty() {
170            break;
171        }
172        let skipped_bytes_len = skipped_bytes.len();
173        reader.consume_unpin(skipped_bytes_len);
174    }
175
176    let output_bytes = &buffer[..bytes_read.min(LIMIT)];
177
178    let status = cmd.status().await.context("Failed to get command status")?;
179
180    let output_string = if bytes_read > LIMIT {
181        // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
182        // multi-byte characters.
183        let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n');
184        let until_last_line = &output_bytes[..last_line_ix.unwrap_or(output_bytes.len())];
185        let output_string = String::from_utf8_lossy(until_last_line);
186
187        format!(
188            "Command output too long. The first {} bytes:\n\n{}",
189            output_string.len(),
190            output_block(&output_string),
191        )
192    } else {
193        output_block(&String::from_utf8_lossy(&output_bytes))
194    };
195
196    let output_with_status = if status.success() {
197        if output_string.is_empty() {
198            "Command executed successfully.".to_string()
199        } else {
200            output_string.to_string()
201        }
202    } else {
203        format!(
204            "Command failed with exit code {}\n\n{}",
205            status.code().unwrap_or(-1),
206            output_string,
207        )
208    };
209
210    Ok(output_with_status)
211}
212
/// Wraps `output` in a Markdown fenced code block.
///
/// A newline is appended when `output` does not already end with one, so the
/// closing fence always sits on its own line. The fence is three backticks
/// unless `output` itself contains a run of three or more backticks; in that
/// case the fence is one backtick longer than the longest run, so the output
/// cannot close the block early.
fn output_block(output: &str) -> String {
    // Splitting on every non-backtick character leaves only the backtick runs
    // (plus empty pieces); the longest piece is the longest run.
    let longest_backtick_run = output
        .split(|c: char| c != '`')
        .map(str::len)
        .max()
        .unwrap_or(0);

    // Never shorter than the conventional three-backtick fence.
    let fence = "`".repeat(longest_backtick_run.max(2) + 1);
    let trailing_newline = if output.ends_with('\n') { "" } else { "\n" };

    format!("{0}\n{1}{2}{0}", fence, output, trailing_newline)
}
220
#[cfg(test)]
#[cfg(not(windows))]
mod tests {
    use gpui::TestAppContext;

    use super::*;

    /// Runs `command` in the current directory and returns the tool output,
    /// asserting that the call itself did not error.
    async fn run_in_current_dir(command: String) -> String {
        let result = run_command_limited(Path::new(".").into(), command).await;
        assert!(result.is_ok());
        result.unwrap()
    }

    /// Number of bytes between the opening fence line and the closing fence.
    fn fenced_content_len(output: &str) -> usize {
        let start = output.find("```\n").map_or(0, |ix| ix + 4);
        let end = output.rfind("\n```").unwrap_or(output.len());
        end - start
    }

    #[gpui::test]
    async fn test_run_command_simple(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let output = run_in_current_dir("echo 'Hello, World!'".to_string()).await;

        assert_eq!(output, "```\nHello, World!\n```");
    }

    #[gpui::test]
    async fn test_interleaved_stdout_stderr(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let command =
            "echo 'stdout 1' && echo 'stderr 1' >&2 && echo 'stdout 2' && echo 'stderr 2' >&2";
        let output = run_in_current_dir(command.to_string()).await;

        assert_eq!(output, "```\nstdout 1\nstderr 1\nstdout 2\nstderr 2\n```");
    }

    #[gpui::test]
    async fn test_multiple_output_reads(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        // The pauses between echoes make it likely the output arrives in several reads.
        let command = "echo '1'; sleep 0.01; echo '2'; sleep 0.01; echo '3'";
        let output = run_in_current_dir(command.to_string()).await;

        assert_eq!(output, "```\n1\n2\n3\n```");
    }

    #[gpui::test]
    async fn test_output_truncation_single_line(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let command = format!("echo '{}';", "X".repeat(LIMIT * 2));
        let output = run_in_current_dir(command).await;

        // A single over-long line is cut to exactly the limit.
        assert_eq!(fenced_content_len(&output), LIMIT);
    }

    #[gpui::test]
    async fn test_output_truncation_multiline(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let command = format!("echo '{}'; ", "X".repeat(120)).repeat(160);
        let output = run_in_current_dir(command).await;

        assert!(output.starts_with("Command output too long. The first 16334 bytes:\n\n"));
        assert!(fenced_content_len(&output) <= LIMIT);
    }
}