tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use gpui::{App, Entity, SharedString, Task};
  9use language_model::{
 10    LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
 11    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 12};
 13use ui::IconName;
 14use util::truncate_lines_to_byte_limit;
 15
 16use crate::thread::MessageId;
 17use crate::thread_store::SerializedMessage;
 18
/// A tool use requested by the assistant, combined with the display
/// information and status computed by [`ToolUseState::tool_uses_for_message`].
#[derive(Debug)]
pub struct ToolUse {
    /// ID of the tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was requested.
    pub name: SharedString,
    /// Label for this tool use: the tool's own `ui_text` for the input, or an
    /// "Unknown tool" fallback when the tool is not in the working set.
    pub ui_text: SharedString,
    /// Status derived from the recorded result if there is one, otherwise
    /// from the pending tool use entry.
    pub status: ToolUseStatus,
    /// JSON input the tool was requested with.
    pub input: serde_json::Value,
    /// Icon reported by the tool; `IconName::Cog` when the tool is not in the
    /// working set.
    pub icon: ui::IconName,
    /// Whether the tool reports that this input needs confirmation; `false`
    /// when the tool is not in the working set.
    pub needs_confirmation: bool,
}
 29
/// Placeholder text that [`ToolUseState::attach_tool_uses`] adds to a request
/// message so that tool uses are never sent without accompanying text.
pub const USING_TOOL_MARKER: &str = "<using_tool>";
 31
