bash_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Context as _, Result, anyhow};
  3use assistant_tool::{ActionLog, Tool};
  4use futures::io::BufReader;
  5use futures::{AsyncBufReadExt, AsyncReadExt};
  6use gpui::{App, Entity, Task};
  7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  8use project::Project;
  9use schemars::JsonSchema;
 10use serde::{Deserialize, Serialize};
 11use std::path::Path;
 12use std::sync::Arc;
 13use ui::IconName;
 14use util::command::new_smol_command;
 15use util::markdown::MarkdownString;
 16
// Arguments the model supplies when it calls the `bash` tool.
//
// NOTE: the `///` comments on the fields below are not only Rust docs. The `JsonSchema`
// derive (schemars) turns them into field descriptions in the schema returned by
// `BashTool::input_schema`, so editing them changes what the model is told. Plain `//`
// comments like this one are not exported.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct BashToolInput {
    // Handed to `bash -c` by `BashTool::run`, wrapped in a subshell with stderr merged
    // into stdout.
    /// The bash one-liner command to execute.
    command: String,
    // NOTE(review): `BashTool::run` also accepts `.` (only when the project has exactly one
    // worktree) and absolute paths inside a worktree; the schema text only mentions roots.
    /// Working directory for the command. This must be one of the root directories of the project.
    cd: String,
}
 24
