1use std::sync::Arc;
2
3use anyhow::Result;
4use assistant_tool::{Tool, ToolWorkingSet};
5use collections::HashMap;
6use futures::future::Shared;
7use futures::FutureExt as _;
8use gpui::{App, SharedString, Task};
9use language_model::{
10 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
11 LanguageModelToolUseId, MessageContent, Role,
12};
13use scripting_tool::ScriptingTool;
14
15use crate::thread::MessageId;
16use crate::thread_store::SerializedMessage;
17
/// A tool use requested by the assistant, resolved into a form ready for display.
///
/// Values are assembled by `ToolUseState::tool_uses_for_message`.
#[derive(Debug)]
pub struct ToolUse {
    /// Identifier of the tool use, as issued in the assistant's message.
    pub id: LanguageModelToolUseId,
    /// Name of the tool being invoked.
    pub name: SharedString,
    /// Human-readable label for this invocation (see `ToolUseState::tool_ui_label`).
    pub ui_text: SharedString,
    /// Where this tool use currently is in its lifecycle.
    pub status: ToolUseStatus,
    /// JSON input the assistant supplied for the tool.
    pub input: serde_json::Value,
}
26
/// Display-oriented status of a [`ToolUse`].
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
    /// The tool is waiting for the user to approve running it.
    NeedsConfirmation,
    /// The tool has not started yet (or no pending record exists for it).
    Pending,
    /// The tool is currently executing.
    Running,
    /// The tool completed successfully; holds its output.
    Finished(SharedString),
    /// The tool failed or was canceled; holds the error text.
    Error(SharedString),
}
35
/// Tracks tool uses requested by the assistant, their pending executions, and
/// their results, keyed by the messages they belong to.
pub struct ToolUseState {
    /// Tools available for lookup when building UI labels.
    tools: Arc<ToolWorkingSet>,
    /// Tool uses requested in each assistant message, in request order.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// IDs of tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Results (successful, failed, or canceled) recorded per tool use.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that have been requested but are not yet settled.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
43
44impl ToolUseState {
45 pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
46 Self {
47 tools,
48 tool_uses_by_assistant_message: HashMap::default(),
49 tool_uses_by_user_message: HashMap::default(),
50 tool_results: HashMap::default(),
51 pending_tool_uses_by_id: HashMap::default(),
52 }
53 }
54
55 /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
56 ///
57 /// Accepts a function to filter the tools that should be used to populate the state.
58 pub fn from_serialized_messages(
59 tools: Arc<ToolWorkingSet>,
60 messages: &[SerializedMessage],
61 mut filter_by_tool_name: impl FnMut(&str) -> bool,
62 ) -> Self {
63 let mut this = Self::new(tools);
64 let mut tool_names_by_id = HashMap::default();
65
66 for message in messages {
67 match message.role {
68 Role::Assistant => {
69 if !message.tool_uses.is_empty() {
70 let tool_uses = message
71 .tool_uses
72 .iter()
73 .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
74 .map(|tool_use| LanguageModelToolUse {
75 id: tool_use.id.clone(),
76 name: tool_use.name.clone().into(),
77 input: tool_use.input.clone(),
78 })
79 .collect::<Vec<_>>();
80
81 tool_names_by_id.extend(
82 tool_uses
83 .iter()
84 .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
85 );
86
87 this.tool_uses_by_assistant_message
88 .insert(message.id, tool_uses);
89 }
90 }
91 Role::User => {
92 if !message.tool_results.is_empty() {
93 let tool_uses_by_user_message = this
94 .tool_uses_by_user_message
95 .entry(message.id)
96 .or_default();
97
98 for tool_result in &message.tool_results {
99 let tool_use_id = tool_result.tool_use_id.clone();
100 let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
101 log::warn!("no tool name found for tool use: {tool_use_id:?}");
102 continue;
103 };
104
105 if !(filter_by_tool_name)(tool_use.as_ref()) {
106 continue;
107 }
108
109 tool_uses_by_user_message.push(tool_use_id.clone());
110 this.tool_results.insert(
111 tool_use_id.clone(),
112 LanguageModelToolResult {
113 tool_use_id,
114 is_error: tool_result.is_error,
115 content: tool_result.content.clone(),
116 },
117 );
118 }
119 }
120 }
121 Role::System => {}
122 }
123 }
124
125 this
126 }
127
128 pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
129 let mut pending_tools = Vec::new();
130 for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
131 self.tool_results.insert(
132 tool_use_id.clone(),
133 LanguageModelToolResult {
134 tool_use_id,
135 content: "Tool canceled by user".into(),
136 is_error: true,
137 },
138 );
139 pending_tools.push(tool_use.clone());
140 }
141 pending_tools
142 }
143
144 pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
145 self.pending_tool_uses_by_id.values().collect()
146 }
147
148 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
149 let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
150 return Vec::new();
151 };
152
153 let mut tool_uses = Vec::new();
154
155 for tool_use in tool_uses_for_message.iter() {
156 let tool_result = self.tool_results.get(&tool_use.id);
157
158 let status = (|| {
159 if let Some(tool_result) = tool_result {
160 return if tool_result.is_error {
161 ToolUseStatus::Error(tool_result.content.clone().into())
162 } else {
163 ToolUseStatus::Finished(tool_result.content.clone().into())
164 };
165 }
166
167 if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
168 match pending_tool_use.status {
169 PendingToolUseStatus::Idle => ToolUseStatus::Pending,
170 PendingToolUseStatus::NeedsConfirmation { .. } => {
171 ToolUseStatus::NeedsConfirmation
172 }
173 PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
174 PendingToolUseStatus::Error(ref err) => {
175 ToolUseStatus::Error(err.clone().into())
176 }
177 }
178 } else {
179 ToolUseStatus::Pending
180 }
181 })();
182
183 tool_uses.push(ToolUse {
184 id: tool_use.id.clone(),
185 name: tool_use.name.clone().into(),
186 ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
187 input: tool_use.input.clone(),
188 status,
189 })
190 }
191
192 tool_uses
193 }
194
195 pub fn tool_ui_label(
196 &self,
197 tool_name: &str,
198 input: &serde_json::Value,
199 cx: &App,
200 ) -> SharedString {
201 if let Some(tool) = self.tools.tool(tool_name, cx) {
202 tool.ui_text(input).into()
203 } else if tool_name == ScriptingTool::NAME {
204 "Run Lua Script".into()
205 } else {
206 "Unknown tool".into()
207 }
208 }
209
210 pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
211 let empty = Vec::new();
212
213 self.tool_uses_by_user_message
214 .get(&message_id)
215 .unwrap_or(&empty)
216 .iter()
217 .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
218 .collect()
219 }
220
221 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
222 self.tool_uses_by_user_message
223 .get(&message_id)
224 .map_or(false, |results| !results.is_empty())
225 }
226
227 pub fn tool_result(
228 &self,
229 tool_use_id: &LanguageModelToolUseId,
230 ) -> Option<&LanguageModelToolResult> {
231 self.tool_results.get(tool_use_id)
232 }
233
234 pub fn request_tool_use(
235 &mut self,
236 assistant_message_id: MessageId,
237 tool_use: LanguageModelToolUse,
238 cx: &App,
239 ) {
240 self.tool_uses_by_assistant_message
241 .entry(assistant_message_id)
242 .or_default()
243 .push(tool_use.clone());
244
245 // The tool use is being requested by the Assistant, so we want to
246 // attach the tool results to the next user message.
247 let next_user_message_id = MessageId(assistant_message_id.0 + 1);
248 self.tool_uses_by_user_message
249 .entry(next_user_message_id)
250 .or_default()
251 .push(tool_use.id.clone());
252
253 self.pending_tool_uses_by_id.insert(
254 tool_use.id.clone(),
255 PendingToolUse {
256 assistant_message_id,
257 id: tool_use.id,
258 name: tool_use.name.clone(),
259 ui_text: self
260 .tool_ui_label(&tool_use.name, &tool_use.input, cx)
261 .into(),
262 input: tool_use.input,
263 status: PendingToolUseStatus::Idle,
264 },
265 );
266 }
267
268 pub fn run_pending_tool(
269 &mut self,
270 tool_use_id: LanguageModelToolUseId,
271 ui_text: SharedString,
272 task: Task<()>,
273 ) {
274 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
275 tool_use.ui_text = ui_text.into();
276 tool_use.status = PendingToolUseStatus::Running {
277 _task: task.shared(),
278 };
279 }
280 }
281
282 pub fn confirm_tool_use(
283 &mut self,
284 tool_use_id: LanguageModelToolUseId,
285 ui_text: impl Into<Arc<str>>,
286 input: serde_json::Value,
287 messages: Arc<Vec<LanguageModelRequestMessage>>,
288 tool_type: ToolType,
289 ) {
290 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
291 let ui_text = ui_text.into();
292 tool_use.ui_text = ui_text.clone();
293 let confirmation = Confirmation {
294 tool_use_id,
295 input,
296 messages,
297 tool_type,
298 ui_text,
299 };
300 tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
301 }
302 }
303
304 pub fn insert_tool_output(
305 &mut self,
306 tool_use_id: LanguageModelToolUseId,
307 output: Result<String>,
308 ) -> Option<PendingToolUse> {
309 match output {
310 Ok(tool_result) => {
311 self.tool_results.insert(
312 tool_use_id.clone(),
313 LanguageModelToolResult {
314 tool_use_id: tool_use_id.clone(),
315 content: tool_result.into(),
316 is_error: false,
317 },
318 );
319 self.pending_tool_uses_by_id.remove(&tool_use_id)
320 }
321 Err(err) => {
322 self.tool_results.insert(
323 tool_use_id.clone(),
324 LanguageModelToolResult {
325 tool_use_id: tool_use_id.clone(),
326 content: err.to_string().into(),
327 is_error: true,
328 },
329 );
330
331 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
332 tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
333 }
334
335 self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
336 }
337 }
338 }
339
340 pub fn attach_tool_uses(
341 &self,
342 message_id: MessageId,
343 request_message: &mut LanguageModelRequestMessage,
344 ) {
345 if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
346 for tool_use in tool_uses {
347 if self.tool_results.contains_key(&tool_use.id) {
348 // Do not send tool uses until they are completed
349 request_message
350 .content
351 .push(MessageContent::ToolUse(tool_use.clone()));
352 } else {
353 log::debug!(
354 "skipped tool use {:?} because it is still pending",
355 tool_use
356 );
357 }
358 }
359 }
360 }
361
362 pub fn attach_tool_results(
363 &self,
364 message_id: MessageId,
365 request_message: &mut LanguageModelRequestMessage,
366 ) {
367 if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
368 for tool_use_id in tool_uses {
369 if let Some(tool_result) = self.tool_results.get(tool_use_id) {
370 request_message.content.push(MessageContent::ToolResult(
371 LanguageModelToolResult {
372 tool_use_id: tool_use_id.clone(),
373 is_error: tool_result.is_error,
374 content: if tool_result.content.is_empty() {
375 // Surprisingly, the API fails if we return an empty string here.
376 // It thinks we are sending a tool use without a tool result.
377 "<Tool returned an empty string>".into()
378 } else {
379 tool_result.content.clone()
380 },
381 },
382 ));
383 }
384 }
385 }
386 }
387}
388
/// A tool use that has been requested but not yet settled, together with its
/// execution state.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    // NOTE(review): this field is written but never read within this file, which is
    // presumably why the lint is silenced — confirm before removing the allow.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the tool being invoked.
    pub name: Arc<str>,
    /// Human-readable label for the invocation; may be updated as the tool progresses.
    pub ui_text: Arc<str>,
    /// JSON input the assistant supplied for the tool.
    pub input: serde_json::Value,
    /// Current execution state.
    pub status: PendingToolUseStatus,
}
400
/// Identifies which kind of tool a [`Confirmation`] refers to.
#[derive(Debug, Clone)]
pub enum ToolType {
    /// The Lua scripting tool (labelled "Run Lua Script" when not in the working set).
    ScriptingTool,
    /// Any other tool, carried as a trait object.
    NonScriptingTool(Arc<dyn Tool>),
}
406
/// Everything captured when a tool use is put into the `NeedsConfirmation` state
/// (see `ToolUseState::confirm_tool_use`).
// NOTE(review): presumably consumed to actually run the tool once the user
// approves — confirm against the caller.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// The tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// JSON input to run the tool with.
    pub input: serde_json::Value,
    /// Label shown for the pending invocation.
    pub ui_text: Arc<str>,
    /// Request messages captured alongside the confirmation.
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// Which tool is to be run.
    pub tool_type: ToolType,
}
415
416#[derive(Debug, Clone)]
417pub enum PendingToolUseStatus {
418 Idle,
419 NeedsConfirmation(Arc<Confirmation>),
420 Running { _task: Shared<Task<()>> },
421 Error(#[allow(unused)] Arc<str>),
422}
423
424impl PendingToolUseStatus {
425 pub fn is_idle(&self) -> bool {
426 matches!(self, PendingToolUseStatus::Idle)
427 }
428
429 pub fn is_error(&self) -> bool {
430 matches!(self, PendingToolUseStatus::Error(_))
431 }
432
433 pub fn needs_confirmation(&self) -> bool {
434 matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
435 }
436}