tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use gpui::{App, Entity, SharedString, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
 11    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 12};
 13use ui::IconName;
 14use util::truncate_lines_to_byte_limit;
 15
 16use crate::thread::{MessageId, PromptId, ThreadId};
 17use crate::thread_store::SerializedMessage;
 18
/// A snapshot of a single tool use, as assembled by
/// [`ToolUseState::tool_uses_for_message`].
#[derive(Debug)]
pub struct ToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was requested.
    pub name: SharedString,
    /// Label for the tool use, produced by [`ToolUseState::tool_ui_label`].
    pub ui_text: SharedString,
    /// Current status, derived from the recorded result or, failing that,
    /// from the pending entry.
    pub status: ToolUseStatus,
    /// JSON input supplied with the tool use.
    pub input: serde_json::Value,
    /// Icon reported by the tool; [`IconName::Cog`] when the tool is not found
    /// in the working set.
    pub icon: ui::IconName,
    /// What the tool's `needs_confirmation` returned for this input; `false`
    /// when the tool is not found in the working set.
    pub needs_confirmation: bool,
}
 29
/// Bookkeeping for tool uses: which assistant message requested them, which
/// user message carries their results, and the per-tool-use results, pending
/// entries, cards and metadata.
pub struct ToolUseState {
    /// Working set used to look tools up by name.
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested by each assistant message.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// IDs of the tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Recorded result (success, error or cancellation) for each tool use.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses that were requested and are still tracked as pending.
    /// Entries are removed on successful output or cancellation; a failed
    /// tool use stays here with an `Error` status.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Cards registered through `insert_tool_result_card`.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Metadata stored by `request_tool_use`; removed by `insert_tool_output`,
    /// which reports it in the telemetry event.
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
 39
 40impl ToolUseState {
 41    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 42        Self {
 43            tools,
 44            tool_uses_by_assistant_message: HashMap::default(),
 45            tool_uses_by_user_message: HashMap::default(),
 46            tool_results: HashMap::default(),
 47            pending_tool_uses_by_id: HashMap::default(),
 48            tool_result_cards: HashMap::default(),
 49            tool_use_metadata_by_id: HashMap::default(),
 50        }
 51    }
 52
 53    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 54    ///
 55    /// Accepts a function to filter the tools that should be used to populate the state.
 56    pub fn from_serialized_messages(
 57        tools: Entity<ToolWorkingSet>,
 58        messages: &[SerializedMessage],
 59        mut filter_by_tool_name: impl FnMut(&str) -> bool,
 60    ) -> Self {
 61        let mut this = Self::new(tools);
 62        let mut tool_names_by_id = HashMap::default();
 63
 64        for message in messages {
 65            match message.role {
 66                Role::Assistant => {
 67                    if !message.tool_uses.is_empty() {
 68                        let tool_uses = message
 69                            .tool_uses
 70                            .iter()
 71                            .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
 72                            .map(|tool_use| LanguageModelToolUse {
 73                                id: tool_use.id.clone(),
 74                                name: tool_use.name.clone().into(),
 75                                input: tool_use.input.clone(),
 76                            })
 77                            .collect::<Vec<_>>();
 78
 79                        tool_names_by_id.extend(
 80                            tool_uses
 81                                .iter()
 82                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 83                        );
 84
 85                        this.tool_uses_by_assistant_message
 86                            .insert(message.id, tool_uses);
 87                    }
 88                }
 89                Role::User => {
 90                    if !message.tool_results.is_empty() {
 91                        let tool_uses_by_user_message = this
 92                            .tool_uses_by_user_message
 93                            .entry(message.id)
 94                            .or_default();
 95
 96                        for tool_result in &message.tool_results {
 97                            let tool_use_id = tool_result.tool_use_id.clone();
 98                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
 99                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
100                                continue;
101                            };
102
103                            if !(filter_by_tool_name)(tool_use.as_ref()) {
104                                continue;
105                            }
106
107                            tool_uses_by_user_message.push(tool_use_id.clone());
108                            this.tool_results.insert(
109                                tool_use_id.clone(),
110                                LanguageModelToolResult {
111                                    tool_use_id,
112                                    tool_name: tool_use.clone(),
113                                    is_error: tool_result.is_error,
114                                    content: tool_result.content.clone(),
115                                },
116                            );
117                        }
118                    }
119                }
120                Role::System => {}
121            }
122        }
123
124        this
125    }
126
127    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
128        let mut pending_tools = Vec::new();
129        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
130            self.tool_results.insert(
131                tool_use_id.clone(),
132                LanguageModelToolResult {
133                    tool_use_id,
134                    tool_name: tool_use.name.clone(),
135                    content: "Tool canceled by user".into(),
136                    is_error: true,
137                },
138            );
139            pending_tools.push(tool_use.clone());
140        }
141        pending_tools
142    }
143
144    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
145        self.pending_tool_uses_by_id.values().collect()
146    }
147
148    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
149        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
150            return Vec::new();
151        };
152
153        let mut tool_uses = Vec::new();
154
155        for tool_use in tool_uses_for_message.iter() {
156            let tool_result = self.tool_results.get(&tool_use.id);
157
158            let status = (|| {
159                if let Some(tool_result) = tool_result {
160                    return if tool_result.is_error {
161                        ToolUseStatus::Error(tool_result.content.clone().into())
162                    } else {
163                        ToolUseStatus::Finished(tool_result.content.clone().into())
164                    };
165                }
166
167                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
168                    match pending_tool_use.status {
169                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
170                        PendingToolUseStatus::NeedsConfirmation { .. } => {
171                            ToolUseStatus::NeedsConfirmation
172                        }
173                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
174                        PendingToolUseStatus::Error(ref err) => {
175                            ToolUseStatus::Error(err.clone().into())
176                        }
177                    }
178                } else {
179                    ToolUseStatus::Pending
180                }
181            })();
182
183            let (icon, needs_confirmation) =
184                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
185                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
186                } else {
187                    (IconName::Cog, false)
188                };
189
190            tool_uses.push(ToolUse {
191                id: tool_use.id.clone(),
192                name: tool_use.name.clone().into(),
193                ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
194                input: tool_use.input.clone(),
195                status,
196                icon,
197                needs_confirmation,
198            })
199        }
200
201        tool_uses
202    }
203
204    pub fn tool_ui_label(
205        &self,
206        tool_name: &str,
207        input: &serde_json::Value,
208        cx: &App,
209    ) -> SharedString {
210        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
211            tool.ui_text(input).into()
212        } else {
213            format!("Unknown tool {tool_name:?}").into()
214        }
215    }
216
217    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
218        let empty = Vec::new();
219
220        self.tool_uses_by_user_message
221            .get(&message_id)
222            .unwrap_or(&empty)
223            .iter()
224            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
225            .collect()
226    }
227
228    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
229        self.tool_uses_by_user_message
230            .get(&message_id)
231            .map_or(false, |results| !results.is_empty())
232    }
233
    /// Returns the recorded result for the given tool use, if there is one.
    pub fn tool_result(
        &self,
        tool_use_id: &LanguageModelToolUseId,
    ) -> Option<&LanguageModelToolResult> {
        self.tool_results.get(tool_use_id)
    }
