batch_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
  4use futures::future::join_all;
  5use gpui::{App, AppContext, Entity, Task};
  6use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  7use project::Project;
  8use schemars::JsonSchema;
  9use serde::{Deserialize, Serialize};
 10use std::sync::Arc;
 11use ui::IconName;
 12
// One entry of a batch: the name of the tool to call plus the raw JSON input to
// hand it. `input` is not validated here; `BatchTool::run` forwards it unchanged
// to the named tool's own `run`. The `///` comments on the fields below are
// derived into the JSON schema (via `JsonSchema`), so editing them changes what
// the model is told.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ToolInvocation {
    /// The name of the tool to invoke
    pub name: String,

    /// The input to the tool in JSON format
    pub input: serde_json::Value,
}
 21
// Top-level input accepted by `BatchTool`; its schema is produced from this type
// by `json_schema_for::<BatchToolInput>` in `input_schema`. The `///` comments on
// the fields (including the JSON examples) become schema descriptions, so they are
// model-facing text. The `//` comments here are for maintainers only.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct BatchToolInput {
    // The tools to run. `BatchTool::run` rejects an empty list with an error, and
    // results are reported back in the same order as this list.
    /// The tool invocations to run as a batch. These tools will be run either sequentially
    /// or concurrently depending on the `run_tools_concurrently` flag.
    ///
    /// <example>
    /// Basic file operations (concurrent)
    ///
    /// ```json
    /// {
    ///   "invocations": [
    ///     {
    ///       "name": "read_file",
    ///       "input": {
    ///         "path": "src/main.rs"
    ///       }
    ///     },
    ///     {
    ///       "name": "list_directory",
    ///       "input": {
    ///         "path": "src/lib"
    ///       }
    ///     },
    ///     {
    ///       "name": "regex_search",
    ///       "input": {
    ///         "regex": "fn run\\("
    ///       }
    ///     }
    ///   ],
    ///   "run_tools_concurrently": true
    /// }
    /// ```
    /// </example>
    ///
    /// <example>
    /// Multiple find-replace operations on the same file (sequential)
    ///
    /// ```json
    /// {
    ///   "invocations": [
    ///     {
    ///       "name": "find_replace_file",
    ///       "input": {
    ///         "path": "src/config.rs",
    ///         "display_description": "Update default timeout value",
    ///         "find": "pub const DEFAULT_TIMEOUT: u64 = 30;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";",
    ///         "replace": "pub const DEFAULT_TIMEOUT: u64 = 60;\n\npub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";"
    ///       }
    ///     },
    ///     {
    ///       "name": "find_replace_file",
    ///       "input": {
    ///         "path": "src/config.rs",
    ///         "display_description": "Update API endpoint URL",
    ///         "find": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.example.com\";\n\npub const API_VERSION: &str = \"v1\";",
    ///         "replace": "pub const MAX_RETRIES: u32 = 3;\n\npub const SERVER_URL: &str = \"https://api.newdomain.com\";\n\npub const API_VERSION: &str = \"v1\";"
    ///       }
    ///     }
    ///   ],
    ///   "run_tools_concurrently": false
    /// }
    /// ```
    /// </example>
    ///
    /// <example>
    /// Searching and analyzing code (concurrent)
    ///
    /// ```json
    /// {
    ///   "invocations": [
    ///     {
    ///       "name": "regex_search",
    ///       "input": {
    ///         "regex": "impl Database"
    ///       }
    ///     },
    ///     {
    ///       "name": "path_search",
    ///       "input": {
    ///         "glob": "**/*test*.rs"
    ///       }
    ///     }
    ///   ],
    ///   "run_tools_concurrently": true
    /// }
    /// ```
    /// </example>
    ///
    /// <example>
    /// Multi-file refactoring (concurrent)
    ///
    /// ```json
    /// {
    ///   "invocations": [
    ///     {
    ///       "name": "find_replace_file",
    ///       "input": {
    ///         "path": "src/models/user.rs",
    ///         "display_description": "Add email field to User struct",
    ///         "find": "pub struct User {\n    pub id: u64,\n    pub username: String,\n    pub created_at: DateTime<Utc>,\n}",
    ///         "replace": "pub struct User {\n    pub id: u64,\n    pub username: String,\n    pub email: String,\n    pub created_at: DateTime<Utc>,\n}"
    ///       }
    ///     },
    ///     {
    ///       "name": "find_replace_file",
    ///       "input": {
    ///         "path": "src/db/queries.rs",
    ///         "display_description": "Update user insertion query",
    ///         "find": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n    conn.execute(\n        \"INSERT INTO users (id, username, created_at) VALUES ($1, $2, $3)\",\n        &[&user.id, &user.username, &user.created_at],\n    ).await?;\n    \n    Ok(())\n}",
    ///         "replace": "pub async fn insert_user(conn: &mut Connection, user: &User) -> Result<(), DbError> {\n    conn.execute(\n        \"INSERT INTO users (id, username, email, created_at) VALUES ($1, $2, $3, $4)\",\n        &[&user.id, &user.username, &user.email, &user.created_at],\n    ).await?;\n    \n    Ok(())\n}"
    ///       }
    ///     }
    ///   ],
    ///   "run_tools_concurrently": true
    /// }
    /// ```
    /// </example>
    pub invocations: Vec<ToolInvocation>,

    // `#[serde(default)]` makes an omitted flag deserialize as `false`, so a batch
    // runs sequentially unless the model explicitly opts into concurrency.
    /// Whether to run the tools in this batch concurrently. If this is false (the default), the tools will run sequentially.
    #[serde(default)]
    pub run_tools_concurrently: bool,
}
146
/// A tool that looks up several other tools by name, runs them as one batch
/// (sequentially or concurrently, per [`BatchToolInput::run_tools_concurrently`]),
/// and combines their outputs into a single text result.
pub struct BatchTool;
148
149impl Tool for BatchTool {
150    fn name(&self) -> String {
151        "batch_tool".into()
152    }
153
154    fn needs_confirmation(&self) -> bool {
155        true
156    }
157
158    fn description(&self) -> String {
159        include_str!("./batch_tool/description.md").into()
160    }
161
162    fn icon(&self) -> IconName {
163        IconName::Cog
164    }
165
166    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
167        json_schema_for::<BatchToolInput>(format)
168    }
169
170    fn ui_text(&self, input: &serde_json::Value) -> String {
171        match serde_json::from_value::<BatchToolInput>(input.clone()) {
172            Ok(input) => {
173                let count = input.invocations.len();
174                let mode = if input.run_tools_concurrently {
175                    "concurrently"
176                } else {
177                    "sequentially"
178                };
179
180                let first_tool_name = input
181                    .invocations
182                    .first()
183                    .map(|inv| inv.name.clone())
184                    .unwrap_or_default();
185
186                let all_same = input
187                    .invocations
188                    .iter()
189                    .all(|invocation| invocation.name == first_tool_name);
190
191                if all_same {
192                    format!(
193                        "Run `{}` {} times {}",
194                        first_tool_name,
195                        input.invocations.len(),
196                        mode
197                    )
198                } else {
199                    format!("Run {} tools {}", count, mode)
200                }
201            }
202            Err(_) => "Batch tools".to_string(),
203        }
204    }
205
206    fn run(
207        self: Arc<Self>,
208        input: serde_json::Value,
209        messages: &[LanguageModelRequestMessage],
210        project: Entity<Project>,
211        action_log: Entity<ActionLog>,
212        cx: &mut App,
213    ) -> Task<Result<String>> {
214        let input = match serde_json::from_value::<BatchToolInput>(input) {
215            Ok(input) => input,
216            Err(err) => return Task::ready(Err(anyhow!(err))),
217        };
218
219        if input.invocations.is_empty() {
220            return Task::ready(Err(anyhow!("No tool invocations provided")));
221        }
222
223        let run_tools_concurrently = input.run_tools_concurrently;
224
225        let foreground_task = {
226            let working_set = ToolWorkingSet::default();
227            let invocations = input.invocations;
228            let messages = messages.to_vec();
229
230            cx.spawn(async move |cx| {
231                let mut tasks = Vec::new();
232                let mut tool_names = Vec::new();
233
234                for invocation in invocations {
235                    let tool_name = invocation.name.clone();
236                    tool_names.push(tool_name.clone());
237
238                    let tool = cx
239                        .update(|cx| working_set.tool(&tool_name, cx))
240                        .map_err(|err| {
241                            anyhow!("Failed to look up tool '{}': {}", tool_name, err)
242                        })?;
243
244                    let Some(tool) = tool else {
245                        return Err(anyhow!("Tool '{}' not found", tool_name));
246                    };
247
248                    let project = project.clone();
249                    let action_log = action_log.clone();
250                    let messages = messages.clone();
251                    let task = cx
252                        .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
253                        .map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
254
255                    tasks.push(task);
256                }
257
258                Ok((tasks, tool_names))
259            })
260        };
261
262        cx.background_spawn(async move {
263            let (tasks, tool_names) = foreground_task.await?;
264            let mut results = Vec::with_capacity(tasks.len());
265
266            if run_tools_concurrently {
267                results.extend(join_all(tasks).await)
268            } else {
269                for task in tasks {
270                    results.push(task.await);
271                }
272            };
273
274            let mut formatted_results = String::new();
275            let mut error_occurred = false;
276
277            for (i, result) in results.into_iter().enumerate() {
278                let tool_name = &tool_names[i];
279
280                match result {
281                    Ok(output) => {
282                        formatted_results
283                            .push_str(&format!("Tool '{}' result:\n{}\n\n", tool_name, output));
284                    }
285                    Err(err) => {
286                        error_occurred = true;
287                        formatted_results
288                            .push_str(&format!("Tool '{}' error: {}\n\n", tool_name, err));
289                    }
290                }
291            }
292
293            if error_occurred {
294                formatted_results
295                    .push_str("Note: Some tool invocations failed. See individual results above.");
296            }
297
298            Ok(formatted_results.trim().to_string())
299        })
300    }
301}