/// The `bash` [`Tool`]: after user confirmation, runs a bash command in one of the
/// project's worktrees and returns its combined stdout/stderr (size-capped) to the model.
pub struct BashTool;
 26
 27impl Tool for BashTool {
 28    fn name(&self) -> String {
 29        "bash".to_string()
 30    }
 31
 32    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 33        true
 34    }
 35
 36    fn description(&self) -> String {
 37        include_str!("./bash_tool/description.md").to_string()
 38    }
 39
 40    fn icon(&self) -> IconName {
 41        IconName::Terminal
 42    }
 43
 44    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
 45        json_schema_for::<BashToolInput>(format)
 46    }
 47
 48    fn ui_text(&self, input: &serde_json::Value) -> String {
 49        match serde_json::from_value::<BashToolInput>(input.clone()) {
 50            Ok(input) => {
 51                let mut lines = input.command.lines();
 52                let first_line = lines.next().unwrap_or_default();
 53                let remaining_line_count = lines.count();
 54                match remaining_line_count {
 55                    0 => MarkdownString::inline_code(&first_line).0,
 56                    1 => {
 57                        MarkdownString::inline_code(&format!(
 58                            "{} - {} more line",
 59                            first_line, remaining_line_count
 60                        ))
 61                        .0
 62                    }
 63                    n => {
 64                        MarkdownString::inline_code(&format!("{} - {} more lines", first_line, n)).0
 65                    }
 66                }
 67            }
 68            Err(_) => "Run bash command".to_string(),
 69        }
 70    }
 71
 72    fn run(
 73        self: Arc<Self>,
 74        input: serde_json::Value,
 75        _messages: &[LanguageModelRequestMessage],
 76        project: Entity<Project>,
 77        _action_log: Entity<ActionLog>,
 78        cx: &mut App,
 79    ) -> Task<Result<String>> {
 80        let input: BashToolInput = match serde_json::from_value(input) {
 81            Ok(input) => input,
 82            Err(err) => return Task::ready(Err(anyhow!(err))),
 83        };
 84
 85        let project = project.read(cx);
 86        let input_path = Path::new(&input.cd);
 87        let working_dir = if input.cd == "." {
 88            // Accept "." as meaning "the one worktree" if we only have one worktree.
 89            let mut worktrees = project.worktrees(cx);
 90
 91            let only_worktree = match worktrees.next() {
 92                Some(worktree) => worktree,
 93                None => return Task::ready(Err(anyhow!("No worktrees found in the project"))),
 94            };
 95
 96            if worktrees.next().is_some() {
 97                return Task::ready(Err(anyhow!(
 98                    "'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly."
 99                )));
100            }
101
102            only_worktree.read(cx).abs_path()
103        } else if input_path.is_absolute() {
104            // Absolute paths are allowed, but only if they're in one of the project's worktrees.
105            if !project
106                .worktrees(cx)
107                .any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
108            {
109                return Task::ready(Err(anyhow!(
110                    "The absolute path must be within one of the project's worktrees"
111                )));
112            }
113
114            input_path.into()
115        } else {
116            let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
117                return Task::ready(Err(anyhow!(
118                    "`cd` directory {} not found in the project",
119                    &input.cd
120                )));
121            };
122
123            worktree.read(cx).abs_path()
124        };
125
126        cx.spawn(async move |_| {
127            // Add 2>&1 to merge stderr into stdout for proper interleaving.
128            let command = format!("({}) 2>&1", input.command);
129
130            let mut cmd = new_smol_command("bash")
131                .arg("-c")
132                .arg(&command)
133                .current_dir(working_dir)
134                .stdout(std::process::Stdio::piped())
135                .spawn()
136                .context("Failed to execute bash command")?;
137
138            // Capture stdout with a limit
139            let stdout = cmd.stdout.take().unwrap();
140            let mut reader = BufReader::new(stdout);
141
142            const MESSAGE_1: &str = "Command output too long. The first ";
143            const MESSAGE_2: &str = " bytes:\n\n";
144            const ERR_MESSAGE_1: &str = "Command failed with exit code ";
145            const ERR_MESSAGE_2: &str = "\n\n";
146
147            const STDOUT_LIMIT: usize = 8192;
148
149            const LIMIT: usize = STDOUT_LIMIT
150                - (MESSAGE_1.len()
151                    + (STDOUT_LIMIT.ilog10() as usize + 1) // byte count
152                    + MESSAGE_2.len()
153                    + ERR_MESSAGE_1.len()
154                    + 3 // status code
155                    + ERR_MESSAGE_2.len());
156
157            // Read one more byte to determine whether the output was truncated
158            let mut buffer = vec![0; LIMIT + 1];
159            let mut bytes_read = 0;
160
161            // Read until we reach the limit
162            loop {
163                let read = reader.read(&mut buffer).await?;
164                if read == 0 {
165                    break;
166                }
167
168                bytes_read += read;
169                if bytes_read > LIMIT {
170                    bytes_read = LIMIT + 1;
171                    break;
172                }
173            }
174
175            // Repeatedly fill the output reader's buffer without copying it.
176            loop {
177                let skipped_bytes = reader.fill_buf().await?;
178                if skipped_bytes.is_empty() {
179                    break;
180                }
181                let skipped_bytes_len = skipped_bytes.len();
182                reader.consume_unpin(skipped_bytes_len);
183            }
184
185            let output_bytes = &buffer[..bytes_read];
186
187            // Let the process continue running
188            let status = cmd.status().await.context("Failed to get command status")?;
189
190            let output_string = if bytes_read > LIMIT {
191                // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
192                // multi-byte characters.
193                let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n');
194                let output_string = String::from_utf8_lossy(
195                    &output_bytes[..last_line_ix.unwrap_or(output_bytes.len())],
196                );
197
198                format!(
199                    "{}{}{}{}",
200                    MESSAGE_1,
201                    output_string.len(),
202                    MESSAGE_2,
203                    output_string
204                )
205            } else {
206                String::from_utf8_lossy(&output_bytes).into()
207            };
208
209            let output_with_status = if status.success() {
210                if output_string.is_empty() {
211                    "Command executed successfully.".to_string()
212                } else {
213                    output_string.to_string()
214                }
215            } else {
216                format!(
217                    "{}{}{}{}",
218                    ERR_MESSAGE_1,
219                    status.code().unwrap_or(-1),
220                    ERR_MESSAGE_2,
221                    output_string,
222                )
223            };
224
225            debug_assert!(output_with_status.len() <= STDOUT_LIMIT);
226
227            Ok(output_with_status)
228        })
229    }
230}