@@ -0,0 +1,301 @@
+use anyhow::{anyhow, Result};
+use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
+use futures::future::join_all;
+use gpui::{App, AppContext, Entity, Task};
+use language_model::LanguageModelRequestMessage;
+use project::Project;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
+use ui::IconName;
+
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+pub struct ToolInvocation {
+ /// The name of the tool to invoke
+ pub name: String,
+
+ /// The input to the tool in JSON format
+ pub input: serde_json::Value,
+}
+
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+pub struct BatchToolInput {
+ /// The tool invocations to run as a batch. These tools will be run either sequentially
+ /// or concurrently depending on the `run_tools_concurrently` flag.
+ ///
+ /// <example>
+ /// Basic file operations (concurrent)
+ ///
+ /// ```json
+ /// {
+ /// "invocations": [
+ /// {
+ /// "name": "read-file",
+ /// "input": {
+ /// "path": "src/main.rs"
+ /// }
+ /// },
+ /// {
+ /// "name": "list-directory",
+ /// "input": {
+ /// "path": "src/lib"
+ /// }
+ /// },
+ /// {
+ /// "name": "regex-search",
+ /// "input": {
+ /// "regex": "fn run\\("
+ /// }
+ /// }
+ /// ],
+ /// "run_tools_concurrently": true
+ /// }
+ /// ```
+ /// </example>
+ ///
+ /// <example>
+ /// Multiple find-replace operations on the same file (sequential)
+ ///
+ /// ```json
+ /// {
+ /// "invocations": [
+ /// {
+ /// "name": "find-replace-file",
+ /// "input": {
+ /// "path": "src/config.rs",
+ /// "display_description": "Update default timeout value",
+ /// "find": "pub const DEFAULT_TIMEOUT: u64 = 30;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";",
+ /// "replace": "pub const DEFAULT_TIMEOUT: u64 = 60;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";"
+ /// }
+ /// },
+ /// {
+ /// "name": "find-replace-file",
+ /// "input": {
+ /// "path": "src/config.rs",
+ /// "display_description": "Update API endpoint URL",
+ /// "find": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";\n\npub const API_VERSION: &str = \"v1\";",
+ /// "replace": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.newdomain.com\";\n\npub const API_VERSION: &str = \"v1\";"
+ /// }
+ /// }
+ /// ],
+ /// "run_tools_concurrently": false
+ /// }
+ /// ```
+ /// </example>
+ ///
+ /// <example>
+ /// Searching and analyzing code (concurrent)
+ ///
+ /// ```json
+ /// {
+ /// "invocations": [
+ /// {
+ /// "name": "regex-search",
+ /// "input": {
+ /// "regex": "impl Database"
+ /// }
+ /// },
+ /// {
+ /// "name": "path-search",
+ /// "input": {
+ /// "glob": "**/*test*.rs"
+ /// }
+ /// }
+ /// ],
+ /// "run_tools_concurrently": true
+ /// }
+ /// ```
+ /// </example>
+ ///
+ /// <example>
+ /// Multi-file refactoring (concurrent)
+ ///
+ /// ```json
+ /// {
+ /// "invocations": [
+ /// {
+ /// "name": "find-replace-file",
+ /// "input": {
+ /// "path": "src/models/user.rs",
+ /// "display_description": "Add email field to User struct",
+ /// "find": "pub struct User {\n pub id: u64,\n pub username: String,\n pub created_at: DateTime<Utc>,\n}",
+ /// "replace": "pub struct User {\n pub id: u64,\n pub username: String,\n pub email: String,\n pub created_at: DateTime<Utc>,\n}"
+ /// }
+ /// },
+ /// {
+ /// "name": "find-replace-file",
+ /// "input": {
+ /// "path": "src/db/queries.rs",
+ /// "display_description": "Update user insertion query",
+ /// "find": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n conn.execute(\n \"INSERT INTO users (id, username, created_at) VALUES ($1, $2, $3)\",\n &[&user.id, &user.username, &user.created_at],\n ).await?;\n \n Ok(())\n}",
+ /// "replace": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n conn.execute(\n \"INSERT INTO users (id, username, email, created_at) VALUES ($1, $2, $3, $4)\",\n &[&user.id, &user.username, &user.email, &user.created_at],\n ).await?;\n \n Ok(())\n}"
+ /// }
+ /// }
+ /// ],
+ /// "run_tools_concurrently": true
+ /// }
+ /// ```
+ /// </example>
+ pub invocations: Vec<ToolInvocation>,
+
+ /// Whether to run the tools in this batch concurrently. If this is false (the default), the tools will run sequentially.
+ #[serde(default)]
+ pub run_tools_concurrently: bool,
+}
+
+pub struct BatchTool;
+
+impl Tool for BatchTool {
+ fn name(&self) -> String {
+ "batch-tool".into()
+ }
+
+ fn needs_confirmation(&self) -> bool {
+ true
+ }
+
+ fn description(&self) -> String {
+ include_str!("./batch_tool/description.md").into()
+ }
+
+ fn icon(&self) -> IconName {
+ IconName::Cog
+ }
+
+ fn input_schema(&self) -> serde_json::Value {
+ let schema = schemars::schema_for!(BatchToolInput);
+ serde_json::to_value(&schema).unwrap()
+ }
+
+ fn ui_text(&self, input: &serde_json::Value) -> String {
+ match serde_json::from_value::<BatchToolInput>(input.clone()) {
+ Ok(input) => {
+ let count = input.invocations.len();
+ let mode = if input.run_tools_concurrently {
+ "concurrently"
+ } else {
+ "sequentially"
+ };
+
+ let first_tool_name = input
+ .invocations
+ .first()
+ .map(|inv| inv.name.clone())
+ .unwrap_or_default();
+
+ let all_same = input
+ .invocations
+ .iter()
+ .all(|invocation| invocation.name == first_tool_name);
+
+ if all_same {
+ format!(
+ "Run `{}` {} times {}",
+ first_tool_name,
+ input.invocations.len(),
+ mode
+ )
+ } else {
+ format!("Run {} tools {}", count, mode)
+ }
+ }
+ Err(_) => "Batch tools".to_string(),
+ }
+ }
+
+ fn run(
+ self: Arc<Self>,
+ input: serde_json::Value,
+ messages: &[LanguageModelRequestMessage],
+ project: Entity<Project>,
+ action_log: Entity<ActionLog>,
+ cx: &mut App,
+ ) -> Task<Result<String>> {
+ let input = match serde_json::from_value::<BatchToolInput>(input) {
+ Ok(input) => input,
+ Err(err) => return Task::ready(Err(anyhow!(err))),
+ };
+
+ if input.invocations.is_empty() {
+ return Task::ready(Err(anyhow!("No tool invocations provided")));
+ }
+
+ let run_tools_concurrently = input.run_tools_concurrently;
+
+ let foreground_task = {
+ let working_set = ToolWorkingSet::default();
+ let invocations = input.invocations;
+ let messages = messages.to_vec();
+
+ cx.spawn(async move |cx| {
+ let mut tasks = Vec::new();
+ let mut tool_names = Vec::new();
+
+ for invocation in invocations {
+ let tool_name = invocation.name.clone();
+ tool_names.push(tool_name.clone());
+
+ let tool = cx
+ .update(|cx| working_set.tool(&tool_name, cx))
+ .map_err(|err| {
+ anyhow!("Failed to look up tool '{}': {}", tool_name, err)
+ })?;
+
+ let Some(tool) = tool else {
+ return Err(anyhow!("Tool '{}' not found", tool_name));
+ };
+
+ let project = project.clone();
+ let action_log = action_log.clone();
+ let messages = messages.clone();
+ let task = cx
+ .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
+ .map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
+
+ tasks.push(task);
+ }
+
+ Ok((tasks, tool_names))
+ })
+ };
+
+ cx.background_spawn(async move {
+ let (tasks, tool_names) = foreground_task.await?;
+ let mut results = Vec::with_capacity(tasks.len());
+
+ if run_tools_concurrently {
+ results.extend(join_all(tasks).await)
+ } else {
+ for task in tasks {
+ results.push(task.await);
+ }
+ };
+
+ let mut formatted_results = String::new();
+ let mut error_occurred = false;
+
+ for (i, result) in results.into_iter().enumerate() {
+ let tool_name = &tool_names[i];
+
+ match result {
+ Ok(output) => {
+ formatted_results
+ .push_str(&format!("Tool '{}' result:\n{}\n\n", tool_name, output));
+ }
+ Err(err) => {
+ error_occurred = true;
+ formatted_results
+ .push_str(&format!("Tool '{}' error: {}\n\n", tool_name, err));
+ }
+ }
+ }
+
+ if error_occurred {
+ formatted_results
+ .push_str("Note: Some tool invocations failed. See individual results above.");
+ }
+
+ Ok(formatted_results.trim().to_string())
+ })
+ }
+}