tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{
  5    AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
  6};
  7use collections::HashMap;
  8use futures::FutureExt as _;
  9use futures::future::Shared;
 10use gpui::{App, Entity, SharedString, Task};
 11use language_model::{
 12    ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
 13    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
 14};
 15use project::Project;
 16use ui::{IconName, Window};
 17use util::truncate_lines_to_byte_limit;
 18
 19use crate::thread::{MessageId, PromptId, ThreadId};
 20use crate::thread_store::SerializedMessage;
 21
 22#[derive(Debug)]
 23pub struct ToolUse {
 24    pub id: LanguageModelToolUseId,
 25    pub name: SharedString,
 26    pub ui_text: SharedString,
 27    pub status: ToolUseStatus,
 28    pub input: serde_json::Value,
 29    pub icon: ui::IconName,
 30    pub needs_confirmation: bool,
 31}
 32
 33pub struct ToolUseState {
 34    tools: Entity<ToolWorkingSet>,
 35    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 36    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
 37    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 38    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
 39    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
 40}
 41
 42impl ToolUseState {
 43    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 44        Self {
 45            tools,
 46            tool_uses_by_assistant_message: HashMap::default(),
 47            tool_results: HashMap::default(),
 48            pending_tool_uses_by_id: HashMap::default(),
 49            tool_result_cards: HashMap::default(),
 50            tool_use_metadata_by_id: HashMap::default(),
 51        }
 52    }
 53
 54    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 55    ///
 56    /// Accepts a function to filter the tools that should be used to populate the state.
 57    ///
 58    /// If `window` is `None` (e.g., when in headless mode or when running evals),
 59    /// tool cards won't be deserialized
 60    pub fn from_serialized_messages(
 61        tools: Entity<ToolWorkingSet>,
 62        messages: &[SerializedMessage],
 63        project: Entity<Project>,
 64        window: Option<&mut Window>, // None in headless mode
 65        cx: &mut App,
 66    ) -> Self {
 67        let mut this = Self::new(tools);
 68        let mut tool_names_by_id = HashMap::default();
 69        let mut window = window;
 70
 71        for message in messages {
 72            match message.role {
 73                Role::Assistant => {
 74                    if !message.tool_uses.is_empty() {
 75                        let tool_uses = message
 76                            .tool_uses
 77                            .iter()
 78                            .map(|tool_use| LanguageModelToolUse {
 79                                id: tool_use.id.clone(),
 80                                name: tool_use.name.clone().into(),
 81                                raw_input: tool_use.input.to_string(),
 82                                input: tool_use.input.clone(),
 83                                is_input_complete: true,
 84                            })
 85                            .collect::<Vec<_>>();
 86
 87                        tool_names_by_id.extend(
 88                            tool_uses
 89                                .iter()
 90                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 91                        );
 92
 93                        this.tool_uses_by_assistant_message
 94                            .insert(message.id, tool_uses);
 95
 96                        for tool_result in &message.tool_results {
 97                            let tool_use_id = tool_result.tool_use_id.clone();
 98                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
 99                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
100                                continue;
101                            };
102
103                            this.tool_results.insert(
104                                tool_use_id.clone(),
105                                LanguageModelToolResult {
106                                    tool_use_id: tool_use_id.clone(),
107                                    tool_name: tool_use.clone(),
108                                    is_error: tool_result.is_error,
109                                    content: tool_result.content.clone(),
110                                    output: tool_result.output.clone(),
111                                },
112                            );
113
114                            if let Some(window) = &mut window {
115                                if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
116                                    if let Some(output) = tool_result.output.clone() {
117                                        if let Some(card) = tool.deserialize_card(
118                                            output,
119                                            project.clone(),
120                                            window,
121                                            cx,
122                                        ) {
123                                            this.tool_result_cards.insert(tool_use_id, card);
124                                        }
125                                    }
126                                }
127                            }
128                        }
129                    }
130                }
131                Role::System | Role::User => {}
132            }
133        }
134
135        this
136    }
137
138    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
139        let mut cancelled_tool_uses = Vec::new();
140        self.pending_tool_uses_by_id
141            .retain(|tool_use_id, tool_use| {
142                if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
143                    return true;
144                }
145
146                let content = "Tool canceled by user".into();
147                self.tool_results.insert(
148                    tool_use_id.clone(),
149                    LanguageModelToolResult {
150                        tool_use_id: tool_use_id.clone(),
151                        tool_name: tool_use.name.clone(),
152                        content,
153                        output: None,
154                        is_error: true,
155                    },
156                );
157                cancelled_tool_uses.push(tool_use.clone());
158                false
159            });
160        cancelled_tool_uses
161    }
162
163    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
164        self.pending_tool_uses_by_id.values().collect()
165    }
166
167    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
168        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
169            return Vec::new();
170        };
171
172        let mut tool_uses = Vec::new();
173
174        for tool_use in tool_uses_for_message.iter() {
175            let tool_result = self.tool_results.get(&tool_use.id);
176
177            let status = (|| {
178                if let Some(tool_result) = tool_result {
179                    let content = tool_result
180                        .content
181                        .to_str()
182                        .map(|str| str.to_owned().into())
183                        .unwrap_or_default();
184
185                    return if tool_result.is_error {
186                        ToolUseStatus::Error(content)
187                    } else {
188                        ToolUseStatus::Finished(content)
189                    };
190                }
191
192                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
193                    match pending_tool_use.status {
194                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
195                        PendingToolUseStatus::NeedsConfirmation { .. } => {
196                            ToolUseStatus::NeedsConfirmation
197                        }
198                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
199                        PendingToolUseStatus::Error(ref err) => {
200                            ToolUseStatus::Error(err.clone().into())
201                        }
202                        PendingToolUseStatus::InputStillStreaming => {
203                            ToolUseStatus::InputStillStreaming
204                        }
205                    }
206                } else {
207                    ToolUseStatus::Pending
208                }
209            })();
210
211            let (icon, needs_confirmation) =
212                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
213                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
214                } else {
215                    (IconName::Cog, false)
216                };
217
218            tool_uses.push(ToolUse {
219                id: tool_use.id.clone(),
220                name: tool_use.name.clone().into(),
221                ui_text: self.tool_ui_label(
222                    &tool_use.name,
223                    &tool_use.input,
224                    tool_use.is_input_complete,
225                    cx,
226                ),
227                input: tool_use.input.clone(),
228                status,
229                icon,
230                needs_confirmation,
231            })
232        }
233
234        tool_uses
235    }
236
237    pub fn tool_ui_label(
238        &self,
239        tool_name: &str,
240        input: &serde_json::Value,
241        is_input_complete: bool,
242        cx: &App,
243    ) -> SharedString {
244        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
245            if is_input_complete {
246                tool.ui_text(input).into()
247            } else {
248                tool.still_streaming_ui_text(input).into()
249            }
250        } else {
251            format!("Unknown tool {tool_name:?}").into()
252        }
253    }
254
255    pub fn tool_results_for_message(
256        &self,
257        assistant_message_id: MessageId,
258    ) -> Vec<&LanguageModelToolResult> {
259        let Some(tool_uses) = self
260            .tool_uses_by_assistant_message
261            .get(&assistant_message_id)
262        else {
263            return Vec::new();
264        };
265
266        tool_uses
267            .iter()
268            .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
269            .collect()
270    }
271
272    pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
273        self.tool_uses_by_assistant_message
274            .get(&assistant_message_id)
275            .map_or(false, |results| !results.is_empty())
276    }
277
278    pub fn tool_result(
279        &self,
280        tool_use_id: &LanguageModelToolUseId,
281    ) -> Option<&LanguageModelToolResult> {
282        self.tool_results.get(tool_use_id)
283    }
284
285    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
286        self.tool_result_cards.get(tool_use_id)
287    }
288
289    pub fn insert_tool_result_card(
290        &mut self,
291        tool_use_id: LanguageModelToolUseId,
292        card: AnyToolCard,
293    ) {
294        self.tool_result_cards.insert(tool_use_id, card);
295    }
296
297    pub fn request_tool_use(
298        &mut self,
299        assistant_message_id: MessageId,
300        tool_use: LanguageModelToolUse,
301        metadata: ToolUseMetadata,
302        cx: &App,
303    ) -> Arc<str> {
304        let tool_uses = self
305            .tool_uses_by_assistant_message
306            .entry(assistant_message_id)
307            .or_default();
308
309        let mut existing_tool_use_found = false;
310
311        for existing_tool_use in tool_uses.iter_mut() {
312            if existing_tool_use.id == tool_use.id {
313                *existing_tool_use = tool_use.clone();
314                existing_tool_use_found = true;
315            }
316        }
317
318        if !existing_tool_use_found {
319            tool_uses.push(tool_use.clone());
320        }
321
322        let status = if tool_use.is_input_complete {
323            self.tool_use_metadata_by_id
324                .insert(tool_use.id.clone(), metadata);
325
326            PendingToolUseStatus::Idle
327        } else {
328            PendingToolUseStatus::InputStillStreaming
329        };
330
331        let ui_text: Arc<str> = self
332            .tool_ui_label(
333                &tool_use.name,
334                &tool_use.input,
335                tool_use.is_input_complete,
336                cx,
337            )
338            .into();
339
340        let may_perform_edits = self
341            .tools
342            .read(cx)
343            .tool(&tool_use.name, cx)
344            .is_some_and(|tool| tool.may_perform_edits());
345
346        self.pending_tool_uses_by_id.insert(
347            tool_use.id.clone(),
348            PendingToolUse {
349                assistant_message_id,
350                id: tool_use.id,
351                name: tool_use.name.clone(),
352                ui_text: ui_text.clone(),
353                input: tool_use.input,
354                may_perform_edits,
355                status,
356            },
357        );
358
359        ui_text
360    }
361
362    pub fn run_pending_tool(
363        &mut self,
364        tool_use_id: LanguageModelToolUseId,
365        ui_text: SharedString,
366        task: Task<()>,
367    ) {
368        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
369            tool_use.ui_text = ui_text.into();
370            tool_use.status = PendingToolUseStatus::Running {
371                _task: task.shared(),
372            };
373        }
374    }
375
376    pub fn confirm_tool_use(
377        &mut self,
378        tool_use_id: LanguageModelToolUseId,
379        ui_text: impl Into<Arc<str>>,
380        input: serde_json::Value,
381        request: Arc<LanguageModelRequest>,
382        tool: Arc<dyn Tool>,
383    ) {
384        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
385            let ui_text = ui_text.into();
386            tool_use.ui_text = ui_text.clone();
387            let confirmation = Confirmation {
388                tool_use_id,
389                input,
390                request,
391                tool,
392                ui_text,
393            };
394            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
395        }
396    }
397
398    pub fn insert_tool_output(
399        &mut self,
400        tool_use_id: LanguageModelToolUseId,
401        tool_name: Arc<str>,
402        output: Result<ToolResultOutput>,
403        configured_model: Option<&ConfiguredModel>,
404    ) -> Option<PendingToolUse> {
405        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
406
407        telemetry::event!(
408            "Agent Tool Finished",
409            model = metadata
410                .as_ref()
411                .map(|metadata| metadata.model.telemetry_id()),
412            model_provider = metadata
413                .as_ref()
414                .map(|metadata| metadata.model.provider_id().to_string()),
415            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
416            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
417            tool_name,
418            success = output.is_ok()
419        );
420
421        match output {
422            Ok(output) => {
423                let tool_result = output.content;
424                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
425
426                let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
427
428                // Protect from overly large output
429                let tool_output_limit = configured_model
430                    .map(|model| model.model.max_token_count() as usize * BYTES_PER_TOKEN_ESTIMATE)
431                    .unwrap_or(usize::MAX);
432
433                let content = match tool_result {
434                    ToolResultContent::Text(text) => {
435                        let text = if text.len() < tool_output_limit {
436                            text
437                        } else {
438                            let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
439                            format!(
440                                "Tool result too long. The first {} bytes:\n\n{}",
441                                truncated.len(),
442                                truncated
443                            )
444                        };
445                        LanguageModelToolResultContent::Text(text.into())
446                    }
447                    ToolResultContent::Image(language_model_image) => {
448                        if language_model_image.estimate_tokens() < tool_output_limit {
449                            LanguageModelToolResultContent::Image(language_model_image)
450                        } else {
451                            self.tool_results.insert(
452                                tool_use_id.clone(),
453                                LanguageModelToolResult {
454                                    tool_use_id: tool_use_id.clone(),
455                                    tool_name,
456                                    content: "Tool responded with an image that would exceeded the remaining tokens".into(),
457                                    is_error: true,
458                                    output: None,
459                                },
460                            );
461
462                            return old_use;
463                        }
464                    }
465                };
466
467                self.tool_results.insert(
468                    tool_use_id.clone(),
469                    LanguageModelToolResult {
470                        tool_use_id: tool_use_id.clone(),
471                        tool_name,
472                        content,
473                        is_error: false,
474                        output: output.output,
475                    },
476                );
477
478                old_use
479            }
480            Err(err) => {
481                self.tool_results.insert(
482                    tool_use_id.clone(),
483                    LanguageModelToolResult {
484                        tool_use_id: tool_use_id.clone(),
485                        tool_name,
486                        content: LanguageModelToolResultContent::Text(err.to_string().into()),
487                        is_error: true,
488                        output: None,
489                    },
490                );
491
492                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
493                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
494                }
495
496                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
497            }
498        }
499    }
500
501    pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
502        self.tool_uses_by_assistant_message
503            .contains_key(&assistant_message_id)
504    }
505
506    pub fn tool_results(
507        &self,
508        assistant_message_id: MessageId,
509    ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
510        self.tool_uses_by_assistant_message
511            .get(&assistant_message_id)
512            .into_iter()
513            .flatten()
514            .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
515    }
516}
517
518#[derive(Debug, Clone)]
519pub struct PendingToolUse {
520    pub id: LanguageModelToolUseId,
521    /// The ID of the Assistant message in which the tool use was requested.
522    #[allow(unused)]
523    pub assistant_message_id: MessageId,
524    pub name: Arc<str>,
525    pub ui_text: Arc<str>,
526    pub input: serde_json::Value,
527    pub status: PendingToolUseStatus,
528    pub may_perform_edits: bool,
529}
530
531#[derive(Debug, Clone)]
532pub struct Confirmation {
533    pub tool_use_id: LanguageModelToolUseId,
534    pub input: serde_json::Value,
535    pub ui_text: Arc<str>,
536    pub request: Arc<LanguageModelRequest>,
537    pub tool: Arc<dyn Tool>,
538}
539
540#[derive(Debug, Clone)]
541pub enum PendingToolUseStatus {
542    InputStillStreaming,
543    Idle,
544    NeedsConfirmation(Arc<Confirmation>),
545    Running { _task: Shared<Task<()>> },
546    Error(#[allow(unused)] Arc<str>),
547}
548
549impl PendingToolUseStatus {
550    pub fn is_idle(&self) -> bool {
551        matches!(self, PendingToolUseStatus::Idle)
552    }
553
554    pub fn is_error(&self) -> bool {
555        matches!(self, PendingToolUseStatus::Error(_))
556    }
557
558    pub fn needs_confirmation(&self) -> bool {
559        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
560    }
561}
562
563#[derive(Clone)]
564pub struct ToolUseMetadata {
565    pub model: Arc<dyn LanguageModel>,
566    pub thread_id: ThreadId,
567    pub prompt_id: PromptId,
568}