batch_tool.rs

  1use anyhow::{anyhow, Result};
  2use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
  3use futures::future::join_all;
  4use gpui::{App, AppContext, Entity, Task};
  5use language_model::LanguageModelRequestMessage;
  6use project::Project;
  7use schemars::JsonSchema;
  8use serde::{Deserialize, Serialize};
  9use std::sync::Arc;
 10use ui::IconName;
 11
// A single tool call inside a batch: the name of the tool to run and the JSON input to give it.
//
// NOTE: plain `//` comments are used for maintainer notes on purpose. `JsonSchema` is derived,
// and schemars turns `///` doc comments into `description` fields of the generated schema
// (which `BatchTool::input_schema` returns), so editing the `///` text changes what the model
// is shown.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ToolInvocation {
    /// The name of the tool to invoke
    // Resolved by name when the batch runs; an unknown name fails the whole batch.
    pub name: String,

    /// The input to the tool in JSON format
    // Forwarded as-is to the target tool's `run`; this struct does not validate it.
    pub input: serde_json::Value,
}
 20
// Top-level input accepted by `BatchTool`; deserialized in both `ui_text` and `run`.
//
// NOTE: the `///` comments below are emitted into the JSON schema produced by the derived
// `JsonSchema` impl, so the examples double as prompt text for the model. Keep them valid JSON
// and put maintainer-only notes in plain `//` comments like this one.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct BatchToolInput {
    /// The tool invocations to run as a batch. These tools will be run either sequentially
    /// or concurrently depending on the `run_tools_concurrently` flag.
    ///
    /// <example>
    /// Basic file operations (concurrent)
    ///
    /// ```json
    /// {
    ///   "invocations": [
    ///     {
    ///       "name": "read-file",
    ///       "input": {
    ///         "path": "src/main.rs"
    ///       }
    ///     },
    ///     {
    ///       "name": "list-directory",
    ///       "input": {
    ///         "path": "src/lib"
    ///       }
    ///     },
    ///     {
    ///       "name": "regex-search",
    ///       "input": {
    ///         "regex": "fn run\\("
    ///       }
    ///     }
    ///   ],
    ///   "run_tools_concurrently": true
    /// }
    /// ```
    /// </example>
    ///
    /// <example>
    /// Multiple find-replace operations on the same file (sequential)
    ///
    /// ```json
    /// {
    ///   "invocations": [
    ///     {
    ///       "name": "find-replace-file",
    ///       "input": {
    ///         "path": "src/config.rs",
    ///         "display_description": "Update default timeout value",
    ///         "find": "pub const DEFAULT_TIMEOUT: u64 = 30;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";",
    ///         "replace": "pub const DEFAULT_TIMEOUT: u64 = 60;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";"
    ///       }
    ///     },
    ///     {
    ///       "name": "find-replace-file",
    ///       "input": {
    ///         "path": "src/config.rs",
    ///         "display_description": "Update API endpoint URL",
    ///         "find": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";\n\npub const API_VERSION: &str = \"v1\";",
    ///         "replace": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.newdomain.com\";\n\npub const API_VERSION: &str = \"v1\";"
    ///       }
    ///     }
    ///   ],
    ///   "run_tools_concurrently": false
    /// }
    /// ```
    /// </example>
    ///
    /// <example>
    /// Searching and analyzing code (concurrent)
    ///
    /// ```json
    /// {
    ///   "invocations": [
    ///     {
    ///       "name": "regex-search",
    ///       "input": {
    ///         "regex": "impl Database"
    ///       }
    ///     },
    ///     {
    ///       "name": "path-search",
    ///       "input": {
    ///         "glob": "**/*test*.rs"
    ///       }
    ///     }
    ///   ],
    ///   "run_tools_concurrently": true
    /// }
    /// ```
    /// </example>
    ///
    /// <example>
    /// Multi-file refactoring (concurrent)
    ///
    /// ```json
    /// {
    ///   "invocations": [
    ///     {
    ///       "name": "find-replace-file",
    ///       "input": {
    ///         "path": "src/models/user.rs",
    ///         "display_description": "Add email field to User struct",
    ///         "find": "pub struct User {\n    pub id: u64,\n    pub username: String,\n    pub created_at: DateTime<Utc>,\n}",
    ///         "replace": "pub struct User {\n    pub id: u64,\n    pub username: String,\n    pub email: String,\n    pub created_at: DateTime<Utc>,\n}"
    ///       }
    ///     },
    ///     {
    ///       "name": "find-replace-file",
    ///       "input": {
    ///         "path": "src/db/queries.rs",
    ///         "display_description": "Update user insertion query",
    ///         "find": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n    conn.execute(\n        \"INSERT INTO users (id, username, created_at) VALUES ($1, $2, $3)\",\n        &[&user.id, &user.username, &user.created_at],\n    ).await?;\n    \n    Ok(())\n}",
    ///         "replace": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n    conn.execute(\n        \"INSERT INTO users (id, username, email, created_at) VALUES ($1, $2, $3, $4)\",\n        &[&user.id, &user.username, &user.email, &user.created_at],\n    ).await?;\n    \n    Ok(())\n}"
    ///       }
    ///     }
    ///   ],
    ///   "run_tools_concurrently": true
    /// }
    /// ```
    /// </example>
    // Must be non-empty: `BatchTool::run` rejects an empty list with an error.
    pub invocations: Vec<ToolInvocation>,

    /// Whether to run the tools in this batch concurrently. If this is false (the default), the tools will run sequentially.
    // `#[serde(default)]` makes an omitted field deserialize as `false` (sequential mode).
    #[serde(default)]
    pub run_tools_concurrently: bool,
}
145
146pub struct BatchTool;
147
148impl Tool for BatchTool {
149    fn name(&self) -> String {
150        "batch-tool".into()
151    }
152
153    fn needs_confirmation(&self) -> bool {
154        true
155    }
156
157    fn description(&self) -> String {
158        include_str!("./batch_tool/description.md").into()
159    }
160
161    fn icon(&self) -> IconName {
162        IconName::Cog
163    }
164
165    fn input_schema(&self) -> serde_json::Value {
166        let schema = schemars::schema_for!(BatchToolInput);
167        serde_json::to_value(&schema).unwrap()
168    }
169
170    fn ui_text(&self, input: &serde_json::Value) -> String {
171        match serde_json::from_value::<BatchToolInput>(input.clone()) {
172            Ok(input) => {
173                let count = input.invocations.len();
174                let mode = if input.run_tools_concurrently {
175                    "concurrently"
176                } else {
177                    "sequentially"
178                };
179
180                let first_tool_name = input
181                    .invocations
182                    .first()
183                    .map(|inv| inv.name.clone())
184                    .unwrap_or_default();
185
186                let all_same = input
187                    .invocations
188                    .iter()
189                    .all(|invocation| invocation.name == first_tool_name);
190
191                if all_same {
192                    format!(
193                        "Run `{}` {} times {}",
194                        first_tool_name,
195                        input.invocations.len(),
196                        mode
197                    )
198                } else {
199                    format!("Run {} tools {}", count, mode)
200                }
201            }
202            Err(_) => "Batch tools".to_string(),
203        }
204    }
205
206    fn run(
207        self: Arc<Self>,
208        input: serde_json::Value,
209        messages: &[LanguageModelRequestMessage],
210        project: Entity<Project>,
211        action_log: Entity<ActionLog>,
212        cx: &mut App,
213    ) -> Task<Result<String>> {
214        let input = match serde_json::from_value::<BatchToolInput>(input) {
215            Ok(input) => input,
216            Err(err) => return Task::ready(Err(anyhow!(err))),
217        };
218
219        if input.invocations.is_empty() {
220            return Task::ready(Err(anyhow!("No tool invocations provided")));
221        }
222
223        let run_tools_concurrently = input.run_tools_concurrently;
224
225        let foreground_task = {
226            let working_set = ToolWorkingSet::default();
227            let invocations = input.invocations;
228            let messages = messages.to_vec();
229
230            cx.spawn(async move |cx| {
231                let mut tasks = Vec::new();
232                let mut tool_names = Vec::new();
233
234                for invocation in invocations {
235                    let tool_name = invocation.name.clone();
236                    tool_names.push(tool_name.clone());
237
238                    let tool = cx
239                        .update(|cx| working_set.tool(&tool_name, cx))
240                        .map_err(|err| {
241                            anyhow!("Failed to look up tool '{}': {}", tool_name, err)
242                        })?;
243
244                    let Some(tool) = tool else {
245                        return Err(anyhow!("Tool '{}' not found", tool_name));
246                    };
247
248                    let project = project.clone();
249                    let action_log = action_log.clone();
250                    let messages = messages.clone();
251                    let task = cx
252                        .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
253                        .map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
254
255                    tasks.push(task);
256                }
257
258                Ok((tasks, tool_names))
259            })
260        };
261
262        cx.background_spawn(async move {
263            let (tasks, tool_names) = foreground_task.await?;
264            let mut results = Vec::with_capacity(tasks.len());
265
266            if run_tools_concurrently {
267                results.extend(join_all(tasks).await)
268            } else {
269                for task in tasks {
270                    results.push(task.await);
271                }
272            };
273
274            let mut formatted_results = String::new();
275            let mut error_occurred = false;
276
277            for (i, result) in results.into_iter().enumerate() {
278                let tool_name = &tool_names[i];
279
280                match result {
281                    Ok(output) => {
282                        formatted_results
283                            .push_str(&format!("Tool '{}' result:\n{}\n\n", tool_name, output));
284                    }
285                    Err(err) => {
286                        error_occurred = true;
287                        formatted_results
288                            .push_str(&format!("Tool '{}' error: {}\n\n", tool_name, err));
289                    }
290                }
291            }
292
293            if error_occurred {
294                formatted_results
295                    .push_str("Note: Some tool invocations failed. See individual results above.");
296            }
297
298            Ok(formatted_results.trim().to_string())
299        })
300    }
301}