Detailed changes
@@ -1487,6 +1487,7 @@ impl Thread {
tool_use_id.clone(),
tool_name,
output,
+ cx,
);
cx.emit(ThreadEvent::ToolFinished {
@@ -1831,7 +1832,7 @@ impl Thread {
));
self.tool_use
- .insert_tool_output(tool_use_id.clone(), tool_name, err);
+ .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
cx.emit(ThreadEvent::ToolFinished {
tool_use_id,
@@ -7,10 +7,11 @@ use futures::FutureExt as _;
use futures::future::Shared;
use gpui::{App, SharedString, Task};
use language_model::{
- LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
- LanguageModelToolUseId, MessageContent, Role,
+ LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
+ LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
};
use ui::IconName;
+use util::truncate_lines_to_byte_limit;
use crate::thread::MessageId;
use crate::thread_store::SerializedMessage;
@@ -331,9 +332,32 @@ impl ToolUseState {
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
output: Result<String>,
+ cx: &App,
) -> Option<PendingToolUse> {
match output {
Ok(tool_result) => {
+ let model_registry = LanguageModelRegistry::read_global(cx);
+
+ const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
+
+ // Protect from clearly large output
+ let tool_output_limit = model_registry
+ .default_model()
+ .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
+ .unwrap_or(usize::MAX);
+
+ let tool_result = if tool_result.len() <= tool_output_limit {
+ tool_result
+ } else {
+ let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
+
+ format!(
+ "Tool result too long. The first {} bytes:\n\n{}",
+ truncated.len(),
+ truncated
+ )
+ };
+
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
@@ -1,6 +1,8 @@
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool};
+use futures::io::BufReader;
+use futures::{AsyncBufReadExt, AsyncReadExt};
use gpui::{App, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -125,29 +127,90 @@ impl Tool for BashTool {
// Add 2>&1 to merge stderr into stdout for proper interleaving.
let command = format!("({}) 2>&1", input.command);
- let output = new_smol_command("bash")
+ let mut cmd = new_smol_command("bash")
.arg("-c")
.arg(&command)
.current_dir(working_dir)
- .output()
- .await
+ .stdout(std::process::Stdio::piped())
+ .spawn()
.context("Failed to execute bash command")?;
- let output_string = String::from_utf8_lossy(&output.stdout).to_string();
+ // Capture stdout with a limit
+ let stdout = cmd.stdout.take().unwrap();
+ let mut reader = BufReader::new(stdout);
+
+ const MESSAGE_1: &str = "Command output too long. The first ";
+ const MESSAGE_2: &str = " bytes:\n\n";
+ const ERR_MESSAGE_1: &str = "Command failed with exit code ";
+ const ERR_MESSAGE_2: &str = "\n\n";
+
+ const STDOUT_LIMIT: usize = 8192;
+
+ const LIMIT: usize = STDOUT_LIMIT
+ - (MESSAGE_1.len()
+ + (STDOUT_LIMIT.ilog10() as usize + 1) // byte count
+ + MESSAGE_2.len()
+ + ERR_MESSAGE_1.len()
+ + 3 // status code
+ + ERR_MESSAGE_2.len());
+
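+ // Read up to LIMIT + 1 bytes; the extra byte signals truncation. NOTE(review): one `read` may return fewer bytes than are pending — confirm a read loop is not needed.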
+ let mut buffer = vec![0; LIMIT + 1];
+ let bytes_read = reader.read(&mut buffer).await?;
+
+ // Repeatedly fill the output reader's buffer without copying it.
+ loop {
+ let skipped_bytes = reader.fill_buf().await?;
+ if skipped_bytes.is_empty() {
+ break;
+ }
+ let skipped_bytes_len = skipped_bytes.len();
+ reader.consume_unpin(skipped_bytes_len);
+ }
+
+ let output_bytes = &buffer[..bytes_read];
+
+ // Output has been drained to EOF above; now wait for the process to exit and collect its status
+ let status = cmd.status().await.context("Failed to get command status")?;
+
+ let output_string = if bytes_read > LIMIT {
+ // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
+ // multi-byte characters.
+ let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n');
+ let output_string = String::from_utf8_lossy(
+ &output_bytes[..last_line_ix.unwrap_or(output_bytes.len())],
+ );
+
+ format!(
+ "{}{}{}{}",
+ MESSAGE_1,
+ output_string.len(),
+ MESSAGE_2,
+ output_string
+ )
+ } else {
+ String::from_utf8_lossy(&output_bytes).into()
+ };
- if output.status.success() {
+ let output_with_status = if status.success() {
if output_string.is_empty() {
- Ok("Command executed successfully.".to_string())
+ "Command executed successfully.".to_string()
} else {
- Ok(output_string)
+ output_string.to_string()
}
} else {
- Ok(format!(
- "Command failed with exit code {}\n{}",
- output.status.code().unwrap_or(-1),
- &output_string
- ))
- }
+ format!(
+ "{}{}{}{}",
+ ERR_MESSAGE_1,
+ status.code().unwrap_or(-1),
+ ERR_MESSAGE_2,
+ output_string,
+ )
+ };
+
+ debug_assert!(output_with_status.len() <= STDOUT_LIMIT);
+
+ Ok(output_with_status)
})
}
}
@@ -145,6 +145,66 @@ pub fn truncate_lines_and_trailoff(s: &str, max_lines: usize) -> String {
}
}
+/// Returns the longest prefix of `s` that ends on a character boundary and is strictly shorter than
+/// `max_bytes` bytes. `s` is returned whole only if `s.len() < max_bytes`; `max_bytes == 0` yields "".
+pub fn truncate_to_byte_limit(s: &str, max_bytes: usize) -> &str {
+ if s.len() < max_bytes {
+ return s;
+ }
+
+ for i in (0..max_bytes).rev() {
+ if s.is_char_boundary(i) {
+ return &s[..i];
+ }
+ }
+
+ ""
+}
+
+/// Returns `s` unchanged if `s.len() < max_bytes`; otherwise the longest prefix ending in `\n` that is at
+/// most `max_bytes` bytes. If no `\n` lies in the first `max_bytes` bytes, uses `truncate_to_byte_limit`.
+pub fn truncate_lines_to_byte_limit(s: &str, max_bytes: usize) -> &str {
+ if s.len() < max_bytes {
+ return s;
+ }
+
+ for i in (0..max_bytes).rev() {
+ if s.is_char_boundary(i) {
+ if s.as_bytes()[i] == b'\n' {
+ // Byte `i` is `\n` (a one-byte character), so slicing at `i + 1` is valid and keeps the newline.
+ return &s[..i + 1];
+ }
+ }
+ }
+
+ truncate_to_byte_limit(s, max_bytes)
+}
+
+#[test]
+fn test_truncate_lines_to_byte_limit() {
+ let text = "Line 1\nLine 2\nLine 3\nLine 4";
+
+ // Limit larger than the whole text: returned unchanged
+ assert_eq!(truncate_lines_to_byte_limit(text, 100), text);
+
+ // Exactly the first line (7 bytes including its newline)
+ assert_eq!(truncate_lines_to_byte_limit(text, 7), "Line 1\n");
+
+ // Limit falls inside a later line: cut back to the last complete line
+ assert_eq!(truncate_lines_to_byte_limit(text, 13), "Line 1\n");
+ assert_eq!(truncate_lines_to_byte_limit(text, 20), "Line 1\nLine 2\n");
+
+ // Limit before first newline: falls back to truncate_to_byte_limit, which keeps fewer than 6 bytes
+ assert_eq!(truncate_lines_to_byte_limit(text, 6), "Line ");
+
+ // Non-ASCII: `í` is two bytes in UTF-8, so the second line's newline is byte 14 (prefix of 15 bytes)
+ let text_utf8 = "Line 1\nLíne 2\nLine 3";
+ assert_eq!(
+ truncate_lines_to_byte_limit(text_utf8, 15),
+ "Line 1\nLíne 2\n"
+ );
+}
+
pub fn post_inc<T: From<u8> + AddAssign<T> + Copy>(value: &mut T) -> T {
let prev = *value;
*value += T::from(1);