bash_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Context as _, Result, anyhow};
  3use assistant_tool::{ActionLog, Tool};
  4use gpui::{App, Entity, Task};
  5use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  6use project::Project;
  7use schemars::JsonSchema;
  8use serde::{Deserialize, Serialize};
  9use std::path::Path;
 10use std::sync::Arc;
 11use ui::IconName;
 12use util::command::new_smol_command;
 13use util::markdown::MarkdownString;
 14
 15#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 16pub struct BashToolInput {
 17    /// The bash command to execute as a one-liner.
 18    command: String,
 19    /// Working directory for the command. This must be one of the root directories of the project.
 20    cd: String,
 21}
 22
 /// Agent tool that runs a shell command through `bash -c`, with the working
/// directory resolved to a location inside one of the project's worktrees.
/// The tool reports that it needs user confirmation before running.
pub struct BashTool;
 24
 25impl Tool for BashTool {
 26    fn name(&self) -> String {
 27        "bash".to_string()
 28    }
 29
 30    fn needs_confirmation(&self) -> bool {
 31        true
 32    }
 33
 34    fn description(&self) -> String {
 35        include_str!("./bash_tool/description.md").to_string()
 36    }
 37
 38    fn icon(&self) -> IconName {
 39        IconName::Terminal
 40    }
 41
 42    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
 43        json_schema_for::<BashToolInput>(format)
 44    }
 45
 46    fn ui_text(&self, input: &serde_json::Value) -> String {
 47        match serde_json::from_value::<BashToolInput>(input.clone()) {
 48            Ok(input) => {
 49                if input.command.contains('\n') {
 50                    MarkdownString::code_block("bash", &input.command).0
 51                } else {
 52                    MarkdownString::inline_code(&input.command).0
 53                }
 54            }
 55            Err(_) => "Run bash command".to_string(),
 56        }
 57    }
 58
 59    fn run(
 60        self: Arc<Self>,
 61        input: serde_json::Value,
 62        _messages: &[LanguageModelRequestMessage],
 63        project: Entity<Project>,
 64        _action_log: Entity<ActionLog>,
 65        cx: &mut App,
 66    ) -> Task<Result<String>> {
 67        let input: BashToolInput = match serde_json::from_value(input) {
 68            Ok(input) => input,
 69            Err(err) => return Task::ready(Err(anyhow!(err))),
 70        };
 71
 72        let project = project.read(cx);
 73        let input_path = Path::new(&input.cd);
 74        let working_dir = if input.cd == "." {
 75            // Accept "." as meaning "the one worktree" if we only have one worktree.
 76            let mut worktrees = project.worktrees(cx);
 77
 78            let only_worktree = match worktrees.next() {
 79                Some(worktree) => worktree,
 80                None => return Task::ready(Err(anyhow!("No worktrees found in the project"))),
 81            };
 82
 83            if worktrees.next().is_some() {
 84                return Task::ready(Err(anyhow!(
 85                    "'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly."
 86                )));
 87            }
 88
 89            only_worktree.read(cx).abs_path()
 90        } else if input_path.is_absolute() {
 91            // Absolute paths are allowed, but only if they're in one of the project's worktrees.
 92            if !project
 93                .worktrees(cx)
 94                .any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
 95            {
 96                return Task::ready(Err(anyhow!(
 97                    "The absolute path must be within one of the project's worktrees"
 98                )));
 99            }
100
101            input_path.into()
102        } else {
103            let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
104                return Task::ready(Err(anyhow!(
105                    "`cd` directory {} not found in the project",
106                    &input.cd
107                )));
108            };
109
110            worktree.read(cx).abs_path()
111        };
112
113        cx.spawn(async move |_| {
114            // Add 2>&1 to merge stderr into stdout for proper interleaving.
115            let command = format!("({}) 2>&1", input.command);
116
117            let output = new_smol_command("bash")
118                .arg("-c")
119                .arg(&command)
120                .current_dir(working_dir)
121                .output()
122                .await
123                .context("Failed to execute bash command")?;
124
125            let output_string = String::from_utf8_lossy(&output.stdout).to_string();
126
127            if output.status.success() {
128                if output_string.is_empty() {
129                    Ok("Command executed successfully.".to_string())
130                } else {
131                    Ok(output_string)
132                }
133            } else {
134                Ok(format!(
135                    "Command failed with exit code {}\n{}",
136                    output.status.code().unwrap_or(-1),
137                    &output_string
138                ))
139            }
140        })
141    }
142}