tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{Tool, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::future::Shared;
  7use futures::FutureExt as _;
  8use gpui::{App, SharedString, Task};
  9use language_model::{
 10    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
 11    LanguageModelToolUseId, MessageContent, Role,
 12};
 13use scripting_tool::ScriptingTool;
 14
 15use crate::thread::MessageId;
 16use crate::thread_store::SerializedMessage;
 17
 18#[derive(Debug)]
 19pub struct ToolUse {
 20    pub id: LanguageModelToolUseId,
 21    pub name: SharedString,
 22    pub ui_text: SharedString,
 23    pub status: ToolUseStatus,
 24    pub input: serde_json::Value,
 25}
 26
 27#[derive(Debug, Clone)]
 28pub enum ToolUseStatus {
 29    NeedsConfirmation,
 30    Pending,
 31    Running,
 32    Finished(SharedString),
 33    Error(SharedString),
 34}
 35
 36pub struct ToolUseState {
 37    tools: Arc<ToolWorkingSet>,
 38    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 39    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
 40    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
 41    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 42}
 43
 44impl ToolUseState {
 45    pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
 46        Self {
 47            tools,
 48            tool_uses_by_assistant_message: HashMap::default(),
 49            tool_uses_by_user_message: HashMap::default(),
 50            tool_results: HashMap::default(),
 51            pending_tool_uses_by_id: HashMap::default(),
 52        }
 53    }
 54
 55    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 56    ///
 57    /// Accepts a function to filter the tools that should be used to populate the state.
 58    pub fn from_serialized_messages(
 59        tools: Arc<ToolWorkingSet>,
 60        messages: &[SerializedMessage],
 61        mut filter_by_tool_name: impl FnMut(&str) -> bool,
 62    ) -> Self {
 63        let mut this = Self::new(tools);
 64        let mut tool_names_by_id = HashMap::default();
 65
 66        for message in messages {
 67            match message.role {
 68                Role::Assistant => {
 69                    if !message.tool_uses.is_empty() {
 70                        let tool_uses = message
 71                            .tool_uses
 72                            .iter()
 73                            .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
 74                            .map(|tool_use| LanguageModelToolUse {
 75                                id: tool_use.id.clone(),
 76                                name: tool_use.name.clone().into(),
 77                                input: tool_use.input.clone(),
 78                            })
 79                            .collect::<Vec<_>>();
 80
 81                        tool_names_by_id.extend(
 82                            tool_uses
 83                                .iter()
 84                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 85                        );
 86
 87                        this.tool_uses_by_assistant_message
 88                            .insert(message.id, tool_uses);
 89                    }
 90                }
 91                Role::User => {
 92                    if !message.tool_results.is_empty() {
 93                        let tool_uses_by_user_message = this
 94                            .tool_uses_by_user_message
 95                            .entry(message.id)
 96                            .or_default();
 97
 98                        for tool_result in &message.tool_results {
 99                            let tool_use_id = tool_result.tool_use_id.clone();
100                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
101                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
102                                continue;
103                            };
104
105                            if !(filter_by_tool_name)(tool_use.as_ref()) {
106                                continue;
107                            }
108
109                            tool_uses_by_user_message.push(tool_use_id.clone());
110                            this.tool_results.insert(
111                                tool_use_id.clone(),
112                                LanguageModelToolResult {
113                                    tool_use_id,
114                                    is_error: tool_result.is_error,
115                                    content: tool_result.content.clone(),
116                                },
117                            );
118                        }
119                    }
120                }
121                Role::System => {}
122            }
123        }
124
125        this
126    }
127
128    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
129        let mut pending_tools = Vec::new();
130        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
131            self.tool_results.insert(
132                tool_use_id.clone(),
133                LanguageModelToolResult {
134                    tool_use_id,
135                    content: "Tool canceled by user".into(),
136                    is_error: true,
137                },
138            );
139            pending_tools.push(tool_use.clone());
140        }
141        pending_tools
142    }
143
144    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
145        self.pending_tool_uses_by_id.values().collect()
146    }
147
148    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
149        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
150            return Vec::new();
151        };
152
153        let mut tool_uses = Vec::new();
154
155        for tool_use in tool_uses_for_message.iter() {
156            let tool_result = self.tool_results.get(&tool_use.id);
157
158            let status = (|| {
159                if let Some(tool_result) = tool_result {
160                    return if tool_result.is_error {
161                        ToolUseStatus::Error(tool_result.content.clone().into())
162                    } else {
163                        ToolUseStatus::Finished(tool_result.content.clone().into())
164                    };
165                }
166
167                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
168                    match pending_tool_use.status {
169                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
170                        PendingToolUseStatus::NeedsConfirmation { .. } => {
171                            ToolUseStatus::NeedsConfirmation
172                        }
173                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
174                        PendingToolUseStatus::Error(ref err) => {
175                            ToolUseStatus::Error(err.clone().into())
176                        }
177                    }
178                } else {
179                    ToolUseStatus::Pending
180                }
181            })();
182
183            tool_uses.push(ToolUse {
184                id: tool_use.id.clone(),
185                name: tool_use.name.clone().into(),
186                ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
187                input: tool_use.input.clone(),
188                status,
189            })
190        }
191
192        tool_uses
193    }
194
195    pub fn tool_ui_label(
196        &self,
197        tool_name: &str,
198        input: &serde_json::Value,
199        cx: &App,
200    ) -> SharedString {
201        if let Some(tool) = self.tools.tool(tool_name, cx) {
202            tool.ui_text(input).into()
203        } else if tool_name == ScriptingTool::NAME {
204            "Run Lua Script".into()
205        } else {
206            "Unknown tool".into()
207        }
208    }
209
210    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
211        let empty = Vec::new();
212
213        self.tool_uses_by_user_message
214            .get(&message_id)
215            .unwrap_or(&empty)
216            .iter()
217            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
218            .collect()
219    }
220
221    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
222        self.tool_uses_by_user_message
223            .get(&message_id)
224            .map_or(false, |results| !results.is_empty())
225    }
226
227    pub fn tool_result(
228        &self,
229        tool_use_id: &LanguageModelToolUseId,
230    ) -> Option<&LanguageModelToolResult> {
231        self.tool_results.get(tool_use_id)
232    }
233
234    pub fn request_tool_use(
235        &mut self,
236        assistant_message_id: MessageId,
237        tool_use: LanguageModelToolUse,
238        cx: &App,
239    ) {
240        self.tool_uses_by_assistant_message
241            .entry(assistant_message_id)
242            .or_default()
243            .push(tool_use.clone());
244
245        // The tool use is being requested by the Assistant, so we want to
246        // attach the tool results to the next user message.
247        let next_user_message_id = MessageId(assistant_message_id.0 + 1);
248        self.tool_uses_by_user_message
249            .entry(next_user_message_id)
250            .or_default()
251            .push(tool_use.id.clone());
252
253        self.pending_tool_uses_by_id.insert(
254            tool_use.id.clone(),
255            PendingToolUse {
256                assistant_message_id,
257                id: tool_use.id,
258                name: tool_use.name.clone(),
259                ui_text: self
260                    .tool_ui_label(&tool_use.name, &tool_use.input, cx)
261                    .into(),
262                input: tool_use.input,
263                status: PendingToolUseStatus::Idle,
264            },
265        );
266    }
267
268    pub fn run_pending_tool(
269        &mut self,
270        tool_use_id: LanguageModelToolUseId,
271        ui_text: SharedString,
272        task: Task<()>,
273    ) {
274        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
275            tool_use.ui_text = ui_text.into();
276            tool_use.status = PendingToolUseStatus::Running {
277                _task: task.shared(),
278            };
279        }
280    }
281
282    pub fn confirm_tool_use(
283        &mut self,
284        tool_use_id: LanguageModelToolUseId,
285        ui_text: impl Into<Arc<str>>,
286        input: serde_json::Value,
287        messages: Arc<Vec<LanguageModelRequestMessage>>,
288        tool_type: ToolType,
289    ) {
290        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
291            let ui_text = ui_text.into();
292            tool_use.ui_text = ui_text.clone();
293            let confirmation = Confirmation {
294                tool_use_id,
295                input,
296                messages,
297                tool_type,
298                ui_text,
299            };
300            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
301        }
302    }
303
304    pub fn insert_tool_output(
305        &mut self,
306        tool_use_id: LanguageModelToolUseId,
307        output: Result<String>,
308    ) -> Option<PendingToolUse> {
309        match output {
310            Ok(tool_result) => {
311                self.tool_results.insert(
312                    tool_use_id.clone(),
313                    LanguageModelToolResult {
314                        tool_use_id: tool_use_id.clone(),
315                        content: tool_result.into(),
316                        is_error: false,
317                    },
318                );
319                self.pending_tool_uses_by_id.remove(&tool_use_id)
320            }
321            Err(err) => {
322                self.tool_results.insert(
323                    tool_use_id.clone(),
324                    LanguageModelToolResult {
325                        tool_use_id: tool_use_id.clone(),
326                        content: err.to_string().into(),
327                        is_error: true,
328                    },
329                );
330
331                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
332                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
333                }
334
335                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
336            }
337        }
338    }
339
340    pub fn attach_tool_uses(
341        &self,
342        message_id: MessageId,
343        request_message: &mut LanguageModelRequestMessage,
344    ) {
345        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
346            for tool_use in tool_uses {
347                if self.tool_results.contains_key(&tool_use.id) {
348                    // Do not send tool uses until they are completed
349                    request_message
350                        .content
351                        .push(MessageContent::ToolUse(tool_use.clone()));
352                } else {
353                    log::debug!(
354                        "skipped tool use {:?} because it is still pending",
355                        tool_use
356                    );
357                }
358            }
359        }
360    }
361
362    pub fn attach_tool_results(
363        &self,
364        message_id: MessageId,
365        request_message: &mut LanguageModelRequestMessage,
366    ) {
367        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
368            for tool_use_id in tool_uses {
369                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
370                    request_message.content.push(MessageContent::ToolResult(
371                        LanguageModelToolResult {
372                            tool_use_id: tool_use_id.clone(),
373                            is_error: tool_result.is_error,
374                            content: if tool_result.content.is_empty() {
375                                // Surprisingly, the API fails if we return an empty string here.
376                                // It thinks we are sending a tool use without a tool result.
377                                "<Tool returned an empty string>".into()
378                            } else {
379                                tool_result.content.clone()
380                            },
381                        },
382                    ));
383                }
384            }
385        }
386    }
387}
388
389#[derive(Debug, Clone)]
390pub struct PendingToolUse {
391    pub id: LanguageModelToolUseId,
392    /// The ID of the Assistant message in which the tool use was requested.
393    #[allow(unused)]
394    pub assistant_message_id: MessageId,
395    pub name: Arc<str>,
396    pub ui_text: Arc<str>,
397    pub input: serde_json::Value,
398    pub status: PendingToolUseStatus,
399}
400
401#[derive(Debug, Clone)]
402pub enum ToolType {
403    ScriptingTool,
404    NonScriptingTool(Arc<dyn Tool>),
405}
406
407#[derive(Debug, Clone)]
408pub struct Confirmation {
409    pub tool_use_id: LanguageModelToolUseId,
410    pub input: serde_json::Value,
411    pub ui_text: Arc<str>,
412    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
413    pub tool_type: ToolType,
414}
415
416#[derive(Debug, Clone)]
417pub enum PendingToolUseStatus {
418    Idle,
419    NeedsConfirmation(Arc<Confirmation>),
420    Running { _task: Shared<Task<()>> },
421    Error(#[allow(unused)] Arc<str>),
422}
423
424impl PendingToolUseStatus {
425    pub fn is_idle(&self) -> bool {
426        matches!(self, PendingToolUseStatus::Idle)
427    }
428
429    pub fn is_error(&self) -> bool {
430        matches!(self, PendingToolUseStatus::Error(_))
431    }
432
433    pub fn needs_confirmation(&self) -> bool {
434        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
435    }
436}