batch_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult, ToolWorkingSet};
  4use futures::future::join_all;
  5use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
  6use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  7use project::Project;
  8use schemars::JsonSchema;
  9use serde::{Deserialize, Serialize};
 10use std::sync::Arc;
 11use ui::IconName;
 12
 13#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 14pub struct ToolInvocation {
 15    /// The name of the tool to invoke
 16    pub name: String,
 17
 18    /// The input to the tool in JSON format
 19    pub input: serde_json::Value,
 20}
 21
 22#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 23pub struct BatchToolInput {
 24    /// The tool invocations to run as a batch. These tools will be run either sequentially
 25    /// or concurrently depending on the `run_tools_concurrently` flag.
 26    ///
 27    /// <example>
 28    /// Basic file operations (concurrent)
 29    ///
 30    /// ```json
 31    /// {
 32    ///   "invocations": [
 33    ///     {
 34    ///       "name": "read_file",
 35    ///       "input": {
 36    ///         "path": "src/main.rs"
 37    ///       }
 38    ///     },
 39    ///     {
 40    ///       "name": "list_directory",
 41    ///       "input": {
 42    ///         "path": "src/lib"
 43    ///       }
 44    ///     },
 45    ///     {
 46    ///       "name": "grep",
 47    ///       "input": {
 48    ///         "regex": "fn run\\("
 49    ///       }
 50    ///     }
 51    ///   ],
 52    ///   "run_tools_concurrently": true
 53    /// }
 54    /// ```
 55    /// </example>
 56    ///
 57    /// <example>
 58    /// Multiple find-replace operations on the same file (sequential)
 59    ///
 60    /// ```json
 61    /// {
 62    ///   "invocations": [
 63    ///     {
 64    ///       "name": "find_replace_file",
 65    ///       "input": {
 66    ///         "path": "src/config.rs",
 67    ///         "display_description": "Update default timeout value",
 68    ///         "find": "pub const DEFAULT_TIMEOUT: u64 = 30;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";",
 69    ///         "replace": "pub const DEFAULT_TIMEOUT: u64 = 60;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";"
 70    ///       }
 71    ///     },
 72    ///     {
 73    ///       "name": "find_replace_file",
 74    ///       "input": {
 75    ///         "path": "src/config.rs",
 76    ///         "display_description": "Update API endpoint URL",
 77    ///         "find": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";\n\npub const API_VERSION: &str = \"v1\";",
 78    ///         "replace": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.newdomain.com\";\n\npub const API_VERSION: &str = \"v1\";"
 79    ///       }
 80    ///     }
 81    ///   ],
 82    ///   "run_tools_concurrently": false
 83    /// }
 84    /// ```
 85    /// </example>
 86    ///
 87    /// <example>
 88    /// Searching and analyzing code (concurrent)
 89    ///
 90    /// ```json
 91    /// {
 92    ///   "invocations": [
 93    ///     {
 94    ///       "name": "grep",
 95    ///       "input": {
 96    ///         "regex": "impl Database"
 97    ///       }
 98    ///     },
 99    ///     {
