agent: Truncate bash tool output (#28291)

Agus Zubiaga and Michael Sloan created

The bash tool will now truncate its output to 8192 bytes (or the last
newline before that).

We also added a global limit for any tool that produces a clearly large
output that wouldn't fit the context window.

Release Notes:

- agent: Truncate bash tool output

---------

Co-authored-by: Michael Sloan <mgsloan@gmail.com>

Change summary

crates/agent/src/thread.rs              |  3 
crates/agent/src/tool_use.rs            | 28 +++++++
crates/assistant_tools/src/bash_tool.rs | 89 +++++++++++++++++++++++---
crates/util/src/util.rs                 | 60 ++++++++++++++++++
4 files changed, 164 insertions(+), 16 deletions(-)

Detailed changes

crates/agent/src/thread.rs ๐Ÿ”—

@@ -1487,6 +1487,7 @@ impl Thread {
                             tool_use_id.clone(),
                             tool_name,
                             output,
+                            cx,
                         );
 
                         cx.emit(ThreadEvent::ToolFinished {
@@ -1831,7 +1832,7 @@ impl Thread {
         ));
 
         self.tool_use
-            .insert_tool_output(tool_use_id.clone(), tool_name, err);
+            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
 
         cx.emit(ThreadEvent::ToolFinished {
             tool_use_id,

crates/agent/src/tool_use.rs ๐Ÿ”—

@@ -7,10 +7,11 @@ use futures::FutureExt as _;
 use futures::future::Shared;
 use gpui::{App, SharedString, Task};
 use language_model::{
-    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
-    LanguageModelToolUseId, MessageContent, Role,
+    LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
+    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 };
 use ui::IconName;
+use util::truncate_lines_to_byte_limit;
 
 use crate::thread::MessageId;
 use crate::thread_store::SerializedMessage;
@@ -331,9 +332,32 @@ impl ToolUseState {
         tool_use_id: LanguageModelToolUseId,
         tool_name: Arc<str>,
         output: Result<String>,
+        cx: &App,
     ) -> Option<PendingToolUse> {
         match output {
             Ok(tool_result) => {
+                let model_registry = LanguageModelRegistry::read_global(cx);
+
+                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
+
+                // Protect from clearly large output
+                let tool_output_limit = model_registry
+                    .default_model()
+                    .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
+                    .unwrap_or(usize::MAX);
+
+                let tool_result = if tool_result.len() <= tool_output_limit {
+                    tool_result
+                } else {
+                    let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
+
+                    format!(
+                        "Tool result too long. The first {} bytes:\n\n{}",
+                        truncated.len(),
+                        truncated
+                    )
+                };
+
                 self.tool_results.insert(
                     tool_use_id.clone(),
                     LanguageModelToolResult {

crates/assistant_tools/src/bash_tool.rs ๐Ÿ”—

@@ -1,6 +1,8 @@
 use crate::schema::json_schema_for;
 use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::{ActionLog, Tool};
+use futures::io::BufReader;
+use futures::{AsyncBufReadExt, AsyncReadExt};
 use gpui::{App, Entity, Task};
 use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 use project::Project;
@@ -125,29 +127,90 @@ impl Tool for BashTool {
             // Add 2>&1 to merge stderr into stdout for proper interleaving.
             let command = format!("({}) 2>&1", input.command);
 
-            let output = new_smol_command("bash")
+            let mut cmd = new_smol_command("bash")
                 .arg("-c")
                 .arg(&command)
                 .current_dir(working_dir)
-                .output()
-                .await
+                .stdout(std::process::Stdio::piped())
+                .spawn()
                 .context("Failed to execute bash command")?;
 
-            let output_string = String::from_utf8_lossy(&output.stdout).to_string();
+            // Capture stdout with a limit
+            let stdout = cmd.stdout.take().unwrap();
+            let mut reader = BufReader::new(stdout);
+
+            const MESSAGE_1: &str = "Command output too long. The first ";
+            const MESSAGE_2: &str = " bytes:\n\n";
+            const ERR_MESSAGE_1: &str = "Command failed with exit code ";
+            const ERR_MESSAGE_2: &str = "\n\n";
+
+            const STDOUT_LIMIT: usize = 8192;
+
+            const LIMIT: usize = STDOUT_LIMIT
+                - (MESSAGE_1.len()
+                    + (STDOUT_LIMIT.ilog10() as usize + 1) // byte count
+                    + MESSAGE_2.len()
+                    + ERR_MESSAGE_1.len()
+                    + 3 // status code
+                    + ERR_MESSAGE_2.len());
+
+            // Read one more byte to determine whether the output was truncated
+            let mut buffer = vec![0; LIMIT + 1];
+            let bytes_read = reader.read(&mut buffer).await?;
+
+            // Repeatedly fill the output reader's buffer without copying it.
+            loop {
+                let skipped_bytes = reader.fill_buf().await?;
+                if skipped_bytes.is_empty() {
+                    break;
+                }
+                let skipped_bytes_len = skipped_bytes.len();
+                reader.consume_unpin(skipped_bytes_len);
+            }
+
+            let output_bytes = &buffer[..bytes_read];
+
+            // Let the process continue running
+            let status = cmd.status().await.context("Failed to get command status")?;
+
+            let output_string = if bytes_read > LIMIT {
+                // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
+                // multi-byte characters.
+                let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n');
+                let output_string = String::from_utf8_lossy(
+                    &output_bytes[..last_line_ix.unwrap_or(output_bytes.len())],
+                );
+
+                format!(
+                    "{}{}{}{}",
+                    MESSAGE_1,
+                    output_string.len(),
+                    MESSAGE_2,
+                    output_string
+                )
+            } else {
+                String::from_utf8_lossy(&output_bytes).into()
+            };
 
-            if output.status.success() {
+            let output_with_status = if status.success() {
                 if output_string.is_empty() {
-                    Ok("Command executed successfully.".to_string())
+                    "Command executed successfully.".to_string()
                 } else {
-                    Ok(output_string)
+                    output_string.to_string()
                 }
             } else {
-                Ok(format!(
-                    "Command failed with exit code {}\n{}",
-                    output.status.code().unwrap_or(-1),
-                    &output_string
-                ))
-            }
+                format!(
+                    "{}{}{}{}",
+                    ERR_MESSAGE_1,
+                    status.code().unwrap_or(-1),
+                    ERR_MESSAGE_2,
+                    output_string,
+                )
+            };
+
+            debug_assert!(output_with_status.len() <= STDOUT_LIMIT);
+
+            Ok(output_with_status)
         })
     }
 }

crates/util/src/util.rs ๐Ÿ”—

@@ -145,6 +145,66 @@ pub fn truncate_lines_and_trailoff(s: &str, max_lines: usize) -> String {
     }
 }
 
+/// Truncates the string at a character boundary, such that the result is less than `max_bytes` in
+/// length.
+pub fn truncate_to_byte_limit(s: &str, max_bytes: usize) -> &str {
+    if s.len() < max_bytes {
+        return s;
+    }
+
+    for i in (0..max_bytes).rev() {
+        if s.is_char_boundary(i) {
+            return &s[..i];
+        }
+    }
+
+    ""
+}
+
+/// Takes a prefix of complete lines which fit within the byte limit. If the first line is longer
+/// than the limit, truncates at a character boundary.
+pub fn truncate_lines_to_byte_limit(s: &str, max_bytes: usize) -> &str {
+    if s.len() < max_bytes {
+        return s;
+    }
+
+    for i in (0..max_bytes).rev() {
+        if s.is_char_boundary(i) {
+            if s.as_bytes()[i] == b'\n' {
+                // Since the i-th character is \n, valid to slice at i + 1.
+                return &s[..i + 1];
+            }
+        }
+    }
+
+    truncate_to_byte_limit(s, max_bytes)
+}
+
+#[test]
+fn test_truncate_lines_to_byte_limit() {
+    let text = "Line 1\nLine 2\nLine 3\nLine 4";
+
+    // Limit that includes all lines
+    assert_eq!(truncate_lines_to_byte_limit(text, 100), text);
+
+    // Exactly the first line
+    assert_eq!(truncate_lines_to_byte_limit(text, 7), "Line 1\n");
+
+    // Limit between lines
+    assert_eq!(truncate_lines_to_byte_limit(text, 13), "Line 1\n");
+    assert_eq!(truncate_lines_to_byte_limit(text, 20), "Line 1\nLine 2\n");
+
+    // Limit before first newline
+    assert_eq!(truncate_lines_to_byte_limit(text, 6), "Line ");
+
+    // Test with non-ASCII characters
+    let text_utf8 = "Line 1\nLรญne 2\nLine 3";
+    assert_eq!(
+        truncate_lines_to_byte_limit(text_utf8, 15),
+        "Line 1\nLรญne 2\n"
+    );
+}
+
 pub fn post_inc<T: From<u8> + AddAssign<T> + Copy>(value: &mut T) -> T {
     let prev = *value;
     *value += T::from(1);