tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use gpui::{App, Entity, SharedString, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
 11    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 12};
 13use ui::IconName;
 14use util::truncate_lines_to_byte_limit;
 15
 16use crate::thread::{MessageId, PromptId, ThreadId};
 17use crate::thread_store::SerializedMessage;
 18
 19#[derive(Debug)]
 20pub struct ToolUse {
 21    pub id: LanguageModelToolUseId,
 22    pub name: SharedString,
 23    pub ui_text: SharedString,
 24    pub status: ToolUseStatus,
 25    pub input: serde_json::Value,
 26    pub icon: ui::IconName,
 27    pub needs_confirmation: bool,
 28}
 29
/// Placeholder text that `ToolUseState::attach_tool_uses` inserts into a request
/// message ahead of its tool uses, because the API rejects a tool use that has no
/// text around it.
pub const USING_TOOL_MARKER: &str = "<using_tool>";
 31
/// Tracks the tool uses requested by the assistant, their in-flight state, and
/// the results that come back for them.
pub struct ToolUseState {
    /// Working set used to look tools up by name (icons, labels, confirmation).
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested in each assistant message.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// IDs of the tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Results (successful, failed, or canceled) keyed by tool use ID.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that have been requested but not completed successfully.
    /// Entries are removed on success or cancellation and kept, with an error
    /// status, on failure.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Result cards registered per tool use ID.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Telemetry metadata for each requested tool use, consumed when its output
    /// is inserted.
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
 41
 42impl ToolUseState {
 43    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 44        Self {
 45            tools,
 46            tool_uses_by_assistant_message: HashMap::default(),
 47            tool_uses_by_user_message: HashMap::default(),
 48            tool_results: HashMap::default(),
 49            pending_tool_uses_by_id: HashMap::default(),
 50            tool_result_cards: HashMap::default(),
 51            tool_use_metadata_by_id: HashMap::default(),
 52        }
 53    }
 54
 55    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 56    ///
 57    /// Accepts a function to filter the tools that should be used to populate the state.
 58    pub fn from_serialized_messages(
 59        tools: Entity<ToolWorkingSet>,
 60        messages: &[SerializedMessage],
 61        mut filter_by_tool_name: impl FnMut(&str) -> bool,
 62    ) -> Self {
 63        let mut this = Self::new(tools);
 64        let mut tool_names_by_id = HashMap::default();
 65
 66        for message in messages {
 67            match message.role {
 68                Role::Assistant => {
 69                    if !message.tool_uses.is_empty() {
 70                        let tool_uses = message
 71                            .tool_uses
 72                            .iter()
 73                            .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
 74                            .map(|tool_use| LanguageModelToolUse {
 75                                id: tool_use.id.clone(),
 76                                name: tool_use.name.clone().into(),
 77                                input: tool_use.input.clone(),
 78                            })
 79                            .collect::<Vec<_>>();
 80
 81                        tool_names_by_id.extend(
 82                            tool_uses
 83                                .iter()
 84                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 85                        );
 86
 87                        this.tool_uses_by_assistant_message
 88                            .insert(message.id, tool_uses);
 89                    }
 90                }
 91                Role::User => {
 92                    if !message.tool_results.is_empty() {
 93                        let tool_uses_by_user_message = this
 94                            .tool_uses_by_user_message
 95                            .entry(message.id)
 96                            .or_default();
 97
 98                        for tool_result in &message.tool_results {
 99                            let tool_use_id = tool_result.tool_use_id.clone();
100                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
101                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
102                                continue;
103                            };
104
105                            if !(filter_by_tool_name)(tool_use.as_ref()) {
106                                continue;
107                            }
108
109                            tool_uses_by_user_message.push(tool_use_id.clone());
110                            this.tool_results.insert(
111                                tool_use_id.clone(),
112                                LanguageModelToolResult {
113                                    tool_use_id,
114                                    tool_name: tool_use.clone(),
115                                    is_error: tool_result.is_error,
116                                    content: tool_result.content.clone(),
117                                },
118                            );
119                        }
120                    }
121                }
122                Role::System => {}
123            }
124        }
125
126        this
127    }
128
129    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
130        let mut pending_tools = Vec::new();
131        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
132            self.tool_results.insert(
133                tool_use_id.clone(),
134                LanguageModelToolResult {
135                    tool_use_id,
136                    tool_name: tool_use.name.clone(),
137                    content: "Tool canceled by user".into(),
138                    is_error: true,
139                },
140            );
141            pending_tools.push(tool_use.clone());
142        }
143        pending_tools
144    }
145
146    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
147        self.pending_tool_uses_by_id.values().collect()
148    }
149
150    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
151        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
152            return Vec::new();
153        };
154
155        let mut tool_uses = Vec::new();
156
157        for tool_use in tool_uses_for_message.iter() {
158            let tool_result = self.tool_results.get(&tool_use.id);
159
160            let status = (|| {
161                if let Some(tool_result) = tool_result {
162                    return if tool_result.is_error {
163                        ToolUseStatus::Error(tool_result.content.clone().into())
164                    } else {
165                        ToolUseStatus::Finished(tool_result.content.clone().into())
166                    };
167                }
168
169                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
170                    match pending_tool_use.status {
171                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
172                        PendingToolUseStatus::NeedsConfirmation { .. } => {
173                            ToolUseStatus::NeedsConfirmation
174                        }
175                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
176                        PendingToolUseStatus::Error(ref err) => {
177                            ToolUseStatus::Error(err.clone().into())
178                        }
179                    }
180                } else {
181                    ToolUseStatus::Pending
182                }
183            })();
184
185            let (icon, needs_confirmation) =
186                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
187                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
188                } else {
189                    (IconName::Cog, false)
190                };
191
192            tool_uses.push(ToolUse {
193                id: tool_use.id.clone(),
194                name: tool_use.name.clone().into(),
195                ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
196                input: tool_use.input.clone(),
197                status,
198                icon,
199                needs_confirmation,
200            })
201        }
202
203        tool_uses
204    }
205
206    pub fn tool_ui_label(
207        &self,
208        tool_name: &str,
209        input: &serde_json::Value,
210        cx: &App,
211    ) -> SharedString {
212        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
213            tool.ui_text(input).into()
214        } else {
215            format!("Unknown tool {tool_name:?}").into()
216        }
217    }
218
219    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
220        let empty = Vec::new();
221
222        self.tool_uses_by_user_message
223            .get(&message_id)
224            .unwrap_or(&empty)
225            .iter()
226            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
227            .collect()
228    }
229
230    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
231        self.tool_uses_by_user_message
232            .get(&message_id)
233            .map_or(false, |results| !results.is_empty())
234    }
235
236    pub fn tool_result(
237        &self,
238        tool_use_id: &LanguageModelToolUseId,
239    ) -> Option<&LanguageModelToolResult> {
240        self.tool_results.get(tool_use_id)
241    }
242
243    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
244        self.tool_result_cards.get(tool_use_id)
245    }
246
247    pub fn insert_tool_result_card(
248        &mut self,
249        tool_use_id: LanguageModelToolUseId,
250        card: AnyToolCard,
251    ) {
252        self.tool_result_cards.insert(tool_use_id, card);
253    }
254
255    pub fn request_tool_use(
256        &mut self,
257        assistant_message_id: MessageId,
258        tool_use: LanguageModelToolUse,
259        metadata: ToolUseMetadata,
260        cx: &App,
261    ) {
262        self.tool_uses_by_assistant_message
263            .entry(assistant_message_id)
264            .or_default()
265            .push(tool_use.clone());
266
267        self.tool_use_metadata_by_id
268            .insert(tool_use.id.clone(), metadata);
269
270        // The tool use is being requested by the Assistant, so we want to
271        // attach the tool results to the next user message.
272        let next_user_message_id = MessageId(assistant_message_id.0 + 1);
273        self.tool_uses_by_user_message
274            .entry(next_user_message_id)
275            .or_default()
276            .push(tool_use.id.clone());
277
278        self.pending_tool_uses_by_id.insert(
279            tool_use.id.clone(),
280            PendingToolUse {
281                assistant_message_id,
282                id: tool_use.id,
283                name: tool_use.name.clone(),
284                ui_text: self
285                    .tool_ui_label(&tool_use.name, &tool_use.input, cx)
286                    .into(),
287                input: tool_use.input,
288                status: PendingToolUseStatus::Idle,
289            },
290        );
291    }
292
293    pub fn run_pending_tool(
294        &mut self,
295        tool_use_id: LanguageModelToolUseId,
296        ui_text: SharedString,
297        task: Task<()>,
298    ) {
299        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
300            tool_use.ui_text = ui_text.into();
301            tool_use.status = PendingToolUseStatus::Running {
302                _task: task.shared(),
303            };
304        }
305    }
306
307    pub fn confirm_tool_use(
308        &mut self,
309        tool_use_id: LanguageModelToolUseId,
310        ui_text: impl Into<Arc<str>>,
311        input: serde_json::Value,
312        messages: Arc<Vec<LanguageModelRequestMessage>>,
313        tool: Arc<dyn Tool>,
314    ) {
315        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
316            let ui_text = ui_text.into();
317            tool_use.ui_text = ui_text.clone();
318            let confirmation = Confirmation {
319                tool_use_id,
320                input,
321                messages,
322                tool,
323                ui_text,
324            };
325            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
326        }
327    }
328
329    pub fn insert_tool_output(
330        &mut self,
331        tool_use_id: LanguageModelToolUseId,
332        tool_name: Arc<str>,
333        output: Result<String>,
334        cx: &App,
335    ) -> Option<PendingToolUse> {
336        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
337
338        telemetry::event!(
339            "Agent Tool Finished",
340            model = metadata
341                .as_ref()
342                .map(|metadata| metadata.model.telemetry_id()),
343            model_provider = metadata
344                .as_ref()
345                .map(|metadata| metadata.model.provider_id().to_string()),
346            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
347            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
348            tool_name,
349            success = output.is_ok()
350        );
351
352        match output {
353            Ok(tool_result) => {
354                let model_registry = LanguageModelRegistry::read_global(cx);
355
356                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
357
358                // Protect from clearly large output
359                let tool_output_limit = model_registry
360                    .default_model()
361                    .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
362                    .unwrap_or(usize::MAX);
363
364                let tool_result = if tool_result.len() <= tool_output_limit {
365                    tool_result
366                } else {
367                    let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
368
369                    format!(
370                        "Tool result too long. The first {} bytes:\n\n{}",
371                        truncated.len(),
372                        truncated
373                    )
374                };
375
376                self.tool_results.insert(
377                    tool_use_id.clone(),
378                    LanguageModelToolResult {
379                        tool_use_id: tool_use_id.clone(),
380                        tool_name,
381                        content: tool_result.into(),
382                        is_error: false,
383                    },
384                );
385                self.pending_tool_uses_by_id.remove(&tool_use_id)
386            }
387            Err(err) => {
388                self.tool_results.insert(
389                    tool_use_id.clone(),
390                    LanguageModelToolResult {
391                        tool_use_id: tool_use_id.clone(),
392                        tool_name,
393                        content: err.to_string().into(),
394                        is_error: true,
395                    },
396                );
397
398                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
399                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
400                }
401
402                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
403            }
404        }
405    }
406
407    pub fn attach_tool_uses(
408        &self,
409        message_id: MessageId,
410        request_message: &mut LanguageModelRequestMessage,
411    ) {
412        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
413            let mut found_tool_use = false;
414
415            for tool_use in tool_uses {
416                if self.tool_results.contains_key(&tool_use.id) {
417                    if !found_tool_use {
418                        // The API fails if a message contains a tool use without any (non-whitespace) text around it
419                        match request_message.content.last_mut() {
420                            Some(MessageContent::Text(txt)) => {
421                                if txt.is_empty() {
422                                    txt.push_str(USING_TOOL_MARKER);
423                                }
424                            }
425                            None | Some(_) => {
426                                request_message
427                                    .content
428                                    .push(MessageContent::Text(USING_TOOL_MARKER.into()));
429                            }
430                        };
431                    }
432
433                    found_tool_use = true;
434
435                    // Do not send tool uses until they are completed
436                    request_message
437                        .content
438                        .push(MessageContent::ToolUse(tool_use.clone()));
439                } else {
440                    log::debug!(
441                        "skipped tool use {:?} because it is still pending",
442                        tool_use
443                    );
444                }
445            }
446        }
447    }
448
449    pub fn attach_tool_results(
450        &self,
451        message_id: MessageId,
452        request_message: &mut LanguageModelRequestMessage,
453    ) {
454        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
455            for tool_use_id in tool_uses {
456                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
457                    request_message.content.push(MessageContent::ToolResult(
458                        LanguageModelToolResult {
459                            tool_use_id: tool_use_id.clone(),
460                            tool_name: tool_result.tool_name.clone(),
461                            is_error: tool_result.is_error,
462                            content: if tool_result.content.is_empty() {
463                                // Surprisingly, the API fails if we return an empty string here.
464                                // It thinks we are sending a tool use without a tool result.
465                                "<Tool returned an empty string>".into()
466                            } else {
467                                tool_result.content.clone()
468                            },
469                        },
470                    ));
471                }
472            }
473        }
474    }
475}
476
/// A tool use that has been requested but has not completed successfully yet.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    // NOTE(review): not read anywhere in this file, hence the lint allowance.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the tool being used.
    pub name: Arc<str>,
    /// Label shown for the tool use. It is initially derived from the tool and its
    /// input, and replaced when the tool runs or awaits confirmation.
    pub ui_text: Arc<str>,
    /// JSON input for the tool.
    pub input: serde_json::Value,
    /// Current progress of the tool use.
    pub status: PendingToolUseStatus,
}
488
/// State captured while a pending tool use waits for user confirmation.
// NOTE(review): presumably consumed to run the tool once the user confirms — verify with callers.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Identifier of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// JSON input for the tool.
    pub input: serde_json::Value,
    /// Label shown for the tool use.
    pub ui_text: Arc<str>,
    /// Request messages captured together with the tool use.
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool that was requested.
    pub tool: Arc<dyn Tool>,
}
497
/// Progress of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// Requested, but not started yet.
    Idle,
    /// Waiting for the user to confirm the tool use.
    NeedsConfirmation(Arc<Confirmation>),
    /// Running. The task is stored here rather than read back.
    // NOTE(review): presumably held to keep the gpui task alive while it runs — confirm.
    Running { _task: Shared<Task<()>> },
    /// Failed with the given error message.
    Error(#[allow(unused)] Arc<str>),
}
505
506impl PendingToolUseStatus {
507    pub fn is_idle(&self) -> bool {
508        matches!(self, PendingToolUseStatus::Idle)
509    }
510
511    pub fn is_error(&self) -> bool {
512        matches!(self, PendingToolUseStatus::Error(_))
513    }
514
515    pub fn needs_confirmation(&self) -> bool {
516        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
517    }
518}
519
/// Context recorded when a tool use is requested. It is reported in the
/// "Agent Tool Finished" telemetry event when the tool's output is inserted.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Model whose telemetry ID and provider ID are reported.
    pub model: Arc<dyn LanguageModel>,
    /// Thread the tool use belongs to.
    pub thread_id: ThreadId,
    /// Prompt the tool use belongs to.
    pub prompt_id: PromptId,
}