1use crate::schema::json_schema_for;
2use anyhow::{Result, anyhow};
3use assistant_tool::{ActionLog, Tool, ToolResult, ToolWorkingSet};
4use futures::future::join_all;
5use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
6use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
7use project::Project;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use ui::IconName;
12
// One entry of a batch: which tool to call and the raw JSON input to hand to it.
//
// NOTE: `#[derive(JsonSchema)]` (schemars) uses the `///` field docs below as the schema
// `description`s, so they are part of the tool's input schema text, not just internal
// docs. Keep them phrased for the model.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ToolInvocation {
    // Looked up by name through `ToolWorkingSet::tool` when the batch is checked or run.
    /// The name of the tool to invoke
    pub name: String,

    // Passed through unchanged to the target tool's `needs_confirmation` / `run`.
    /// The input to the tool in JSON format
    pub input: serde_json::Value,
}
21
// Input accepted by `BatchTool`: a list of tool invocations plus an execution-mode flag.
//
// NOTE: `#[derive(JsonSchema)]` (schemars) turns the `///` docs on these fields, including
// the <example> blocks, into the schema returned by `BatchTool::input_schema` (via
// `json_schema_for`). Treat that text as prompt text: changing it changes what the
// model is shown. Use plain `//` comments for maintainer-only notes like this one.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct BatchToolInput {
    /// The tool invocations to run as a batch. These tools will be run either sequentially
    /// or concurrently depending on the `run_tools_concurrently` flag.
    ///
    /// <example>
    /// Basic file operations (concurrent)
    ///
    /// ```json
    /// {
    /// "invocations": [
    /// {
    /// "name": "read_file",
    /// "input": {
    /// "path": "src/main.rs"
    /// }
    /// },
    /// {
    /// "name": "list_directory",
    /// "input": {
    /// "path": "src/lib"
    /// }
    /// },
    /// {
    /// "name": "grep",
    /// "input": {
    /// "regex": "fn run\\("
    /// }
    /// }
    /// ],
    /// "run_tools_concurrently": true
    /// }
    /// ```
    /// </example>
    ///
    /// <example>
    /// Multiple find-replace operations on the same file (sequential)
    ///
    /// ```json
    /// {
    /// "invocations": [
    /// {
    /// "name": "find_replace_file",
    /// "input": {
    /// "path": "src/config.rs",
    /// "display_description": "Update default timeout value",
    /// "find": "pub const DEFAULT_TIMEOUT: u64 = 30;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";",
    /// "replace": "pub const DEFAULT_TIMEOUT: u64 = 60;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";"
    /// }
    /// },
    /// {
    /// "name": "find_replace_file",
    /// "input": {
    /// "path": "src/config.rs",
    /// "display_description": "Update API endpoint URL",
    /// "find": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";\n\npub const API_VERSION: &str = \"v1\";",
    /// "replace": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.newdomain.com\";\n\npub const API_VERSION: &str = \"v1\";"
    /// }
    /// }
    /// ],
    /// "run_tools_concurrently": false
    /// }
    /// ```
    /// </example>
    ///
    /// <example>
    /// Searching and analyzing code (concurrent)
    ///
    /// ```json
    /// {
    /// "invocations": [
    /// {
    /// "name": "grep",
    /// "input": {
    /// "regex": "impl Database"
    /// }
    /// },
    /// {
    /// "name": "find_path",
    /// "input": {
    /// "glob": "**/*test*.rs"
    /// }
    /// }
    /// ],
    /// "run_tools_concurrently": true
    /// }
    /// ```
    /// </example>
    ///
    /// <example>
    /// Multi-file refactoring (concurrent)
    ///
    /// ```json
    /// {
    /// "invocations": [
    /// {
    /// "name": "find_replace_file",
    /// "input": {
    /// "path": "src/models/user.rs",
    /// "display_description": "Add email field to User struct",
    /// "find": "pub struct User {\n pub id: u64,\n pub username: String,\n pub created_at: DateTime<Utc>,\n}",
    /// "replace": "pub struct User {\n pub id: u64,\n pub username: String,\n pub email: String,\n pub created_at: DateTime<Utc>,\n}"
    /// }
    /// },
    /// {
    /// "name": "find_replace_file",
    /// "input": {
    /// "path": "src/db/queries.rs",
    /// "display_description": "Update user insertion query",
    /// "find": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n conn.execute(\n \"INSERT INTO users (id, username, created_at) VALUES ($1, $2, $3)\",\n &[&user.id, &user.username, &user.created_at],\n ).await?;\n \n Ok(())\n}",
    /// "replace": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n conn.execute(\n \"INSERT INTO users (id, username, email, created_at) VALUES ($1, $2, $3, $4)\",\n &[&user.id, &user.username, &user.email, &user.created_at],\n ).await?;\n \n Ok(())\n}"
    /// }
    /// }
    /// ],
    /// "run_tools_concurrently": true
    /// }
    /// ```
    /// </example>
    pub invocations: Vec<ToolInvocation>,

    // `#[serde(default)]` makes a missing key deserialize as `bool::default()`, i.e.
    // `false`, which selects the sequential mode described in the doc below.
    /// Whether to run the tools in this batch concurrently. If this is false (the default), the tools will run sequentially.
    #[serde(default)]
    pub run_tools_concurrently: bool,
}
146
147pub struct BatchTool;
148
149impl Tool for BatchTool {
150 fn name(&self) -> String {
151 "batch_tool".into()
152 }
153
154 fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool {
155 serde_json::from_value::<BatchToolInput>(input.clone())
156 .map(|input| {
157 let working_set = ToolWorkingSet::default();
158 input.invocations.iter().any(|invocation| {
159 working_set
160 .tool(&invocation.name, cx)
161 .map_or(false, |tool| tool.needs_confirmation(&invocation.input, cx))
162 })
163 })
164 .unwrap_or(false)
165 }
166
167 fn description(&self) -> String {
168 include_str!("./batch_tool/description.md").into()
169 }
170
171 fn icon(&self) -> IconName {
172 IconName::Cog
173 }
174
175 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
176 json_schema_for::<BatchToolInput>(format)
177 }
178
179 fn ui_text(&self, input: &serde_json::Value) -> String {
180 match serde_json::from_value::<BatchToolInput>(input.clone()) {
181 Ok(input) => {
182 let count = input.invocations.len();
183 let mode = if input.run_tools_concurrently {
184 "concurrently"
185 } else {
186 "sequentially"
187 };
188
189 let first_tool_name = input
190 .invocations
191 .first()
192 .map(|inv| inv.name.clone())
193 .unwrap_or_default();
194
195 let all_same = input
196 .invocations
197 .iter()
198 .all(|invocation| invocation.name == first_tool_name);
199
200 if all_same {
201 format!(
202 "Run `{}` {} times {}",
203 first_tool_name,
204 input.invocations.len(),
205 mode
206 )
207 } else {
208 format!("Run {} tools {}", count, mode)
209 }
210 }
211 Err(_) => "Batch tools".to_string(),
212 }
213 }
214
215 fn run(
216 self: Arc<Self>,
217 input: serde_json::Value,
218 messages: &[LanguageModelRequestMessage],
219 project: Entity<Project>,
220 action_log: Entity<ActionLog>,
221 window: Option<AnyWindowHandle>,
222 cx: &mut App,
223 ) -> ToolResult {
224 let input = match serde_json::from_value::<BatchToolInput>(input) {
225 Ok(input) => input,
226 Err(err) => return Task::ready(Err(anyhow!(err))).into(),
227 };
228
229 if input.invocations.is_empty() {
230 return Task::ready(Err(anyhow!("No tool invocations provided"))).into();
231 }
232
233 let run_tools_concurrently = input.run_tools_concurrently;
234
235 let foreground_task = {
236 let working_set = ToolWorkingSet::default();
237 let invocations = input.invocations;
238 let messages = messages.to_vec();
239
240 cx.spawn(async move |cx| {
241 let mut tasks = Vec::new();
242 let mut tool_names = Vec::new();
243
244 for invocation in invocations {
245 let tool_name = invocation.name.clone();
246 tool_names.push(tool_name.clone());
247
248 let tool = cx
249 .update(|cx| working_set.tool(&tool_name, cx))
250 .map_err(|err| {
251 anyhow!("Failed to look up tool '{}': {}", tool_name, err)
252 })?;
253
254 let Some(tool) = tool else {
255 return Err(anyhow!("Tool '{}' not found", tool_name));
256 };
257
258 let project = project.clone();
259 let action_log = action_log.clone();
260 let messages = messages.clone();
261 let tool_result = cx
262 .update(|cx| {
263 tool.run(invocation.input, &messages, project, action_log, window, cx)
264 })
265 .map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
266
267 tasks.push(tool_result.output);
268 }
269
270 Ok((tasks, tool_names))
271 })
272 };
273
274 cx.background_spawn(async move {
275 let (tasks, tool_names) = foreground_task.await?;
276 let mut results = Vec::with_capacity(tasks.len());
277
278 if run_tools_concurrently {
279 results.extend(join_all(tasks).await)
280 } else {
281 for task in tasks {
282 results.push(task.await);
283 }
284 };
285
286 let mut formatted_results = String::new();
287 let mut error_occurred = false;
288
289 for (i, result) in results.into_iter().enumerate() {
290 let tool_name = &tool_names[i];
291
292 match result {
293 Ok(output) => {
294 formatted_results
295 .push_str(&format!("Tool '{}' result:\n{}\n\n", tool_name, output));
296 }
297 Err(err) => {
298 error_occurred = true;
299 formatted_results
300 .push_str(&format!("Tool '{}' error: {}\n\n", tool_name, err));
301 }
302 }
303 }
304
305 if error_occurred {
306 formatted_results
307 .push_str("Note: Some tool invocations failed. See individual results above.");
308 }
309
310 Ok(formatted_results.trim().to_string())
311 })
312 .into()
313 }
314}