tool_use.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::{
  5    AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
  6};
  7use collections::HashMap;
  8use futures::FutureExt as _;
  9use futures::future::Shared;
 10use gpui::{App, Entity, SharedString, Task};
 11use language_model::{
 12    ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
 13    LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
 14};
 15use project::Project;
 16use ui::{IconName, Window};
 17use util::truncate_lines_to_byte_limit;
 18
 19use crate::thread::{MessageId, PromptId, ThreadId};
 20use crate::thread_store::SerializedMessage;
 21
 22#[derive(Debug)]
 23pub struct ToolUse {
 24    pub id: LanguageModelToolUseId,
 25    pub name: SharedString,
 26    pub ui_text: SharedString,
 27    pub status: ToolUseStatus,
 28    pub input: serde_json::Value,
 29    pub icon: ui::IconName,
 30    pub needs_confirmation: bool,
 31}
 32
 33pub struct ToolUseState {
 34    tools: Entity<ToolWorkingSet>,
 35    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 36    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
 37    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 38    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
 39    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
 40}
 41
 42impl ToolUseState {
 43    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
 44        Self {
 45            tools,
 46            tool_uses_by_assistant_message: HashMap::default(),
 47            tool_results: HashMap::default(),
 48            pending_tool_uses_by_id: HashMap::default(),
 49            tool_result_cards: HashMap::default(),
 50            tool_use_metadata_by_id: HashMap::default(),
 51        }
 52    }
 53
 54    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
 55    ///
 56    /// Accepts a function to filter the tools that should be used to populate the state.
 57    ///
 58    /// If `window` is `None` (e.g., when in headless mode or when running evals),
 59    /// tool cards won't be deserialized
 60    pub fn from_serialized_messages(
 61        tools: Entity<ToolWorkingSet>,
 62        messages: &[SerializedMessage],
 63        project: Entity<Project>,
 64        window: Option<&mut Window>, // None in headless mode
 65        cx: &mut App,
 66    ) -> Self {
 67        let mut this = Self::new(tools);
 68        let mut tool_names_by_id = HashMap::default();
 69        let mut window = window;
 70
 71        for message in messages {
 72            match message.role {
 73                Role::Assistant => {
 74                    if !message.tool_uses.is_empty() {
 75                        let tool_uses = message
 76                            .tool_uses
 77                            .iter()
 78                            .map(|tool_use| LanguageModelToolUse {
 79                                id: tool_use.id.clone(),
 80                                name: tool_use.name.clone().into(),
 81                                raw_input: tool_use.input.to_string(),
 82                                input: tool_use.input.clone(),
 83                                is_input_complete: true,
 84                            })
 85                            .collect::<Vec<_>>();
 86
 87                        tool_names_by_id.extend(
 88                            tool_uses
 89                                .iter()
 90                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
 91                        );
 92
 93                        this.tool_uses_by_assistant_message
 94                            .insert(message.id, tool_uses);
 95
 96                        for tool_result in &message.tool_results {
 97                            let tool_use_id = tool_result.tool_use_id.clone();
 98                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
 99                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
100                                continue;
101                            };
102
103                            this.tool_results.insert(
104                                tool_use_id.clone(),
105                                LanguageModelToolResult {
106                                    tool_use_id: tool_use_id.clone(),
107                                    tool_name: tool_use.clone(),
108                                    is_error: tool_result.is_error,
109                                    content: tool_result.content.clone(),
110                                    output: tool_result.output.clone(),
111                                },
112                            );
113
114                            if let Some(window) = &mut window {
115                                if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
116                                    if let Some(output) = tool_result.output.clone() {
117                                        if let Some(card) = tool.deserialize_card(
118                                            output,
119                                            project.clone(),
120                                            window,
121                                            cx,
122                                        ) {
123                                            this.tool_result_cards.insert(tool_use_id, card);
124                                        }
125                                    }
126                                }
127                            }
128                        }
129                    }
130                }
131                Role::System | Role::User => {}
132            }
133        }
134
135        this
136    }
137
138    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
139        let mut cancelled_tool_uses = Vec::new();
140        self.pending_tool_uses_by_id
141            .retain(|tool_use_id, tool_use| {
142                if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
143                    return true;
144                }
145
146                let content = "Tool canceled by user".into();
147                self.tool_results.insert(
148                    tool_use_id.clone(),
149                    LanguageModelToolResult {
150                        tool_use_id: tool_use_id.clone(),
151                        tool_name: tool_use.name.clone(),
152                        content,
153                        output: None,
154                        is_error: true,
155                    },
156                );
157                cancelled_tool_uses.push(tool_use.clone());
158                false
159            });
160        cancelled_tool_uses
161    }
162
163    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
164        self.pending_tool_uses_by_id.values().collect()
165    }
166
167    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
168        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
169            return Vec::new();
170        };
171
172        let mut tool_uses = Vec::new();
173
174        for tool_use in tool_uses_for_message.iter() {
175            let tool_result = self.tool_results.get(&tool_use.id);
176
177            let status = (|| {
178                if let Some(tool_result) = tool_result {
179                    let content = tool_result
180                        .content
181                        .to_str()
182                        .map(|str| str.to_owned().into())
183                        .unwrap_or_default();
184
185                    return if tool_result.is_error {
186                        ToolUseStatus::Error(content)
187                    } else {
188                        ToolUseStatus::Finished(content)
189                    };
190                }
191
192                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
193                    match pending_tool_use.status {
194                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
195                        PendingToolUseStatus::NeedsConfirmation { .. } => {
196                            ToolUseStatus::NeedsConfirmation
197                        }
198                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
199                        PendingToolUseStatus::Error(ref err) => {
200                            ToolUseStatus::Error(err.clone().into())
201                        }
202                        PendingToolUseStatus::InputStillStreaming => {
203                            ToolUseStatus::InputStillStreaming
204                        }
205                    }
206                } else {
207                    ToolUseStatus::Pending
208                }
209            })();
210
211            let (icon, needs_confirmation) =
212                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
213                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
214                } else {
215                    (IconName::Cog, false)
216                };
217
218            tool_uses.push(ToolUse {
219                id: tool_use.id.clone(),
220                name: tool_use.name.clone().into(),
221                ui_text: self.tool_ui_label(
222                    &tool_use.name,
223                    &tool_use.input,
224                    tool_use.is_input_complete,
225                    cx,
226                ),
227                input: tool_use.input.clone(),
228                status,
229                icon,
230                needs_confirmation,
231            })
232        }
233
234        tool_uses
235    }
236
237    pub fn tool_ui_label(
238        &self,
239        tool_name: &str,
240        input: &serde_json::Value,
241        is_input_complete: bool,
242        cx: &App,
243    ) -> SharedString {
244        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
245            if is_input_complete {
246                tool.ui_text(input).into()
247            } else {
248                tool.still_streaming_ui_text(input).into()
249            }
250        } else {
251            format!("Unknown tool {tool_name:?}").into()
252        }
253    }
254
255    pub fn tool_results_for_message(
256        &self,
257        assistant_message_id: MessageId,
258    ) -> Vec<&LanguageModelToolResult> {
259        let Some(tool_uses) = self
260            .tool_uses_by_assistant_message
261            .get(&assistant_message_id)
262        else {
263            return Vec::new();
264        };
265
266        tool_uses
267            .iter()
268            .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
269            .collect()
270    }
271
272    pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
273        self.tool_uses_by_assistant_message
274            .get(&assistant_message_id)
275            .map_or(false, |results| !results.is_empty())
276    }
277
278    pub fn tool_result(
279        &self,
280        tool_use_id: &LanguageModelToolUseId,
281    ) -> Option<&LanguageModelToolResult> {
282        self.tool_results.get(tool_use_id)
283    }
284
285    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
286        self.tool_result_cards.get(tool_use_id)
287    }
288
289    pub fn insert_tool_result_card(
290        &mut self,
291        tool_use_id: LanguageModelToolUseId,
292        card: AnyToolCard,
293    ) {
294        self.tool_result_cards.insert(tool_use_id, card);
295    }
296
297    pub fn request_tool_use(
298        &mut self,
299        assistant_message_id: MessageId,
300        tool_use: LanguageModelToolUse,
301        metadata: ToolUseMetadata,
302        cx: &App,
303    ) -> Arc<str> {
304        let tool_uses = self
305            .tool_uses_by_assistant_message
306            .entry(assistant_message_id)
307            .or_default();
308
309        let mut existing_tool_use_found = false;
310
311        for existing_tool_use in tool_uses.iter_mut() {
312            if existing_tool_use.id == tool_use.id {
313                *existing_tool_use = tool_use.clone();
314                existing_tool_use_found = true;
315            }
316        }
317
318        if !existing_tool_use_found {
319            tool_uses.push(tool_use.clone());
320        }
321
322        let status = if tool_use.is_input_complete {
323            self.tool_use_metadata_by_id
324                .insert(tool_use.id.clone(), metadata);
325
326            PendingToolUseStatus::Idle
327        } else {
328            PendingToolUseStatus::InputStillStreaming
329        };
330
331        let ui_text: Arc<str> = self
332            .tool_ui_label(
333                &tool_use.name,
334                &tool_use.input,
335                tool_use.is_input_complete,
336                cx,
337            )
338            .into();
339
340        self.pending_tool_uses_by_id.insert(
341            tool_use.id.clone(),
342            PendingToolUse {
343                assistant_message_id,
344                id: tool_use.id,
345                name: tool_use.name.clone(),
346                ui_text: ui_text.clone(),
347                input: tool_use.input,
348                status,
349            },
350        );
351
352        ui_text
353    }
354
355    pub fn run_pending_tool(
356        &mut self,
357        tool_use_id: LanguageModelToolUseId,
358        ui_text: SharedString,
359        task: Task<()>,
360    ) {
361        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
362            tool_use.ui_text = ui_text.into();
363            tool_use.status = PendingToolUseStatus::Running {
364                _task: task.shared(),
365            };
366        }
367    }
368
369    pub fn confirm_tool_use(
370        &mut self,
371        tool_use_id: LanguageModelToolUseId,
372        ui_text: impl Into<Arc<str>>,
373        input: serde_json::Value,
374        request: Arc<LanguageModelRequest>,
375        tool: Arc<dyn Tool>,
376    ) {
377        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
378            let ui_text = ui_text.into();
379            tool_use.ui_text = ui_text.clone();
380            let confirmation = Confirmation {
381                tool_use_id,
382                input,
383                request,
384                tool,
385                ui_text,
386            };
387            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
388        }
389    }
390
391    pub fn insert_tool_output(
392        &mut self,
393        tool_use_id: LanguageModelToolUseId,
394        tool_name: Arc<str>,
395        output: Result<ToolResultOutput>,
396        configured_model: Option<&ConfiguredModel>,
397    ) -> Option<PendingToolUse> {
398        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
399
400        telemetry::event!(
401            "Agent Tool Finished",
402            model = metadata
403                .as_ref()
404                .map(|metadata| metadata.model.telemetry_id()),
405            model_provider = metadata
406                .as_ref()
407                .map(|metadata| metadata.model.provider_id().to_string()),
408            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
409            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
410            tool_name,
411            success = output.is_ok()
412        );
413
414        match output {
415            Ok(output) => {
416                let tool_result = output.content;
417                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
418
419                let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
420
421                // Protect from overly large output
422                let tool_output_limit = configured_model
423                    .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
424                    .unwrap_or(usize::MAX);
425
426                let content = match tool_result {
427                    ToolResultContent::Text(text) => {
428                        let text = if text.len() < tool_output_limit {
429                            text
430                        } else {
431                            let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
432                            format!(
433                                "Tool result too long. The first {} bytes:\n\n{}",
434                                truncated.len(),
435                                truncated
436                            )
437                        };
438                        LanguageModelToolResultContent::Text(text.into())
439                    }
440                    ToolResultContent::Image(language_model_image) => {
441                        if language_model_image.estimate_tokens() < tool_output_limit {
442                            LanguageModelToolResultContent::Image(language_model_image)
443                        } else {
444                            self.tool_results.insert(
445                                tool_use_id.clone(),
446                                LanguageModelToolResult {
447                                    tool_use_id: tool_use_id.clone(),
448                                    tool_name,
449                                    content: "Tool responded with an image that would exceeded the remaining tokens".into(),
450                                    is_error: true,
451                                    output: None,
452                                },
453                            );
454
455                            return old_use;
456                        }
457                    }
458                };
459
460                self.tool_results.insert(
461                    tool_use_id.clone(),
462                    LanguageModelToolResult {
463                        tool_use_id: tool_use_id.clone(),
464                        tool_name,
465                        content,
466                        is_error: false,
467                        output: output.output,
468                    },
469                );
470
471                old_use
472            }
473            Err(err) => {
474                self.tool_results.insert(
475                    tool_use_id.clone(),
476                    LanguageModelToolResult {
477                        tool_use_id: tool_use_id.clone(),
478                        tool_name,
479                        content: LanguageModelToolResultContent::Text(err.to_string().into()),
480                        is_error: true,
481                        output: None,
482                    },
483                );
484
485                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
486                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
487                }
488
489                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
490            }
491        }
492    }
493
494    pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
495        self.tool_uses_by_assistant_message
496            .contains_key(&assistant_message_id)
497    }
498
499    pub fn tool_results(
500        &self,
501        assistant_message_id: MessageId,
502    ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
503        self.tool_uses_by_assistant_message
504            .get(&assistant_message_id)
505            .into_iter()
506            .flatten()
507            .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
508    }
509}
510
511#[derive(Debug, Clone)]
512pub struct PendingToolUse {
513    pub id: LanguageModelToolUseId,
514    /// The ID of the Assistant message in which the tool use was requested.
515    #[allow(unused)]
516    pub assistant_message_id: MessageId,
517    pub name: Arc<str>,
518    pub ui_text: Arc<str>,
519    pub input: serde_json::Value,
520    pub status: PendingToolUseStatus,
521}
522
523#[derive(Debug, Clone)]
524pub struct Confirmation {
525    pub tool_use_id: LanguageModelToolUseId,
526    pub input: serde_json::Value,
527    pub ui_text: Arc<str>,
528    pub request: Arc<LanguageModelRequest>,
529    pub tool: Arc<dyn Tool>,
530}
531
532#[derive(Debug, Clone)]
533pub enum PendingToolUseStatus {
534    InputStillStreaming,
535    Idle,
536    NeedsConfirmation(Arc<Confirmation>),
537    Running { _task: Shared<Task<()>> },
538    Error(#[allow(unused)] Arc<str>),
539}
540
541impl PendingToolUseStatus {
542    pub fn is_idle(&self) -> bool {
543        matches!(self, PendingToolUseStatus::Idle)
544    }
545
546    pub fn is_error(&self) -> bool {
547        matches!(self, PendingToolUseStatus::Error(_))
548    }
549
550    pub fn needs_confirmation(&self) -> bool {
551        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
552    }
553}
554
555#[derive(Clone)]
556pub struct ToolUseMetadata {
557    pub model: Arc<dyn LanguageModel>,
558    pub thread_id: ThreadId,
559    pub prompt_id: PromptId,
560}