diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 6bdd92e7855c838acaf5d0e1a7e1a76415bdf9e7..fd3192bdf043294224bce443cedafd963067579e 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -1487,6 +1487,7 @@ impl Thread { tool_use_id.clone(), tool_name, output, + cx, ); cx.emit(ThreadEvent::ToolFinished { @@ -1831,7 +1832,7 @@ impl Thread { )); self.tool_use - .insert_tool_output(tool_use_id.clone(), tool_name, err); + .insert_tool_output(tool_use_id.clone(), tool_name, err, cx); cx.emit(ThreadEvent::ToolFinished { tool_use_id, diff --git a/crates/agent/src/tool_use.rs b/crates/agent/src/tool_use.rs index 5dd9e119c3830e12aeac9dca605c603b8db533ec..b71c0348c39188bebb59e9eaf36481f425ca1323 100644 --- a/crates/agent/src/tool_use.rs +++ b/crates/agent/src/tool_use.rs @@ -7,10 +7,11 @@ use futures::FutureExt as _; use futures::future::Shared; use gpui::{App, SharedString, Task}; use language_model::{ - LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, - LanguageModelToolUseId, MessageContent, Role, + LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult, + LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, }; use ui::IconName; +use util::truncate_lines_to_byte_limit; use crate::thread::MessageId; use crate::thread_store::SerializedMessage; @@ -331,9 +332,32 @@ impl ToolUseState { tool_use_id: LanguageModelToolUseId, tool_name: Arc, output: Result, + cx: &App, ) -> Option { match output { Ok(tool_result) => { + let model_registry = LanguageModelRegistry::read_global(cx); + + const BYTES_PER_TOKEN_ESTIMATE: usize = 3; + + // Protect from clearly large output + let tool_output_limit = model_registry + .default_model() + .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE) + .unwrap_or(usize::MAX); + + let tool_result = if tool_result.len() <= tool_output_limit { + tool_result + } else { + let truncated = truncate_lines_to_byte_limit(&tool_result, 
tool_output_limit); + + format!( + "Tool result too long. The first {} bytes:\n\n{}", + truncated.len(), + truncated + ) + }; + self.tool_results.insert( tool_use_id.clone(), LanguageModelToolResult { diff --git a/crates/assistant_tools/src/bash_tool.rs b/crates/assistant_tools/src/bash_tool.rs index f504fb61c35a9d5072f5d2e001f344547a39abb7..2dd66cea8bb8eab8fadf2b92953a8d612cb2e1a5 100644 --- a/crates/assistant_tools/src/bash_tool.rs +++ b/crates/assistant_tools/src/bash_tool.rs @@ -1,6 +1,8 @@ use crate::schema::json_schema_for; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::{ActionLog, Tool}; +use futures::io::BufReader; +use futures::{AsyncBufReadExt, AsyncReadExt}; use gpui::{App, Entity, Task}; use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use project::Project; @@ -125,29 +127,90 @@ impl Tool for BashTool { // Add 2>&1 to merge stderr into stdout for proper interleaving. let command = format!("({}) 2>&1", input.command); - let output = new_smol_command("bash") + let mut cmd = new_smol_command("bash") .arg("-c") .arg(&command) .current_dir(working_dir) - .output() - .await + .stdout(std::process::Stdio::piped()) + .spawn() .context("Failed to execute bash command")?; - let output_string = String::from_utf8_lossy(&output.stdout).to_string(); + // Capture stdout with a limit + let stdout = cmd.stdout.take().unwrap(); + let mut reader = BufReader::new(stdout); + + const MESSAGE_1: &str = "Command output too long. 
The first "; + const MESSAGE_2: &str = " bytes:\n\n"; + const ERR_MESSAGE_1: &str = "Command failed with exit code "; + const ERR_MESSAGE_2: &str = "\n\n"; + + const STDOUT_LIMIT: usize = 8192; + + const LIMIT: usize = STDOUT_LIMIT + - (MESSAGE_1.len() + + (STDOUT_LIMIT.ilog10() as usize + 1) // byte count + + MESSAGE_2.len() + + ERR_MESSAGE_1.len() + + 3 // status code + + ERR_MESSAGE_2.len()); + + // Read one more byte to determine whether the output was truncated + let mut buffer = vec![0; LIMIT + 1]; + let bytes_read = reader.read(&mut buffer).await?; + + // Repeatedly fill the output reader's buffer without copying it. + loop { + let skipped_bytes = reader.fill_buf().await?; + if skipped_bytes.is_empty() { + break; + } + let skipped_bytes_len = skipped_bytes.len(); + reader.consume_unpin(skipped_bytes_len); + } + + let output_bytes = &buffer[..bytes_read]; + + // Let the process continue running + let status = cmd.status().await.context("Failed to get command status")?; + + let output_string = if bytes_read > LIMIT { + // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in + // multi-byte characters. 
+ let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n'); + let output_string = String::from_utf8_lossy( + &output_bytes[..last_line_ix.unwrap_or(output_bytes.len())], + ); + + format!( + "{}{}{}{}", + MESSAGE_1, + output_string.len(), + MESSAGE_2, + output_string + ) + } else { + String::from_utf8_lossy(&output_bytes).into() + }; - if output.status.success() { + let output_with_status = if status.success() { if output_string.is_empty() { - Ok("Command executed successfully.".to_string()) + "Command executed successfully.".to_string() } else { - Ok(output_string) + output_string.to_string() } } else { - Ok(format!( - "Command failed with exit code {}\n{}", - output.status.code().unwrap_or(-1), - &output_string - )) - } + format!( + "{}{}{}{}", + ERR_MESSAGE_1, + status.code().unwrap_or(-1), + ERR_MESSAGE_2, + output_string, + ) + }; + + debug_assert!(output_with_status.len() <= STDOUT_LIMIT); + + Ok(output_with_status) }) } } diff --git a/crates/util/src/util.rs b/crates/util/src/util.rs index e1917de3037dde6e960566c97bee71a1b9641375..c6cf114c296a5d584813ba498626b4ba23dff8c0 100644 --- a/crates/util/src/util.rs +++ b/crates/util/src/util.rs @@ -145,6 +145,66 @@ pub fn truncate_lines_and_trailoff(s: &str, max_lines: usize) -> String { } } +/// Truncates the string at a character boundary, such that the result is shorter than +/// `max_bytes` bytes (empty when `max_bytes` is 0); shorter input is returned unchanged. +pub fn truncate_to_byte_limit(s: &str, max_bytes: usize) -> &str { + if s.len() < max_bytes { + return s; + } + + for i in (0..max_bytes).rev() { + if s.is_char_boundary(i) { + return &s[..i]; + } + } + + "" +} + +/// Takes a prefix of complete lines which fit within the byte limit. If the first line is longer +/// than the limit, truncates at a character boundary.
+pub fn truncate_lines_to_byte_limit(s: &str, max_bytes: usize) -> &str { + if s.len() < max_bytes { + return s; + } + + for i in (0..max_bytes).rev() { + if s.is_char_boundary(i) { + if s.as_bytes()[i] == b'\n' { + // Since the i-th character is \n, valid to slice at i + 1. + return &s[..i + 1]; + } + } + } + + truncate_to_byte_limit(s, max_bytes) +} + +#[test] +fn test_truncate_lines_to_byte_limit() { + let text = "Line 1\nLine 2\nLine 3\nLine 4"; + + // Limit that includes all lines + assert_eq!(truncate_lines_to_byte_limit(text, 100), text); + + // Exactly the first line + assert_eq!(truncate_lines_to_byte_limit(text, 7), "Line 1\n"); + + // Limit between lines + assert_eq!(truncate_lines_to_byte_limit(text, 13), "Line 1\n"); + assert_eq!(truncate_lines_to_byte_limit(text, 20), "Line 1\nLine 2\n"); + + // Limit before first newline + assert_eq!(truncate_lines_to_byte_limit(text, 6), "Line "); + + // Test with non-ASCII characters + let text_utf8 = "Line 1\nLíne 2\nLine 3"; + assert_eq!( + truncate_lines_to_byte_limit(text_utf8, 15), + "Line 1\nLíne 2\n" + ); +} + pub fn post_inc + AddAssign + Copy>(value: &mut T) -> T { let prev = *value; *value += T::from(1);