bash_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Context as _, Result, anyhow};
  3use assistant_tool::{ActionLog, Tool};
  4use futures::io::BufReader;
  5use futures::{AsyncBufReadExt, AsyncReadExt};
  6use gpui::{App, Entity, Task};
  7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  8use project::Project;
  9use schemars::JsonSchema;
 10use serde::{Deserialize, Serialize};
 11use std::path::Path;
 12use std::sync::Arc;
 13use ui::IconName;
 14use util::command::new_smol_command;
 15use util::markdown::MarkdownString;
 16
// Arguments the model supplies when it invokes the `bash` tool; deserialized from the tool-call
// JSON by both `ui_text` and `run`.
//
// NOTE: the `///` field docs below are not only documentation. The `JsonSchema` derive uses them
// as the field descriptions in the schema returned by `input_schema`, so editing them changes
// what the model is told about this tool.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct BashToolInput {
    // Passed to `bash -c` (wrapped in a subshell). Despite the "one-liner" wording below,
    // `ui_text` also handles multi-line commands.
    /// The bash command to execute as a one-liner.
    command: String,
    // `run` accepts three forms: "." (only when the project has exactly one worktree), an
    // absolute path inside one of the worktrees, or a worktree root name.
    /// Working directory for the command. This must be one of the root directories of the project.
    cd: String,
}
 24
 25pub struct BashTool;
 26
 27impl Tool for BashTool {
 28    fn name(&self) -> String {
 29        "bash".to_string()
 30    }
 31
 32    fn needs_confirmation(&self) -> bool {
 33        true
 34    }
 35
 36    fn description(&self) -> String {
 37        include_str!("./bash_tool/description.md").to_string()
 38    }
 39
 40    fn icon(&self) -> IconName {
 41        IconName::Terminal
 42    }
 43
 44    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
 45        json_schema_for::<BashToolInput>(format)
 46    }
 47
 48    fn ui_text(&self, input: &serde_json::Value) -> String {
 49        match serde_json::from_value::<BashToolInput>(input.clone()) {
 50            Ok(input) => {
 51                let mut lines = input.command.lines();
 52                let first_line = lines.next().unwrap_or_default();
 53                let remaining_line_count = lines.count();
 54                match remaining_line_count {
 55                    0 => MarkdownString::inline_code(&first_line).0,
 56                    1 => {
 57                        MarkdownString::inline_code(&format!(
 58                            "{} - {} more line",
 59                            first_line, remaining_line_count
 60                        ))
 61                        .0
 62                    }
 63                    n => {
 64                        MarkdownString::inline_code(&format!("{} - {} more lines", first_line, n)).0
 65                    }
 66                }
 67            }
 68            Err(_) => "Run bash command".to_string(),
 69        }
 70    }
 71
 72    fn run(
 73        self: Arc<Self>,
 74        input: serde_json::Value,
 75        _messages: &[LanguageModelRequestMessage],
 76        project: Entity<Project>,
 77        _action_log: Entity<ActionLog>,
 78        cx: &mut App,
 79    ) -> Task<Result<String>> {
 80        let input: BashToolInput = match serde_json::from_value(input) {
 81            Ok(input) => input,
 82            Err(err) => return Task::ready(Err(anyhow!(err))),
 83        };
 84
 85        let project = project.read(cx);
 86        let input_path = Path::new(&input.cd);
 87        let working_dir = if input.cd == "." {
 88            // Accept "." as meaning "the one worktree" if we only have one worktree.
 89            let mut worktrees = project.worktrees(cx);
 90
 91            let only_worktree = match worktrees.next() {
 92                Some(worktree) => worktree,
 93                None => return Task::ready(Err(anyhow!("No worktrees found in the project"))),
 94            };
 95
 96            if worktrees.next().is_some() {
 97                return Task::ready(Err(anyhow!(
 98                    "'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly."
 99                )));
100            }
101
102            only_worktree.read(cx).abs_path()
103        } else if input_path.is_absolute() {
104            // Absolute paths are allowed, but only if they're in one of the project's worktrees.
105            if !project
106                .worktrees(cx)
107                .any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
108            {
109                return Task::ready(Err(anyhow!(
110                    "The absolute path must be within one of the project's worktrees"
111                )));
112            }
113
114            input_path.into()
115        } else {
116            let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
117                return Task::ready(Err(anyhow!(
118                    "`cd` directory {} not found in the project",
119                    &input.cd
120                )));
121            };
122
123            worktree.read(cx).abs_path()
124        };
125
126        cx.spawn(async move |_| {
127            // Add 2>&1 to merge stderr into stdout for proper interleaving.
128            let command = format!("({}) 2>&1", input.command);
129
130            let mut cmd = new_smol_command("bash")
131                .arg("-c")
132                .arg(&command)
133                .current_dir(working_dir)
134                .stdout(std::process::Stdio::piped())
135                .spawn()
136                .context("Failed to execute bash command")?;
137
138            // Capture stdout with a limit
139            let stdout = cmd.stdout.take().unwrap();
140            let mut reader = BufReader::new(stdout);
141
142            const MESSAGE_1: &str = "Command output too long. The first ";
143            const MESSAGE_2: &str = " bytes:\n\n";
144            const ERR_MESSAGE_1: &str = "Command failed with exit code ";
145            const ERR_MESSAGE_2: &str = "\n\n";
146
147            const STDOUT_LIMIT: usize = 8192;
148
149            const LIMIT: usize = STDOUT_LIMIT
150                - (MESSAGE_1.len()
151                    + (STDOUT_LIMIT.ilog10() as usize + 1) // byte count
152                    + MESSAGE_2.len()
153                    + ERR_MESSAGE_1.len()
154                    + 3 // status code
155                    + ERR_MESSAGE_2.len());
156
157            // Read one more byte to determine whether the output was truncated
158            let mut buffer = vec![0; LIMIT + 1];
159            let bytes_read = reader.read(&mut buffer).await?;
160
161            // Repeatedly fill the output reader's buffer without copying it.
162            loop {
163                let skipped_bytes = reader.fill_buf().await?;
164                if skipped_bytes.is_empty() {
165                    break;
166                }
167                let skipped_bytes_len = skipped_bytes.len();
168                reader.consume_unpin(skipped_bytes_len);
169            }
170
171            let output_bytes = &buffer[..bytes_read];
172
173            // Let the process continue running
174            let status = cmd.status().await.context("Failed to get command status")?;
175
176            let output_string = if bytes_read > LIMIT {
177                // Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
178                // multi-byte characters.
179                let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n');
180                let output_string = String::from_utf8_lossy(
181                    &output_bytes[..last_line_ix.unwrap_or(output_bytes.len())],
182                );
183
184                format!(
185                    "{}{}{}{}",
186                    MESSAGE_1,
187                    output_string.len(),
188                    MESSAGE_2,
189                    output_string
190                )
191            } else {
192                String::from_utf8_lossy(&output_bytes).into()
193            };
194
195            let output_with_status = if status.success() {
196                if output_string.is_empty() {
197                    "Command executed successfully.".to_string()
198                } else {
199                    output_string.to_string()
200                }
201            } else {
202                format!(
203                    "{}{}{}{}",
204                    ERR_MESSAGE_1,
205                    status.code().unwrap_or(-1),
206                    ERR_MESSAGE_2,
207                    output_string,
208                )
209            };
210
211            debug_assert!(output_with_status.len() <= STDOUT_LIMIT);
212
213            Ok(output_with_status)
214        })
215    }
216}