bash_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Context as _, Result, anyhow};
  3use assistant_tool::{ActionLog, Tool};
  4use gpui::{App, Entity, Task};
  5use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  6use project::Project;
  7use schemars::JsonSchema;
  8use serde::{Deserialize, Serialize};
  9use std::path::Path;
 10use std::sync::Arc;
 11use ui::IconName;
 12use util::command::new_smol_command;
 13use util::markdown::MarkdownString;
 14
 15#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 16pub struct BashToolInput {
 17    /// The bash command to execute as a one-liner.
 18    command: String,
 19    /// Working directory for the command. This must be one of the root directories of the project.
 20    cd: String,
 21}
 22
/// Tool that runs a model-supplied command with `bash -c` in a directory belonging to one of
/// the project's worktrees, returning stdout and stderr merged into a single text result.
pub struct BashTool;
 24
 25impl Tool for BashTool {
 26    fn name(&self) -> String {
 27        "bash".to_string()
 28    }
 29
 30    fn needs_confirmation(&self) -> bool {
 31        true
 32    }
 33
 34    fn description(&self) -> String {
 35        include_str!("./bash_tool/description.md").to_string()
 36    }
 37
 38    fn icon(&self) -> IconName {
 39        IconName::Terminal
 40    }
 41
 42    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
 43        json_schema_for::<BashToolInput>(format)
 44    }
 45
 46    fn ui_text(&self, input: &serde_json::Value) -> String {
 47        match serde_json::from_value::<BashToolInput>(input.clone()) {
 48            Ok(input) => {
 49                let mut lines = input.command.lines();
 50                let first_line = lines.next().unwrap_or_default();
 51                let remaining_line_count = lines.count();
 52                match remaining_line_count {
 53                    0 => MarkdownString::inline_code(&first_line).0,
 54                    1 => {
 55                        MarkdownString::inline_code(&format!(
 56                            "{} - {} more line",
 57                            first_line, remaining_line_count
 58                        ))
 59                        .0
 60                    }
 61                    n => {
 62                        MarkdownString::inline_code(&format!("{} - {} more lines", first_line, n)).0
 63                    }
 64                }
 65            }
 66            Err(_) => "Run bash command".to_string(),
 67        }
 68    }
 69
 70    fn run(
 71        self: Arc<Self>,
 72        input: serde_json::Value,
 73        _messages: &[LanguageModelRequestMessage],
 74        project: Entity<Project>,
 75        _action_log: Entity<ActionLog>,
 76        cx: &mut App,
 77    ) -> Task<Result<String>> {
 78        let input: BashToolInput = match serde_json::from_value(input) {
 79            Ok(input) => input,
 80            Err(err) => return Task::ready(Err(anyhow!(err))),
 81        };
 82
 83        let project = project.read(cx);
 84        let input_path = Path::new(&input.cd);
 85        let working_dir = if input.cd == "." {
 86            // Accept "." as meaning "the one worktree" if we only have one worktree.
 87            let mut worktrees = project.worktrees(cx);
 88
 89            let only_worktree = match worktrees.next() {
 90                Some(worktree) => worktree,
 91                None => return Task::ready(Err(anyhow!("No worktrees found in the project"))),
 92            };
 93
 94            if worktrees.next().is_some() {
 95                return Task::ready(Err(anyhow!(
 96                    "'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly."
 97                )));
 98            }
 99
100            only_worktree.read(cx).abs_path()
101        } else if input_path.is_absolute() {
102            // Absolute paths are allowed, but only if they're in one of the project's worktrees.
103            if !project
104                .worktrees(cx)
105                .any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
106            {
107                return Task::ready(Err(anyhow!(
108                    "The absolute path must be within one of the project's worktrees"
109                )));
110            }
111
112            input_path.into()
113        } else {
114            let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
115                return Task::ready(Err(anyhow!(
116                    "`cd` directory {} not found in the project",
117                    &input.cd
118                )));
119            };
120
121            worktree.read(cx).abs_path()
122        };
123
124        cx.spawn(async move |_| {
125            // Add 2>&1 to merge stderr into stdout for proper interleaving.
126            let command = format!("({}) 2>&1", input.command);
127
128            let output = new_smol_command("bash")
129                .arg("-c")
130                .arg(&command)
131                .current_dir(working_dir)
132                .output()
133                .await
134                .context("Failed to execute bash command")?;
135
136            let output_string = String::from_utf8_lossy(&output.stdout).to_string();
137
138            if output.status.success() {
139                if output_string.is_empty() {
140                    Ok("Command executed successfully.".to_string())
141                } else {
142                    Ok(output_string)
143                }
144            } else {
145                Ok(format!(
146                    "Command failed with exit code {}\n{}",
147                    output.status.code().unwrap_or(-1),
148                    &output_string
149                ))
150            }
151        })
152    }
153}