@@ -3,7 +3,7 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use futures::io::BufReader;
use futures::{AsyncBufReadExt, AsyncReadExt};
-use gpui::{App, Entity, Task};
+use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -123,108 +123,184 @@ impl Tool for BashTool {
worktree.read(cx).abs_path()
};
- cx.spawn(async move |_| {
- // Add 2>&1 to merge stderr into stdout for proper interleaving.
- let command = format!("({}) 2>&1", input.command);
-
- let mut cmd = new_smol_command("bash")
- .arg("-c")
- .arg(&command)
- .current_dir(working_dir)
- .stdout(std::process::Stdio::piped())
- .spawn()
- .context("Failed to execute bash command")?;
-
- // Capture stdout with a limit
- let stdout = cmd.stdout.take().unwrap();
- let mut reader = BufReader::new(stdout);
-
- const MESSAGE_1: &str = "Command output too long. The first ";
- const MESSAGE_2: &str = " bytes:\n\n";
- const ERR_MESSAGE_1: &str = "Command failed with exit code ";
- const ERR_MESSAGE_2: &str = "\n\n";
-
- const STDOUT_LIMIT: usize = 8192;
-
- const LIMIT: usize = STDOUT_LIMIT
- - (MESSAGE_1.len()
- + (STDOUT_LIMIT.ilog10() as usize + 1) // byte count
- + MESSAGE_2.len()
- + ERR_MESSAGE_1.len()
- + 3 // status code
- + ERR_MESSAGE_2.len());
-
- // Read one more byte to determine whether the output was truncated
- let mut buffer = vec![0; LIMIT + 1];
- let mut bytes_read = 0;
-
- // Read until we reach the limit
- loop {
- let read = reader.read(&mut buffer).await?;
- if read == 0 {
- break;
- }
+ cx.background_spawn(run_command_limited(working_dir, input.command))
+ }
+}
- bytes_read += read;
- if bytes_read > LIMIT {
- bytes_read = LIMIT + 1;
- break;
- }
- }
+const LIMIT: usize = 16 * 1024;
+
+async fn run_command_limited(working_dir: Arc<Path>, command: String) -> Result<String> {
+ // Add 2>&1 to merge stderr into stdout for proper interleaving.
+ let command = format!("({}) 2>&1", command);
+
+ let mut cmd = new_smol_command("bash")
+ .arg("-c")
+ .arg(&command)
+ .current_dir(working_dir)
+ .stdout(std::process::Stdio::piped())
+ .spawn()
+ .context("Failed to execute bash command")?;
+
+ // Capture stdout with a limit
+ let stdout = cmd.stdout.take().unwrap();
+ let mut reader = BufReader::new(stdout);
+
+ // Read one more byte to determine whether the output was truncated
+ let mut buffer = vec![0; LIMIT + 1];
+ let mut bytes_read = 0;
+
+ // Read until we reach the limit
+ loop {
+ let read = reader.read(&mut buffer[bytes_read..]).await?;
+ if read == 0 {
+ break;
+ }
- // Repeatedly fill the output reader's buffer without copying it.
- loop {
- let skipped_bytes = reader.fill_buf().await?;
- if skipped_bytes.is_empty() {
- break;
- }
- let skipped_bytes_len = skipped_bytes.len();
- reader.consume_unpin(skipped_bytes_len);
- }
+ bytes_read += read;
+ if bytes_read > LIMIT {
+ bytes_read = LIMIT + 1;
+ break;
+ }
+ }
- let output_bytes = &buffer[..bytes_read];
-
- // Let the process continue running
- let status = cmd.status().await.context("Failed to get command status")?;
-
- let output_string = if bytes_read > LIMIT {
- // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
- // multi-byte characters.
- let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n');
- let output_string = String::from_utf8_lossy(
- &output_bytes[..last_line_ix.unwrap_or(output_bytes.len())],
- );
-
- format!(
- "{}{}{}{}",
- MESSAGE_1,
- output_string.len(),
- MESSAGE_2,
- output_string
- )
- } else {
- String::from_utf8_lossy(&output_bytes).into()
- };
+ // Repeatedly fill the output reader's buffer without copying it.
+ loop {
+ let skipped_bytes = reader.fill_buf().await?;
+ if skipped_bytes.is_empty() {
+ break;
+ }
+ let skipped_bytes_len = skipped_bytes.len();
+ reader.consume_unpin(skipped_bytes_len);
+ }
- let output_with_status = if status.success() {
- if output_string.is_empty() {
- "Command executed successfully.".to_string()
- } else {
- output_string.to_string()
- }
- } else {
- format!(
- "{}{}{}{}",
- ERR_MESSAGE_1,
- status.code().unwrap_or(-1),
- ERR_MESSAGE_2,
- output_string,
- )
- };
+ let output_bytes = &buffer[..bytes_read.min(LIMIT)];
+
+ let status = cmd.status().await.context("Failed to get command status")?;
+
+ let output_string = if bytes_read > LIMIT {
+ // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
+ // multi-byte characters.
+ let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n');
+ let until_last_line = &output_bytes[..last_line_ix.unwrap_or(output_bytes.len())];
+ let output_string = String::from_utf8_lossy(until_last_line);
+
+ format!(
+ "Command output too long. The first {} bytes:\n\n{}",
+ output_string.len(),
+ output_block(&output_string),
+ )
+ } else {
+ output_block(&String::from_utf8_lossy(&output_bytes))
+ };
+
+ let output_with_status = if status.success() {
+ if output_string.is_empty() {
+ "Command executed successfully.".to_string()
+ } else {
+ output_string.to_string()
+ }
+ } else {
+ format!(
+ "Command failed with exit code {}\n\n{}",
+ status.code().unwrap_or(-1),
+ output_string,
+ )
+ };
+
+ Ok(output_with_status)
+}
+
+fn output_block(output: &str) -> String {
+ format!(
+ "```\n{}{}```",
+ output,
+ if output.ends_with('\n') { "" } else { "\n" }
+ )
+}
+
+#[cfg(test)]
+#[cfg(not(windows))]
+mod tests {
+ use gpui::TestAppContext;
+
+ use super::*;
+
+ #[gpui::test]
+ async fn test_run_command_simple(cx: &mut TestAppContext) {
+ cx.executor().allow_parking();
+
+ let result =
+ run_command_limited(Path::new(".").into(), "echo 'Hello, World!'".to_string()).await;
+
+ assert!(result.is_ok());
+ assert_eq!(result.unwrap(), "```\nHello, World!\n```");
+ }
+
+ #[gpui::test]
+ async fn test_interleaved_stdout_stderr(cx: &mut TestAppContext) {
+ cx.executor().allow_parking();
+
+ let command =
+ "echo 'stdout 1' && echo 'stderr 1' >&2 && echo 'stdout 2' && echo 'stderr 2' >&2";
+ let result = run_command_limited(Path::new(".").into(), command.to_string()).await;
+
+ assert!(result.is_ok());
+ assert_eq!(
+ result.unwrap(),
+ "```\nstdout 1\nstderr 1\nstdout 2\nstderr 2\n```"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_multiple_output_reads(cx: &mut TestAppContext) {
+ cx.executor().allow_parking();
+
+ // Command with multiple outputs that might require multiple reads
+ let result = run_command_limited(
+ Path::new(".").into(),
+ "echo '1'; sleep 0.01; echo '2'; sleep 0.01; echo '3'".to_string(),
+ )
+ .await;
+
+ assert!(result.is_ok());
+ assert_eq!(result.unwrap(), "```\n1\n2\n3\n```");
+ }
+
+ #[gpui::test]
+ async fn test_output_truncation_single_line(cx: &mut TestAppContext) {
+ cx.executor().allow_parking();
+
+ let cmd = format!("echo '{}';", "X".repeat(LIMIT * 2));
+
+ let result = run_command_limited(Path::new(".").into(), cmd).await;
+
+ assert!(result.is_ok());
+ let output = result.unwrap();
+
+ let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
+ let content_end = output.rfind("\n```").unwrap_or(output.len());
+ let content_length = content_end - content_start;
+
+ // Output should be exactly the limit
+ assert_eq!(content_length, LIMIT);
+ }
+
+ #[gpui::test]
+ async fn test_output_truncation_multiline(cx: &mut TestAppContext) {
+ cx.executor().allow_parking();
+
+ let cmd = format!("echo '{}'; ", "X".repeat(120)).repeat(160);
+ let result = run_command_limited(Path::new(".").into(), cmd).await;
+
+ assert!(result.is_ok());
+ let output = result.unwrap();
+
+ assert!(output.starts_with("Command output too long. The first 16334 bytes:\n\n"));
- debug_assert!(output_with_status.len() <= STDOUT_LIMIT);
+ let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
+ let content_end = output.rfind("\n```").unwrap_or(output.len());
+ let content_length = content_end - content_start;
- Ok(output_with_status)
- })
+ assert!(content_length <= LIMIT);
}
}