Fix bash tool output (#28391)

Agus Zubiaga created

Change summary

crates/assistant_tools/src/bash_tool.rs | 272 +++++++++++++++++---------
1 file changed, 174 insertions(+), 98 deletions(-)

Detailed changes

crates/assistant_tools/src/bash_tool.rs 🔗

@@ -3,7 +3,7 @@ use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::{ActionLog, Tool};
 use futures::io::BufReader;
 use futures::{AsyncBufReadExt, AsyncReadExt};
-use gpui::{App, Entity, Task};
+use gpui::{App, AppContext, Entity, Task};
 use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 use project::Project;
 use schemars::JsonSchema;
@@ -123,108 +123,184 @@ impl Tool for BashTool {
             worktree.read(cx).abs_path()
         };
 
-        cx.spawn(async move |_| {
-            // Add 2>&1 to merge stderr into stdout for proper interleaving.
-            let command = format!("({}) 2>&1", input.command);
-
-            let mut cmd = new_smol_command("bash")
-                .arg("-c")
-                .arg(&command)
-                .current_dir(working_dir)
-                .stdout(std::process::Stdio::piped())
-                .spawn()
-                .context("Failed to execute bash command")?;
-
-            // Capture stdout with a limit
-            let stdout = cmd.stdout.take().unwrap();
-            let mut reader = BufReader::new(stdout);
-
-            const MESSAGE_1: &str = "Command output too long. The first ";
-            const MESSAGE_2: &str = " bytes:\n\n";
-            const ERR_MESSAGE_1: &str = "Command failed with exit code ";
-            const ERR_MESSAGE_2: &str = "\n\n";
-
-            const STDOUT_LIMIT: usize = 8192;
-
-            const LIMIT: usize = STDOUT_LIMIT
-                - (MESSAGE_1.len()
-                    + (STDOUT_LIMIT.ilog10() as usize + 1) // byte count
-                    + MESSAGE_2.len()
-                    + ERR_MESSAGE_1.len()
-                    + 3 // status code
-                    + ERR_MESSAGE_2.len());
-
-            // Read one more byte to determine whether the output was truncated
-            let mut buffer = vec![0; LIMIT + 1];
-            let mut bytes_read = 0;
-
-            // Read until we reach the limit
-            loop {
-                let read = reader.read(&mut buffer).await?;
-                if read == 0 {
-                    break;
-                }
+        cx.background_spawn(run_command_limited(working_dir, input.command))
+    }
+}
 
-                bytes_read += read;
-                if bytes_read > LIMIT {
-                    bytes_read = LIMIT + 1;
-                    break;
-                }
-            }
+const LIMIT: usize = 16 * 1024;
+
+async fn run_command_limited(working_dir: Arc<Path>, command: String) -> Result<String> {
+    // Add 2>&1 to merge stderr into stdout for proper interleaving.
+    let command = format!("({}) 2>&1", command);
+
+    let mut cmd = new_smol_command("bash")
+        .arg("-c")
+        .arg(&command)
+        .current_dir(working_dir)
+        .stdout(std::process::Stdio::piped())
+        .spawn()
+        .context("Failed to execute bash command")?;
+
+    // Capture stdout with a limit
+    let stdout = cmd.stdout.take().unwrap();
+    let mut reader = BufReader::new(stdout);
+
+    // Read one more byte to determine whether the output was truncated
+    let mut buffer = vec![0; LIMIT + 1];
+    let mut bytes_read = 0;
+
+    // Read until we reach the limit
+    loop {
+        let read = reader.read(&mut buffer[bytes_read..]).await?;
+        if read == 0 {
+            break;
+        }
 
-            // Repeatedly fill the output reader's buffer without copying it.
-            loop {
-                let skipped_bytes = reader.fill_buf().await?;
-                if skipped_bytes.is_empty() {
-                    break;
-                }
-                let skipped_bytes_len = skipped_bytes.len();
-                reader.consume_unpin(skipped_bytes_len);
-            }
+        bytes_read += read;
+        if bytes_read > LIMIT {
+            bytes_read = LIMIT + 1;
+            break;
+        }
+    }
 
-            let output_bytes = &buffer[..bytes_read];
-
-            // Let the process continue running
-            let status = cmd.status().await.context("Failed to get command status")?;
-
-            let output_string = if bytes_read > LIMIT {
-                // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
-                // multi-byte characters.
-                let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n');
-                let output_string = String::from_utf8_lossy(
-                    &output_bytes[..last_line_ix.unwrap_or(output_bytes.len())],
-                );
-
-                format!(
-                    "{}{}{}{}",
-                    MESSAGE_1,
-                    output_string.len(),
-                    MESSAGE_2,
-                    output_string
-                )
-            } else {
-                String::from_utf8_lossy(&output_bytes).into()
-            };
+    // Repeatedly fill the output reader's buffer without copying it.
+    loop {
+        let skipped_bytes = reader.fill_buf().await?;
+        if skipped_bytes.is_empty() {
+            break;
+        }
+        let skipped_bytes_len = skipped_bytes.len();
+        reader.consume_unpin(skipped_bytes_len);
+    }
 
-            let output_with_status = if status.success() {
-                if output_string.is_empty() {
-                    "Command executed successfully.".to_string()
-                } else {
-                    output_string.to_string()
-                }
-            } else {
-                format!(
-                    "{}{}{}{}",
-                    ERR_MESSAGE_1,
-                    status.code().unwrap_or(-1),
-                    ERR_MESSAGE_2,
-                    output_string,
-                )
-            };
+    let output_bytes = &buffer[..bytes_read.min(LIMIT)];
+
+    let status = cmd.status().await.context("Failed to get command status")?;
+
+    let output_string = if bytes_read > LIMIT {
+        // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
+        // multi-byte characters.
+        let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n');
+        let until_last_line = &output_bytes[..last_line_ix.unwrap_or(output_bytes.len())];
+        let output_string = String::from_utf8_lossy(until_last_line);
+
+        format!(
+            "Command output too long. The first {} bytes:\n\n{}",
+            output_string.len(),
+            output_block(&output_string),
+        )
+    } else {
+        output_block(&String::from_utf8_lossy(&output_bytes))
+    };
+
+    let output_with_status = if status.success() {
+        if output_string.is_empty() {
+            "Command executed successfully.".to_string()
+        } else {
+            output_string.to_string()
+        }
+    } else {
+        format!(
+            "Command failed with exit code {}\n\n{}",
+            status.code().unwrap_or(-1),
+            output_string,
+        )
+    };
+
+    Ok(output_with_status)
+}
+
+fn output_block(output: &str) -> String {
+    format!(
+        "```\n{}{}```",
+        output,
+        if output.ends_with('\n') { "" } else { "\n" }
+    )
+}
+
+#[cfg(test)]
+#[cfg(not(windows))]
+mod tests {
+    use gpui::TestAppContext;
+
+    use super::*;
+
+    #[gpui::test]
+    async fn test_run_command_simple(cx: &mut TestAppContext) {
+        cx.executor().allow_parking();
+
+        let result =
+            run_command_limited(Path::new(".").into(), "echo 'Hello, World!'".to_string()).await;
+
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap(), "```\nHello, World!\n```");
+    }
+
+    #[gpui::test]
+    async fn test_interleaved_stdout_stderr(cx: &mut TestAppContext) {
+        cx.executor().allow_parking();
+
+        let command =
+            "echo 'stdout 1' && echo 'stderr 1' >&2 && echo 'stdout 2' && echo 'stderr 2' >&2";
+        let result = run_command_limited(Path::new(".").into(), command.to_string()).await;
+
+        assert!(result.is_ok());
+        assert_eq!(
+            result.unwrap(),
+            "```\nstdout 1\nstderr 1\nstdout 2\nstderr 2\n```"
+        );
+    }
+
+    #[gpui::test]
+    async fn test_multiple_output_reads(cx: &mut TestAppContext) {
+        cx.executor().allow_parking();
+
+        // Command with multiple outputs that might require multiple reads
+        let result = run_command_limited(
+            Path::new(".").into(),
+            "echo '1'; sleep 0.01; echo '2'; sleep 0.01; echo '3'".to_string(),
+        )
+        .await;
+
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap(), "```\n1\n2\n3\n```");
+    }
+
+    #[gpui::test]
+    async fn test_output_truncation_single_line(cx: &mut TestAppContext) {
+        cx.executor().allow_parking();
+
+        let cmd = format!("echo '{}';", "X".repeat(LIMIT * 2));
+
+        let result = run_command_limited(Path::new(".").into(), cmd).await;
+
+        assert!(result.is_ok());
+        let output = result.unwrap();
+
+        let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
+        let content_end = output.rfind("\n```").unwrap_or(output.len());
+        let content_length = content_end - content_start;
+
+        // Output should be exactly the limit
+        assert_eq!(content_length, LIMIT);
+    }
+
+    #[gpui::test]
+    async fn test_output_truncation_multiline(cx: &mut TestAppContext) {
+        cx.executor().allow_parking();
+
+        let cmd = format!("echo '{}'; ", "X".repeat(120)).repeat(160);
+        let result = run_command_limited(Path::new(".").into(), cmd).await;
+
+        assert!(result.is_ok());
+        let output = result.unwrap();
+
+        assert!(output.starts_with("Command output too long. The first 16334 bytes:\n\n"));
 
-            debug_assert!(output_with_status.len() <= STDOUT_LIMIT);
+        let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
+        let content_end = output.rfind("\n```").unwrap_or(output.len());
+        let content_length = content_end - content_start;
 
-            Ok(output_with_status)
-        })
+        assert!(content_length <= LIMIT);
     }
 }