tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{Tool, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use gpui::{App, SharedString, Task};
  9use language_model::{
 10    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
 11    LanguageModelToolUseId, MessageContent, Role,
 12};
 13use ui::IconName;
 14
 15use crate::thread::MessageId;
 16use crate::thread_store::SerializedMessage;
 17
 18#[derive(Debug)]
 19pub struct ToolUse {
 20    pub id: LanguageModelToolUseId,
 21    pub name: SharedString,
 22    pub ui_text: SharedString,
 23    pub status: ToolUseStatus,
 24    pub input: serde_json::Value,
 25    pub icon: ui::IconName,
 26    pub needs_confirmation: bool,
 27}
 28
 29#[derive(Debug, Clone)]
 30pub enum ToolUseStatus {
 31    NeedsConfirmation,
 32    Pending,
 33    Running,
 34    Finished(SharedString),
 35    Error(SharedString),
 36}
 37
 38pub struct ToolUseState {
 39    tools: Arc<ToolWorkingSet>,
 40    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 41    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
 42    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
 43    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 44}
 45
 46pub const USING_TOOL_MARKER: &str = "<using_tool>";
 47
 48impl ToolUseState {
 49    pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
 50        Self {
 51            tools,
 52            tool_uses_by_assistant_message: HashMap::default(),
 53            tool_uses_by_user_message: HashMap::default(),
 54            tool_results: HashMap::default(),
 55            pending_tool_uses_by_id: HashMap::default(),
 56        }
 57    }
 58
 59    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 60    ///
 61    /// Accepts a function to filter the tools that should be used to populate the state.
 62    pub fn from_serialized_messages(
 63        tools: Arc<ToolWorkingSet>,
 64        messages: &[SerializedMessage],
 65        mut filter_by_tool_name: impl FnMut(&str) -> bool,
 66    ) -> Self {
 67        let mut this = Self::new(tools);
 68        let mut tool_names_by_id = HashMap::default();
 69
 70        for message in messages {
 71            match message.role {
 72                Role::Assistant => {
 73                    if !message.tool_uses.is_empty() {
 74                        let tool_uses = message
 75                            .tool_uses
 76                            .iter()
 77                            .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
 78                            .map(|tool_use| LanguageModelToolUse {
 79                                id: tool_use.id.clone(),
 80                                name: tool_use.name.clone().into(),
 81                                input: tool_use.input.clone(),
 82                            })
 83                            .collect::<Vec<_>>();
 84
 85                        tool_names_by_id.extend(
 86                            tool_uses
 87                                .iter()
 88                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 89                        );
 90
 91                        this.tool_uses_by_assistant_message
 92                            .insert(message.id, tool_uses);
 93                    }
 94                }
 95                Role::User => {
 96                    if !message.tool_results.is_empty() {
 97                        let tool_uses_by_user_message = this
 98                            .tool_uses_by_user_message
 99                            .entry(message.id)
100                            .or_default();
101
102                        for tool_result in &message.tool_results {
103                            let tool_use_id = tool_result.tool_use_id.clone();
104                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
105                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
106                                continue;
107                            };
108
109                            if !(filter_by_tool_name)(tool_use.as_ref()) {
110                                continue;
111                            }
112
113                            tool_uses_by_user_message.push(tool_use_id.clone());
114                            this.tool_results.insert(
115                                tool_use_id.clone(),
116                                LanguageModelToolResult {
117                                    tool_use_id,
118                                    tool_name: tool_use.clone(),
119                                    is_error: tool_result.is_error,
120                                    content: tool_result.content.clone(),
121                                },
122                            );
123                        }
124                    }
125                }
126                Role::System => {}
127            }
128        }
129
130        this
131    }
132
133    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
134        let mut pending_tools = Vec::new();
135        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
136            self.tool_results.insert(
137                tool_use_id.clone(),
138                LanguageModelToolResult {
139                    tool_use_id,
140                    tool_name: tool_use.name.clone(),
141                    content: "Tool canceled by user".into(),
142                    is_error: true,
143                },
144            );
145            pending_tools.push(tool_use.clone());
146        }
147        pending_tools
148    }
149
150    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
151        self.pending_tool_uses_by_id.values().collect()
152    }
153
154    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
155        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
156            return Vec::new();
157        };
158
159        let mut tool_uses = Vec::new();
160
161        for tool_use in tool_uses_for_message.iter() {
162            let tool_result = self.tool_results.get(&tool_use.id);
163
164            let status = (|| {
165                if let Some(tool_result) = tool_result {
166                    return if tool_result.is_error {
167                        ToolUseStatus::Error(tool_result.content.clone().into())
168                    } else {
169                        ToolUseStatus::Finished(tool_result.content.clone().into())
170                    };
171                }
172
173                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
174                    match pending_tool_use.status {
175                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
176                        PendingToolUseStatus::NeedsConfirmation { .. } => {
177                            ToolUseStatus::NeedsConfirmation
178                        }
179                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
180                        PendingToolUseStatus::Error(ref err) => {
181                            ToolUseStatus::Error(err.clone().into())
182                        }
183                    }
184                } else {
185                    ToolUseStatus::Pending
186                }
187            })();
188
189            let (icon, needs_confirmation) = if let Some(tool) = self.tools.tool(&tool_use.name, cx)
190            {
191                (tool.icon(), tool.needs_confirmation())
192            } else {
193                (IconName::Cog, false)
194            };
195
196            tool_uses.push(ToolUse {
197                id: tool_use.id.clone(),
198                name: tool_use.name.clone().into(),
199                ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
200                input: tool_use.input.clone(),
201                status,
202                icon,
203                needs_confirmation,
204            })
205        }
206
207        tool_uses
208    }
209
210    pub fn tool_ui_label(
211        &self,
212        tool_name: &str,
213        input: &serde_json::Value,
214        cx: &App,
215    ) -> SharedString {
216        if let Some(tool) = self.tools.tool(tool_name, cx) {
217            tool.ui_text(input).into()
218        } else {
219            format!("Unknown tool {tool_name:?}").into()
220        }
221    }
222
223    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
224        let empty = Vec::new();
225
226        self.tool_uses_by_user_message
227            .get(&message_id)
228            .unwrap_or(&empty)
229            .iter()
230            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
231            .collect()
232    }
233
234    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
235        self.tool_uses_by_user_message
236            .get(&message_id)
237            .map_or(false, |results| !results.is_empty())
238    }
239
240    pub fn tool_result(
241        &self,
242        tool_use_id: &LanguageModelToolUseId,
243    ) -> Option<&LanguageModelToolResult> {
244        self.tool_results.get(tool_use_id)
245    }
246
247    pub fn request_tool_use(
248        &mut self,
249        assistant_message_id: MessageId,
250        tool_use: LanguageModelToolUse,
251        cx: &App,
252    ) {
253        self.tool_uses_by_assistant_message
254            .entry(assistant_message_id)
255            .or_default()
256            .push(tool_use.clone());
257
258        // The tool use is being requested by the Assistant, so we want to
259        // attach the tool results to the next user message.
260        let next_user_message_id = MessageId(assistant_message_id.0 + 1);
261        self.tool_uses_by_user_message
262            .entry(next_user_message_id)
263            .or_default()
264            .push(tool_use.id.clone());
265
266        self.pending_tool_uses_by_id.insert(
267            tool_use.id.clone(),
268            PendingToolUse {
269                assistant_message_id,
270                id: tool_use.id,
271                name: tool_use.name.clone(),
272                ui_text: self
273                    .tool_ui_label(&tool_use.name, &tool_use.input, cx)
274                    .into(),
275                input: tool_use.input,
276                status: PendingToolUseStatus::Idle,
277            },
278        );
279    }
280
281    pub fn run_pending_tool(
282        &mut self,
283        tool_use_id: LanguageModelToolUseId,
284        ui_text: SharedString,
285        task: Task<()>,
286    ) {
287        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
288            tool_use.ui_text = ui_text.into();
289            tool_use.status = PendingToolUseStatus::Running {
290                _task: task.shared(),
291            };
292        }
293    }
294
295    pub fn confirm_tool_use(
296        &mut self,
297        tool_use_id: LanguageModelToolUseId,
298        ui_text: impl Into<Arc<str>>,
299        input: serde_json::Value,
300        messages: Arc<Vec<LanguageModelRequestMessage>>,
301        tool: Arc<dyn Tool>,
302    ) {
303        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
304            let ui_text = ui_text.into();
305            tool_use.ui_text = ui_text.clone();
306            let confirmation = Confirmation {
307                tool_use_id,
308                input,
309                messages,
310                tool,
311                ui_text,
312            };
313            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
314        }
315    }
316
317    pub fn insert_tool_output(
318        &mut self,
319        tool_use_id: LanguageModelToolUseId,
320        tool_name: Arc<str>,
321        output: Result<String>,
322    ) -> Option<PendingToolUse> {
323        match output {
324            Ok(tool_result) => {
325                self.tool_results.insert(
326                    tool_use_id.clone(),
327                    LanguageModelToolResult {
328                        tool_use_id: tool_use_id.clone(),
329                        tool_name,
330                        content: tool_result.into(),
331                        is_error: false,
332                    },
333                );
334                self.pending_tool_uses_by_id.remove(&tool_use_id)
335            }
336            Err(err) => {
337                self.tool_results.insert(
338                    tool_use_id.clone(),
339                    LanguageModelToolResult {
340                        tool_use_id: tool_use_id.clone(),
341                        tool_name,
342                        content: err.to_string().into(),
343                        is_error: true,
344                    },
345                );
346
347                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
348                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
349                }
350
351                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
352            }
353        }
354    }
355
356    pub fn attach_tool_uses(
357        &self,
358        message_id: MessageId,
359        request_message: &mut LanguageModelRequestMessage,
360    ) {
361        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
362            let mut found_tool_use = false;
363
364            for tool_use in tool_uses {
365                if self.tool_results.contains_key(&tool_use.id) {
366                    if !found_tool_use {
367                        // The API fails if a message contains a tool use without any (non-whitespace) text around it
368                        match request_message.content.last_mut() {
369                            Some(MessageContent::Text(txt)) => {
370                                if txt.is_empty() {
371                                    txt.push_str(USING_TOOL_MARKER);
372                                }
373                            }
374                            None | Some(_) => {
375                                request_message
376                                    .content
377                                    .push(MessageContent::Text(USING_TOOL_MARKER.into()));
378                            }
379                        };
380                    }
381
382                    found_tool_use = true;
383
384                    // Do not send tool uses until they are completed
385                    request_message
386                        .content
387                        .push(MessageContent::ToolUse(tool_use.clone()));
388                } else {
389                    log::debug!(
390                        "skipped tool use {:?} because it is still pending",
391                        tool_use
392                    );
393                }
394            }
395        }
396    }
397
398    pub fn attach_tool_results(
399        &self,
400        message_id: MessageId,
401        request_message: &mut LanguageModelRequestMessage,
402    ) {
403        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
404            for tool_use_id in tool_uses {
405                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
406                    request_message.content.push(MessageContent::ToolResult(
407                        LanguageModelToolResult {
408                            tool_use_id: tool_use_id.clone(),
409                            tool_name: tool_result.tool_name.clone(),
410                            is_error: tool_result.is_error,
411                            content: if tool_result.content.is_empty() {
412                                // Surprisingly, the API fails if we return an empty string here.
413                                // It thinks we are sending a tool use without a tool result.
414                                "<Tool returned an empty string>".into()
415                            } else {
416                                tool_result.content.clone()
417                            },
418                        },
419                    ));
420                }
421            }
422        }
423    }
424}
425
426#[derive(Debug, Clone)]
427pub struct PendingToolUse {
428    pub id: LanguageModelToolUseId,
429    /// The ID of the Assistant message in which the tool use was requested.
430    #[allow(unused)]
431    pub assistant_message_id: MessageId,
432    pub name: Arc<str>,
433    pub ui_text: Arc<str>,
434    pub input: serde_json::Value,
435    pub status: PendingToolUseStatus,
436}
437
438#[derive(Debug, Clone)]
439pub struct Confirmation {
440    pub tool_use_id: LanguageModelToolUseId,
441    pub input: serde_json::Value,
442    pub ui_text: Arc<str>,
443    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
444    pub tool: Arc<dyn Tool>,
445}
446
447#[derive(Debug, Clone)]
448pub enum PendingToolUseStatus {
449    Idle,
450    NeedsConfirmation(Arc<Confirmation>),
451    Running { _task: Shared<Task<()>> },
452    Error(#[allow(unused)] Arc<str>),
453}
454
455impl PendingToolUseStatus {
456    pub fn is_idle(&self) -> bool {
457        matches!(self, PendingToolUseStatus::Idle)
458    }
459
460    pub fn is_error(&self) -> bool {
461        matches!(self, PendingToolUseStatus::Error(_))
462    }
463
464    pub fn needs_confirmation(&self) -> bool {
465        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
466    }
467}