/// Tracks the tool uses requested by the assistant, their pending state, and
/// their results, keyed by message and by tool use ID.
pub struct ToolUseState {
    /// Working set used to look up tools by name.
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested in each assistant message.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// IDs of the tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Result recorded for each tool use (successful output, error, or cancellation).
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that have been requested and not yet removed. Entries are
    /// removed on successful output or cancellation; failed ones stay with an
    /// error status.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Cards registered through `insert_tool_result_card`, keyed by tool use ID.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
}
 40
 41impl ToolUseState {
 42    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 43        Self {
 44            tools,
 45            tool_uses_by_assistant_message: HashMap::default(),
 46            tool_uses_by_user_message: HashMap::default(),
 47            tool_results: HashMap::default(),
 48            pending_tool_uses_by_id: HashMap::default(),
 49            tool_result_cards: HashMap::default(),
 50        }
 51    }
 52
 53    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 54    ///
 55    /// Accepts a function to filter the tools that should be used to populate the state.
 56    pub fn from_serialized_messages(
 57        tools: Entity<ToolWorkingSet>,
 58        messages: &[SerializedMessage],
 59        mut filter_by_tool_name: impl FnMut(&str) -> bool,
 60    ) -> Self {
 61        let mut this = Self::new(tools);
 62        let mut tool_names_by_id = HashMap::default();
 63
 64        for message in messages {
 65            match message.role {
 66                Role::Assistant => {
 67                    if !message.tool_uses.is_empty() {
 68                        let tool_uses = message
 69                            .tool_uses
 70                            .iter()
 71                            .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
 72                            .map(|tool_use| LanguageModelToolUse {
 73                                id: tool_use.id.clone(),
 74                                name: tool_use.name.clone().into(),
 75                                input: tool_use.input.clone(),
 76                            })
 77                            .collect::<Vec<_>>();
 78
 79                        tool_names_by_id.extend(
 80                            tool_uses
 81                                .iter()
 82                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 83                        );
 84
 85                        this.tool_uses_by_assistant_message
 86                            .insert(message.id, tool_uses);
 87                    }
 88                }
 89                Role::User => {
 90                    if !message.tool_results.is_empty() {
 91                        let tool_uses_by_user_message = this
 92                            .tool_uses_by_user_message
 93                            .entry(message.id)
 94                            .or_default();
 95
 96                        for tool_result in &message.tool_results {
 97                            let tool_use_id = tool_result.tool_use_id.clone();
 98                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
 99                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
100                                continue;
101                            };
102
103                            if !(filter_by_tool_name)(tool_use.as_ref()) {
104                                continue;
105                            }
106
107                            tool_uses_by_user_message.push(tool_use_id.clone());
108                            this.tool_results.insert(
109                                tool_use_id.clone(),
110                                LanguageModelToolResult {
111                                    tool_use_id,
112                                    tool_name: tool_use.clone(),
113                                    is_error: tool_result.is_error,
114                                    content: tool_result.content.clone(),
115                                },
116                            );
117                        }
118                    }
119                }
120                Role::System => {}
121            }
122        }
123
124        this
125    }
126
127    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
128        let mut pending_tools = Vec::new();
129        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
130            self.tool_results.insert(
131                tool_use_id.clone(),
132                LanguageModelToolResult {
133                    tool_use_id,
134                    tool_name: tool_use.name.clone(),
135                    content: "Tool canceled by user".into(),
136                    is_error: true,
137                },
138            );
139            pending_tools.push(tool_use.clone());
140        }
141        pending_tools
142    }
143
144    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
145        self.pending_tool_uses_by_id.values().collect()
146    }
147
148    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
149        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
150            return Vec::new();
151        };
152
153        let mut tool_uses = Vec::new();
154
155        for tool_use in tool_uses_for_message.iter() {
156            let tool_result = self.tool_results.get(&tool_use.id);
157
158            let status = (|| {
159                if let Some(tool_result) = tool_result {
160                    return if tool_result.is_error {
161                        ToolUseStatus::Error(tool_result.content.clone().into())
162                    } else {
163                        ToolUseStatus::Finished(tool_result.content.clone().into())
164                    };
165                }
166
167                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
168                    match pending_tool_use.status {
169                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
170                        PendingToolUseStatus::NeedsConfirmation { .. } => {
171                            ToolUseStatus::NeedsConfirmation
172                        }
173                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
174                        PendingToolUseStatus::Error(ref err) => {
175                            ToolUseStatus::Error(err.clone().into())
176                        }
177                    }
178                } else {
179                    ToolUseStatus::Pending
180                }
181            })();
182
183            let (icon, needs_confirmation) =
184                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
185                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
186                } else {
187                    (IconName::Cog, false)
188                };
189
190            tool_uses.push(ToolUse {
191                id: tool_use.id.clone(),
192                name: tool_use.name.clone().into(),
193                ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
194                input: tool_use.input.clone(),
195                status,
196                icon,
197                needs_confirmation,
198            })
199        }
200
201        tool_uses
202    }
203
204    pub fn tool_ui_label(
205        &self,
206        tool_name: &str,
207        input: &serde_json::Value,
208        cx: &App,
209    ) -> SharedString {
210        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
211            tool.ui_text(input).into()
212        } else {
213            format!("Unknown tool {tool_name:?}").into()
214        }
215    }
216
217    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
218        let empty = Vec::new();
219
220        self.tool_uses_by_user_message
221            .get(&message_id)
222            .unwrap_or(&empty)
223            .iter()
224            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
225            .collect()
226    }
227
228    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
229        self.tool_uses_by_user_message
230            .get(&message_id)
231            .map_or(false, |results| !results.is_empty())
232    }
233
234    pub fn tool_result(
235        &self,
236        tool_use_id: &LanguageModelToolUseId,
237    ) -> Option<&LanguageModelToolResult> {
238        self.tool_results.get(tool_use_id)
239    }
240
241    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
242        self.tool_result_cards.get(tool_use_id)
243    }
244
245    pub fn insert_tool_result_card(
246        &mut self,
247        tool_use_id: LanguageModelToolUseId,
248        card: AnyToolCard,
249    ) {
250        self.tool_result_cards.insert(tool_use_id, card);
251    }
252
253    pub fn request_tool_use(
254        &mut self,
255        assistant_message_id: MessageId,
256        tool_use: LanguageModelToolUse,
257        cx: &App,
258    ) {
259        self.tool_uses_by_assistant_message
260            .entry(assistant_message_id)
261            .or_default()
262            .push(tool_use.clone());
263
264        // The tool use is being requested by the Assistant, so we want to
265        // attach the tool results to the next user message.
266        let next_user_message_id = MessageId(assistant_message_id.0 + 1);
267        self.tool_uses_by_user_message
268            .entry(next_user_message_id)
269            .or_default()
270            .push(tool_use.id.clone());
271
272        self.pending_tool_uses_by_id.insert(
273            tool_use.id.clone(),
274            PendingToolUse {
275                assistant_message_id,
276                id: tool_use.id,
277                name: tool_use.name.clone(),
278                ui_text: self
279                    .tool_ui_label(&tool_use.name, &tool_use.input, cx)
280                    .into(),
281                input: tool_use.input,
282                status: PendingToolUseStatus::Idle,
283            },
284        );
285    }
286
287    pub fn run_pending_tool(
288        &mut self,
289        tool_use_id: LanguageModelToolUseId,
290        ui_text: SharedString,
291        task: Task<()>,
292    ) {
293        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
294            tool_use.ui_text = ui_text.into();
295            tool_use.status = PendingToolUseStatus::Running {
296                _task: task.shared(),
297            };
298        }
299    }
300
301    pub fn confirm_tool_use(
302        &mut self,
303        tool_use_id: LanguageModelToolUseId,
304        ui_text: impl Into<Arc<str>>,
305        input: serde_json::Value,
306        messages: Arc<Vec<LanguageModelRequestMessage>>,
307        tool: Arc<dyn Tool>,
308    ) {
309        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
310            let ui_text = ui_text.into();
311            tool_use.ui_text = ui_text.clone();
312            let confirmation = Confirmation {
313                tool_use_id,
314                input,
315                messages,
316                tool,
317                ui_text,
318            };
319            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
320        }
321    }
322
323    pub fn insert_tool_output(
324        &mut self,
325        tool_use_id: LanguageModelToolUseId,
326        tool_name: Arc<str>,
327        output: Result<String>,
328        cx: &App,
329    ) -> Option<PendingToolUse> {
330        telemetry::event!("Agent Tool Finished", tool_name, success = output.is_ok());
331
332        match output {
333            Ok(tool_result) => {
334                let model_registry = LanguageModelRegistry::read_global(cx);
335
336                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
337
338                // Protect from clearly large output
339                let tool_output_limit = model_registry
340                    .default_model()
341                    .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
342                    .unwrap_or(usize::MAX);
343
344                let tool_result = if tool_result.len() <= tool_output_limit {
345                    tool_result
346                } else {
347                    let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
348
349                    format!(
350                        "Tool result too long. The first {} bytes:\n\n{}",
351                        truncated.len(),
352                        truncated
353                    )
354                };
355
356                self.tool_results.insert(
357                    tool_use_id.clone(),
358                    LanguageModelToolResult {
359                        tool_use_id: tool_use_id.clone(),
360                        tool_name,
361                        content: tool_result.into(),
362                        is_error: false,
363                    },
364                );
365                self.pending_tool_uses_by_id.remove(&tool_use_id)
366            }
367            Err(err) => {
368                self.tool_results.insert(
369                    tool_use_id.clone(),
370                    LanguageModelToolResult {
371                        tool_use_id: tool_use_id.clone(),
372                        tool_name,
373                        content: err.to_string().into(),
374                        is_error: true,
375                    },
376                );
377
378                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
379                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
380                }
381
382                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
383            }
384        }
385    }
386
387    pub fn attach_tool_uses(
388        &self,
389        message_id: MessageId,
390        request_message: &mut LanguageModelRequestMessage,
391    ) {
392        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
393            let mut found_tool_use = false;
394
395            for tool_use in tool_uses {
396                if self.tool_results.contains_key(&tool_use.id) {
397                    if !found_tool_use {
398                        // The API fails if a message contains a tool use without any (non-whitespace) text around it
399                        match request_message.content.last_mut() {
400                            Some(MessageContent::Text(txt)) => {
401                                if txt.is_empty() {
402                                    txt.push_str(USING_TOOL_MARKER);
403                                }
404                            }
405                            None | Some(_) => {
406                                request_message
407                                    .content
408                                    .push(MessageContent::Text(USING_TOOL_MARKER.into()));
409                            }
410                        };
411                    }
412
413                    found_tool_use = true;
414
415                    // Do not send tool uses until they are completed
416                    request_message
417                        .content
418                        .push(MessageContent::ToolUse(tool_use.clone()));
419                } else {
420                    log::debug!(
421                        "skipped tool use {:?} because it is still pending",
422                        tool_use
423                    );
424                }
425            }
426        }
427    }
428
429    pub fn attach_tool_results(
430        &self,
431        message_id: MessageId,
432        request_message: &mut LanguageModelRequestMessage,
433    ) {
434        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
435            for tool_use_id in tool_uses {
436                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
437                    request_message.content.push(MessageContent::ToolResult(
438                        LanguageModelToolResult {
439                            tool_use_id: tool_use_id.clone(),
440                            tool_name: tool_result.tool_name.clone(),
441                            is_error: tool_result.is_error,
442                            content: if tool_result.content.is_empty() {
443                                // Surprisingly, the API fails if we return an empty string here.
444                                // It thinks we are sending a tool use without a tool result.
445                                "<Tool returned an empty string>".into()
446                            } else {
447                                tool_result.content.clone()
448                            },
449                        },
450                    ));
451                }
452            }
453        }
454    }
455}
456
/// A tool use that has been requested and is tracked in
/// `ToolUseState::pending_tool_uses_by_id`.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// ID of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Label for this tool use; set on request and updated when the tool use
    /// is run or put up for confirmation.
    pub ui_text: Arc<str>,
    /// JSON input the tool was requested with.
    pub input: serde_json::Value,
    /// Current state of the tool use.
    pub status: PendingToolUseStatus,
}
468
/// The values passed to `ToolUseState::confirm_tool_use`, stored in
/// [`PendingToolUseStatus::NeedsConfirmation`] while the tool use awaits
/// confirmation.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// ID of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// JSON input supplied with the confirmation request.
    pub input: serde_json::Value,
    /// Label for the tool use (same value stored on the pending entry).
    pub ui_text: Arc<str>,
    /// Request messages supplied with the confirmation request.
    // NOTE(review): presumably handed to the tool when it is eventually run — confirm against the caller.
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool supplied with the confirmation request.
    pub tool: Arc<dyn Tool>,
}
477
/// State of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// Initial state, set when the tool use is requested.
    Idle,
    /// Set by `ToolUseState::confirm_tool_use`; carries the data supplied with
    /// the confirmation request.
    NeedsConfirmation(Arc<Confirmation>),
    /// Set by `ToolUseState::run_pending_tool`; holds the task passed to it
    /// as a shared future.
    // NOTE(review): presumably held so the task is not dropped while the tool runs — confirm `Task` drop semantics.
    Running { _task: Shared<Task<()>> },
    /// Set by `ToolUseState::insert_tool_output` when the tool returned an
    /// error; holds the error message.
    Error(#[allow(unused)] Arc<str>),
}
485
486impl PendingToolUseStatus {
487    pub fn is_idle(&self) -> bool {
488        matches!(self, PendingToolUseStatus::Idle)
489    }
490
491    pub fn is_error(&self) -> bool {
492        matches!(self, PendingToolUseStatus::Error(_))
493    }
494
495    pub fn needs_confirmation(&self) -> bool {
496        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
497    }
498}