tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{Tool, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::future::Shared;
  7use futures::FutureExt as _;
  8use gpui::{App, SharedString, Task};
  9use language_model::{
 10    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
 11    LanguageModelToolUseId, MessageContent, Role,
 12};
 13use ui::IconName;
 14
 15use crate::thread::MessageId;
 16use crate::thread_store::SerializedMessage;
 17
 18#[derive(Debug)]
 19pub struct ToolUse {
 20    pub id: LanguageModelToolUseId,
 21    pub name: SharedString,
 22    pub ui_text: SharedString,
 23    pub status: ToolUseStatus,
 24    pub input: serde_json::Value,
 25    pub icon: ui::IconName,
 26    pub needs_confirmation: bool,
 27}
 28
 29#[derive(Debug, Clone)]
 30pub enum ToolUseStatus {
 31    NeedsConfirmation,
 32    Pending,
 33    Running,
 34    Finished(SharedString),
 35    Error(SharedString),
 36}
 37
 38pub struct ToolUseState {
 39    tools: Arc<ToolWorkingSet>,
 40    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 41    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
 42    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
 43    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 44}
 45
 46impl ToolUseState {
 47    pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
 48        Self {
 49            tools,
 50            tool_uses_by_assistant_message: HashMap::default(),
 51            tool_uses_by_user_message: HashMap::default(),
 52            tool_results: HashMap::default(),
 53            pending_tool_uses_by_id: HashMap::default(),
 54        }
 55    }
 56
 57    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 58    ///
 59    /// Accepts a function to filter the tools that should be used to populate the state.
 60    pub fn from_serialized_messages(
 61        tools: Arc<ToolWorkingSet>,
 62        messages: &[SerializedMessage],
 63        mut filter_by_tool_name: impl FnMut(&str) -> bool,
 64    ) -> Self {
 65        let mut this = Self::new(tools);
 66        let mut tool_names_by_id = HashMap::default();
 67
 68        for message in messages {
 69            match message.role {
 70                Role::Assistant => {
 71                    if !message.tool_uses.is_empty() {
 72                        let tool_uses = message
 73                            .tool_uses
 74                            .iter()
 75                            .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
 76                            .map(|tool_use| LanguageModelToolUse {
 77                                id: tool_use.id.clone(),
 78                                name: tool_use.name.clone().into(),
 79                                input: tool_use.input.clone(),
 80                            })
 81                            .collect::<Vec<_>>();
 82
 83                        tool_names_by_id.extend(
 84                            tool_uses
 85                                .iter()
 86                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 87                        );
 88
 89                        this.tool_uses_by_assistant_message
 90                            .insert(message.id, tool_uses);
 91                    }
 92                }
 93                Role::User => {
 94                    if !message.tool_results.is_empty() {
 95                        let tool_uses_by_user_message = this
 96                            .tool_uses_by_user_message
 97                            .entry(message.id)
 98                            .or_default();
 99
100                        for tool_result in &message.tool_results {
101                            let tool_use_id = tool_result.tool_use_id.clone();
102                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
103                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
104                                continue;
105                            };
106
107                            if !(filter_by_tool_name)(tool_use.as_ref()) {
108                                continue;
109                            }
110
111                            tool_uses_by_user_message.push(tool_use_id.clone());
112                            this.tool_results.insert(
113                                tool_use_id.clone(),
114                                LanguageModelToolResult {
115                                    tool_use_id,
116                                    tool_name: tool_use.clone(),
117                                    is_error: tool_result.is_error,
118                                    content: tool_result.content.clone(),
119                                },
120                            );
121                        }
122                    }
123                }
124                Role::System => {}
125            }
126        }
127
128        this
129    }
130
131    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
132        let mut pending_tools = Vec::new();
133        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
134            self.tool_results.insert(
135                tool_use_id.clone(),
136                LanguageModelToolResult {
137                    tool_use_id,
138                    tool_name: tool_use.name.clone(),
139                    content: "Tool canceled by user".into(),
140                    is_error: true,
141                },
142            );
143            pending_tools.push(tool_use.clone());
144        }
145        pending_tools
146    }
147
148    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
149        self.pending_tool_uses_by_id.values().collect()
150    }
151
152    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
153        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
154            return Vec::new();
155        };
156
157        let mut tool_uses = Vec::new();
158
159        for tool_use in tool_uses_for_message.iter() {
160            let tool_result = self.tool_results.get(&tool_use.id);
161
162            let status = (|| {
163                if let Some(tool_result) = tool_result {
164                    return if tool_result.is_error {
165                        ToolUseStatus::Error(tool_result.content.clone().into())
166                    } else {
167                        ToolUseStatus::Finished(tool_result.content.clone().into())
168                    };
169                }
170
171                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
172                    match pending_tool_use.status {
173                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
174                        PendingToolUseStatus::NeedsConfirmation { .. } => {
175                            ToolUseStatus::NeedsConfirmation
176                        }
177                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
178                        PendingToolUseStatus::Error(ref err) => {
179                            ToolUseStatus::Error(err.clone().into())
180                        }
181                    }
182                } else {
183                    ToolUseStatus::Pending
184                }
185            })();
186
187            let (icon, needs_confirmation) = if let Some(tool) = self.tools.tool(&tool_use.name, cx)
188            {
189                (tool.icon(), tool.needs_confirmation())
190            } else {
191                (IconName::Cog, false)
192            };
193
194            tool_uses.push(ToolUse {
195                id: tool_use.id.clone(),
196                name: tool_use.name.clone().into(),
197                ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
198                input: tool_use.input.clone(),
199                status,
200                icon,
201                needs_confirmation,
202            })
203        }
204
205        tool_uses
206    }
207
208    pub fn tool_ui_label(
209        &self,
210        tool_name: &str,
211        input: &serde_json::Value,
212        cx: &App,
213    ) -> SharedString {
214        if let Some(tool) = self.tools.tool(tool_name, cx) {
215            tool.ui_text(input).into()
216        } else {
217            format!("Unknown tool {tool_name:?}").into()
218        }
219    }
220
221    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
222        let empty = Vec::new();
223
224        self.tool_uses_by_user_message
225            .get(&message_id)
226            .unwrap_or(&empty)
227            .iter()
228            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
229            .collect()
230    }
231
232    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
233        self.tool_uses_by_user_message
234            .get(&message_id)
235            .map_or(false, |results| !results.is_empty())
236    }
237
238    pub fn tool_result(
239        &self,
240        tool_use_id: &LanguageModelToolUseId,
241    ) -> Option<&LanguageModelToolResult> {
242        self.tool_results.get(tool_use_id)
243    }
244
245    pub fn request_tool_use(
246        &mut self,
247        assistant_message_id: MessageId,
248        tool_use: LanguageModelToolUse,
249        cx: &App,
250    ) {
251        self.tool_uses_by_assistant_message
252            .entry(assistant_message_id)
253            .or_default()
254            .push(tool_use.clone());
255
256        // The tool use is being requested by the Assistant, so we want to
257        // attach the tool results to the next user message.
258        let next_user_message_id = MessageId(assistant_message_id.0 + 1);
259        self.tool_uses_by_user_message
260            .entry(next_user_message_id)
261            .or_default()
262            .push(tool_use.id.clone());
263
264        self.pending_tool_uses_by_id.insert(
265            tool_use.id.clone(),
266            PendingToolUse {
267                assistant_message_id,
268                id: tool_use.id,
269                name: tool_use.name.clone(),
270                ui_text: self
271                    .tool_ui_label(&tool_use.name, &tool_use.input, cx)
272                    .into(),
273                input: tool_use.input,
274                status: PendingToolUseStatus::Idle,
275            },
276        );
277    }
278
279    pub fn run_pending_tool(
280        &mut self,
281        tool_use_id: LanguageModelToolUseId,
282        ui_text: SharedString,
283        task: Task<()>,
284    ) {
285        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
286            tool_use.ui_text = ui_text.into();
287            tool_use.status = PendingToolUseStatus::Running {
288                _task: task.shared(),
289            };
290        }
291    }
292
293    pub fn confirm_tool_use(
294        &mut self,
295        tool_use_id: LanguageModelToolUseId,
296        ui_text: impl Into<Arc<str>>,
297        input: serde_json::Value,
298        messages: Arc<Vec<LanguageModelRequestMessage>>,
299        tool: Arc<dyn Tool>,
300    ) {
301        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
302            let ui_text = ui_text.into();
303            tool_use.ui_text = ui_text.clone();
304            let confirmation = Confirmation {
305                tool_use_id,
306                input,
307                messages,
308                tool,
309                ui_text,
310            };
311            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
312        }
313    }
314
315    pub fn insert_tool_output(
316        &mut self,
317        tool_use_id: LanguageModelToolUseId,
318        tool_name: Arc<str>,
319        output: Result<String>,
320    ) -> Option<PendingToolUse> {
321        match output {
322            Ok(tool_result) => {
323                self.tool_results.insert(
324                    tool_use_id.clone(),
325                    LanguageModelToolResult {
326                        tool_use_id: tool_use_id.clone(),
327                        tool_name,
328                        content: tool_result.into(),
329                        is_error: false,
330                    },
331                );
332                self.pending_tool_uses_by_id.remove(&tool_use_id)
333            }
334            Err(err) => {
335                self.tool_results.insert(
336                    tool_use_id.clone(),
337                    LanguageModelToolResult {
338                        tool_use_id: tool_use_id.clone(),
339                        tool_name,
340                        content: err.to_string().into(),
341                        is_error: true,
342                    },
343                );
344
345                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
346                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
347                }
348
349                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
350            }
351        }
352    }
353
354    pub fn attach_tool_uses(
355        &self,
356        message_id: MessageId,
357        request_message: &mut LanguageModelRequestMessage,
358    ) {
359        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
360            for tool_use in tool_uses {
361                if self.tool_results.contains_key(&tool_use.id) {
362                    // Do not send tool uses until they are completed
363                    request_message
364                        .content
365                        .push(MessageContent::ToolUse(tool_use.clone()));
366                } else {
367                    log::debug!(
368                        "skipped tool use {:?} because it is still pending",
369                        tool_use
370                    );
371                }
372            }
373        }
374    }
375
376    pub fn attach_tool_results(
377        &self,
378        message_id: MessageId,
379        request_message: &mut LanguageModelRequestMessage,
380    ) {
381        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
382            for tool_use_id in tool_uses {
383                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
384                    request_message.content.push(MessageContent::ToolResult(
385                        LanguageModelToolResult {
386                            tool_use_id: tool_use_id.clone(),
387                            tool_name: tool_result.tool_name.clone(),
388                            is_error: tool_result.is_error,
389                            content: if tool_result.content.is_empty() {
390                                // Surprisingly, the API fails if we return an empty string here.
391                                // It thinks we are sending a tool use without a tool result.
392                                "<Tool returned an empty string>".into()
393                            } else {
394                                tool_result.content.clone()
395                            },
396                        },
397                    ));
398                }
399            }
400        }
401    }
402}
403
404#[derive(Debug, Clone)]
405pub struct PendingToolUse {
406    pub id: LanguageModelToolUseId,
407    /// The ID of the Assistant message in which the tool use was requested.
408    #[allow(unused)]
409    pub assistant_message_id: MessageId,
410    pub name: Arc<str>,
411    pub ui_text: Arc<str>,
412    pub input: serde_json::Value,
413    pub status: PendingToolUseStatus,
414}
415
416#[derive(Debug, Clone)]
417pub struct Confirmation {
418    pub tool_use_id: LanguageModelToolUseId,
419    pub input: serde_json::Value,
420    pub ui_text: Arc<str>,
421    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
422    pub tool: Arc<dyn Tool>,
423}
424
425#[derive(Debug, Clone)]
426pub enum PendingToolUseStatus {
427    Idle,
428    NeedsConfirmation(Arc<Confirmation>),
429    Running { _task: Shared<Task<()>> },
430    Error(#[allow(unused)] Arc<str>),
431}
432
433impl PendingToolUseStatus {
434    pub fn is_idle(&self) -> bool {
435        matches!(self, PendingToolUseStatus::Idle)
436    }
437
438    pub fn is_error(&self) -> bool {
439        matches!(self, PendingToolUseStatus::Error(_))
440    }
441
442    pub fn needs_confirmation(&self) -> bool {
443        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
444    }
445}