1use crate::schema::json_schema_for;
2use anyhow::{Result, anyhow};
3use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
4use futures::future::join_all;
5use gpui::{App, AppContext, Entity, Task};
6use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
7use project::Project;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use ui::IconName;
12
// One entry in a batch: the name of a tool to call and the JSON arguments to pass it.
// `BatchTool::run` resolves `name` through a `ToolWorkingSet` and forwards `input`
// unchanged to that tool's `run`.
// NOTE: schemars' `JsonSchema` derive uses the `///` doc comments on these fields as
// schema descriptions, so they are model-facing text; keep behavioral notes in `//`.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ToolInvocation {
    /// The name of the tool to invoke
    pub name: String,

    /// The input to the tool in JSON format
    pub input: serde_json::Value,
}
21
// Input accepted by `BatchTool`, deserialized from the model's tool-call JSON and also
// exported as the tool's input schema via `json_schema_for::<BatchToolInput>`.
// NOTE: schemars' `JsonSchema` derive turns the `///` doc comments below into schema
// descriptions, so the examples on `invocations` are prompt content seen by the model —
// edit them as such, and put maintainer-only notes in `//` comments.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct BatchToolInput {
    /// The tool invocations to run as a batch. These tools will be run either sequentially
    /// or concurrently depending on the `run_tools_concurrently` flag.
    ///
    /// <example>
    /// Basic file operations (concurrent)
    ///
    /// ```json
    /// {
    /// "invocations": [
    /// {
    /// "name": "read-file",
    /// "input": {
    /// "path": "src/main.rs"
    /// }
    /// },
    /// {
    /// "name": "list-directory",
    /// "input": {
    /// "path": "src/lib"
    /// }
    /// },
    /// {
    /// "name": "regex-search",
    /// "input": {
    /// "regex": "fn run\\("
    /// }
    /// }
    /// ],
    /// "run_tools_concurrently": true
    /// }
    /// ```
    /// </example>
    ///
    /// <example>
    /// Multiple find-replace operations on the same file (sequential)
    ///
    /// ```json
    /// {
    /// "invocations": [
    /// {
    /// "name": "find-replace-file",
    /// "input": {
    /// "path": "src/config.rs",
    /// "display_description": "Update default timeout value",
    /// "find": "pub const DEFAULT_TIMEOUT: u64 = 30;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";",
    /// "replace": "pub const DEFAULT_TIMEOUT: u64 = 60;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";"
    /// }
    /// },
    /// {
    /// "name": "find-replace-file",
    /// "input": {
    /// "path": "src/config.rs",
    /// "display_description": "Update API endpoint URL",
    /// "find": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";\n\npub const API_VERSION: &str = \"v1\";",
    /// "replace": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.newdomain.com\";\n\npub const API_VERSION: &str = \"v1\";"
    /// }
    /// }
    /// ],
    /// "run_tools_concurrently": false
    /// }
    /// ```
    /// </example>
    ///
    /// <example>
    /// Searching and analyzing code (concurrent)
    ///
    /// ```json
    /// {
    /// "invocations": [
    /// {
    /// "name": "regex-search",
    /// "input": {
    /// "regex": "impl Database"
    /// }
    /// },
    /// {
    /// "name": "path-search",
    /// "input": {
    /// "glob": "**/*test*.rs"
    /// }
    /// }
    /// ],
    /// "run_tools_concurrently": true
    /// }
    /// ```
    /// </example>
    ///
    /// <example>
    /// Multi-file refactoring (concurrent)
    ///
    /// ```json
    /// {
    /// "invocations": [
    /// {
    /// "name": "find-replace-file",
    /// "input": {
    /// "path": "src/models/user.rs",
    /// "display_description": "Add email field to User struct",
    /// "find": "pub struct User {\n pub id: u64,\n pub username: String,\n pub created_at: DateTime<Utc>,\n}",
    /// "replace": "pub struct User {\n pub id: u64,\n pub username: String,\n pub email: String,\n pub created_at: DateTime<Utc>,\n}"
    /// }
    /// },
    /// {
    /// "name": "find-replace-file",
    /// "input": {
    /// "path": "src/db/queries.rs",
    /// "display_description": "Update user insertion query",
    /// "find": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n conn.execute(\n \"INSERT INTO users (id, username, created_at) VALUES ($1, $2, $3)\",\n &[&user.id, &user.username, &user.created_at],\n ).await?;\n \n Ok(())\n}",
    /// "replace": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n conn.execute(\n \"INSERT INTO users (id, username, email, created_at) VALUES ($1, $2, $3, $4)\",\n &[&user.id, &user.username, &user.email, &user.created_at],\n ).await?;\n \n Ok(())\n}"
    /// }
    /// }
    /// ],
    /// "run_tools_concurrently": true
    /// }
    /// ```
    /// </example>
    // `BatchTool::run` rejects an empty list with "No tool invocations provided".
    pub invocations: Vec<ToolInvocation>,

    // `#[serde(default)]` makes this field optional in the incoming JSON; when omitted it
    // deserializes as `false`.
    /// Whether to run the tools in this batch concurrently. If this is false (the default), the tools will run sequentially.
    #[serde(default)]
    pub run_tools_concurrently: bool,
}
146
/// A [`Tool`] (registered as `batch-tool`) that invokes several other tools by name in one
/// call and combines their outputs into a single text result.
///
/// The `run_tools_concurrently` flag of [`BatchToolInput`] selects between sequential and
/// concurrent execution of the invocations.
pub struct BatchTool;
148
149impl Tool for BatchTool {
150 fn name(&self) -> String {
151 "batch-tool".into()
152 }
153
154 fn needs_confirmation(&self) -> bool {
155 true
156 }
157
158 fn description(&self) -> String {
159 include_str!("./batch_tool/description.md").into()
160 }
161
162 fn icon(&self) -> IconName {
163 IconName::Cog
164 }
165
166 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
167 json_schema_for::<BatchToolInput>(format)
168 }
169
170 fn ui_text(&self, input: &serde_json::Value) -> String {
171 match serde_json::from_value::<BatchToolInput>(input.clone()) {
172 Ok(input) => {
173 let count = input.invocations.len();
174 let mode = if input.run_tools_concurrently {
175 "concurrently"
176 } else {
177 "sequentially"
178 };
179
180 let first_tool_name = input
181 .invocations
182 .first()
183 .map(|inv| inv.name.clone())
184 .unwrap_or_default();
185
186 let all_same = input
187 .invocations
188 .iter()
189 .all(|invocation| invocation.name == first_tool_name);
190
191 if all_same {
192 format!(
193 "Run `{}` {} times {}",
194 first_tool_name,
195 input.invocations.len(),
196 mode
197 )
198 } else {
199 format!("Run {} tools {}", count, mode)
200 }
201 }
202 Err(_) => "Batch tools".to_string(),
203 }
204 }
205
206 fn run(
207 self: Arc<Self>,
208 input: serde_json::Value,
209 messages: &[LanguageModelRequestMessage],
210 project: Entity<Project>,
211 action_log: Entity<ActionLog>,
212 cx: &mut App,
213 ) -> Task<Result<String>> {
214 let input = match serde_json::from_value::<BatchToolInput>(input) {
215 Ok(input) => input,
216 Err(err) => return Task::ready(Err(anyhow!(err))),
217 };
218
219 if input.invocations.is_empty() {
220 return Task::ready(Err(anyhow!("No tool invocations provided")));
221 }
222
223 let run_tools_concurrently = input.run_tools_concurrently;
224
225 let foreground_task = {
226 let working_set = ToolWorkingSet::default();
227 let invocations = input.invocations;
228 let messages = messages.to_vec();
229
230 cx.spawn(async move |cx| {
231 let mut tasks = Vec::new();
232 let mut tool_names = Vec::new();
233
234 for invocation in invocations {
235 let tool_name = invocation.name.clone();
236 tool_names.push(tool_name.clone());
237
238 let tool = cx
239 .update(|cx| working_set.tool(&tool_name, cx))
240 .map_err(|err| {
241 anyhow!("Failed to look up tool '{}': {}", tool_name, err)
242 })?;
243
244 let Some(tool) = tool else {
245 return Err(anyhow!("Tool '{}' not found", tool_name));
246 };
247
248 let project = project.clone();
249 let action_log = action_log.clone();
250 let messages = messages.clone();
251 let task = cx
252 .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
253 .map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
254
255 tasks.push(task);
256 }
257
258 Ok((tasks, tool_names))
259 })
260 };
261
262 cx.background_spawn(async move {
263 let (tasks, tool_names) = foreground_task.await?;
264 let mut results = Vec::with_capacity(tasks.len());
265
266 if run_tools_concurrently {
267 results.extend(join_all(tasks).await)
268 } else {
269 for task in tasks {
270 results.push(task.await);
271 }
272 };
273
274 let mut formatted_results = String::new();
275 let mut error_occurred = false;
276
277 for (i, result) in results.into_iter().enumerate() {
278 let tool_name = &tool_names[i];
279
280 match result {
281 Ok(output) => {
282 formatted_results
283 .push_str(&format!("Tool '{}' result:\n{}\n\n", tool_name, output));
284 }
285 Err(err) => {
286 error_occurred = true;
287 formatted_results
288 .push_str(&format!("Tool '{}' error: {}\n\n", tool_name, err));
289 }
290 }
291 }
292
293 if error_occurred {
294 formatted_results
295 .push_str("Note: Some tool invocations failed. See individual results above.");
296 }
297
298 Ok(formatted_results.trim().to_string())
299 })
300 }
301}