240
    /// Returns the card registered for the given tool use, if there is one.
    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
        self.tool_result_cards.get(tool_use_id)
    }
244
    /// Stores `card` under the given tool use ID; the value previously returned
    /// by `insert` for that ID, if any, is discarded.
    pub fn insert_tool_result_card(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        card: AnyToolCard,
    ) {
        self.tool_result_cards.insert(tool_use_id, card);
    }
252
253    pub fn request_tool_use(
254        &mut self,
255        assistant_message_id: MessageId,
256        tool_use: LanguageModelToolUse,
257        metadata: ToolUseMetadata,
258        cx: &App,
259    ) {
260        self.tool_uses_by_assistant_message
261            .entry(assistant_message_id)
262            .or_default()
263            .push(tool_use.clone());
264
265        self.tool_use_metadata_by_id
266            .insert(tool_use.id.clone(), metadata);
267
268        // The tool use is being requested by the Assistant, so we want to
269        // attach the tool results to the next user message.
270        let next_user_message_id = MessageId(assistant_message_id.0 + 1);
271        self.tool_uses_by_user_message
272            .entry(next_user_message_id)
273            .or_default()
274            .push(tool_use.id.clone());
275
276        self.pending_tool_uses_by_id.insert(
277            tool_use.id.clone(),
278            PendingToolUse {
279                assistant_message_id,
280                id: tool_use.id,
281                name: tool_use.name.clone(),
282                ui_text: self
283                    .tool_ui_label(&tool_use.name, &tool_use.input, cx)
284                    .into(),
285                input: tool_use.input,
286                status: PendingToolUseStatus::Idle,
287            },
288        );
289    }
290
291    pub fn run_pending_tool(
292        &mut self,
293        tool_use_id: LanguageModelToolUseId,
294        ui_text: SharedString,
295        task: Task<()>,
296    ) {
297        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
298            tool_use.ui_text = ui_text.into();
299            tool_use.status = PendingToolUseStatus::Running {
300                _task: task.shared(),
301            };
302        }
303    }
304
305    pub fn confirm_tool_use(
306        &mut self,
307        tool_use_id: LanguageModelToolUseId,
308        ui_text: impl Into<Arc<str>>,
309        input: serde_json::Value,
310        messages: Arc<Vec<LanguageModelRequestMessage>>,
311        tool: Arc<dyn Tool>,
312    ) {
313        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
314            let ui_text = ui_text.into();
315            tool_use.ui_text = ui_text.clone();
316            let confirmation = Confirmation {
317                tool_use_id,
318                input,
319                messages,
320                tool,
321                ui_text,
322            };
323            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
324        }
325    }
326
327    pub fn insert_tool_output(
328        &mut self,
329        tool_use_id: LanguageModelToolUseId,
330        tool_name: Arc<str>,
331        output: Result<String>,
332        cx: &App,
333    ) -> Option<PendingToolUse> {
334        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
335
336        telemetry::event!(
337            "Agent Tool Finished",
338            model = metadata
339                .as_ref()
340                .map(|metadata| metadata.model.telemetry_id()),
341            model_provider = metadata
342                .as_ref()
343                .map(|metadata| metadata.model.provider_id().to_string()),
344            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
345            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
346            tool_name,
347            success = output.is_ok()
348        );
349
350        match output {
351            Ok(tool_result) => {
352                let model_registry = LanguageModelRegistry::read_global(cx);
353
354                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
355
356                // Protect from clearly large output
357                let tool_output_limit = model_registry
358                    .default_model()
359                    .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
360                    .unwrap_or(usize::MAX);
361
362                let tool_result = if tool_result.len() <= tool_output_limit {
363                    tool_result
364                } else {
365                    let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
366
367                    format!(
368                        "Tool result too long. The first {} bytes:\n\n{}",
369                        truncated.len(),
370                        truncated
371                    )
372                };
373
374                self.tool_results.insert(
375                    tool_use_id.clone(),
376                    LanguageModelToolResult {
377                        tool_use_id: tool_use_id.clone(),
378                        tool_name,
379                        content: tool_result.into(),
380                        is_error: false,
381                    },
382                );
383                self.pending_tool_uses_by_id.remove(&tool_use_id)
384            }
385            Err(err) => {
386                self.tool_results.insert(
387                    tool_use_id.clone(),
388                    LanguageModelToolResult {
389                        tool_use_id: tool_use_id.clone(),
390                        tool_name,
391                        content: err.to_string().into(),
392                        is_error: true,
393                    },
394                );
395
396                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
397                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
398                }
399
400                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
401            }
402        }
403    }
404
405    pub fn attach_tool_uses(
406        &self,
407        message_id: MessageId,
408        request_message: &mut LanguageModelRequestMessage,
409    ) {
410        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
411            for tool_use in tool_uses {
412                if self.tool_results.contains_key(&tool_use.id) {
413                    // Do not send tool uses until they are completed
414                    request_message
415                        .content
416                        .push(MessageContent::ToolUse(tool_use.clone()));
417                } else {
418                    log::debug!(
419                        "skipped tool use {:?} because it is still pending",
420                        tool_use
421                    );
422                }
423            }
424        }
425    }
426
427    pub fn attach_tool_results(
428        &self,
429        message_id: MessageId,
430        request_message: &mut LanguageModelRequestMessage,
431    ) {
432        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
433            for tool_use_id in tool_uses {
434                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
435                    request_message.content.push(MessageContent::ToolResult(
436                        LanguageModelToolResult {
437                            tool_use_id: tool_use_id.clone(),
438                            tool_name: tool_result.tool_name.clone(),
439                            is_error: tool_result.is_error,
440                            content: if tool_result.content.is_empty() {
441                                // Surprisingly, the API fails if we return an empty string here.
442                                // It thinks we are sending a tool use without a tool result.
443                                "<Tool returned an empty string>".into()
444                            } else {
445                                tool_result.content.clone()
446                            },
447                        },
448                    ));
449                }
450            }
451        }
452    }
453}
454
/// A tool use that has been requested and is tracked as pending by
/// [`ToolUseState`].
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier of the tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the requested tool.
    pub name: Arc<str>,
    /// Label for the tool use; overwritten by `ToolUseState::run_pending_tool`
    /// and `ToolUseState::confirm_tool_use`.
    pub ui_text: Arc<str>,
    /// JSON input supplied with the tool use.
    pub input: serde_json::Value,
    /// Where the tool use currently stands.
    pub status: PendingToolUseStatus,
}
466
/// Details stored with a tool use that is awaiting confirmation; built from the
/// arguments of `ToolUseState::confirm_tool_use`.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// Identifier of the tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// JSON input passed to `confirm_tool_use`.
    pub input: serde_json::Value,
    /// Label for the tool use (the same value written to the pending entry).
    pub ui_text: Arc<str>,
    /// Request messages passed to `confirm_tool_use`.
    // NOTE(review): presumably the conversation handed to the tool once it is
    // confirmed — verify against the caller.
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool passed to `confirm_tool_use`.
    pub tool: Arc<dyn Tool>,
}
475
/// Status of a [`PendingToolUse`].
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    /// Initial status, assigned by `ToolUseState::request_tool_use`.
    Idle,
    /// Awaiting confirmation; carries the details stored by
    /// `ToolUseState::confirm_tool_use`.
    NeedsConfirmation(Arc<Confirmation>),
    /// Running; holds the shared task passed to `ToolUseState::run_pending_tool`.
    Running { _task: Shared<Task<()>> },
    // NOTE(review): the payload is read in `ToolUseState::tool_uses_for_message`,
    // so the `#[allow(unused)]` below may no longer be needed — confirm.
    /// Failed; carries the error message set by `ToolUseState::insert_tool_output`.
    Error(#[allow(unused)] Arc<str>),
}
483
484impl PendingToolUseStatus {
485    pub fn is_idle(&self) -> bool {
486        matches!(self, PendingToolUseStatus::Idle)
487    }
488
489    pub fn is_error(&self) -> bool {
490        matches!(self, PendingToolUseStatus::Error(_))
491    }
492
493    pub fn needs_confirmation(&self) -> bool {
494        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
495    }
496}
497
/// Context recorded alongside a tool use by `ToolUseState::request_tool_use`
/// and reported in the "Agent Tool Finished" telemetry event.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Model associated with the tool use; its telemetry ID and provider ID are
    /// what the telemetry event reports.
    pub model: Arc<dyn LanguageModel>,
    /// Thread ID reported in the telemetry event.
    pub thread_id: ThreadId,
    /// Prompt ID reported in the telemetry event.
    pub prompt_id: PromptId,
}