tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
  5use collections::HashMap;
  6use futures::FutureExt as _;
  7use futures::future::Shared;
  8use gpui::{App, Entity, SharedString, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
 11    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 12};
 13use ui::IconName;
 14use util::truncate_lines_to_byte_limit;
 15
 16use crate::thread::{MessageId, PromptId, ThreadId};
 17use crate::thread_store::SerializedMessage;
 18
 19#[derive(Debug)]
 20pub struct ToolUse {
 21    pub id: LanguageModelToolUseId,
 22    pub name: SharedString,
 23    pub ui_text: SharedString,
 24    pub status: ToolUseStatus,
 25    pub input: serde_json::Value,
 26    pub icon: ui::IconName,
 27    pub needs_confirmation: bool,
 28}
 29
/// Placeholder text added to a request message ahead of its tool uses when the
/// message has no accompanying text, because the API rejects a message that
/// contains a tool use without any text around it.
pub const USING_TOOL_MARKER: &str = "<using_tool>";
 31
/// Tracks the tool uses requested by the model in a thread, together with their
/// pending state, results, result cards, and telemetry metadata.
pub struct ToolUseState {
    /// Working set used to look tools up by name (icons, labels, confirmation).
    tools: Entity<ToolWorkingSet>,
    /// Tool uses requested in each assistant message, in the order they were requested.
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    /// Ids of the tool uses whose results are attached to each user message.
    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
    /// Results (success, error, or cancellation) keyed by tool use id.
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    /// Tool uses awaiting a result. Entries are removed on success or cancellation,
    /// but kept (with an error status) when the tool fails.
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    /// Cards registered via `insert_tool_result_card`, keyed by tool use id.
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    /// Telemetry metadata stored once a tool use's input is complete; removed
    /// when the tool's output is inserted.
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
 41
 42impl ToolUseState {
 43    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 44        Self {
 45            tools,
 46            tool_uses_by_assistant_message: HashMap::default(),
 47            tool_uses_by_user_message: HashMap::default(),
 48            tool_results: HashMap::default(),
 49            pending_tool_uses_by_id: HashMap::default(),
 50            tool_result_cards: HashMap::default(),
 51            tool_use_metadata_by_id: HashMap::default(),
 52        }
 53    }
 54
 55    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 56    ///
 57    /// Accepts a function to filter the tools that should be used to populate the state.
 58    pub fn from_serialized_messages(
 59        tools: Entity<ToolWorkingSet>,
 60        messages: &[SerializedMessage],
 61        mut filter_by_tool_name: impl FnMut(&str) -> bool,
 62    ) -> Self {
 63        let mut this = Self::new(tools);
 64        let mut tool_names_by_id = HashMap::default();
 65
 66        for message in messages {
 67            match message.role {
 68                Role::Assistant => {
 69                    if !message.tool_uses.is_empty() {
 70                        let tool_uses = message
 71                            .tool_uses
 72                            .iter()
 73                            .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
 74                            .map(|tool_use| LanguageModelToolUse {
 75                                id: tool_use.id.clone(),
 76                                name: tool_use.name.clone().into(),
 77                                input: tool_use.input.clone(),
 78                                is_input_complete: true,
 79                            })
 80                            .collect::<Vec<_>>();
 81
 82                        tool_names_by_id.extend(
 83                            tool_uses
 84                                .iter()
 85                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 86                        );
 87
 88                        this.tool_uses_by_assistant_message
 89                            .insert(message.id, tool_uses);
 90                    }
 91                }
 92                Role::User => {
 93                    if !message.tool_results.is_empty() {
 94                        let tool_uses_by_user_message = this
 95                            .tool_uses_by_user_message
 96                            .entry(message.id)
 97                            .or_default();
 98
 99                        for tool_result in &message.tool_results {
100                            let tool_use_id = tool_result.tool_use_id.clone();
101                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
102                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
103                                continue;
104                            };
105
106                            if !(filter_by_tool_name)(tool_use.as_ref()) {
107                                continue;
108                            }
109
110                            tool_uses_by_user_message.push(tool_use_id.clone());
111                            this.tool_results.insert(
112                                tool_use_id.clone(),
113                                LanguageModelToolResult {
114                                    tool_use_id,
115                                    tool_name: tool_use.clone(),
116                                    is_error: tool_result.is_error,
117                                    content: tool_result.content.clone(),
118                                },
119                            );
120                        }
121                    }
122                }
123                Role::System => {}
124            }
125        }
126
127        this
128    }
129
130    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
131        let mut pending_tools = Vec::new();
132        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
133            self.tool_results.insert(
134                tool_use_id.clone(),
135                LanguageModelToolResult {
136                    tool_use_id,
137                    tool_name: tool_use.name.clone(),
138                    content: "Tool canceled by user".into(),
139                    is_error: true,
140                },
141            );
142            pending_tools.push(tool_use.clone());
143        }
144        pending_tools
145    }
146
147    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
148        self.pending_tool_uses_by_id.values().collect()
149    }
150
151    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
152        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
153            return Vec::new();
154        };
155
156        let mut tool_uses = Vec::new();
157
158        for tool_use in tool_uses_for_message.iter() {
159            let tool_result = self.tool_results.get(&tool_use.id);
160
161            let status = (|| {
162                if let Some(tool_result) = tool_result {
163                    return if tool_result.is_error {
164                        ToolUseStatus::Error(tool_result.content.clone().into())
165                    } else {
166                        ToolUseStatus::Finished(tool_result.content.clone().into())
167                    };
168                }
169
170                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
171                    match pending_tool_use.status {
172                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
173                        PendingToolUseStatus::NeedsConfirmation { .. } => {
174                            ToolUseStatus::NeedsConfirmation
175                        }
176                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
177                        PendingToolUseStatus::Error(ref err) => {
178                            ToolUseStatus::Error(err.clone().into())
179                        }
180                        PendingToolUseStatus::InputStillStreaming => {
181                            ToolUseStatus::InputStillStreaming
182                        }
183                    }
184                } else {
185                    ToolUseStatus::Pending
186                }
187            })();
188
189            let (icon, needs_confirmation) =
190                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
191                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
192                } else {
193                    (IconName::Cog, false)
194                };
195
196            tool_uses.push(ToolUse {
197                id: tool_use.id.clone(),
198                name: tool_use.name.clone().into(),
199                ui_text: self.tool_ui_label(
200                    &tool_use.name,
201                    &tool_use.input,
202                    tool_use.is_input_complete,
203                    cx,
204                ),
205                input: tool_use.input.clone(),
206                status,
207                icon,
208                needs_confirmation,
209            })
210        }
211
212        tool_uses
213    }
214
215    pub fn tool_ui_label(
216        &self,
217        tool_name: &str,
218        input: &serde_json::Value,
219        is_input_complete: bool,
220        cx: &App,
221    ) -> SharedString {
222        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
223            if is_input_complete {
224                tool.ui_text(input).into()
225            } else {
226                tool.still_streaming_ui_text(input).into()
227            }
228        } else {
229            format!("Unknown tool {tool_name:?}").into()
230        }
231    }
232
233    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
234        let empty = Vec::new();
235
236        self.tool_uses_by_user_message
237            .get(&message_id)
238            .unwrap_or(&empty)
239            .iter()
240            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
241            .collect()
242    }
243
244    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
245        self.tool_uses_by_user_message
246            .get(&message_id)
247            .map_or(false, |results| !results.is_empty())
248    }
249
250    pub fn tool_result(
251        &self,
252        tool_use_id: &LanguageModelToolUseId,
253    ) -> Option<&LanguageModelToolResult> {
254        self.tool_results.get(tool_use_id)
255    }
256
257    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
258        self.tool_result_cards.get(tool_use_id)
259    }
260
261    pub fn insert_tool_result_card(
262        &mut self,
263        tool_use_id: LanguageModelToolUseId,
264        card: AnyToolCard,
265    ) {
266        self.tool_result_cards.insert(tool_use_id, card);
267    }
268
269    pub fn request_tool_use(
270        &mut self,
271        assistant_message_id: MessageId,
272        tool_use: LanguageModelToolUse,
273        metadata: ToolUseMetadata,
274        cx: &App,
275    ) -> Arc<str> {
276        let tool_uses = self
277            .tool_uses_by_assistant_message
278            .entry(assistant_message_id)
279            .or_default();
280
281        let mut existing_tool_use_found = false;
282
283        for existing_tool_use in tool_uses.iter_mut() {
284            if existing_tool_use.id == tool_use.id {
285                *existing_tool_use = tool_use.clone();
286                existing_tool_use_found = true;
287            }
288        }
289
290        if !existing_tool_use_found {
291            tool_uses.push(tool_use.clone());
292        }
293
294        let status = if tool_use.is_input_complete {
295            self.tool_use_metadata_by_id
296                .insert(tool_use.id.clone(), metadata);
297
298            // The tool use is being requested by the Assistant, so we want to
299            // attach the tool results to the next user message.
300            let next_user_message_id = MessageId(assistant_message_id.0 + 1);
301            self.tool_uses_by_user_message
302                .entry(next_user_message_id)
303                .or_default()
304                .push(tool_use.id.clone());
305
306            PendingToolUseStatus::Idle
307        } else {
308            PendingToolUseStatus::InputStillStreaming
309        };
310
311        let ui_text: Arc<str> = self
312            .tool_ui_label(
313                &tool_use.name,
314                &tool_use.input,
315                tool_use.is_input_complete,
316                cx,
317            )
318            .into();
319
320        self.pending_tool_uses_by_id.insert(
321            tool_use.id.clone(),
322            PendingToolUse {
323                assistant_message_id,
324                id: tool_use.id,
325                name: tool_use.name.clone(),
326                ui_text: ui_text.clone(),
327                input: tool_use.input,
328                status,
329            },
330        );
331
332        ui_text
333    }
334
335    pub fn run_pending_tool(
336        &mut self,
337        tool_use_id: LanguageModelToolUseId,
338        ui_text: SharedString,
339        task: Task<()>,
340    ) {
341        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
342            tool_use.ui_text = ui_text.into();
343            tool_use.status = PendingToolUseStatus::Running {
344                _task: task.shared(),
345            };
346        }
347    }
348
349    pub fn confirm_tool_use(
350        &mut self,
351        tool_use_id: LanguageModelToolUseId,
352        ui_text: impl Into<Arc<str>>,
353        input: serde_json::Value,
354        messages: Arc<Vec<LanguageModelRequestMessage>>,
355        tool: Arc<dyn Tool>,
356    ) {
357        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
358            let ui_text = ui_text.into();
359            tool_use.ui_text = ui_text.clone();
360            let confirmation = Confirmation {
361                tool_use_id,
362                input,
363                messages,
364                tool,
365                ui_text,
366            };
367            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
368        }
369    }
370
371    pub fn insert_tool_output(
372        &mut self,
373        tool_use_id: LanguageModelToolUseId,
374        tool_name: Arc<str>,
375        output: Result<String>,
376        cx: &App,
377    ) -> Option<PendingToolUse> {
378        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
379
380        telemetry::event!(
381            "Agent Tool Finished",
382            model = metadata
383                .as_ref()
384                .map(|metadata| metadata.model.telemetry_id()),
385            model_provider = metadata
386                .as_ref()
387                .map(|metadata| metadata.model.provider_id().to_string()),
388            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
389            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
390            tool_name,
391            success = output.is_ok()
392        );
393
394        match output {
395            Ok(tool_result) => {
396                let model_registry = LanguageModelRegistry::read_global(cx);
397
398                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
399
400                // Protect from clearly large output
401                let tool_output_limit = model_registry
402                    .default_model()
403                    .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
404                    .unwrap_or(usize::MAX);
405
406                let tool_result = if tool_result.len() <= tool_output_limit {
407                    tool_result
408                } else {
409                    let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
410
411                    format!(
412                        "Tool result too long. The first {} bytes:\n\n{}",
413                        truncated.len(),
414                        truncated
415                    )
416                };
417
418                self.tool_results.insert(
419                    tool_use_id.clone(),
420                    LanguageModelToolResult {
421                        tool_use_id: tool_use_id.clone(),
422                        tool_name,
423                        content: tool_result.into(),
424                        is_error: false,
425                    },
426                );
427                self.pending_tool_uses_by_id.remove(&tool_use_id)
428            }
429            Err(err) => {
430                self.tool_results.insert(
431                    tool_use_id.clone(),
432                    LanguageModelToolResult {
433                        tool_use_id: tool_use_id.clone(),
434                        tool_name,
435                        content: err.to_string().into(),
436                        is_error: true,
437                    },
438                );
439
440                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
441                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
442                }
443
444                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
445            }
446        }
447    }
448
449    pub fn attach_tool_uses(
450        &self,
451        message_id: MessageId,
452        request_message: &mut LanguageModelRequestMessage,
453    ) {
454        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
455            let mut found_tool_use = false;
456
457            for tool_use in tool_uses {
458                if self.tool_results.contains_key(&tool_use.id) {
459                    if !found_tool_use {
460                        // The API fails if a message contains a tool use without any (non-whitespace) text around it
461                        match request_message.content.last_mut() {
462                            Some(MessageContent::Text(txt)) => {
463                                if txt.is_empty() {
464                                    txt.push_str(USING_TOOL_MARKER);
465                                }
466                            }
467                            None | Some(_) => {
468                                request_message
469                                    .content
470                                    .push(MessageContent::Text(USING_TOOL_MARKER.into()));
471                            }
472                        };
473                    }
474
475                    found_tool_use = true;
476
477                    // Do not send tool uses until they are completed
478                    request_message
479                        .content
480                        .push(MessageContent::ToolUse(tool_use.clone()));
481                } else {
482                    log::debug!(
483                        "skipped tool use {:?} because it is still pending",
484                        tool_use
485                    );
486                }
487            }
488        }
489    }
490
491    pub fn attach_tool_results(
492        &self,
493        message_id: MessageId,
494        request_message: &mut LanguageModelRequestMessage,
495    ) {
496        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
497            for tool_use_id in tool_uses {
498                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
499                    request_message.content.push(MessageContent::ToolResult(
500                        LanguageModelToolResult {
501                            tool_use_id: tool_use_id.clone(),
502                            tool_name: tool_result.tool_name.clone(),
503                            is_error: tool_result.is_error,
504                            content: if tool_result.content.is_empty() {
505                                // Surprisingly, the API fails if we return an empty string here.
506                                // It thinks we are sending a tool use without a tool result.
507                                "<Tool returned an empty string>".into()
508                            } else {
509                                tool_result.content.clone()
510                            },
511                        },
512                    ));
513                }
514            }
515        }
516    }
517}
518
/// A tool use that has been requested and is still tracked as pending: its input
/// may be streaming, it may be idle, awaiting confirmation, or running, or it
/// may have failed.
#[derive(Debug, Clone)]
pub struct PendingToolUse {
    /// Identifier the model assigned to this tool use.
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    // NOTE(review): the `allow(unused)` presumably exists because nothing reads this
    // field yet; confirm before removing the attribute.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    /// Name of the tool being invoked.
    pub name: Arc<str>,
    /// Current UI label for the tool use.
    pub ui_text: Arc<str>,
    /// JSON input supplied by the model (possibly partial while streaming).
    pub input: serde_json::Value,
    /// Where the tool use is in its lifecycle.
    pub status: PendingToolUseStatus,
}
530
/// Data captured when a pending tool use is moved into the needs-confirmation
/// state, holding what is needed to run the tool later.
#[derive(Debug, Clone)]
pub struct Confirmation {
    /// The tool use awaiting confirmation.
    pub tool_use_id: LanguageModelToolUseId,
    /// Tool input captured at confirmation time.
    pub input: serde_json::Value,
    /// UI label for the tool use.
    pub ui_text: Arc<str>,
    /// Request messages captured alongside the input (presumably passed to the
    /// tool when it runs — confirm against the caller).
    pub messages: Arc<Vec<LanguageModelRequestMessage>>,
    /// The tool to invoke.
    pub tool: Arc<dyn Tool>,
}
539
540#[derive(Debug, Clone)]
541pub enum PendingToolUseStatus {
542    InputStillStreaming,
543    Idle,
544    NeedsConfirmation(Arc<Confirmation>),
545    Running { _task: Shared<Task<()>> },
546    Error(#[allow(unused)] Arc<str>),
547}
548
549impl PendingToolUseStatus {
550    pub fn is_idle(&self) -> bool {
551        matches!(self, PendingToolUseStatus::Idle)
552    }
553
554    pub fn is_error(&self) -> bool {
555        matches!(self, PendingToolUseStatus::Error(_))
556    }
557
558    pub fn needs_confirmation(&self) -> bool {
559        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
560    }
561}
562
/// Context stored when a tool use's input completes. It annotates the
/// "Agent Tool Finished" telemetry event emitted when the tool's output is inserted.
#[derive(Clone)]
pub struct ToolUseMetadata {
    /// Model associated with the tool use (reported as `model` / `model_provider`).
    pub model: Arc<dyn LanguageModel>,
    /// Thread the tool use belongs to.
    pub thread_id: ThreadId,
    /// Prompt associated with the tool use.
    pub prompt_id: PromptId,
}