100    ///       "name": "find_path",
101    ///       "input": {
102    ///         "glob": "**/*test*.rs"
103    ///       }
104    ///     }
105    ///   ],
106    ///   "run_tools_concurrently": true
107    /// }
108    /// ```
109    /// </example>
110    ///
111    /// <example>
112    /// Multi-file refactoring (concurrent)
113    ///
114    /// ```json
115    /// {
116    ///   "invocations": [
117    ///     {
118    ///       "name": "find_replace_file",
119    ///       "input": {
120    ///         "path": "src/models/user.rs",
121    ///         "display_description": "Add email field to User struct",
122    ///         "find": "pub struct User {\n    pub id: u64,\n    pub username: String,\n    pub created_at: DateTime<Utc>,\n}",
123    ///         "replace": "pub struct User {\n    pub id: u64,\n    pub username: String,\n    pub email: String,\n    pub created_at: DateTime<Utc>,\n}"
124    ///       }
125    ///     },
126    ///     {
127    ///       "name": "find_replace_file",
128    ///       "input": {
129    ///         "path": "src/db/queries.rs",
130    ///         "display_description": "Update user insertion query",
131    ///         "find": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n    conn.execute(\n        \"INSERT INTO users (id, username, created_at) VALUES ($1, $2, $3)\",\n        &[&user.id, &user.username, &user.created_at],\n    ).await?;\n    \n    Ok(())\n}",
132    ///         "replace": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n    conn.execute(\n        \"INSERT INTO users (id, username, email, created_at) VALUES ($1, $2, $3, $4)\",\n        &[&user.id, &user.username, &user.email, &user.created_at],\n    ).await?;\n    \n    Ok(())\n}"
133    ///       }
134    ///     }
135    ///   ],
136    ///   "run_tools_concurrently": true
137    /// }
138    /// ```
139    /// </example>
140    pub invocations: Vec<ToolInvocation>,
141
142    /// Whether to run the tools in this batch concurrently. If this is false (the default), the tools will run sequentially.
143    #[serde(default)]
144    pub run_tools_concurrently: bool,
145}
146
147pub struct BatchTool;
148
149impl Tool for BatchTool {
150    fn name(&self) -> String {
151        "batch_tool".into()
152    }
153
154    fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool {
155        serde_json::from_value::<BatchToolInput>(input.clone())
156            .map(|input| {
157                let working_set = ToolWorkingSet::default();
158                input.invocations.iter().any(|invocation| {
159                    working_set
160                        .tool(&invocation.name, cx)
161                        .map_or(false, |tool| tool.needs_confirmation(&invocation.input, cx))
162                })
163            })
164            .unwrap_or(false)
165    }
166
167    fn description(&self) -> String {
168        include_str!("./batch_tool/description.md").into()
169    }
170
171    fn icon(&self) -> IconName {
172        IconName::Cog
173    }
174
175    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
176        json_schema_for::<BatchToolInput>(format)
177    }
178
179    fn ui_text(&self, input: &serde_json::Value) -> String {
180        match serde_json::from_value::<BatchToolInput>(input.clone()) {
181            Ok(input) => {
182                let count = input.invocations.len();
183                let mode = if input.run_tools_concurrently {
184                    "concurrently"
185                } else {
186                    "sequentially"
187                };
188
189                let first_tool_name = input
190                    .invocations
191                    .first()
192                    .map(|inv| inv.name.clone())
193                    .unwrap_or_default();
194
195                let all_same = input
196                    .invocations
197                    .iter()
198                    .all(|invocation| invocation.name == first_tool_name);
199
200                if all_same {
201                    format!(
202                        "Run `{}` {} times {}",
203                        first_tool_name,
204                        input.invocations.len(),
205                        mode
206                    )
207                } else {
208                    format!("Run {} tools {}", count, mode)
209                }
210            }
211            Err(_) => "Batch tools".to_string(),
212        }
213    }
214
215    fn run(
216        self: Arc<Self>,
217        input: serde_json::Value,
218        messages: &[LanguageModelRequestMessage],
219        project: Entity<Project>,
220        action_log: Entity<ActionLog>,
221        window: Option<AnyWindowHandle>,
222        cx: &mut App,
223    ) -> ToolResult {
224        let input = match serde_json::from_value::<BatchToolInput>(input) {
225            Ok(input) => input,
226            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
227        };
228
229        if input.invocations.is_empty() {
230            return Task::ready(Err(anyhow!("No tool invocations provided"))).into();
231        }
232
233        let run_tools_concurrently = input.run_tools_concurrently;
234
235        let foreground_task = {
236            let working_set = ToolWorkingSet::default();
237            let invocations = input.invocations;
238            let messages = messages.to_vec();
239
240            cx.spawn(async move |cx| {
241                let mut tasks = Vec::new();
242                let mut tool_names = Vec::new();
243
244                for invocation in invocations {
245                    let tool_name = invocation.name.clone();
246                    tool_names.push(tool_name.clone());
247
248                    let tool = cx
249                        .update(|cx| working_set.tool(&tool_name, cx))
250                        .map_err(|err| {
251                            anyhow!("Failed to look up tool '{}': {}", tool_name, err)
252                        })?;
253
254                    let Some(tool) = tool else {
255                        return Err(anyhow!("Tool '{}' not found", tool_name));
256                    };
257
258                    let project = project.clone();
259                    let action_log = action_log.clone();
260                    let messages = messages.clone();
261                    let tool_result = cx
262                        .update(|cx| {
263                            tool.run(invocation.input, &messages, project, action_log, window, cx)
264                        })
265                        .map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
266
267                    tasks.push(tool_result.output);
268                }
269
270                Ok((tasks, tool_names))
271            })
272        };
273
274        cx.background_spawn(async move {
275            let (tasks, tool_names) = foreground_task.await?;
276            let mut results = Vec::with_capacity(tasks.len());
277
278            if run_tools_concurrently {
279                results.extend(join_all(tasks).await)
280            } else {
281                for task in tasks {
282                    results.push(task.await);
283                }
284            };
285
286            let mut formatted_results = String::new();
287            let mut error_occurred = false;
288
289            for (i, result) in results.into_iter().enumerate() {
290                let tool_name = &tool_names[i];
291
292                match result {
293                    Ok(output) => {
294                        formatted_results
295                            .push_str(&format!("Tool '{}' result:\n{}\n\n", tool_name, output));
296                    }
297                    Err(err) => {
298                        error_occurred = true;
299                        formatted_results
300                            .push_str(&format!("Tool '{}' error: {}\n\n", tool_name, err));
301                    }
302                }
303            }
304
305            if error_occurred {
306                formatted_results
307                    .push_str("Note: Some tool invocations failed. See individual results above.");
308            }
309
310            Ok(formatted_results.trim().to_string())
311        })
312        .into()
313    }
314}