1use anyhow::{anyhow, Result};
2use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
3use futures::future::join_all;
4use gpui::{App, AppContext, Entity, Task};
5use language_model::LanguageModelRequestMessage;
6use project::Project;
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10use ui::IconName;
11
// A single entry of a batch: the name of the tool to call plus the arguments to pass it.
// NOTE: the `///` comments on the fields below are turned into schema `description`s by
// the `JsonSchema` derive and are therefore model-facing text; plain `//` comments like
// this one are not part of the schema.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ToolInvocation {
    /// The name of the tool to invoke
    pub name: String,

    // Arbitrary JSON; it is handed to the target tool's `run` unchanged, so its shape is
    // whatever that tool's own input schema expects.
    /// The input to the tool in JSON format
    pub input: serde_json::Value,
}
20
// The complete argument object accepted by the batch tool.
// NOTE: every `///` comment inside this struct is turned into a `description` by the
// `JsonSchema` derive, and this type's schema is what `schemars::schema_for!` produces for
// the tool's input. The examples below are therefore instructions shown to the model, not
// just rustdoc, so edit them with care. Plain `//` comments are not included in the schema.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct BatchToolInput {
    /// The tool invocations to run as a batch. These tools will be run either sequentially
    /// or concurrently depending on the `run_tools_concurrently` flag.
    ///
    /// <example>
    /// Basic file operations (concurrent)
    ///
    /// ```json
    /// {
    /// "invocations": [
    /// {
    /// "name": "read-file",
    /// "input": {
    /// "path": "src/main.rs"
    /// }
    /// },
    /// {
    /// "name": "list-directory",
    /// "input": {
    /// "path": "src/lib"
    /// }
    /// },
    /// {
    /// "name": "regex-search",
    /// "input": {
    /// "regex": "fn run\\("
    /// }
    /// }
    /// ],
    /// "run_tools_concurrently": true
    /// }
    /// ```
    /// </example>
    ///
    /// <example>
    /// Multiple find-replace operations on the same file (sequential)
    ///
    /// ```json
    /// {
    /// "invocations": [
    /// {
    /// "name": "find-replace-file",
    /// "input": {
    /// "path": "src/config.rs",
    /// "display_description": "Update default timeout value",
    /// "find": "pub const DEFAULT_TIMEOUT: u64 = 30;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";",
    /// "replace": "pub const DEFAULT_TIMEOUT: u64 = 60;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";"
    /// }
    /// },
    /// {
    /// "name": "find-replace-file",
    /// "input": {
    /// "path": "src/config.rs",
    /// "display_description": "Update API endpoint URL",
    /// "find": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";\n\npub const API_VERSION: &str = \"v1\";",
    /// "replace": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.newdomain.com\";\n\npub const API_VERSION: &str = \"v1\";"
    /// }
    /// }
    /// ],
    /// "run_tools_concurrently": false
    /// }
    /// ```
    /// </example>
    ///
    /// <example>
    /// Searching and analyzing code (concurrent)
    ///
    /// ```json
    /// {
    /// "invocations": [
    /// {
    /// "name": "regex-search",
    /// "input": {
    /// "regex": "impl Database"
    /// }
    /// },
    /// {
    /// "name": "path-search",
    /// "input": {
    /// "glob": "**/*test*.rs"
    /// }
    /// }
    /// ],
    /// "run_tools_concurrently": true
    /// }
    /// ```
    /// </example>
    ///
    /// <example>
    /// Multi-file refactoring (concurrent)
    ///
    /// ```json
    /// {
    /// "invocations": [
    /// {
    /// "name": "find-replace-file",
    /// "input": {
    /// "path": "src/models/user.rs",
    /// "display_description": "Add email field to User struct",
    /// "find": "pub struct User {\n pub id: u64,\n pub username: String,\n pub created_at: DateTime<Utc>,\n}",
    /// "replace": "pub struct User {\n pub id: u64,\n pub username: String,\n pub email: String,\n pub created_at: DateTime<Utc>,\n}"
    /// }
    /// },
    /// {
    /// "name": "find-replace-file",
    /// "input": {
    /// "path": "src/db/queries.rs",
    /// "display_description": "Update user insertion query",
    /// "find": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n conn.execute(\n \"INSERT INTO users (id, username, created_at) VALUES ($1, $2, $3)\",\n &[&user.id, &user.username, &user.created_at],\n ).await?;\n \n Ok(())\n}",
    /// "replace": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n conn.execute(\n \"INSERT INTO users (id, username, email, created_at) VALUES ($1, $2, $3, $4)\",\n &[&user.id, &user.username, &user.email, &user.created_at],\n ).await?;\n \n Ok(())\n}"
    /// }
    /// }
    /// ],
    /// "run_tools_concurrently": true
    /// }
    /// ```
    /// </example>
    pub invocations: Vec<ToolInvocation>,

    // `#[serde(default)]` lets the model omit this field entirely; it then deserializes to
    // `false` (sequential execution), matching the description below.
    /// Whether to run the tools in this batch concurrently. If this is false (the default), the tools will run sequentially.
    #[serde(default)]
    pub run_tools_concurrently: bool,
}
145
/// Meta-tool (registered under the name `batch-tool`) that executes a list of other tools'
/// invocations, described by [`BatchToolInput`], and combines each invocation's outcome
/// into a single textual result.
pub struct BatchTool;
147
148impl Tool for BatchTool {
149 fn name(&self) -> String {
150 "batch-tool".into()
151 }
152
153 fn needs_confirmation(&self) -> bool {
154 true
155 }
156
157 fn description(&self) -> String {
158 include_str!("./batch_tool/description.md").into()
159 }
160
161 fn icon(&self) -> IconName {
162 IconName::Cog
163 }
164
165 fn input_schema(&self) -> serde_json::Value {
166 let schema = schemars::schema_for!(BatchToolInput);
167 serde_json::to_value(&schema).unwrap()
168 }
169
170 fn ui_text(&self, input: &serde_json::Value) -> String {
171 match serde_json::from_value::<BatchToolInput>(input.clone()) {
172 Ok(input) => {
173 let count = input.invocations.len();
174 let mode = if input.run_tools_concurrently {
175 "concurrently"
176 } else {
177 "sequentially"
178 };
179
180 let first_tool_name = input
181 .invocations
182 .first()
183 .map(|inv| inv.name.clone())
184 .unwrap_or_default();
185
186 let all_same = input
187 .invocations
188 .iter()
189 .all(|invocation| invocation.name == first_tool_name);
190
191 if all_same {
192 format!(
193 "Run `{}` {} times {}",
194 first_tool_name,
195 input.invocations.len(),
196 mode
197 )
198 } else {
199 format!("Run {} tools {}", count, mode)
200 }
201 }
202 Err(_) => "Batch tools".to_string(),
203 }
204 }
205
206 fn run(
207 self: Arc<Self>,
208 input: serde_json::Value,
209 messages: &[LanguageModelRequestMessage],
210 project: Entity<Project>,
211 action_log: Entity<ActionLog>,
212 cx: &mut App,
213 ) -> Task<Result<String>> {
214 let input = match serde_json::from_value::<BatchToolInput>(input) {
215 Ok(input) => input,
216 Err(err) => return Task::ready(Err(anyhow!(err))),
217 };
218
219 if input.invocations.is_empty() {
220 return Task::ready(Err(anyhow!("No tool invocations provided")));
221 }
222
223 let run_tools_concurrently = input.run_tools_concurrently;
224
225 let foreground_task = {
226 let working_set = ToolWorkingSet::default();
227 let invocations = input.invocations;
228 let messages = messages.to_vec();
229
230 cx.spawn(async move |cx| {
231 let mut tasks = Vec::new();
232 let mut tool_names = Vec::new();
233
234 for invocation in invocations {
235 let tool_name = invocation.name.clone();
236 tool_names.push(tool_name.clone());
237
238 let tool = cx
239 .update(|cx| working_set.tool(&tool_name, cx))
240 .map_err(|err| {
241 anyhow!("Failed to look up tool '{}': {}", tool_name, err)
242 })?;
243
244 let Some(tool) = tool else {
245 return Err(anyhow!("Tool '{}' not found", tool_name));
246 };
247
248 let project = project.clone();
249 let action_log = action_log.clone();
250 let messages = messages.clone();
251 let task = cx
252 .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
253 .map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
254
255 tasks.push(task);
256 }
257
258 Ok((tasks, tool_names))
259 })
260 };
261
262 cx.background_spawn(async move {
263 let (tasks, tool_names) = foreground_task.await?;
264 let mut results = Vec::with_capacity(tasks.len());
265
266 if run_tools_concurrently {
267 results.extend(join_all(tasks).await)
268 } else {
269 for task in tasks {
270 results.push(task.await);
271 }
272 };
273
274 let mut formatted_results = String::new();
275 let mut error_occurred = false;
276
277 for (i, result) in results.into_iter().enumerate() {
278 let tool_name = &tool_names[i];
279
280 match result {
281 Ok(output) => {
282 formatted_results
283 .push_str(&format!("Tool '{}' result:\n{}\n\n", tool_name, output));
284 }
285 Err(err) => {
286 error_occurred = true;
287 formatted_results
288 .push_str(&format!("Tool '{}' error: {}\n\n", tool_name, err));
289 }
290 }
291 }
292
293 if error_occurred {
294 formatted_results
295 .push_str("Note: Some tool invocations failed. See individual results above.");
296 }
297
298 Ok(formatted_results.trim().to_string())
299 })
300 }
301}